add .clang-format

reformat all src
This commit is contained in:
NiccoloN
2026-02-26 19:16:42 +01:00
parent a2c31836ae
commit 810e5e75f9
32 changed files with 902 additions and 953 deletions

143
.clang-format Normal file
View File

@@ -0,0 +1,143 @@
---
# clang-format style configuration.
# Nested option groups are indented beneath their parent key so the file
# parses as a single YAML mapping; keys and values are otherwise unchanged.

# --- Alignment -------------------------------------------------------------
AccessModifierOffset: -2
AlignAfterOpenBracket: Align
AlignArrayOfStructures: Left
AlignConsecutiveShortCaseStatements:
  Enabled: true
AlignConsecutiveAssignments:
  Enabled: false
AlignConsecutiveBitFields:
  Enabled: false
AlignConsecutiveDeclarations:
  Enabled: false
AlignConsecutiveMacros:
  Enabled: false
AlignEscapedNewlines: Left
AlignOperands: AlignAfterOperator
AlignTrailingComments:
  Kind: Always
  OverEmptyLines: 4

# --- What may stay on a single line ----------------------------------------
AllowAllArgumentsOnNextLine: true
AllowAllParametersOfDeclarationOnNextLine: true
AllowShortBlocksOnASingleLine: Empty
AllowShortCaseExpressionOnASingleLine: true
AllowShortCaseLabelsOnASingleLine: true
AllowShortCompoundRequirementOnASingleLine: true
AllowShortEnumsOnASingleLine: false
AllowShortFunctionsOnASingleLine: All
AllowShortIfStatementsOnASingleLine: Never
AllowShortLambdasOnASingleLine: All
AllowShortLoopsOnASingleLine: false
AlwaysBreakBeforeMultilineStrings: false

# --- Packing and breaking --------------------------------------------------
BinPackArguments: false
BinPackParameters: false
BitFieldColonSpacing: Both
# These sub-options are honoured because BreakBeforeBraces is set to Custom
# just below.
BraceWrapping:
  BeforeElse: true
  BeforeCatch: true
  BeforeWhile: true
BreakBeforeBraces: Custom
BracedInitializerIndentWidth: 2
BreakAdjacentStringLiterals: true
BreakAfterAttributes: Never
BreakAfterJavaFieldAnnotations: false
BreakArrays: false
BreakBeforeBinaryOperators: NonAssignment
BreakBeforeConceptDeclarations: Always
BreakBeforeInlineASMColon: OnlyMultiline
BreakBeforeTernaryOperators: true
BreakConstructorInitializers: BeforeColon
BreakFunctionDefinitionParameters: false
BreakInheritanceList: AfterComma
BreakStringLiterals: true
BreakTemplateDeclarations: Yes
ColumnLimit: 120
CompactNamespaces: false
ConstructorInitializerIndentWidth: 0
ContinuationIndentWidth: 2
Cpp11BracedListStyle: true
DerivePointerAlignment: false
DisableFormat: false
EmptyLineAfterAccessModifier: Never
EmptyLineBeforeAccessModifier: Always
FixNamespaceComments: true

# --- Include sorting -------------------------------------------------------
# Includes are regrouped by category: mlir headers (2), llvm headers (3),
# any other angle-bracket include (4). Includes matching none of the regexes
# fall into clang-format's default category.
IncludeBlocks: Regroup
IncludeCategories:
  - Regex: (<|")mlir(.*)(>|")
    Priority: 2
  - Regex: (<|")llvm(.*)(>|")
    Priority: 3
  - Regex: <.*>
    Priority: 4
IncludeIsMainRegex: (Test)?$
IncludeIsMainSourceRegex: ""

# --- Indentation -----------------------------------------------------------
IndentAccessModifiers: false
IndentCaseBlocks: false
IndentCaseLabels: false
IndentExternBlock: Indent
IndentGotoLabels: false
IndentPPDirectives: None
IndentRequiresClause: true
IndentWidth: 2
IndentWrappedFunctionNames: false
InsertBraces: false
InsertNewlineAtEOF: true
KeepEmptyLines:
  AtEndOfFile: false
  AtStartOfBlock: true
  AtStartOfFile: false
LambdaBodyIndentation: Signature
Language: Cpp
LineEnding: LF
MainIncludeChar: AngleBracket
MaxEmptyLinesToKeep: 1
NamespaceIndentation: None
PackConstructorInitializers: NextLineOnly
PointerAlignment: Left
ReferenceAlignment: Pointer

# --- Token-changing options ------------------------------------------------
# These three options add or remove tokens (braces, parentheses, semicolons)
# rather than only whitespace.
RemoveBracesLLVM: true
RemoveParentheses: ReturnStatement
RemoveSemicolon: true
RequiresClausePosition: OwnLine
RequiresExpressionIndentation: OuterScope
SeparateDefinitionBlocks: Leave
ShortNamespaceLines: 4
SortIncludes: CaseSensitive

# --- Spacing ---------------------------------------------------------------
SpaceAfterCStyleCast: true
SpaceAfterLogicalNot: false
SpaceAfterTemplateKeyword: true
SpaceAroundPointerQualifiers: Before
SpaceBeforeAssignmentOperators: true
SpaceBeforeCaseColon: false
SpaceBeforeCpp11BracedList: true
SpaceBeforeCtorInitializerColon: true
SpaceBeforeInheritanceColon: true
SpaceBeforeJsonColon: false
# NOTE(review): the clang-format documentation states these sub-options are
# only used when `SpaceBeforeParens: Custom` is set; this file does not set
# SpaceBeforeParens. Confirm whether `SpaceBeforeParens: Custom` should be
# added so the values below (e.g. AfterPlacementOperator) actually apply.
SpaceBeforeParensOptions:
  AfterControlStatements: true
  AfterForeachMacros: true
  AfterFunctionDeclarationName: false
  AfterFunctionDefinitionName: false
  AfterIfMacros: true
  AfterOverloadedOperator: false
  AfterPlacementOperator: false
  AfterRequiresInClause: false
  AfterRequiresInExpression: false
  BeforeNonEmptyParentheses: false
SpaceBeforeRangeBasedForLoopColon: true
SpaceBeforeSquareBrackets: false
SpaceInEmptyBlock: false
SpacesBeforeTrailingComments: 1
SpacesInAngles: Never
SpacesInContainerLiterals: false
SpacesInLineCommentPrefix:
  Minimum: 1
  Maximum: 1
SpacesInParens: Never
SpacesInSquareBrackets: false
Standard: Latest
TabWidth: 4
UseTab: Never
VerilogBreakBetweenInstancePorts: true

# --- Macro expansion hints -------------------------------------------------
# NOTE(review): in a plain (unquoted) YAML scalar the trailing `\n` is a
# literal backslash followed by `n`, not a newline. Confirm this is the
# intended expansion for LLVM_DEBUG.
Macros:
  - LLVM_DEBUG(X)=X\n

View File

@@ -14,9 +14,9 @@ namespace onnx_mlir {
std::string getOutputDir(); std::string getOutputDir();
void createDirectory(const std::string &directory); void createDirectory(const std::string& directory);
void dumpModule(mlir::ModuleOp moduleOp, const std::string &name); void dumpModule(mlir::ModuleOp moduleOp, const std::string& name);
llvm::FailureOr<mlir::Operation*> llvm::FailureOr<mlir::Operation*>
getOtherEndOfChannel(mlir::Operation* op, bool opIsReceive, mlir::RewriterBase& rewriter); getOtherEndOfChannel(mlir::Operation* op, bool opIsReceive, mlir::RewriterBase& rewriter);

View File

@@ -11,20 +11,21 @@
#include "mlir/IR/Value.h" #include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h" #include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
#include "src/Accelerators/PIM/Common/PIMCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include "llvm/Support/LogicalResult.h" #include "llvm/Support/LogicalResult.h"
#include <cstddef>
#include <cstddef>
#include <memory> #include <memory>
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#include "src/Accelerators/PIM/Common/PIMCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir; using namespace mlir;
using namespace std; using namespace std;
@@ -40,8 +41,8 @@ namespace onnx_mlir {
*/ */
class Core { class Core {
public: public:
Core(const size_t coreId, ConversionPatternRewriter &rewriter) Core(const size_t coreId, ConversionPatternRewriter& rewriter)
: coreId(coreId), rewriter(rewriter) {} : coreId(coreId), rewriter(rewriter) {}
/** /**
* @brief Add a MVM operation to the core. * @brief Add a MVM operation to the core.
@@ -52,8 +53,7 @@ public:
* @param mvmOutType The result's shape. * @param mvmOutType The result's shape.
* @return Value The result of the MVM operation. * @return Value The result of the MVM operation.
*/ */
Value addMVM( Value addMVM(Value inputTile, size_t xbarIndex, size_t outputTileId, Type mvmOutType) {
Value inputTile, size_t xbarIndex, size_t outputTileId, Type mvmOutType) {
// Use the inputTile as the reference location for the MVM operation. // Use the inputTile as the reference location for the MVM operation.
Location loc = inputTile.getLoc(); Location loc = inputTile.getLoc();
@@ -72,8 +72,7 @@ public:
// is correct. // is correct.
// Construct the MVM operation // Construct the MVM operation
Value result = rewriter.create<spatial::SpatWeightedMVMOp>( Value result = rewriter.create<spatial::SpatWeightedMVMOp>(loc, mvmOutType, xbarIndex, operand);
loc, mvmOutType, xbarIndex, operand);
// Since we are within the same core and no computation can happen in // Since we are within the same core and no computation can happen in
// parallel, we can just apply a linear reduction in case we have multiple // parallel, we can just apply a linear reduction in case we have multiple
@@ -84,8 +83,7 @@ public:
if (lastMVM != outputTileToMVM.end()) { if (lastMVM != outputTileToMVM.end()) {
// MVM results should have the same type for reduction. // MVM results should have the same type for reduction.
assert(lastMVM->second.getType() == result.getType()); assert(lastMVM->second.getType() == result.getType());
result = rewriter.create<spatial::SpatVAddOp>( result = rewriter.create<spatial::SpatVAddOp>(loc, mvmOutType, lastMVM->second, result);
loc, mvmOutType, lastMVM->second, result);
} }
outputTileToMVM[outputTileId] = result; outputTileToMVM[outputTileId] = result;
@@ -139,16 +137,14 @@ public:
spatial::SpatWeightedCompute createWComputeOp(Location loc) { spatial::SpatWeightedCompute createWComputeOp(Location loc) {
// Get the shape of the results. // Get the shape of the results.
SmallVector<Type> resultTypes; SmallVector<Type> resultTypes;
for (const auto &value : results) { for (const auto& value : results)
resultTypes.push_back(value.getType()); resultTypes.push_back(value.getType());
}
// Create the WComputeOp, with non-remappable operands only. // Create the WComputeOp, with non-remappable operands only.
wcomputeOp = rewriter.create<spatial::SpatWeightedCompute>( wcomputeOp = rewriter.create<spatial::SpatWeightedCompute>(loc, resultTypes, xbarWeights, operands);
loc, resultTypes, xbarWeights, operands);
// Add the body to the WComputeOp. // Add the body to the WComputeOp.
Block *releasedBlock = block.release(); Block* releasedBlock = block.release();
wcomputeOp.getBody().push_back(releasedBlock); wcomputeOp.getBody().push_back(releasedBlock);
// Add the `yieldOp` at the end, with the results. // Add the `yieldOp` at the end, with the results.
@@ -164,21 +160,18 @@ public:
void remapResults() { void remapResults() {
// Remap all the results to the WComputeOp results. // Remap all the results to the WComputeOp results.
assert(resultsToRemap.size() == wcomputeOp->getNumResults()); assert(resultsToRemap.size() == wcomputeOp->getNumResults());
for (size_t i = 0; i < resultsToRemap.size(); i++) { for (size_t i = 0; i < resultsToRemap.size(); i++)
*resultsToRemap[i] = wcomputeOp.getResult(i); *resultsToRemap[i] = wcomputeOp.getResult(i);
}
} }
void addRemappedOperands() { void addRemappedOperands() {
// Insert the remappableOperands (which were remapped in // Insert the remappableOperands (which were remapped in
// `addRemappableOperand` of another Core) // `addRemappableOperand` of another Core)
for (auto remappedValue : remappableOperands) { for (auto remappedValue : remappableOperands)
wcomputeOp->insertOperands(wcomputeOp->getNumOperands(), *remappedValue); wcomputeOp->insertOperands(wcomputeOp->getNumOperands(), *remappedValue);
}
// Update the wcomputeOp operandSegmentSize // Update the wcomputeOp operandSegmentSize
incrementWeightedComputeInputsSegmentSize( incrementWeightedComputeInputsSegmentSize(wcomputeOp, static_cast<int>(remappableOperands.size()));
wcomputeOp, static_cast<int>(remappableOperands.size()));
} }
size_t addXbarWeight(Value weight) { size_t addXbarWeight(Value weight) {
@@ -199,31 +192,27 @@ public:
llvm::outs() << "Core " << coreId << ":\n"; llvm::outs() << "Core " << coreId << ":\n";
// Print the weights // Print the weights
llvm::outs() << "Xbar Weights:\n"; llvm::outs() << "Xbar Weights:\n";
for (auto weight : xbarWeights) { for (auto weight : xbarWeights)
weight.dump(); weight.dump();
}
// Print the operands // Print the operands
llvm::outs() << "Operands:\n"; llvm::outs() << "Operands:\n";
for (auto operand : operands) { for (auto operand : operands)
llvm::outs() << operand << "\n"; llvm::outs() << operand << "\n";
}
// Dump the body block // Dump the body block
for (auto &op : block->getOperations()) { for (auto& op : block->getOperations())
op.dump(); op.dump();
}
// Print the results // Print the results
llvm::outs() << "Results:\n"; llvm::outs() << "Results:\n";
for (auto result : results) { for (auto result : results)
llvm::outs() << result << "\n"; llvm::outs() << result << "\n";
}
} }
const size_t coreId; const size_t coreId;
private: private:
ConversionPatternRewriter &rewriter; ConversionPatternRewriter& rewriter;
// Should these be set<Value> instead? But I need to keep the order // Should these be set<Value> instead? But I need to keep the order
vector<Value> operands; vector<Value> operands;
@@ -246,15 +235,16 @@ private:
}; };
struct ONNXConvOpTile : public OpConversionPattern<ONNXConvOp> { struct ONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
ONNXConvOpTile(MLIRContext *ctx) : OpConversionPattern(ctx) {} ONNXConvOpTile(MLIRContext* ctx)
: OpConversionPattern(ctx) {}
struct Producer_t { struct Producer_t {
Value value; Value value;
shared_ptr<Core> core; shared_ptr<Core> core;
}; };
LogicalResult matchAndRewrite(ONNXConvOp conv, ONNXConvOpAdaptor convAdaptor, LogicalResult
ConversionPatternRewriter &rewriter) const final { matchAndRewrite(ONNXConvOp conv, ONNXConvOpAdaptor convAdaptor, ConversionPatternRewriter& rewriter) const final {
ShapedType xShape = mlir::cast<ShapedType>(convAdaptor.getX().getType()); ShapedType xShape = mlir::cast<ShapedType>(convAdaptor.getX().getType());
ShapedType wShape = mlir::cast<ShapedType>(convAdaptor.getW().getType()); ShapedType wShape = mlir::cast<ShapedType>(convAdaptor.getW().getType());
ShapedType bShape = mlir::cast<ShapedType>(convAdaptor.getB().getType()); ShapedType bShape = mlir::cast<ShapedType>(convAdaptor.getB().getType());
@@ -264,11 +254,9 @@ struct ONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
unpackOptionalPairVector(conv.getStrides(), stride_x, stride_y); unpackOptionalPairVector(conv.getStrides(), stride_x, stride_y);
unpackOptionalPairVector(conv.getDilations(), dilation_x, dilation_y); unpackOptionalPairVector(conv.getDilations(), dilation_x, dilation_y);
auto padUnpackError = auto padUnpackError = unpackOptionalPadsVector(convAdaptor.getPads(), pad_x, pad_y);
unpackOptionalPadsVector(convAdaptor.getPads(), pad_x, pad_y); if (padUnpackError.has_value())
if (padUnpackError.has_value()) {
return rewriter.notifyMatchFailure(conv, padUnpackError.value()); return rewriter.notifyMatchFailure(conv, padUnpackError.value());
}
// TODO: Pad value at beginning and end of each dimension could be // TODO: Pad value at beginning and end of each dimension could be
// different. We should handle this case. // different. We should handle this case.
@@ -296,11 +284,9 @@ struct ONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
Location loc = conv.getLoc(); Location loc = conv.getLoc();
size_t inputTileCount = size_t inputTileCount = ceilIntegerDivide(GET_IMAGE_CHANNEL(xShape), crossbarSize.getValue());
ceilIntegerDivide(GET_IMAGE_CHANNEL(xShape), crossbarSize.getValue());
size_t inputTileRemainder = GET_IMAGE_CHANNEL(xShape) % crossbarSize; size_t inputTileRemainder = GET_IMAGE_CHANNEL(xShape) % crossbarSize;
size_t outputTileCount = size_t outputTileCount = ceilIntegerDivide(GET_IMAGE_CHANNEL(yShape), crossbarSize.getValue());
ceilIntegerDivide(GET_IMAGE_CHANNEL(yShape), crossbarSize.getValue());
size_t outputTileRemainder = GET_IMAGE_CHANNEL(yShape) % crossbarSize; size_t outputTileRemainder = GET_IMAGE_CHANNEL(yShape) % crossbarSize;
// Tile the input tensor // Tile the input tensor
@@ -310,22 +296,20 @@ struct ONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
// c. Pixel `y` position // c. Pixel `y` position
// For example: inputTiles[channelTile][x][y] // For example: inputTiles[channelTile][x][y]
// Example complete input tensor: tensor<1x3x6x6xf32> (NxCxWxH) // Example complete input tensor: tensor<1x3x6x6xf32> (NxCxWxH)
SmallVector<SmallVector<SmallVector<Value>>> inputTiles(inputTileCount, SmallVector<SmallVector<SmallVector<Value>>> inputTiles(
SmallVector<SmallVector<Value>>(input_w, SmallVector<Value>(input_h))); inputTileCount, SmallVector<SmallVector<Value>>(input_w, SmallVector<Value>(input_h)));
auto resolveErrorOpt = resolveImgInputTiles(convAdaptor.getX(), inputTiles, auto resolveErrorOpt = resolveImgInputTiles(
inputTileCount, inputTileRemainder, input_h, input_h, rewriter); convAdaptor.getX(), inputTiles, inputTileCount, inputTileRemainder, input_h, input_h, rewriter);
if (resolveErrorOpt.has_value()) { if (resolveErrorOpt.has_value())
return rewriter.notifyMatchFailure(conv, *resolveErrorOpt); return rewriter.notifyMatchFailure(conv, *resolveErrorOpt);
}
SmallVector<OpFoldResult> strides = SmallVector<OpFoldResult> strides = SmallVector<OpFoldResult>(4, rewriter.getIndexAttr(1));
SmallVector<OpFoldResult>(4, rewriter.getIndexAttr(1)); SmallVector<OpFoldResult> offsets = SmallVector<OpFoldResult>(4, rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> offsets = SmallVector<OpFoldResult> sizes = SmallVector<OpFoldResult> {rewriter.getIndexAttr(1),
SmallVector<OpFoldResult>(4, rewriter.getIndexAttr(0)); rewriter.getIndexAttr(crossbarSize),
SmallVector<OpFoldResult> sizes = SmallVector<OpFoldResult>{ rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1), rewriter.getIndexAttr(crossbarSize), rewriter.getIndexAttr(1)};
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
// Tile the weight tensor // Tile the weight tensor
// Weight tiles need to be indexed by: // Weight tiles need to be indexed by:
@@ -336,31 +320,30 @@ struct ONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
// For example: weightTiles[filterTile][channelTile][x][y] // For example: weightTiles[filterTile][channelTile][x][y]
// Example complete weight tensor: tensor<32x3x3x3xf32> (FxCxWxH) // Example complete weight tensor: tensor<32x3x3x3xf32> (FxCxWxH)
SmallVector<SmallVector<SmallVector<SmallVector<Value>>>> weightTiles( SmallVector<SmallVector<SmallVector<SmallVector<Value>>>> weightTiles(
outputTileCount, outputTileCount,
SmallVector<SmallVector<SmallVector<Value>>>(inputTileCount, SmallVector<SmallVector<SmallVector<Value>>>(inputTileCount,
SmallVector<SmallVector<Value>>(krn_w, SmallVector<Value>(krn_h)))); SmallVector<SmallVector<Value>>(krn_w, SmallVector<Value>(krn_h))));
strides = SmallVector<OpFoldResult>(4, rewriter.getIndexAttr(1)); strides = SmallVector<OpFoldResult>(4, rewriter.getIndexAttr(1));
offsets = SmallVector<OpFoldResult>(4, rewriter.getIndexAttr(0)); offsets = SmallVector<OpFoldResult>(4, rewriter.getIndexAttr(0));
sizes = {rewriter.getIndexAttr(crossbarSize), sizes = {rewriter.getIndexAttr(crossbarSize),
rewriter.getIndexAttr(crossbarSize), rewriter.getIndexAttr(1), rewriter.getIndexAttr(crossbarSize),
rewriter.getIndexAttr(1)}; rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1)};
for (size_t i = 0; i < outputTileCount; i++) { for (size_t i = 0; i < outputTileCount; i++) {
if (i == outputTileCount - 1 && outputTileRemainder != 0) { if (i == outputTileCount - 1 && outputTileRemainder != 0)
sizes[0] = rewriter.getIndexAttr(outputTileRemainder); sizes[0] = rewriter.getIndexAttr(outputTileRemainder);
}
sizes[1] = rewriter.getIndexAttr(crossbarSize); sizes[1] = rewriter.getIndexAttr(crossbarSize);
offsets[0] = rewriter.getIndexAttr(i * crossbarSize); offsets[0] = rewriter.getIndexAttr(i * crossbarSize);
for (size_t j = 0; j < inputTileCount; j++) { for (size_t j = 0; j < inputTileCount; j++) {
if (j == inputTileCount - 1 && inputTileRemainder != 0) { if (j == inputTileCount - 1 && inputTileRemainder != 0)
sizes[1] = rewriter.getIndexAttr(inputTileRemainder); sizes[1] = rewriter.getIndexAttr(inputTileRemainder);
}
for (size_t x = 0; x < krn_w; x++) { for (size_t x = 0; x < krn_w; x++) {
for (size_t y = 0; y < krn_h; y++) { for (size_t y = 0; y < krn_h; y++) {
offsets[1] = rewriter.getIndexAttr(j * crossbarSize); offsets[1] = rewriter.getIndexAttr(j * crossbarSize);
offsets[2] = rewriter.getIndexAttr(x); offsets[2] = rewriter.getIndexAttr(x);
offsets[3] = rewriter.getIndexAttr(y); offsets[3] = rewriter.getIndexAttr(y);
weightTiles[i][j][x][y] = rewriter.create<tensor::ExtractSliceOp>( weightTiles[i][j][x][y] =
loc, convAdaptor.getW(), offsets, sizes, strides); rewriter.create<tensor::ExtractSliceOp>(loc, convAdaptor.getW(), offsets, sizes, strides);
} }
} }
} }
@@ -379,56 +362,45 @@ struct ONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
// For example: outputTiles[filterTile][x][y] // For example: outputTiles[filterTile][x][y]
// Example complete output tensor: tensor<1x32x3x3xf32> (NxFxWxH) // Example complete output tensor: tensor<1x32x3x3xf32> (NxFxWxH)
SmallVector<SmallVector<SmallVector<shared_ptr<Value>>>> outputTiles( SmallVector<SmallVector<SmallVector<shared_ptr<Value>>>> outputTiles(
outputTileCount, outputTileCount,
SmallVector<SmallVector<shared_ptr<Value>>>( SmallVector<SmallVector<shared_ptr<Value>>>(output_w, SmallVector<shared_ptr<Value>>(output_h, nullptr)));
output_w, SmallVector<shared_ptr<Value>>(output_h, nullptr)));
size_t replicationFactor; size_t replicationFactor;
if (!conv->hasAttr(REPLICATION_ATTR_NAME)) { if (!conv->hasAttr(REPLICATION_ATTR_NAME))
replicationFactor = 1; replicationFactor = 1;
} else { else
replicationFactor = replicationFactor = conv->getAttrOfType<IntegerAttr>(REPLICATION_ATTR_NAME).getInt();
conv->getAttrOfType<IntegerAttr>(REPLICATION_ATTR_NAME).getInt();
}
// producers[outTile][out_x][out_y][producerIndex] // producers[outTile][out_x][out_y][producerIndex]
vector<vector<vector<vector<Producer_t>>>> producers = vector<vector<vector<vector<Producer_t>>>> producers = vector<vector<vector<vector<Producer_t>>>>(
vector<vector<vector<vector<Producer_t>>>>(outputTileCount, outputTileCount,
vector<vector<vector<Producer_t>>>(output_w, vector<vector<vector<Producer_t>>>(output_w, vector<vector<Producer_t>>(output_h, vector<Producer_t>())));
vector<vector<Producer_t>>(output_h, vector<Producer_t>())));
// Schedule in cores // Schedule in cores
size_t coreId = 0; size_t coreId = 0;
vector<shared_ptr<Core>> curCores(replicationFactor); vector<shared_ptr<Core>> curCores(replicationFactor);
for (size_t i = 0; i < replicationFactor; i++) { for (size_t i = 0; i < replicationFactor; i++)
curCores[i] = make_shared<Core>(coreId++, rewriter); curCores[i] = make_shared<Core>(coreId++, rewriter);
}
vector<shared_ptr<Core>> cores; vector<shared_ptr<Core>> cores;
const size_t replicationSliceSize = const size_t replicationSliceSize = ceilIntegerDivide(input_w, replicationFactor);
ceilIntegerDivide(input_w, replicationFactor);
for (size_t krn_x = 0; krn_x < krn_h; krn_x++) { for (size_t krn_x = 0; krn_x < krn_h; krn_x++) {
for (size_t krn_y = 0; krn_y < krn_w; krn_y++) { for (size_t krn_y = 0; krn_y < krn_w; krn_y++) {
RankedTensorType mvmOutType = RankedTensorType mvmOutType =
RankedTensorType::get({1, static_cast<long>(crossbarSize), 1, 1}, RankedTensorType::get({1, static_cast<long>(crossbarSize), 1, 1}, bShape.getElementType());
bShape.getElementType());
for (size_t outTile = 0; outTile < outputTileCount; outTile++) { for (size_t outTile = 0; outTile < outputTileCount; outTile++) {
if (outTile == outputTileCount - 1 && outputTileRemainder != 0) { if (outTile == outputTileCount - 1 && outputTileRemainder != 0)
mvmOutType = mvmOutType.clone( mvmOutType = mvmOutType.clone({1, static_cast<long>(outputTileRemainder), 1, 1});
{1, static_cast<long>(outputTileRemainder), 1, 1});
}
for (size_t inTile = 0; inTile < inputTileCount; inTile++) { for (size_t inTile = 0; inTile < inputTileCount; inTile++) {
vector<size_t> xbarIndexes(replicationFactor); vector<size_t> xbarIndexes(replicationFactor);
for (size_t i = 0; i < replicationFactor; i++) { for (size_t i = 0; i < replicationFactor; i++)
xbarIndexes[i] = curCores[i]->addXbarWeight( xbarIndexes[i] = curCores[i]->addXbarWeight(weightTiles[outTile][inTile][krn_x][krn_y]);
weightTiles[outTile][inTile][krn_x][krn_y]);
}
size_t out_x = 0; size_t out_x = 0;
for (size_t in_x = 0; in_x < input_w; in_x += stride_x) { for (size_t in_x = 0; in_x < input_w; in_x += stride_x) {
@@ -443,25 +415,20 @@ struct ONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
for (size_t in_y = 0; in_y < input_h; in_y += stride_y) { for (size_t in_y = 0; in_y < input_h; in_y += stride_y) {
// Adjust the input based on the kernel // Adjust the input based on the kernel
int actual_in_x = in_x - ((int)krn_w / 2) + krn_x * dilation_x; int actual_in_x = in_x - ((int) krn_w / 2) + krn_x * dilation_x;
int actual_in_y = in_y - ((int)krn_h / 2) + krn_y * dilation_y; int actual_in_y = in_y - ((int) krn_h / 2) + krn_y * dilation_y;
// Check if we are within the input image // Check if we are within the input image
if (verifyWithinBoundsAndPaddings(input_w, input_h, actual_in_x, if (verifyWithinBoundsAndPaddings(input_w, input_h, actual_in_x, actual_in_y, pad_x, pad_y).failed()) {
actual_in_y, pad_x, pad_y)
.failed()) {
out_y++; out_y++;
continue; continue;
} }
size_t outTileId = size_t outTileId = outTile * output_w * output_h + out_x * output_h + out_y;
outTile * output_w * output_h + out_x * output_h + out_y;
auto mvm = curCores[coreIndex]->addMVM( auto mvm = curCores[coreIndex]->addMVM(
inputTiles[inTile][actual_in_x][actual_in_y], inputTiles[inTile][actual_in_x][actual_in_y], xbarIndexes[coreIndex], outTileId, mvmOutType);
xbarIndexes[coreIndex], outTileId, mvmOutType);
producers[outTile][out_x][out_y].push_back( producers[outTile][out_x][out_y].push_back({mvm, curCores[coreIndex]});
{mvm, curCores[coreIndex]});
out_y++; out_y++;
} }
@@ -481,11 +448,9 @@ struct ONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
} }
} }
for (auto &curCore : curCores) { for (auto& curCore : curCores)
if (curCore->isCoreEmpty() == false) { if (curCore->isCoreEmpty() == false)
cores.emplace_back(std::move(curCore)); cores.emplace_back(std::move(curCore));
}
}
curCores.clear(); curCores.clear();
// Now, do the reduction of each output pixel tile // Now, do the reduction of each output pixel tile
for (size_t outTile = 0; outTile < outputTileCount; outTile++) { for (size_t outTile = 0; outTile < outputTileCount; outTile++) {
@@ -497,9 +462,8 @@ struct ONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
// core. // core.
std::unordered_map<size_t, Producer_t> withinCoreReducedProducers; std::unordered_map<size_t, Producer_t> withinCoreReducedProducers;
for (auto producer : producers[outTile][out_x][out_y]) { for (auto producer : producers[outTile][out_x][out_y])
withinCoreReducedProducers[producer.core->coreId] = producer; withinCoreReducedProducers[producer.core->coreId] = producer;
}
// Now, we need to apply inter-core reduction // Now, we need to apply inter-core reduction
@@ -509,8 +473,7 @@ struct ONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
auto singleProducer = withinCoreReducedProducers.begin()->second; auto singleProducer = withinCoreReducedProducers.begin()->second;
// Use last producer as the final result // Use last producer as the final result
auto reducedValue = auto reducedValue = singleProducer.core->makeResultRemappable(singleProducer.value);
singleProducer.core->makeResultRemappable(singleProducer.value);
outputTiles[outTile][out_x][out_y] = reducedValue; outputTiles[outTile][out_x][out_y] = reducedValue;
continue; continue;
} }
@@ -535,9 +498,9 @@ struct ONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
auto lastProducerCoreId = lastProducer.core->coreId; auto lastProducerCoreId = lastProducer.core->coreId;
auto curProducerCoreId = curProducer.core->coreId; auto curProducerCoreId = curProducer.core->coreId;
assert(lastProducerCoreId != curProducerCoreId && assert(lastProducerCoreId != curProducerCoreId
"We should have already applied within-core reduction, how " && "We should have already applied within-core reduction, how "
"could we have same cores here?"); "could we have same cores here?");
// Sort the cores by coreId // Sort the cores by coreId
if (curProducerCoreId < lastProducerCoreId) { if (curProducerCoreId < lastProducerCoreId) {
@@ -545,7 +508,8 @@ struct ONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
core1Value = curProducer.value; core1Value = curProducer.value;
core2 = lastProducer.core; core2 = lastProducer.core;
core2Value = lastProducer.value; core2Value = lastProducer.value;
} else { }
else {
core1 = lastProducer.core; core1 = lastProducer.core;
core1Value = lastProducer.value; core1Value = lastProducer.value;
core2 = curProducer.core; core2 = curProducer.core;
@@ -556,9 +520,8 @@ struct ONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
auto secondCoreBlockArg = core2->addRemappableOperand(newCoreRes); auto secondCoreBlockArg = core2->addRemappableOperand(newCoreRes);
rewriter.setInsertionPointAfterValue(core2Value); rewriter.setInsertionPointAfterValue(core2Value);
Value vaddRes = Value vaddRes = rewriter.create<spatial::SpatVAddOp>(
rewriter.create<spatial::SpatVAddOp>(core2Value.getLoc(), core2Value.getLoc(), core2Value.getType(), core2Value, secondCoreBlockArg);
core2Value.getType(), core2Value, secondCoreBlockArg);
lastProducer = {vaddRes, core2}; lastProducer = {vaddRes, core2};
@@ -568,8 +531,7 @@ struct ONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
// TODO: Add the bias and apply mapping (if present) // TODO: Add the bias and apply mapping (if present)
// Use last producer as the final result // Use last producer as the final result
auto reducedValue = auto reducedValue = lastProducer.core->makeResultRemappable(lastProducer.value);
lastProducer.core->makeResultRemappable(lastProducer.value);
outputTiles[outTile][out_x][out_y] = reducedValue; outputTiles[outTile][out_x][out_y] = reducedValue;
} }
} }
@@ -578,15 +540,14 @@ struct ONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
// Now, we need to turn the cores into a spatial::SpatWeightedCompute. // Now, we need to turn the cores into a spatial::SpatWeightedCompute.
rewriter.setInsertionPointAfter(conv); rewriter.setInsertionPointAfter(conv);
spatial::SpatWeightedCompute lastWComputeOp; spatial::SpatWeightedCompute lastWComputeOp;
for (auto &core : cores) { for (auto& core : cores) {
lastWComputeOp = core->createWComputeOp(loc); lastWComputeOp = core->createWComputeOp(loc);
core->remapResults(); core->remapResults();
rewriter.setInsertionPointAfter(lastWComputeOp); rewriter.setInsertionPointAfter(lastWComputeOp);
} }
for (auto &core : cores) { for (auto& core : cores)
core->addRemappedOperands(); core->addRemappedOperands();
}
// Set the insertion point after the last WComputeOp. // Set the insertion point after the last WComputeOp.
rewriter.setInsertionPointAfter(lastWComputeOp); rewriter.setInsertionPointAfter(lastWComputeOp);
@@ -597,8 +558,7 @@ struct ONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
for (size_t outTile = 0; outTile < outputTileCount; outTile++) for (size_t outTile = 0; outTile < outputTileCount; outTile++)
tilesToConcat.push_back(*outputTiles[outTile][outX][outY]); tilesToConcat.push_back(*outputTiles[outTile][outX][outY]);
Value outputImage = rewriter.create<spatial::SpatImgConcatOp>( Value outputImage = rewriter.create<spatial::SpatImgConcatOp>(loc, conv.getY().getType(), tilesToConcat);
loc, conv.getY().getType(), tilesToConcat);
// Value outputImage = // Value outputImage =
// createImgConcatOp(outputTiles, rewriter, loc, Y.getType()); // createImgConcatOp(outputTiles, rewriter, loc, Y.getType());
@@ -616,8 +576,7 @@ struct ONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
} }
}; };
void populateTilingConvOpPattern( void populateTilingConvOpPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
RewritePatternSet &patterns, MLIRContext *ctx) {
patterns.insert<ONNXConvOpTile>(ctx); patterns.insert<ONNXConvOpTile>(ctx);
} }

View File

@@ -1,6 +1,3 @@
#include "Compiler/PimCompilerOptions.hpp"
#include "Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "Dialect/Spatial/SpatialOps.hpp"
#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h" #include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h" #include "mlir/IR/PatternMatch.h"
@@ -8,13 +5,19 @@
#include "mlir/IR/Value.h" #include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h" #include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/WeightSubdivider.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include <algorithm> #include <algorithm>
#include <cstddef> #include <cstddef>
#include <unistd.h> #include <unistd.h>
#include "Compiler/PimCompilerOptions.hpp"
#include "Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/WeightSubdivider.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir; using namespace mlir;
using namespace std; using namespace std;
@@ -27,10 +30,11 @@ namespace onnx_mlir {
* output tensor. * output tensor.
*/ */
struct ExperimentalONNXConvOpTile : public OpConversionPattern<ONNXConvOp> { struct ExperimentalONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
ExperimentalONNXConvOpTile(MLIRContext *ctx) : OpConversionPattern(ctx) {} ExperimentalONNXConvOpTile(MLIRContext* ctx)
: OpConversionPattern(ctx) {}
LogicalResult matchAndRewrite(ONNXConvOp conv, ONNXConvOpAdaptor convAdaptor, LogicalResult
ConversionPatternRewriter &rewriter) const final { matchAndRewrite(ONNXConvOp conv, ONNXConvOpAdaptor convAdaptor, ConversionPatternRewriter& rewriter) const final {
// --------------------------------- // // --------------------------------- //
// --- READ OPERATION PARAMETERS --- // // --- READ OPERATION PARAMETERS --- //
@@ -46,12 +50,12 @@ struct ExperimentalONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
ShapedType weightsType = cast<ShapedType>(convAdaptor.getW().getType()); ShapedType weightsType = cast<ShapedType>(convAdaptor.getW().getType());
// TODO: Address bigger batches. // TODO: Address bigger batches.
assert(GET_IMAGE_N(inputType) == 1 && "Batch size must be 1" assert(GET_IMAGE_N(inputType) == 1
"for convolution."); && "Batch size must be 1"
"for convolution.");
// TODO: Address replication. // TODO: Address replication.
assert(coresCount.getValue() == -1 && assert(coresCount.getValue() == -1 && "Replication is not yet supported for convolution.");
"Replication is not yet supported for convolution.");
// TODO: Address bias addition. // TODO: Address bias addition.
@@ -97,11 +101,10 @@ struct ExperimentalONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
long inputTile = it; long inputTile = it;
// Create the slicing sizes. // Create the slicing sizes.
SmallVector<OpFoldResult> slicingSizes{ SmallVector<OpFoldResult> slicingSizes {/* 0 */ rewriter.getIndexAttr(crossbarHeight),
/* 0 */ rewriter.getIndexAttr(crossbarHeight), /* 1 */ rewriter.getIndexAttr(crossbarWidth),
/* 1 */ rewriter.getIndexAttr(crossbarWidth), /* 2 */ rewriter.getIndexAttr(1),
/* 2 */ rewriter.getIndexAttr(1), /* 3 */ rewriter.getIndexAttr(1)};
/* 3 */ rewriter.getIndexAttr(1)};
// - Slicing along the filter x position. // - Slicing along the filter x position.
// - Slicing along the filter y position. // - Slicing along the filter y position.
@@ -109,16 +112,14 @@ struct ExperimentalONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
for (size_t filterY = 0; filterY < kernelHeight; ++filterY) { for (size_t filterY = 0; filterY < kernelHeight; ++filterY) {
// Create the slicing offsets. // Create the slicing offsets.
SmallVector<OpFoldResult> slicingOffsets{ SmallVector<OpFoldResult> slicingOffsets {/* 0 */ rewriter.getIndexAttr(outputTile * crossbarSize),
/* 0 */ rewriter.getIndexAttr(outputTile * crossbarSize), /* 1 */ rewriter.getIndexAttr(inputTile * crossbarSize),
/* 1 */ rewriter.getIndexAttr(inputTile * crossbarSize), /* 2 */ rewriter.getIndexAttr(filterX),
/* 2 */ rewriter.getIndexAttr(filterX), /* 3 */ rewriter.getIndexAttr(filterY)};
/* 3 */ rewriter.getIndexAttr(filterY)};
// Create the slice extraction operation. // Create the slice extraction operation.
auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>( auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
conv.getLoc(), convAdaptor.getW(), slicingOffsets, slicingSizes, conv.getLoc(), convAdaptor.getW(), slicingOffsets, slicingSizes, slicingStrides);
slicingStrides);
// Add a note to the extractSliceOp, with the filterX and filterY. // Add a note to the extractSliceOp, with the filterX and filterY.
weightsGroups[inputTile][outputTile].push_back(extractSliceOp); weightsGroups[inputTile][outputTile].push_back(extractSliceOp);
@@ -150,8 +151,7 @@ struct ExperimentalONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
// -------------------------------- // // -------------------------------- //
// Get the next group of weights for the compute unit. // Get the next group of weights for the compute unit.
SmallVector<TaggedWeights> weightsGroups = SmallVector<TaggedWeights> weightsGroups = weightSubdivider.popGroups(crossbarCountInCore.getValue());
weightSubdivider.popGroups(crossbarCountInCore.getValue());
SmallVector<Value> computeWeights; SmallVector<Value> computeWeights;
SmallVector<Value> computeOperands; SmallVector<Value> computeOperands;
@@ -168,15 +168,13 @@ struct ExperimentalONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
// Iterate over all weights groups for this compute unit. // Iterate over all weights groups for this compute unit.
map<long, Value> localSlices; // WRT the current compute unit. map<long, Value> localSlices; // WRT the current compute unit.
for (auto group : weightsGroups) { for (auto group : weightsGroups) {
for (Value weight : group.weights) { for (Value weight : group.weights)
computeWeights.push_back(weight); computeWeights.push_back(weight);
}
// There might be multiple weight groups for the same input tile, so if // There might be multiple weight groups for the same input tile, so if
// we've already added the input tile, skip it. // we've already added the input tile, skip it.
if (localSlices.find(group.inputTile) != localSlices.end()) { if (localSlices.find(group.inputTile) != localSlices.end())
continue; continue;
}
// We might have already sliced the input tensor for some other compute // We might have already sliced the input tensor for some other compute
// unit, so if we have, reuse the slicing operation without creating a // unit, so if we have, reuse the slicing operation without creating a
@@ -188,26 +186,21 @@ struct ExperimentalONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
} }
// Create the input tensor slicing offsets. // Create the input tensor slicing offsets.
SmallVector<OpFoldResult> slicingOffsets{ SmallVector<OpFoldResult> slicingOffsets {/* 0 */ rewriter.getIndexAttr(0), // No offset along the batch axis.
/* 0 */ rewriter.getIndexAttr(0), // No offset along the batch axis. /* 1 */ rewriter.getIndexAttr(group.inputTile * crossbarSize),
/* 1 */ rewriter.getIndexAttr(group.inputTile * crossbarSize), /* 2 */ rewriter.getIndexAttr(0),
/* 2 */ rewriter.getIndexAttr(0), /* 3 */ rewriter.getIndexAttr(0)};
/* 3 */ rewriter.getIndexAttr(0)};
// Create the input tensor slicing sizes. // Create the input tensor slicing sizes.
size_t tilingSize = group.inputTile == inputTileCount.quot size_t tilingSize = group.inputTile == inputTileCount.quot ? inputTileCount.rem : crossbarSize;
? inputTileCount.rem SmallVector<OpFoldResult> slicingSizes {/* 0 */ rewriter.getIndexAttr(1), // Batch size is always 1.
: crossbarSize; /* 1 */ rewriter.getIndexAttr(tilingSize),
SmallVector<OpFoldResult> slicingSizes{ /* 2 */ rewriter.getIndexAttr(GET_IMAGE_WIDTH(inputType)),
/* 0 */ rewriter.getIndexAttr(1), // Batch size is always 1. /* 3 */ rewriter.getIndexAttr(GET_IMAGE_HEIGHT(inputType))};
/* 1 */ rewriter.getIndexAttr(tilingSize),
/* 2 */ rewriter.getIndexAttr(GET_IMAGE_WIDTH(inputType)),
/* 3 */ rewriter.getIndexAttr(GET_IMAGE_HEIGHT(inputType))};
// Create the slice extraction operation. // Create the slice extraction operation.
auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>( auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
conv.getLoc(), convAdaptor.getX(), slicingOffsets, slicingSizes, conv.getLoc(), convAdaptor.getX(), slicingOffsets, slicingSizes, slicingStrides);
slicingStrides);
computeOperands.push_back(extractSliceOp); computeOperands.push_back(extractSliceOp);
@@ -229,36 +222,28 @@ struct ExperimentalONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
// There might be multiple weight groups for the same output tile, so if // There might be multiple weight groups for the same output tile, so if
// we've already added the output tile, skip it. // we've already added the output tile, skip it.
if (outputTileIndices.find(group.outputTile) != if (outputTileIndices.find(group.outputTile) != outputTileIndices.end())
outputTileIndices.end()) {
continue; continue;
}
// Additionally, after adding the input slices as operands, also add any // Additionally, after adding the input slices as operands, also add any
// compatible partial results from previous compute units. // compatible partial results from previous compute units.
if (globalPartialResults.find(group.outputTile) != if (globalPartialResults.find(group.outputTile) != globalPartialResults.end()) {
globalPartialResults.end()) {
computeOperands.push_back(globalPartialResults[group.outputTile]); computeOperands.push_back(globalPartialResults[group.outputTile]);
reductionTileIndices[group.outputTile] = computeOperands.size() - 1; reductionTileIndices[group.outputTile] = computeOperands.size() - 1;
} }
// Define the output shape for this group. // Define the output shape for this group.
long outputTileSize = group.outputTile == outputTileCount.quot long outputTileSize = group.outputTile == outputTileCount.quot ? outputTileCount.rem : crossbarSize;
? outputTileCount.rem
: crossbarSize;
// TODO: Address non-same padding. // TODO: Address non-same padding.
SmallVector<int64_t> outputShapeArray{ SmallVector<int64_t> outputShapeArray {/* 0 */ 1, // Batch size is always 1.
/* 0 */ 1, // Batch size is always 1. /* 1 */ outputTileSize,
/* 1 */ outputTileSize, /* 2 */ GET_IMAGE_WIDTH(outputType), // Same padding assumed.
/* 2 */ GET_IMAGE_WIDTH(outputType), // Same padding assumed. /* 3 */ GET_IMAGE_HEIGHT(outputType)};
/* 3 */ GET_IMAGE_HEIGHT(outputType)};
auto elementType = auto elementType = dyn_cast<RankedTensorType>(conv.getY().getType()).getElementType();
dyn_cast<RankedTensorType>(conv.getY().getType()).getElementType();
computeOutputType.push_back( computeOutputType.push_back(RankedTensorType::get(outputShapeArray, elementType));
RankedTensorType::get(outputShapeArray, elementType));
outputTileIndices[group.outputTile] = computeOutputType.size() - 1; outputTileIndices[group.outputTile] = computeOutputType.size() - 1;
} }
@@ -268,43 +253,36 @@ struct ExperimentalONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
// ----------------------------- // // ----------------------------- //
// Create the compute unit. // Create the compute unit.
spatial::SpatWeightedCompute currentCompute = spatial::SpatWeightedCompute currentCompute = rewriter.create<spatial::SpatWeightedCompute>(
rewriter.create<spatial::SpatWeightedCompute>(conv.getLoc(), conv.getLoc(), computeOutputType, computeWeights, computeOperands);
computeOutputType, computeWeights, computeOperands);
// Create a new block for the compute unit and add the operands. // Create a new block for the compute unit and add the operands.
Block *block = rewriter.createBlock(&currentCompute.getRegion()); Block* block = rewriter.createBlock(&currentCompute.getRegion());
rewriter.setInsertionPointToStart(block); rewriter.setInsertionPointToStart(block);
for (Value operand : computeOperands) { for (Value operand : computeOperands)
block->addArgument(operand.getType(), conv->getLoc()); block->addArgument(operand.getType(), conv->getLoc());
}
// Initialize a map of local partial results. // Initialize a map of local partial results.
map<long, Value> localPartialResults; // WRT the current compute unit. map<long, Value> localPartialResults; // WRT the current compute unit.
// If we have any reduction tiles, add them to the local partial results. // If we have any reduction tiles, add them to the local partial results.
for (auto reductionTileIndex : reductionTileIndices) { for (auto reductionTileIndex : reductionTileIndices)
localPartialResults[reductionTileIndex.first] = localPartialResults[reductionTileIndex.first] = block->getArgument(reductionTileIndex.second);
block->getArgument(reductionTileIndex.second);
}
// Add all the applyFilters operations to the block. // Add all the applyFilters operations to the block.
for (TaggedWeights group : weightsGroups) { for (TaggedWeights group : weightsGroups) {
// Get the outputType for this group. // Get the outputType for this group.
Type outputType = Type outputType = computeOutputType[outputTileIndices[group.outputTile]];
computeOutputType[outputTileIndices[group.outputTile]];
// Create an apply filters operation. // Create an apply filters operation.
BlockArgument blockArgument = BlockArgument blockArgument = block->getArgument(inputTileIndices[group.inputTile]);
block->getArgument(inputTileIndices[group.inputTile]);
// The list of weight indices is group.startingCrossbarIndex + 0, 1, 2, // The list of weight indices is group.startingCrossbarIndex + 0, 1, 2,
// ... As many weights as the size of group.weights. // ... As many weights as the size of group.weights.
SmallVector<long> weightIndices; SmallVector<long> weightIndices;
for (size_t i = 0; i < group.weights.size(); ++i) { for (size_t i = 0; i < group.weights.size(); ++i)
weightIndices.push_back(group.startingCrossbarIndex + i); weightIndices.push_back(group.startingCrossbarIndex + i);
}
SmallVector<int64_t> xKerPos; SmallVector<int64_t> xKerPos;
SmallVector<int64_t> yKerPos; SmallVector<int64_t> yKerPos;
@@ -323,16 +301,14 @@ struct ExperimentalONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
ArrayAttr xKerPosAttr = rewriter.getI64ArrayAttr(xKerPos); ArrayAttr xKerPosAttr = rewriter.getI64ArrayAttr(xKerPos);
ArrayAttr yKerPosAttr = rewriter.getI64ArrayAttr(yKerPos); ArrayAttr yKerPosAttr = rewriter.getI64ArrayAttr(yKerPos);
Value result = Value result = rewriter.create<spatial::SpatApplyFiltersOp>(
rewriter.create<spatial::SpatApplyFiltersOp>(conv.getLoc(), outputType, conv.getLoc(), outputType, weightIndicesAttr, xKerPosAttr, yKerPosAttr, blockArgument);
weightIndicesAttr, xKerPosAttr, yKerPosAttr, blockArgument);
// Perform local reduction if necessary. // Perform local reduction if necessary.
if (localPartialResults.find(group.outputTile) != if (localPartialResults.find(group.outputTile) != localPartialResults.end()) {
localPartialResults.end()) {
result = rewriter.create<spatial::SpatVAddOp>(conv.getLoc(), result = rewriter.create<spatial::SpatVAddOp>(
result.getType(), localPartialResults[group.outputTile], result); conv.getLoc(), result.getType(), localPartialResults[group.outputTile], result);
} }
// Update the partial results map. // Update the partial results map.
@@ -385,22 +361,18 @@ struct ExperimentalONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
// Turn the values into a SmallVector. // Turn the values into a SmallVector.
SmallVector<Value> outputValues; SmallVector<Value> outputValues;
for (long i = 0; i < outputTileCount.quot + (outputTileCount.rem > 0); for (long i = 0; i < outputTileCount.quot + (outputTileCount.rem > 0); ++i)
++i) {
outputValues.push_back(globalPartialResults[i]); outputValues.push_back(globalPartialResults[i]);
}
// Assert that the number of output values is correct. // Assert that the number of output values is correct.
assert(outputValues.size() > 0 && assert(outputValues.size() > 0 && "No output values were generated for the convolution.");
"No output values were generated for the convolution.");
// If the conv's user is a ReLU... // If the conv's user is a ReLU...
if (conv->hasOneUse()) { if (conv->hasOneUse()) {
Operation *user = *conv->getUsers().begin(); Operation* user = *conv->getUsers().begin();
if (auto relu = dyn_cast<ONNXReluOp>(user)) { if (auto relu = dyn_cast<ONNXReluOp>(user)) {
// ...then we can just replace the ReLU with the concatenation. // ...then we can just replace the ReLU with the concatenation.
rewriter.replaceOp(relu, rewriter.replaceOp(relu, rewriter.create<tensor::ConcatOp>(conv.getLoc(), 1, outputValues));
rewriter.create<tensor::ConcatOp>(conv.getLoc(), 1, outputValues));
// And erase the convolution. // And erase the convolution.
rewriter.eraseOp(conv); rewriter.eraseOp(conv);
@@ -409,8 +381,7 @@ struct ExperimentalONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
} }
// Return the final output. // Return the final output.
rewriter.replaceOp(conv, rewriter.replaceOp(conv, rewriter.create<tensor::ConcatOp>(conv.getLoc(), 1, outputValues));
rewriter.create<tensor::ConcatOp>(conv.getLoc(), 1, outputValues));
return success(); return success();
} }
@@ -422,8 +393,7 @@ struct ExperimentalONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
* @param patterns The pattern set to populate. * @param patterns The pattern set to populate.
* @param ctx The MLIR context. * @param ctx The MLIR context.
*/ */
void populateExperimentalTilingConvOpPattern( void populateExperimentalTilingConvOpPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
RewritePatternSet &patterns, MLIRContext *ctx) {
patterns.insert<ExperimentalONNXConvOpTile>(ctx); patterns.insert<ExperimentalONNXConvOpTile>(ctx);
} }

View File

@@ -1,24 +1,25 @@
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Transforms/DialectConversion.h"
#include <cstdlib>
#include "Compiler/PimCompilerOptions.hpp" #include "Compiler/PimCompilerOptions.hpp"
#include "Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp" #include "Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp" #include "Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp"
#include "Conversion/ONNXToSpatial/Utils/WeightSubdivider.hpp" #include "Conversion/ONNXToSpatial/Utils/WeightSubdivider.hpp"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Transforms/DialectConversion.h"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
#include <cstdlib>
using namespace mlir; using namespace mlir;
using namespace std; using namespace std;
namespace onnx_mlir { namespace onnx_mlir {
struct ExperimentalGemmConversionPattern struct ExperimentalGemmConversionPattern : public OpConversionPattern<ONNXGemmOp> {
: public OpConversionPattern<ONNXGemmOp> { ExperimentalGemmConversionPattern(MLIRContext* ctx)
ExperimentalGemmConversionPattern(MLIRContext *ctx) : OpConversionPattern(ctx) {}
: OpConversionPattern(ctx) {}
LogicalResult matchAndRewrite(ONNXGemmOp gemmOp, ONNXGemmOpAdaptor adaptor, LogicalResult
ConversionPatternRewriter &rewriter) const final { matchAndRewrite(ONNXGemmOp gemmOp, ONNXGemmOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final {
// --------------------------------- // // --------------------------------- //
// --- READ OPERATION PARAMETERS --- // // --- READ OPERATION PARAMETERS --- //
@@ -34,17 +35,14 @@ struct ExperimentalGemmConversionPattern
ShapedType matrixType = cast<ShapedType>(adaptor.getB().getType()); ShapedType matrixType = cast<ShapedType>(adaptor.getB().getType());
// TODO: Address bigger batches. // TODO: Address bigger batches.
assert(inputType.getShape()[0] == 1 && assert(inputType.getShape()[0] == 1 && "Only batch size of 1 is supported for GEMM.");
"Only batch size of 1 is supported for GEMM.");
// TODO: Address replication. // TODO: Address replication.
assert(coresCount.getValue() == -1 && assert(coresCount.getValue() == -1 && "Replication is not yet supported for GEMM.");
"Replication is not yet supported for GEMM.");
// TODO: Address bias addition. // TODO: Address bias addition.
assert(inputType.getShape()[1] == matrixType.getShape()[0] && assert(inputType.getShape()[1] == matrixType.getShape()[0] && "Input tile size must match the matrix's row size.");
"Input tile size must match the matrix's row size.");
ldiv_t inputTileCount = div(inputType.getShape()[1], crossbarSize); ldiv_t inputTileCount = div(inputType.getShape()[1], crossbarSize);
ldiv_t outputTileCount = div(outputType.getShape()[1], crossbarSize); ldiv_t outputTileCount = div(outputType.getShape()[1], crossbarSize);
@@ -88,11 +86,10 @@ struct ExperimentalGemmConversionPattern
long inputTile = it; long inputTile = it;
// Create the slicing sizes. // Create the slicing sizes.
SmallVector<OpFoldResult> slicingSizes{ SmallVector<OpFoldResult> slicingSizes {/* 0 */ rewriter.getIndexAttr(crossbarHeight),
/* 0 */ rewriter.getIndexAttr(crossbarHeight), /* 1 */ rewriter.getIndexAttr(crossbarWidth),
/* 1 */ rewriter.getIndexAttr(crossbarWidth), /* 2 */ /* rewriter.getIndexAttr(1), */
/* 2 */ /* rewriter.getIndexAttr(1), */ /* 3 */ /* rewriter.getIndexAttr(1) */};
/* 3 */ /* rewriter.getIndexAttr(1) */};
// - Slicing along the filter x position. // - Slicing along the filter x position.
// - Slicing along the filter y position. // - Slicing along the filter y position.
@@ -100,16 +97,14 @@ struct ExperimentalGemmConversionPattern
for (size_t filterY = 0; filterY < kernelHeight; ++filterY) { for (size_t filterY = 0; filterY < kernelHeight; ++filterY) {
// Create the slicing offsets. // Create the slicing offsets.
SmallVector<OpFoldResult> slicingOffsets{ SmallVector<OpFoldResult> slicingOffsets {/* 0 */ rewriter.getIndexAttr(outputTile * crossbarSize),
/* 0 */ rewriter.getIndexAttr(outputTile * crossbarSize), /* 1 */ rewriter.getIndexAttr(inputTile * crossbarSize),
/* 1 */ rewriter.getIndexAttr(inputTile * crossbarSize), /* 2 */ /* rewriter.getIndexAttr(filterX), */
/* 2 */ /* rewriter.getIndexAttr(filterX), */ /* 3 */ /* rewriter.getIndexAttr(filterY) */};
/* 3 */ /* rewriter.getIndexAttr(filterY) */};
// Create the slice extraction operation. // Create the slice extraction operation.
auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>( auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
gemmOp.getLoc(), adaptor.getB(), slicingOffsets, slicingSizes, gemmOp.getLoc(), adaptor.getB(), slicingOffsets, slicingSizes, slicingStrides);
slicingStrides);
// Add a note to the extractSliceOp, with the filterX and filterY. // Add a note to the extractSliceOp, with the filterX and filterY.
weightsGroups[inputTile][outputTile].push_back(extractSliceOp); weightsGroups[inputTile][outputTile].push_back(extractSliceOp);
@@ -141,8 +136,7 @@ struct ExperimentalGemmConversionPattern
// -------------------------------- // // -------------------------------- //
// Get the next group of weights for the compute unit. // Get the next group of weights for the compute unit.
SmallVector<TaggedWeights> weightsGroups = SmallVector<TaggedWeights> weightsGroups = weightSubdivider.popGroups(crossbarCountInCore.getValue());
weightSubdivider.popGroups(crossbarCountInCore.getValue());
SmallVector<Value> computeWeights; SmallVector<Value> computeWeights;
SmallVector<Value> computeOperands; SmallVector<Value> computeOperands;
@@ -159,15 +153,13 @@ struct ExperimentalGemmConversionPattern
// Iterate over all weights groups for this compute unit. // Iterate over all weights groups for this compute unit.
map<long, Value> localSlices; // WRT the current compute unit. map<long, Value> localSlices; // WRT the current compute unit.
for (auto group : weightsGroups) { for (auto group : weightsGroups) {
for (Value weight : group.weights) { for (Value weight : group.weights)
computeWeights.push_back(weight); computeWeights.push_back(weight);
}
// There might be multiple weight groups for the same input tile, so if // There might be multiple weight groups for the same input tile, so if
// we've already added the input tile, skip it. // we've already added the input tile, skip it.
if (localSlices.find(group.inputTile) != localSlices.end()) { if (localSlices.find(group.inputTile) != localSlices.end())
continue; continue;
}
// We might have already sliced the input tensor for some other compute // We might have already sliced the input tensor for some other compute
// unit, so if we have, reuse the slicing operation without creating a // unit, so if we have, reuse the slicing operation without creating a
@@ -179,26 +171,21 @@ struct ExperimentalGemmConversionPattern
} }
// Create the input tensor slicing offsets. // Create the input tensor slicing offsets.
SmallVector<OpFoldResult> slicingOffsets{ SmallVector<OpFoldResult> slicingOffsets {/* 0 */ rewriter.getIndexAttr(0), // No offset along the batch axis.
/* 0 */ rewriter.getIndexAttr(0), // No offset along the batch axis. /* 1 */ rewriter.getIndexAttr(group.inputTile * crossbarSize),
/* 1 */ rewriter.getIndexAttr(group.inputTile * crossbarSize), /* 2 */ /* rewriter.getIndexAttr(0), */
/* 2 */ /* rewriter.getIndexAttr(0), */ /* 3 */ /* rewriter.getIndexAttr(0) */};
/* 3 */ /* rewriter.getIndexAttr(0) */};
// Create the input tensor slicing sizes. // Create the input tensor slicing sizes.
size_t tilingSize = group.inputTile == inputTileCount.quot size_t tilingSize = group.inputTile == inputTileCount.quot ? inputTileCount.rem : crossbarSize;
? inputTileCount.rem SmallVector<OpFoldResult> slicingSizes {/* 0 */ rewriter.getIndexAttr(1), // Batch size is always 1.
: crossbarSize; /* 1 */ rewriter.getIndexAttr(tilingSize),
SmallVector<OpFoldResult> slicingSizes{ /* 2 */ /* rewriter.getIndexAttr(GET_IMAGE_WIDTH(inputType)), */
/* 0 */ rewriter.getIndexAttr(1), // Batch size is always 1. /* 3 */ /* rewriter.getIndexAttr(GET_IMAGE_HEIGHT(inputType)) */};
/* 1 */ rewriter.getIndexAttr(tilingSize),
/* 2 */ /* rewriter.getIndexAttr(GET_IMAGE_WIDTH(inputType)), */
/* 3 */ /* rewriter.getIndexAttr(GET_IMAGE_HEIGHT(inputType)) */};
// Create the slice extraction operation. // Create the slice extraction operation.
auto extractSliceOp = auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
rewriter.create<tensor::ExtractSliceOp>(gemmOp.getLoc(), gemmOp.getLoc(), adaptor.getA(), slicingOffsets, slicingSizes, slicingStrides);
adaptor.getA(), slicingOffsets, slicingSizes, slicingStrides);
computeOperands.push_back(extractSliceOp); computeOperands.push_back(extractSliceOp);
@@ -220,36 +207,28 @@ struct ExperimentalGemmConversionPattern
// There might be multiple weight groups for the same output tile, so if // There might be multiple weight groups for the same output tile, so if
// we've already added the output tile, skip it. // we've already added the output tile, skip it.
if (outputTileIndices.find(group.outputTile) != if (outputTileIndices.find(group.outputTile) != outputTileIndices.end())
outputTileIndices.end()) {
continue; continue;
}
// Additionally, after adding the input slices as operands, also add any // Additionally, after adding the input slices as operands, also add any
// compatible partial results from previous compute units. // compatible partial results from previous compute units.
if (globalPartialResults.find(group.outputTile) != if (globalPartialResults.find(group.outputTile) != globalPartialResults.end()) {
globalPartialResults.end()) {
computeOperands.push_back(globalPartialResults[group.outputTile]); computeOperands.push_back(globalPartialResults[group.outputTile]);
reductionTileIndices[group.outputTile] = computeOperands.size() - 1; reductionTileIndices[group.outputTile] = computeOperands.size() - 1;
} }
// Define the output shape for this group. // Define the output shape for this group.
long outputTileSize = group.outputTile == outputTileCount.quot long outputTileSize = group.outputTile == outputTileCount.quot ? outputTileCount.rem : crossbarSize;
? outputTileCount.rem
: crossbarSize;
// TODO: Address non-same padding. // TODO: Address non-same padding.
SmallVector<int64_t> outputShapeArray{ SmallVector<int64_t> outputShapeArray {/* 0 */ 1, // Batch size is always 1.
/* 0 */ 1, // Batch size is always 1. /* 1 */ outputTileSize,
/* 1 */ outputTileSize, /* 2 */ /* GET_IMAGE_WIDTH(outputType), */ // Same padding assumed.
/* 2 */ /* GET_IMAGE_WIDTH(outputType), */ // Same padding assumed. /* 3 */ /* GET_IMAGE_HEIGHT(outputType) */};
/* 3 */ /* GET_IMAGE_HEIGHT(outputType) */};
auto elementType = dyn_cast<RankedTensorType>(gemmOp.getY().getType()) auto elementType = dyn_cast<RankedTensorType>(gemmOp.getY().getType()).getElementType();
.getElementType();
computeOutputType.push_back( computeOutputType.push_back(RankedTensorType::get(outputShapeArray, elementType));
RankedTensorType::get(outputShapeArray, elementType));
outputTileIndices[group.outputTile] = computeOutputType.size() - 1; outputTileIndices[group.outputTile] = computeOutputType.size() - 1;
} }
@@ -259,43 +238,36 @@ struct ExperimentalGemmConversionPattern
// ----------------------------- // // ----------------------------- //
// Create the compute unit. // Create the compute unit.
spatial::SpatWeightedCompute currentCompute = spatial::SpatWeightedCompute currentCompute = rewriter.create<spatial::SpatWeightedCompute>(
rewriter.create<spatial::SpatWeightedCompute>(gemmOp.getLoc(), gemmOp.getLoc(), computeOutputType, computeWeights, computeOperands);
computeOutputType, computeWeights, computeOperands);
// Create a new block for the compute unit and add the operands. // Create a new block for the compute unit and add the operands.
Block *block = rewriter.createBlock(&currentCompute.getRegion()); Block* block = rewriter.createBlock(&currentCompute.getRegion());
rewriter.setInsertionPointToStart(block); rewriter.setInsertionPointToStart(block);
for (Value operand : computeOperands) { for (Value operand : computeOperands)
block->addArgument(operand.getType(), gemmOp->getLoc()); block->addArgument(operand.getType(), gemmOp->getLoc());
}
// Initialize a map of local partial results. // Initialize a map of local partial results.
map<long, Value> localPartialResults; // WRT the current compute unit. map<long, Value> localPartialResults; // WRT the current compute unit.
// If we have any reduction tiles, add them to the local partial results. // If we have any reduction tiles, add them to the local partial results.
for (auto reductionTileIndex : reductionTileIndices) { for (auto reductionTileIndex : reductionTileIndices)
localPartialResults[reductionTileIndex.first] = localPartialResults[reductionTileIndex.first] = block->getArgument(reductionTileIndex.second);
block->getArgument(reductionTileIndex.second);
}
// Add all the applyFilters operations to the block. // Add all the applyFilters operations to the block.
for (TaggedWeights group : weightsGroups) { for (TaggedWeights group : weightsGroups) {
// Get the outputType for this group. // Get the outputType for this group.
Type outputType = Type outputType = computeOutputType[outputTileIndices[group.outputTile]];
computeOutputType[outputTileIndices[group.outputTile]];
// Create an apply filters operation. // Create an apply filters operation.
BlockArgument blockArgument = BlockArgument blockArgument = block->getArgument(inputTileIndices[group.inputTile]);
block->getArgument(inputTileIndices[group.inputTile]);
// The list of weight indices is group.startingCrossbarIndex + 0, 1, 2, // The list of weight indices is group.startingCrossbarIndex + 0, 1, 2,
// ... As many weights as the size of group.weights. // ... As many weights as the size of group.weights.
SmallVector<long> weightIndices; SmallVector<long> weightIndices;
for (size_t i = 0; i < group.weights.size(); ++i) { for (size_t i = 0; i < group.weights.size(); ++i)
weightIndices.push_back(group.startingCrossbarIndex + i); weightIndices.push_back(group.startingCrossbarIndex + i);
}
SmallVector<int64_t> xKerPos; SmallVector<int64_t> xKerPos;
SmallVector<int64_t> yKerPos; SmallVector<int64_t> yKerPos;
@@ -313,16 +285,14 @@ struct ExperimentalGemmConversionPattern
ArrayAttr xKerPosAttr = rewriter.getI64ArrayAttr(xKerPos); ArrayAttr xKerPosAttr = rewriter.getI64ArrayAttr(xKerPos);
ArrayAttr yKerPosAttr = rewriter.getI64ArrayAttr(yKerPos); ArrayAttr yKerPosAttr = rewriter.getI64ArrayAttr(yKerPos);
Value result = rewriter.create<spatial::SpatApplyFiltersOp>(gemmOp.getLoc(), Value result = rewriter.create<spatial::SpatApplyFiltersOp>(
outputType, weightIndicesAttr, xKerPosAttr, yKerPosAttr, gemmOp.getLoc(), outputType, weightIndicesAttr, xKerPosAttr, yKerPosAttr, blockArgument);
blockArgument);
// Perform local reduction if necessary. // Perform local reduction if necessary.
if (localPartialResults.find(group.outputTile) != if (localPartialResults.find(group.outputTile) != localPartialResults.end()) {
localPartialResults.end()) {
result = rewriter.create<spatial::SpatVAddOp>(gemmOp.getLoc(), result = rewriter.create<spatial::SpatVAddOp>(
result.getType(), localPartialResults[group.outputTile], result); gemmOp.getLoc(), result.getType(), localPartialResults[group.outputTile], result);
} }
// Update the partial results map. // Update the partial results map.
@@ -375,25 +345,20 @@ struct ExperimentalGemmConversionPattern
// Turn the values into a SmallVector. // Turn the values into a SmallVector.
SmallVector<Value> outputValues; SmallVector<Value> outputValues;
for (long i = 0; i < outputTileCount.quot + (outputTileCount.rem > 0); for (long i = 0; i < outputTileCount.quot + (outputTileCount.rem > 0); ++i)
++i) {
outputValues.push_back(globalPartialResults[i]); outputValues.push_back(globalPartialResults[i]);
}
// Assert that the number of output values is correct. // Assert that the number of output values is correct.
assert(outputValues.size() > 0 && assert(outputValues.size() > 0 && "No output values were generated for the GEMM operation.");
"No output values were generated for the GEMM operation.");
// Return the final output. // Return the final output.
rewriter.replaceOp(gemmOp, rewriter.replaceOp(gemmOp, rewriter.create<tensor::ConcatOp>(gemmOp.getLoc(), 1, outputValues));
rewriter.create<tensor::ConcatOp>(gemmOp.getLoc(), 1, outputValues));
return success(); return success();
} }
}; };
void populateGemmToConvConversionPattern( void populateGemmToConvConversionPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
RewritePatternSet &patterns, MLIRContext *ctx) {
patterns.insert<ExperimentalGemmConversionPattern>(ctx); patterns.insert<ExperimentalGemmConversionPattern>(ctx);
} }

View File

@@ -6,20 +6,22 @@
#include "mlir/IR/PatternMatch.h" #include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h" #include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h" #include "mlir/IR/ValueRange.h"
#include "src/Accelerators/PIM/Common/PIMCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp"
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h" #include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h" #include "llvm/Support/raw_ostream.h"
#include <cassert> #include <cassert>
#include <cmath> #include <cmath>
#include <cstddef> #include <cstddef>
#include "src/Accelerators/PIM/Common/PIMCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir; using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
@@ -35,71 +37,68 @@ bool hasPostProcessExperimentalPoolingWindow<ONNXAveragePoolOp>() {
} }
// Generic version of the pooling-window post-processing hook.
//
// This primary template performs no post-processing: it ignores all of its
// arguments and yields a null Value. Pool operations that need a real
// post-processing step provide an explicit specialization.
template <typename PoolOp>
Value postProcessExperimentalPoolingWindow(ConversionPatternRewriter& rewriter,
                                           Location loc,
                                           PoolOp poolOp,
                                           Value valueToDivide,
                                           size_t krn_size,
                                           size_t tilesSkippedByPadding) {
  return Value();
}
// Average-pool post-processing: divides the reduced window value by the number
// of elements that count towards the average.
//
// rewriter               Rewriter used to create the new operations; its
//                        insertion point is moved by this function.
// loc                    Location attached to the created operations.
// poolOp                 Pool op; only its count_include_pad attribute is read.
// valueToDivide          Reduced window value. It must live inside the region
//                        of a spatial::SpatWeightedCompute (enforced by cast<>).
// krn_size               Number of elements in the full pooling window.
// tilesSkippedByPadding  Number of window elements skipped because of padding.
//
// Returns the result of the spatial::SpatVSDivOp created after valueToDivide.
template <>
Value postProcessExperimentalPoolingWindow<ONNXAveragePoolOp>(ConversionPatternRewriter& rewriter,
                                                              Location loc,
                                                              ONNXAveragePoolOp poolOp,
                                                              Value valueToDivide,
                                                              size_t krn_size,
                                                              size_t tilesSkippedByPadding) {
  // With count_include_pad == 1 the whole kernel counts towards the divisor;
  // otherwise the elements skipped by padding are excluded.
  const bool countIncludePad = poolOp.getCountIncludePad() == 1;
  const size_t divisorNumber = countIncludePad ? krn_size : krn_size - tilesSkippedByPadding;

  RankedTensorType scalarTensor = RankedTensorType::get({1}, rewriter.getF32Type());

  // Put a spat.const before the computeOp, and use its value. We do this to be
  // compatible with the current code generation, which assumes constant to be
  // loaded in global memory, which is allocated by adding a spat.const OP
  // directly under func.func (i.e. alongside ComputeOps)
  //
  // The enclosing compute is resolved through the value's parent block instead
  // of its defining op: a block argument has no defining op, and going through
  // getDefiningOp() would dereference a null Operation* in that case.
  auto computeOp = cast<spatial::SpatWeightedCompute>(valueToDivide.getParentBlock()->getParentOp());
  rewriter.setInsertionPoint(computeOp);
  auto divisorValue = rewriter.create<spatial::SpatConstantOp>(loc,
                                                               scalarTensor,
                                                               rewriter.getI64IntegerAttr(divisorNumber),
                                                               /* should_allocate = */ rewriter.getBoolAttr(true));

  // The division itself is emitted inside the compute, right after the value.
  rewriter.setInsertionPointAfterValue(valueToDivide);
  return rewriter.create<spatial::SpatVSDivOp>(loc, valueToDivide.getType(), valueToDivide, divisorValue);
}
// Reduces a list of tiles to a single Value by recursively splitting the list
// in half and combining the two partial results with ReductionOp.
//
// inputTiles  Tiles to reduce; must contain at least one element.
// rewriter    Rewriter used to create the ReductionOp instances at its current
//             insertion point.
//
// Returns the tile itself when there is exactly one, otherwise the result of
// the top-level ReductionOp.
template <typename ReductionOp>
Value reduceInputTiles(SmallVector<Value>& inputTiles, ConversionPatternRewriter& rewriter) {
  // An empty list would split into two empty halves and recurse forever.
  assert(!inputTiles.empty() && "reduceInputTiles requires at least one input tile.");

  if (inputTiles.size() == 1)
    return inputTiles[0];

  // Every combination step uses ReductionOp. The former two-tile shortcut
  // always created a spatial::SpatVMaxOp, which produced a max instead of the
  // requested reduction (e.g. a sum when instantiated with SpatVAddOp).
  auto middle = inputTiles.begin() + inputTiles.size() / 2;
  SmallVector<Value> left(inputTiles.begin(), middle);
  SmallVector<Value> right(middle, inputTiles.end());

  Value leftReduced = reduceInputTiles<ReductionOp>(left, rewriter);
  Value rightReduced = reduceInputTiles<ReductionOp>(right, rewriter);

  return rewriter.create<ReductionOp>(inputTiles[0].getLoc(), leftReduced.getType(), leftReduced, rightReduced);
}
template <typename PoolOp, typename PoolOpAdaptor, typename ReduceOp> template <typename PoolOp, typename PoolOpAdaptor, typename ReduceOp>
struct ExperimentalPoolingBaseConverter : public OpConversionPattern<PoolOp> { struct ExperimentalPoolingBaseConverter : public OpConversionPattern<PoolOp> {
ExperimentalPoolingBaseConverter(MLIRContext *ctx) ExperimentalPoolingBaseConverter(MLIRContext* ctx)
: OpConversionPattern<PoolOp>(ctx) {} : OpConversionPattern<PoolOp>(ctx) {}
LogicalResult matchAndRewrite(PoolOp poolOp, PoolOpAdaptor adaptor, LogicalResult matchAndRewrite(PoolOp poolOp, PoolOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final {
ConversionPatternRewriter &rewriter) const final {
Value X = adaptor.getX(); Value X = adaptor.getX();
ShapedType xShape = mlir::cast<ShapedType>(X.getType()); ShapedType xShape = mlir::cast<ShapedType>(X.getType());
Value Y = poolOp.getResult(); Value Y = poolOp.getResult();
@@ -110,17 +109,13 @@ struct ExperimentalPoolingBaseConverter : public OpConversionPattern<PoolOp> {
unpackOptionalPairVector(adaptor.getDilations(), dilation_x, dilation_y); unpackOptionalPairVector(adaptor.getDilations(), dilation_x, dilation_y);
unpackOptionalPairVector(adaptor.getKernelShape(), krn_w, krn_h); unpackOptionalPairVector(adaptor.getKernelShape(), krn_w, krn_h);
if (adaptor.getAutoPad() != "NOTSET") { if (adaptor.getAutoPad() != "NOTSET")
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(poolOp, "auto_pad != NOTSET is deprecated.");
poolOp, "auto_pad != NOTSET is deprecated.");
}
size_t pad_x, pad_y; size_t pad_x, pad_y;
auto padUnpackError = auto padUnpackError = unpackOptionalPadsVector(adaptor.getPads(), pad_x, pad_y);
unpackOptionalPadsVector(adaptor.getPads(), pad_x, pad_y); if (padUnpackError.has_value())
if (padUnpackError.has_value()) {
return rewriter.notifyMatchFailure(poolOp, padUnpackError.value()); return rewriter.notifyMatchFailure(poolOp, padUnpackError.value());
}
Location loc = poolOp.getLoc(); Location loc = poolOp.getLoc();
@@ -133,10 +128,8 @@ struct ExperimentalPoolingBaseConverter : public OpConversionPattern<PoolOp> {
// Assert that the input is a tensor.ConcatOp. // Assert that the input is a tensor.ConcatOp.
auto concat = X.getDefiningOp<tensor::ConcatOp>(); auto concat = X.getDefiningOp<tensor::ConcatOp>();
if (!concat) { if (!concat)
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(poolOp, "Expected input to be a tensor.ConcatOp");
poolOp, "Expected input to be a tensor.ConcatOp");
}
// Create a [channel_tile][x][y] array to store the input tiles. // Create a [channel_tile][x][y] array to store the input tiles.
std::map<long, std::map<long, std::map<long, Value>>> inputTiles; std::map<long, std::map<long, std::map<long, Value>>> inputTiles;
@@ -145,24 +138,21 @@ struct ExperimentalPoolingBaseConverter : public OpConversionPattern<PoolOp> {
for (size_t y = 0; y < input_h; ++y) { for (size_t y = 0; y < input_h; ++y) {
for (size_t x = 0; x < input_w; ++x) { for (size_t x = 0; x < input_w; ++x) {
for (long it = 0; it < tileCount.quot + (tileCount.rem > 0); ++it) { for (long it = 0; it < tileCount.quot + (tileCount.rem > 0); ++it) {
size_t tilingSize = size_t tilingSize = it == tileCount.quot ? tileCount.rem : crossbarSize;
it == tileCount.quot ? tileCount.rem : crossbarSize;
SmallVector<OpFoldResult> strides(4, rewriter.getIndexAttr(1)); SmallVector<OpFoldResult> strides(4, rewriter.getIndexAttr(1));
SmallVector<OpFoldResult> offsets = {/* 0 */ rewriter.getIndexAttr(0), SmallVector<OpFoldResult> offsets = {/* 0 */ rewriter.getIndexAttr(0),
/* 1 */ rewriter.getIndexAttr(0), /* 1 */ rewriter.getIndexAttr(0),
/* 2 */ rewriter.getIndexAttr(x), /* 2 */ rewriter.getIndexAttr(x),
/* 3 */ rewriter.getIndexAttr(y)}; /* 3 */ rewriter.getIndexAttr(y)};
SmallVector<OpFoldResult> sizes = { SmallVector<OpFoldResult> sizes = {/* 0 */ rewriter.getIndexAttr(1), // Batch size is always 1.
/* 0 */ rewriter.getIndexAttr(1), // Batch size is always 1. /* 1 */ rewriter.getIndexAttr(tilingSize),
/* 1 */ rewriter.getIndexAttr(tilingSize), /* 2 */ rewriter.getIndexAttr(1),
/* 2 */ rewriter.getIndexAttr(1), /* 3 */ rewriter.getIndexAttr(1)};
/* 3 */ rewriter.getIndexAttr(1)};
// Get the concat's operand that we want to slice. // Get the concat's operand that we want to slice.
Value concatInput = concat.getOperand(it); Value concatInput = concat.getOperand(it);
Value slicedTile = rewriter.create<tensor::ExtractSliceOp>( Value slicedTile = rewriter.create<tensor::ExtractSliceOp>(loc, concatInput, offsets, sizes, strides);
loc, concatInput, offsets, sizes, strides);
inputTiles[it][x][y] = slicedTile; inputTiles[it][x][y] = slicedTile;
} }
@@ -175,19 +165,15 @@ struct ExperimentalPoolingBaseConverter : public OpConversionPattern<PoolOp> {
for (size_t y = 0; y < output_h; ++y) { for (size_t y = 0; y < output_h; ++y) {
for (size_t x = 0; x < output_w; ++x) { for (size_t x = 0; x < output_w; ++x) {
for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) { for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) {
SmallVector<int64_t> outputShapeArray{ SmallVector<int64_t> outputShapeArray {/* 0 */ 1, // Batch size is always 1.
/* 0 */ 1, // Batch size is always 1. /* 1 */
/* 1 */ cast<RankedTensorType>(inputTiles[it][0][0].getType()).getShape()[1],
cast<RankedTensorType>(inputTiles[it][0][0].getType()) /* 2 */ 1,
.getShape()[1], /* 3 */ 1};
/* 2 */ 1,
/* 3 */ 1};
auto elementType = auto elementType = dyn_cast<RankedTensorType>(xShape).getElementType();
dyn_cast<RankedTensorType>(xShape).getElementType();
outputTileTypes.push_back( outputTileTypes.push_back(RankedTensorType::get(outputShapeArray, elementType));
RankedTensorType::get(outputShapeArray, elementType));
} }
} }
} }
@@ -195,29 +181,25 @@ struct ExperimentalPoolingBaseConverter : public OpConversionPattern<PoolOp> {
// Create a plain value list of the input tiles. // Create a plain value list of the input tiles.
SmallVector<Value> inputTilesList; SmallVector<Value> inputTilesList;
for (size_t y = 0; y < input_h; ++y) { for (size_t y = 0; y < input_h; ++y) {
for (size_t x = 0; x < input_w; ++x) { for (size_t x = 0; x < input_w; ++x)
for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) { for (long it = 0; it < itc.quot + (itc.rem > 0); ++it)
inputTilesList.push_back(inputTiles[it][y][x]); inputTilesList.push_back(inputTiles[it][y][x]);
}
}
} }
// Create a single compute to calculate the output. // Create a single compute to calculate the output.
auto computeOp = rewriter.create<spatial::SpatWeightedCompute>( auto computeOp =
loc, outputTileTypes, SmallVector<Value>(), inputTilesList); rewriter.create<spatial::SpatWeightedCompute>(loc, outputTileTypes, SmallVector<Value>(), inputTilesList);
// Create a new block for the compute unit and add the operands. // Create a new block for the compute unit and add the operands.
Block *block = rewriter.createBlock(&computeOp.getRegion()); Block* block = rewriter.createBlock(&computeOp.getRegion());
// Fill the block arguments and keep a reference to them. // Fill the block arguments and keep a reference to them.
std::map<size_t, std::map<size_t, std::map<size_t, Value>>> inputTilesArgs; std::map<size_t, std::map<size_t, std::map<size_t, Value>>> inputTilesArgs;
for (size_t y = 0; y < input_h; ++y) { for (size_t y = 0; y < input_h; ++y) {
for (size_t x = 0; x < input_w; ++x) { for (size_t x = 0; x < input_w; ++x) {
for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) { for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) {
auto tileIndex = y * input_w * (itc.quot + (itc.rem > 0)) + auto tileIndex = y * input_w * (itc.quot + (itc.rem > 0)) + x * (itc.quot + (itc.rem > 0)) + it;
x * (itc.quot + (itc.rem > 0)) + it; inputTilesArgs[it][y][x] = block->addArgument(computeOp->getOperand(tileIndex).getType(), loc);
inputTilesArgs[it][y][x] = block->addArgument(
computeOp->getOperand(tileIndex).getType(), loc);
} }
} }
} }
@@ -236,28 +218,26 @@ struct ExperimentalPoolingBaseConverter : public OpConversionPattern<PoolOp> {
size_t end_y = std::min(start_y + krn_h, input_h); size_t end_y = std::min(start_y + krn_h, input_h);
SmallVector<Value> inputTilesToReduce; SmallVector<Value> inputTilesToReduce;
for (size_t ky = start_y; ky < end_y; ++ky) { for (size_t ky = start_y; ky < end_y; ++ky)
for (size_t kx = start_x; kx < end_x; ++kx) { for (size_t kx = start_x; kx < end_x; ++kx)
inputTilesToReduce.push_back(inputTilesArgs[it][ky][kx]); inputTilesToReduce.push_back(inputTilesArgs[it][ky][kx]);
}
}
auto reduceResult = auto reduceResult = reduceInputTiles<ReduceOp>(inputTilesToReduce, rewriter);
reduceInputTiles<ReduceOp>(inputTilesToReduce, rewriter);
// If the reduce op is add, we need to divide the result by the // If the reduce op is add, we need to divide the result by the
// number of elements in the pooling window. // number of elements in the pooling window.
if (hasPostProcessExperimentalPoolingWindow<PoolOp>()) { if (hasPostProcessExperimentalPoolingWindow<PoolOp>()) {
// Add a spat.const before the computeOp. // Add a spat.const before the computeOp.
rewriter.setInsertionPoint(computeOp); rewriter.setInsertionPoint(computeOp);
auto divisorValue = rewriter.create<spatial::SpatConstantOp>(loc, auto divisorValue =
RankedTensorType::get({1}, rewriter.getF32Type()), rewriter.create<spatial::SpatConstantOp>(loc,
rewriter.getI64IntegerAttr(krn_w * krn_h), RankedTensorType::get({1}, rewriter.getF32Type()),
rewriter.getBoolAttr(true)); rewriter.getI64IntegerAttr(krn_w * krn_h),
rewriter.getBoolAttr(true));
rewriter.setInsertionPointAfter(reduceResult.getDefiningOp()); rewriter.setInsertionPointAfter(reduceResult.getDefiningOp());
reduceResult = rewriter.create<spatial::SpatVSDivOp>( reduceResult =
loc, reduceResult.getType(), reduceResult, divisorValue); rewriter.create<spatial::SpatVSDivOp>(loc, reduceResult.getType(), reduceResult, divisorValue);
} }
outputTiles.push_back(reduceResult); outputTiles.push_back(reduceResult);
} }
@@ -274,8 +254,7 @@ struct ExperimentalPoolingBaseConverter : public OpConversionPattern<PoolOp> {
for (size_t y = 0; y < output_h; ++y) { for (size_t y = 0; y < output_h; ++y) {
for (size_t x = 0; x < output_w; ++x) { for (size_t x = 0; x < output_w; ++x) {
for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) { for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) {
auto tileIndex = y * output_w * (itc.quot + (itc.rem > 0)) + auto tileIndex = y * output_w * (itc.quot + (itc.rem > 0)) + x * (itc.quot + (itc.rem > 0)) + it;
x * (itc.quot + (itc.rem > 0)) + it;
computeOutput[it][y][x] = computeOp.getResult(tileIndex); computeOutput[it][y][x] = computeOp.getResult(tileIndex);
} }
} }
@@ -285,30 +264,25 @@ struct ExperimentalPoolingBaseConverter : public OpConversionPattern<PoolOp> {
SmallVector<Value> outputTilesList; SmallVector<Value> outputTilesList;
for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) { for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) {
SmallVector<Value> imgConcatTiles; SmallVector<Value> imgConcatTiles;
for (size_t y = 0; y < output_h; ++y) { for (size_t y = 0; y < output_h; ++y)
for (size_t x = 0; x < output_w; ++x) { for (size_t x = 0; x < output_w; ++x)
imgConcatTiles.push_back(computeOutput[it][y][x]); imgConcatTiles.push_back(computeOutput[it][y][x]);
}
}
size_t tilingSize = it == tileCount.quot ? tileCount.rem : crossbarSize; size_t tilingSize = it == tileCount.quot ? tileCount.rem : crossbarSize;
SmallVector<int64_t> outputShapeArray{ SmallVector<int64_t> outputShapeArray {/* 0 */ 1, // Batch size is always 1.
/* 0 */ 1, // Batch size is always 1. /* 1 */ (long) tilingSize,
/* 1 */ (long)tilingSize, /* 2 */ (long) output_w,
/* 2 */ (long)output_w, /* 3 */ (long) output_h};
/* 3 */ (long)output_h};
auto elementType = dyn_cast<RankedTensorType>(xShape).getElementType(); auto elementType = dyn_cast<RankedTensorType>(xShape).getElementType();
outputTilesList.push_back(rewriter.create<spatial::SpatImgConcatOp>(loc, outputTilesList.push_back(rewriter.create<spatial::SpatImgConcatOp>(
RankedTensorType::get(outputShapeArray, elementType), loc, RankedTensorType::get(outputShapeArray, elementType), imgConcatTiles));
imgConcatTiles));
} }
// Create a new tensor.ConcatOp to concatenate the output tiles. // Create a new tensor.ConcatOp to concatenate the output tiles.
Value outputTensor = Value outputTensor = rewriter.create<tensor::ConcatOp>(loc, 1, outputTilesList);
rewriter.create<tensor::ConcatOp>(loc, 1, outputTilesList);
rewriter.replaceOp(poolOp, outputTensor); rewriter.replaceOp(poolOp, outputTensor);
@@ -316,12 +290,11 @@ struct ExperimentalPoolingBaseConverter : public OpConversionPattern<PoolOp> {
} }
}; };
void populateExperimentalPoolingTilingPattern( void populateExperimentalPoolingTilingPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
RewritePatternSet &patterns, MLIRContext *ctx) { patterns.insert<
patterns.insert<ExperimentalPoolingBaseConverter<ONNXMaxPoolSingleOutOp, ExperimentalPoolingBaseConverter<ONNXMaxPoolSingleOutOp, ONNXMaxPoolSingleOutOpAdaptor, spatial::SpatVMaxOp>>(ctx);
ONNXMaxPoolSingleOutOpAdaptor, spatial::SpatVMaxOp>>(ctx); patterns.insert<ExperimentalPoolingBaseConverter<ONNXAveragePoolOp, ONNXAveragePoolOpAdaptor, spatial::SpatVAddOp>>(
patterns.insert<ExperimentalPoolingBaseConverter<ONNXAveragePoolOp, ctx);
ONNXAveragePoolOpAdaptor, spatial::SpatVAddOp>>(ctx);
} }
} // namespace onnx_mlir } // namespace onnx_mlir

View File

@@ -6,38 +6,39 @@
#include "mlir/IR/PatternMatch.h" #include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h" #include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h" #include "mlir/IR/ValueRange.h"
#include "src/Accelerators/PIM/Common/PIMCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp"
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h" #include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h" #include "llvm/Support/raw_ostream.h"
#include <cassert> #include <cassert>
#include <cmath> #include <cmath>
#include <cstddef> #include <cstddef>
#include "src/Accelerators/PIM/Common/PIMCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir; using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
// Namespace-scope set of Operation pointers.
// NOTE(review): presumably records compute ops that have already been replaced
// during conversion; its readers/writers are not visible in this part of the
// file — confirm before relying on it.
llvm::SmallPtrSet<Operation*, 16> oldComputeOpsReplaced;
Value applyReducePatternNew(SmallVector<Value> &valuesToReduce, Value applyReducePatternNew(SmallVector<Value>& valuesToReduce,
ConversionPatternRewriter &rewriter, ConversionPatternRewriter& rewriter,
std::function<Value(const Value &, const Value &)> reduce, std::function<Value(const Value&, const Value&)> reduce,
std::function<Value(const Value &)> preprocess, std::function<Value(const Value&)> preprocess,
std::function<Value(const Value &)> postprocess) { std::function<Value(const Value&)> postprocess) {
// Simple case: if we have only one input, just return it // Simple case: if we have only one input, just return it
if (valuesToReduce.size() == 1) { if (valuesToReduce.size() == 1)
return valuesToReduce[0]; return valuesToReduce[0];
}
if (preprocess) { if (preprocess) {
for (auto &valToReduce : valuesToReduce) { for (auto& valToReduce : valuesToReduce) {
rewriter.setInsertionPointAfterValue(valToReduce); rewriter.setInsertionPointAfterValue(valToReduce);
valToReduce = preprocess(valToReduce); valToReduce = preprocess(valToReduce);
} }
@@ -47,9 +48,9 @@ Value applyReducePatternNew(SmallVector<Value> &valuesToReduce,
// computeOp. In this case, we need to apply the reduction within-computef // computeOp. In this case, we need to apply the reduction within-computef
// Keep a map between a computeOp and the last Value for this reduction // Keep a map between a computeOp and the last Value for this reduction
std::unordered_map<Operation *, Value> lastValueForCompute; std::unordered_map<Operation*, Value> lastValueForCompute;
for (auto &valToReduce : valuesToReduce) { for (auto& valToReduce : valuesToReduce) {
Operation *computeOp = valToReduce.getParentBlock()->getParentOp(); Operation* computeOp = valToReduce.getParentBlock()->getParentOp();
// if (valToReduce.getDefiningOp()) { // if (valToReduce.getDefiningOp()) {
// // If the value is defined by an operation, we take the parent // // If the value is defined by an operation, we take the parent
// operation computeOp = valToReduce.getDefiningOp()->getParentOp(); // operation computeOp = valToReduce.getDefiningOp()->getParentOp();
@@ -67,12 +68,10 @@ Value applyReducePatternNew(SmallVector<Value> &valuesToReduce,
// within-compute // within-compute
Value lastWithinComputeValue = it->second; Value lastWithinComputeValue = it->second;
if (valToReduce.getDefiningOp()->isBeforeInBlock( if (valToReduce.getDefiningOp()->isBeforeInBlock(lastWithinComputeValue.getDefiningOp()))
lastWithinComputeValue.getDefiningOp())) {
rewriter.setInsertionPointAfterValue(lastWithinComputeValue); rewriter.setInsertionPointAfterValue(lastWithinComputeValue);
} else { else
rewriter.setInsertionPointAfterValue(valToReduce); rewriter.setInsertionPointAfterValue(valToReduce);
}
valToReduce = reduce(lastWithinComputeValue, valToReduce); valToReduce = reduce(lastWithinComputeValue, valToReduce);
lastValueForCompute[computeOp] = valToReduce; lastValueForCompute[computeOp] = valToReduce;
} }
@@ -83,9 +82,8 @@ Value applyReducePatternNew(SmallVector<Value> &valuesToReduce,
// Now, reconstruct from the map the valuesToReduce list // Now, reconstruct from the map the valuesToReduce list
valuesToReduce.clear(); valuesToReduce.clear();
valuesToReduce.reserve(lastValueForCompute.size()); valuesToReduce.reserve(lastValueForCompute.size());
for (auto &entry : lastValueForCompute) { for (auto& entry : lastValueForCompute)
valuesToReduce.push_back(entry.second); valuesToReduce.push_back(entry.second);
}
Location loc = valuesToReduce[0].getLoc(); Location loc = valuesToReduce[0].getLoc();
auto channelType = spatial::SpatChannelType::get(rewriter.getContext()); auto channelType = spatial::SpatChannelType::get(rewriter.getContext());
@@ -123,8 +121,7 @@ Value applyReducePatternNew(SmallVector<Value> &valuesToReduce,
// 3. Add a receiveOp after the second value // 3. Add a receiveOp after the second value
rewriter.setInsertionPointAfterValue(secondValue); rewriter.setInsertionPointAfterValue(secondValue);
auto receivedValue = rewriter.create<spatial::SpatChannelReceiveOp>( auto receivedValue = rewriter.create<spatial::SpatChannelReceiveOp>(loc, secondValue.getType(), channel);
loc, secondValue.getType(), channel);
// 4. Apply reduction between second value and received value // 4. Apply reduction between second value and received value
rewriter.setInsertionPointAfterValue(receivedValue); rewriter.setInsertionPointAfterValue(receivedValue);
@@ -135,17 +132,14 @@ Value applyReducePatternNew(SmallVector<Value> &valuesToReduce,
// If we have an odd number of inputs, we need to add the last one to the // If we have an odd number of inputs, we need to add the last one to the
// newInputs list. // newInputs list.
if (valuesToReduceRef.size() % 2 == 1) { if (valuesToReduceRef.size() % 2 == 1)
nextValuesToReduce.push_back(valuesToReduceRef.back()); nextValuesToReduce.push_back(valuesToReduceRef.back());
}
// Replace the inputOps list with the new one. // Replace the inputOps list with the new one.
valuesToReduceRef = valuesToReduceRef = llvm::OwningArrayRef<Value>(std::move(nextValuesToReduce));
llvm::OwningArrayRef<Value>(std::move(nextValuesToReduce));
} }
assert(valuesToReduceRef.size() == 1 && assert(valuesToReduceRef.size() == 1 && "Internal error: expected a single input at this point.");
"Internal error: expected a single input at this point.");
auto finalValue = valuesToReduceRef[0]; auto finalValue = valuesToReduceRef[0];
@@ -168,46 +162,49 @@ bool hasPostProcessPoolingWindow<ONNXAveragePoolOp>() {
} }
// Generic version of the pooling-window post-processing hook.
//
// This primary template performs no post-processing: it ignores all of its
// arguments and yields a null Value. Pool operations that need a real
// post-processing step provide an explicit specialization.
template <typename PoolOp>
Value postProcessPoolingWindow(ConversionPatternRewriter& rewriter,
                               Location loc,
                               PoolOp poolOp,
                               Value valueToDivide,
                               size_t krn_size,
                               size_t tilesSkippedByPadding) {
  return Value();
}
// Average-pool post-processing: divides the reduced window value by the number
// of elements that count towards the average.
//
// rewriter               Rewriter used to create the new operations; its
//                        insertion point is moved by this function.
// loc                    Location attached to the created operations.
// poolOp                 Pool op; only its count_include_pad attribute is read.
// valueToDivide          Reduced window value. It must live inside the region
//                        of a spatial::SpatWeightedCompute (enforced by cast<>).
// krn_size               Number of elements in the full pooling window.
// tilesSkippedByPadding  Number of window elements skipped because of padding.
//
// Returns the result of the spatial::SpatVSDivOp created after valueToDivide.
template <>
Value postProcessPoolingWindow<ONNXAveragePoolOp>(ConversionPatternRewriter& rewriter,
                                                  Location loc,
                                                  ONNXAveragePoolOp poolOp,
                                                  Value valueToDivide,
                                                  size_t krn_size,
                                                  size_t tilesSkippedByPadding) {
  // With count_include_pad == 1 the whole kernel counts towards the divisor;
  // otherwise the elements skipped by padding are excluded.
  const bool countIncludePad = poolOp.getCountIncludePad() == 1;
  const size_t divisorNumber = countIncludePad ? krn_size : krn_size - tilesSkippedByPadding;

  RankedTensorType scalarTensor = RankedTensorType::get({1}, rewriter.getF32Type());

  // Put a spat.const before the computeOp, and use its value. We do this to be
  // compatible with the current code generation, which assumes constant to be
  // loaded in global memory, which is allocated by adding a spat.const OP
  // directly under func.func (i.e. alongside ComputeOps)
  //
  // The enclosing compute is resolved through the value's parent block instead
  // of its defining op: a block argument has no defining op, and going through
  // getDefiningOp() would dereference a null Operation* in that case.
  auto computeOp = cast<spatial::SpatWeightedCompute>(valueToDivide.getParentBlock()->getParentOp());
  rewriter.setInsertionPoint(computeOp);
  auto divisorValue = rewriter.create<spatial::SpatConstantOp>(loc,
                                                               scalarTensor,
                                                               rewriter.getI64IntegerAttr(divisorNumber),
                                                               /* should_allocate = */ rewriter.getBoolAttr(true));

  // The division itself is emitted inside the compute, right after the value.
  rewriter.setInsertionPointAfterValue(valueToDivide);
  return rewriter.create<spatial::SpatVSDivOp>(loc, valueToDivide.getType(), valueToDivide, divisorValue);
}
template <typename PoolOp, typename PoolOpAdaptor, typename ReduceOp> template <typename PoolOp, typename PoolOpAdaptor, typename ReduceOp>
struct PoolingBaseConverter : public OpConversionPattern<PoolOp> { struct PoolingBaseConverter : public OpConversionPattern<PoolOp> {
PoolingBaseConverter(MLIRContext *ctx) : OpConversionPattern<PoolOp>(ctx) {} PoolingBaseConverter(MLIRContext* ctx)
: OpConversionPattern<PoolOp>(ctx) {}
LogicalResult matchAndRewrite(PoolOp poolOp, PoolOpAdaptor adaptor, LogicalResult matchAndRewrite(PoolOp poolOp, PoolOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final {
ConversionPatternRewriter &rewriter) const final {
Value X = adaptor.getX(); Value X = adaptor.getX();
ShapedType xShape = mlir::cast<ShapedType>(X.getType()); ShapedType xShape = mlir::cast<ShapedType>(X.getType());
Value Y = poolOp.getResult(); Value Y = poolOp.getResult();
@@ -218,17 +215,13 @@ struct PoolingBaseConverter : public OpConversionPattern<PoolOp> {
unpackOptionalPairVector(adaptor.getDilations(), dilation_x, dilation_y); unpackOptionalPairVector(adaptor.getDilations(), dilation_x, dilation_y);
unpackOptionalPairVector(adaptor.getKernelShape(), krn_w, krn_h); unpackOptionalPairVector(adaptor.getKernelShape(), krn_w, krn_h);
if (adaptor.getAutoPad() != "NOTSET") { if (adaptor.getAutoPad() != "NOTSET")
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(poolOp, "auto_pad != NOTSET is deprecated.");
poolOp, "auto_pad != NOTSET is deprecated.");
}
size_t pad_x, pad_y; size_t pad_x, pad_y;
auto padUnpackError = auto padUnpackError = unpackOptionalPadsVector(adaptor.getPads(), pad_x, pad_y);
unpackOptionalPadsVector(adaptor.getPads(), pad_x, pad_y); if (padUnpackError.has_value())
if (padUnpackError.has_value()) {
return rewriter.notifyMatchFailure(poolOp, padUnpackError.value()); return rewriter.notifyMatchFailure(poolOp, padUnpackError.value());
}
Location loc = poolOp.getLoc(); Location loc = poolOp.getLoc();
@@ -236,8 +229,7 @@ struct PoolingBaseConverter : public OpConversionPattern<PoolOp> {
size_t input_w = GET_IMAGE_WIDTH(xShape); size_t input_w = GET_IMAGE_WIDTH(xShape);
size_t output_h = GET_IMAGE_HEIGHT(yShape); size_t output_h = GET_IMAGE_HEIGHT(yShape);
size_t output_w = GET_IMAGE_WIDTH(yShape); size_t output_w = GET_IMAGE_WIDTH(yShape);
size_t channelTileCount = size_t channelTileCount = ceilIntegerDivide(GET_IMAGE_CHANNEL(xShape), crossbarSize.getValue());
ceilIntegerDivide(GET_IMAGE_CHANNEL(xShape), crossbarSize.getValue());
size_t channelTileRest = GET_IMAGE_CHANNEL(xShape) % crossbarSize; size_t channelTileRest = GET_IMAGE_CHANNEL(xShape) % crossbarSize;
// 1: Tile the input tensor // 1: Tile the input tensor
@@ -249,14 +241,13 @@ struct PoolingBaseConverter : public OpConversionPattern<PoolOp> {
// Example complete input tensor: tensor<1x3x12x12xf32> (NxCxWxH) // Example complete input tensor: tensor<1x3x12x12xf32> (NxCxWxH)
// Suppose that the input tensor is produced by concatenating the results of // Suppose that the input tensor is produced by concatenating the results of
// many ComputeOps. Get the result tiles from these ComputeOps. // many ComputeOps. Get the result tiles from these ComputeOps.
SmallVector<SmallVector<SmallVector<Value>>> inputTiles(channelTileCount, SmallVector<SmallVector<SmallVector<Value>>> inputTiles(
SmallVector<SmallVector<Value>>(input_w, SmallVector<Value>(input_h))); channelTileCount, SmallVector<SmallVector<Value>>(input_w, SmallVector<Value>(input_h)));
auto resolveErrorOpt = resolveImgInputTiles(X, inputTiles, channelTileCount, auto resolveErrorOpt =
channelTileRest, input_w, input_h, rewriter); resolveImgInputTiles(X, inputTiles, channelTileCount, channelTileRest, input_w, input_h, rewriter);
if (resolveErrorOpt.has_value()) { if (resolveErrorOpt.has_value())
return rewriter.notifyMatchFailure(poolOp, *resolveErrorOpt); return rewriter.notifyMatchFailure(poolOp, *resolveErrorOpt);
}
// TODO: This requires a core for each input tile, which is not ideal. We // TODO: This requires a core for each input tile, which is not ideal. We
// can do better. // can do better.
@@ -265,18 +256,17 @@ struct PoolingBaseConverter : public OpConversionPattern<PoolOp> {
for (size_t t = 0; t < channelTileCount; t++) { for (size_t t = 0; t < channelTileCount; t++) {
for (size_t x = 0; x < input_w; x++) { for (size_t x = 0; x < input_w; x++) {
for (size_t y = 0; y < input_h; y++) { for (size_t y = 0; y < input_h; y++) {
if (auto extractSliceOp = if (auto extractSliceOp = inputTiles[t][x][y].getDefiningOp<tensor::ExtractSliceOp>()) {
inputTiles[t][x][y].getDefiningOp<tensor::ExtractSliceOp>()) {
Location tileLoc = extractSliceOp.getLoc(); Location tileLoc = extractSliceOp.getLoc();
auto tempComputeOp = rewriter.create<spatial::SpatWeightedCompute>( auto tempComputeOp = rewriter.create<spatial::SpatWeightedCompute>(tileLoc,
tileLoc, extractSliceOp.getResultType(), extractSliceOp.getResultType(),
/* xbarWeights =*/ValueRange(), extractSliceOp.getResult()); /* xbarWeights =*/ValueRange(),
extractSliceOp.getResult());
Block *tempComputeOpBlock = new Block(); Block* tempComputeOpBlock = new Block();
tempComputeOp.getBody().push_back(tempComputeOpBlock); tempComputeOp.getBody().push_back(tempComputeOpBlock);
auto tempComputeOpBlockArg = tempComputeOpBlock->addArgument( auto tempComputeOpBlockArg = tempComputeOpBlock->addArgument(extractSliceOp.getType(), tileLoc);
extractSliceOp.getType(), tileLoc);
rewriter.setInsertionPointToStart(tempComputeOpBlock); rewriter.setInsertionPointToStart(tempComputeOpBlock);
rewriter.create<spatial::SpatYieldOp>(tileLoc, tempComputeOpBlockArg); rewriter.create<spatial::SpatYieldOp>(tileLoc, tempComputeOpBlockArg);
@@ -295,8 +285,7 @@ struct PoolingBaseConverter : public OpConversionPattern<PoolOp> {
// For example: outputTiles[channelTile][x][y] // For example: outputTiles[channelTile][x][y]
// Example complete output tensor: tensor<1x3x6x6xf32> (NxCxWxH) // Example complete output tensor: tensor<1x3x6x6xf32> (NxCxWxH)
SmallVector<SmallVector<SmallVector<Value>>> outputTiles( SmallVector<SmallVector<SmallVector<Value>>> outputTiles(
channelTileCount, SmallVector<SmallVector<Value>>( channelTileCount, SmallVector<SmallVector<Value>>(output_w, SmallVector<Value>(output_h, nullptr)));
output_w, SmallVector<Value>(output_h, nullptr)));
// List of values to pool for each output pixel // List of values to pool for each output pixel
SmallVector<Value> valuesToPool; SmallVector<Value> valuesToPool;
@@ -312,15 +301,12 @@ struct PoolingBaseConverter : public OpConversionPattern<PoolOp> {
valuesToPool.clear(); valuesToPool.clear();
size_t tilesSkippedByPadding = 0; size_t tilesSkippedByPadding = 0;
auto [start_x, end_x] = kernel_get_start_and_end( auto [start_x, end_x] = kernel_get_start_and_end(outX, input_w, krn_w, stride_x, dilation_x, pad_x);
outX, input_w, krn_w, stride_x, dilation_x, pad_x); auto [start_y, end_y] = kernel_get_start_and_end(outY, input_h, krn_h, stride_y, dilation_y, pad_y);
auto [start_y, end_y] = kernel_get_start_and_end(
outY, input_h, krn_h, stride_y, dilation_y, pad_y);
for (size_t inX = start_x; inX < end_x; inX += dilation_x) { for (size_t inX = start_x; inX < end_x; inX += dilation_x) {
for (size_t inY = start_y; inY < end_y; inY += dilation_y) { for (size_t inY = start_y; inY < end_y; inY += dilation_y) {
if (failed(verifyWithinBoundsAndPaddings( if (failed(verifyWithinBoundsAndPaddings(input_w, input_h, inX, inY, pad_x, pad_y))) {
input_w, input_h, inX, inY, pad_x, pad_y))) {
tilesSkippedByPadding++; tilesSkippedByPadding++;
continue; continue;
} }
@@ -328,78 +314,73 @@ struct PoolingBaseConverter : public OpConversionPattern<PoolOp> {
Value inputTile = inputTiles[outTile][inX][inY]; Value inputTile = inputTiles[outTile][inX][inY];
Value valueToPool; Value valueToPool;
if (auto computeProducer = if (auto computeProducer = inputTile.getDefiningOp<spatial::SpatWeightedCompute>()) {
inputTile.getDefiningOp<spatial::SpatWeightedCompute>()) {
int resultNumber = getResultIndex(computeProducer, inputTile); int resultNumber = getResultIndex(computeProducer, inputTile);
auto yieldInComputeOp = cast<spatial::SpatYieldOp>( auto yieldInComputeOp = cast<spatial::SpatYieldOp>(computeProducer.getBody().front().getTerminator());
computeProducer.getBody().front().getTerminator());
valueToPool = yieldInComputeOp.getOperand(resultNumber); valueToPool = yieldInComputeOp.getOperand(resultNumber);
} else if (auto receiveProducer = }
inputTile else if (auto receiveProducer = inputTile.getDefiningOp<spatial::SpatChannelReceiveOp>()) {
.getDefiningOp<spatial::SpatChannelReceiveOp>()) { auto sendOpOpt = getOtherEndOfChannel(receiveProducer, true, rewriter);
auto sendOpOpt =
getOtherEndOfChannel(receiveProducer, true, rewriter);
if (failed(sendOpOpt)) { if (failed(sendOpOpt)) {
return rewriter.notifyMatchFailure(poolOp, return rewriter.notifyMatchFailure(poolOp,
"ChannelReceiveOp does not have a matching " "ChannelReceiveOp does not have a matching "
"ChannelSendOp."); "ChannelSendOp.");
} }
auto sendOp = cast<spatial::SpatChannelSendOp>(*sendOpOpt); auto sendOp = cast<spatial::SpatChannelSendOp>(*sendOpOpt);
valueToPool = sendOp.getData(); valueToPool = sendOp.getData();
} else { }
else {
return rewriter.notifyMatchFailure(poolOp, return rewriter.notifyMatchFailure(poolOp,
"Input tile for Pooling is not produced by a " "Input tile for Pooling is not produced by a "
"WeightedComputeOp nor a receiveOp"); "WeightedComputeOp nor a receiveOp");
} }
valuesToPool.push_back(valueToPool); valuesToPool.push_back(valueToPool);
} }
} }
assert(valuesToPool.size() != 0 && assert(valuesToPool.size() != 0 && "Pooling computed on zero tiles make no sense.");
"Pooling computed on zero tiles make no sense.");
// assert(computeOpsForPooling.size() != 1 && // assert(computeOpsForPooling.size() != 1 &&
// "Pooling computed on one tiles make no sense??? Or maybe // "Pooling computed on one tiles make no sense??? Or maybe
// this " "should have been simplified earlier???"); // this " "should have been simplified earlier???");
std::function<Value(const Value &)> postProcessFn = nullptr; std::function<Value(const Value&)> postProcessFn = nullptr;
if (hasPostProcessPoolingWindow<PoolOp>()) { if (hasPostProcessPoolingWindow<PoolOp>()) {
postProcessFn = [&](const Value prevFinalRes) { postProcessFn = [&](const Value prevFinalRes) {
return postProcessPoolingWindow(rewriter, loc, poolOp, return postProcessPoolingWindow(
prevFinalRes, krn_h * krn_w, tilesSkippedByPadding); rewriter, loc, poolOp, prevFinalRes, krn_h * krn_w, tilesSkippedByPadding);
}; };
} }
Value reducedWithinCompute = applyReducePatternNew( Value reducedWithinCompute = applyReducePatternNew(
valuesToPool, rewriter, valuesToPool,
[&](const Value lhs, const Value rhs) { rewriter,
return rewriter.create<ReduceOp>(loc, lhs.getType(), lhs, rhs); [&](const Value lhs, const Value rhs) { return rewriter.create<ReduceOp>(loc, lhs.getType(), lhs, rhs); },
}, nullptr,
nullptr, postProcessFn); postProcessFn);
// Send this value through a channel, and receive it in the // Send this value through a channel, and receive it in the
// `func.func`. During lowering, we will need to "move it" into the // `func.func`. During lowering, we will need to "move it" into the
// users computeOps // users computeOps
auto computeOpOfReduced = cast<spatial::SpatWeightedCompute>( auto computeOpOfReduced =
reducedWithinCompute.getDefiningOp()->getParentOp()); cast<spatial::SpatWeightedCompute>(reducedWithinCompute.getDefiningOp()->getParentOp());
// Create a new channel before the computeOp // Create a new channel before the computeOp
rewriter.setInsertionPoint(computeOpOfReduced); rewriter.setInsertionPoint(computeOpOfReduced);
auto reduceChannel = rewriter.create<spatial::SpatChannelNewOp>( auto reduceChannel =
loc, spatial::SpatChannelType::get(rewriter.getContext())); rewriter.create<spatial::SpatChannelNewOp>(loc, spatial::SpatChannelType::get(rewriter.getContext()));
// Send value through the channel // Send value through the channel
rewriter.setInsertionPointAfterValue(reducedWithinCompute); rewriter.setInsertionPointAfterValue(reducedWithinCompute);
rewriter.create<spatial::SpatChannelSendOp>( rewriter.create<spatial::SpatChannelSendOp>(loc, reduceChannel, reducedWithinCompute);
loc, reduceChannel, reducedWithinCompute);
// Receive after the computeOp // Receive after the computeOp
rewriter.setInsertionPointAfter(computeOpOfReduced); rewriter.setInsertionPointAfter(computeOpOfReduced);
auto receivedValue = rewriter.create<spatial::SpatChannelReceiveOp>( auto receivedValue =
loc, reducedWithinCompute.getType(), reduceChannel); rewriter.create<spatial::SpatChannelReceiveOp>(loc, reducedWithinCompute.getType(), reduceChannel);
outputTiles[outTile][outX][outY] = receivedValue; outputTiles[outTile][outX][outY] = receivedValue;
} }
@@ -409,9 +390,7 @@ struct PoolingBaseConverter : public OpConversionPattern<PoolOp> {
// TODO: outputTiles are not the results of the computeOps! We need to add // TODO: outputTiles are not the results of the computeOps! We need to add
// them! // them!
std::unordered_map<Operation *, std::unordered_map<Operation*, SmallVector<std::tuple<size_t, size_t, size_t, Value>>> computeOpNeedingResults;
SmallVector<std::tuple<size_t, size_t, size_t, Value>>>
computeOpNeedingResults;
// Iterate each output tile // Iterate each output tile
for (size_t outTile = 0; outTile < channelTileCount; outTile++) { for (size_t outTile = 0; outTile < channelTileCount; outTile++) {
@@ -422,18 +401,16 @@ struct PoolingBaseConverter : public OpConversionPattern<PoolOp> {
auto outputTileProducer = outputTile.getDefiningOp()->getParentOp(); auto outputTileProducer = outputTile.getDefiningOp()->getParentOp();
if (!outputTileProducer) { if (!outputTileProducer) {
return rewriter.notifyMatchFailure(poolOp, return rewriter.notifyMatchFailure(poolOp,
"Output tile for Pooling is not produced by a " "Output tile for Pooling is not produced by a "
"WeightedComputeOp."); "WeightedComputeOp.");
} }
computeOpNeedingResults[outputTileProducer].push_back( computeOpNeedingResults[outputTileProducer].push_back(std::make_tuple(outTile, outX, outY, outputTile));
std::make_tuple(outTile, outX, outY, outputTile));
} }
} }
} }
Value outputImage = Value outputImage = createImgConcatOp(outputTiles, rewriter, loc, poolOp.getType());
createImgConcatOp(outputTiles, rewriter, loc, poolOp.getType());
rewriter.replaceOp(poolOp, outputImage); rewriter.replaceOp(poolOp, outputImage);
@@ -441,12 +418,10 @@ struct PoolingBaseConverter : public OpConversionPattern<PoolOp> {
} }
}; };
void populatePoolingTilingPattern( void populatePoolingTilingPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
RewritePatternSet &patterns, MLIRContext *ctx) { patterns.insert<PoolingBaseConverter<ONNXMaxPoolSingleOutOp, ONNXMaxPoolSingleOutOpAdaptor, spatial::SpatVMaxOp>>(
patterns.insert<PoolingBaseConverter<ONNXMaxPoolSingleOutOp, ctx);
ONNXMaxPoolSingleOutOpAdaptor, spatial::SpatVMaxOp>>(ctx); patterns.insert<PoolingBaseConverter<ONNXAveragePoolOp, ONNXAveragePoolOpAdaptor, spatial::SpatVAddOp>>(ctx);
patterns.insert<PoolingBaseConverter<ONNXAveragePoolOp,
ONNXAveragePoolOpAdaptor, spatial::SpatVAddOp>>(ctx);
} }
} // namespace onnx_mlir } // namespace onnx_mlir

View File

@@ -1,20 +1,19 @@
#include "mlir/Transforms/DialectConversion.h"
#include "Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp" #include "Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp"
#include "mlir/Transforms/DialectConversion.h"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir; using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
struct ReduceMeanConversionPattern struct ReduceMeanConversionPattern : public OpConversionPattern<ONNXReduceMeanV13Op> {
: public OpConversionPattern<ONNXReduceMeanV13Op> {
ReduceMeanConversionPattern(MLIRContext *ctx) : OpConversionPattern(ctx) {} ReduceMeanConversionPattern(MLIRContext* ctx)
: OpConversionPattern(ctx) {}
LogicalResult matchAndRewrite(ONNXReduceMeanV13Op reduceMean, LogicalResult matchAndRewrite(ONNXReduceMeanV13Op reduceMean,
ONNXReduceMeanV13OpAdaptor adaptor, ONNXReduceMeanV13OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final { ConversionPatternRewriter& rewriter) const final {
// Get the input tensor. // Get the input tensor.
Value inputTensor = adaptor.getData(); Value inputTensor = adaptor.getData();
@@ -38,42 +37,42 @@ struct ReduceMeanConversionPattern
SmallVector<int64_t> padsVals = {0, 0, 0, 0}; SmallVector<int64_t> padsVals = {0, 0, 0, 0};
// Create the ArrayAttrs // Create the ArrayAttrs
auto kernelShape = mlir::ArrayAttr::get(rewriter.getContext(), auto kernelShape = mlir::ArrayAttr::get(
llvm::to_vector( rewriter.getContext(), llvm::to_vector(llvm::map_range(kernelShapeVals, [&](int64_t v) -> mlir::Attribute {
llvm::map_range(kernelShapeVals, [&](int64_t v) -> mlir::Attribute { return rewriter.getI64IntegerAttr(v);
return rewriter.getI64IntegerAttr(v); })));
})));
auto strides = mlir::ArrayAttr::get(rewriter.getContext(), auto strides = mlir::ArrayAttr::get(rewriter.getContext(),
llvm::to_vector( llvm::to_vector(llvm::map_range(stridesVals, [&](int64_t v) -> mlir::Attribute {
llvm::map_range(stridesVals, [&](int64_t v) -> mlir::Attribute { return rewriter.getI64IntegerAttr(v);
return rewriter.getI64IntegerAttr(v); })));
})));
auto dilations = mlir::ArrayAttr::get(rewriter.getContext(), auto dilations = mlir::ArrayAttr::get(
llvm::to_vector( rewriter.getContext(), llvm::to_vector(llvm::map_range(dilationsVals, [&](int64_t v) -> mlir::Attribute {
llvm::map_range(dilationsVals, [&](int64_t v) -> mlir::Attribute { return rewriter.getI64IntegerAttr(v);
return rewriter.getI64IntegerAttr(v); })));
})));
auto pads = mlir::ArrayAttr::get(rewriter.getContext(), auto pads = mlir::ArrayAttr::get(rewriter.getContext(),
llvm::to_vector( llvm::to_vector(llvm::map_range(padsVals, [&](int64_t v) -> mlir::Attribute {
llvm::map_range(padsVals, [&](int64_t v) -> mlir::Attribute { return rewriter.getI64IntegerAttr(v);
return rewriter.getI64IntegerAttr(v); })));
})));
// Create the resulting tensor type. // Create the resulting tensor type.
auto resultType = RankedTensorType::get( auto resultType = RankedTensorType::get(
/*shape=*/{inputTensorType.getShape()[0], inputTensorType.getShape()[1], /*shape=*/ {inputTensorType.getShape()[0], inputTensorType.getShape()[1], 1, 1},
1, 1}, /*elementType=*/inputTensorType.getElementType());
/*elementType=*/inputTensorType.getElementType());
// Create the ONNXAveragePoolOp. // Create the ONNXAveragePoolOp.
auto averagePool = rewriter.create<ONNXAveragePoolOp>(reduceMean.getLoc(), auto averagePool = rewriter.create<ONNXAveragePoolOp>(reduceMean.getLoc(),
resultType, inputTensor, /*auto_pad=*/"NOTSET", resultType,
/*ceil_mode=*/0, /*count_include_pad=*/1, dilations, inputTensor,
/*kernel_shape=*/kernelShape, /*auto_pad=*/"NOTSET",
/*pads=*/pads, /*strides=*/strides); /*ceil_mode=*/0,
/*count_include_pad=*/1,
dilations,
/*kernel_shape=*/kernelShape,
/*pads=*/pads,
/*strides=*/strides);
// Replace the ONNXReduceMeanV13Op with the ONNXAveragePoolOp. // Replace the ONNXReduceMeanV13Op with the ONNXAveragePoolOp.
rewriter.replaceOp(reduceMean, averagePool.getResult()); rewriter.replaceOp(reduceMean, averagePool.getResult());
@@ -82,8 +81,7 @@ struct ReduceMeanConversionPattern
} }
}; };
void populateReduceMeanConversionPattern( void populateReduceMeanConversionPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
RewritePatternSet &patterns, MLIRContext *ctx) {
patterns.insert<ReduceMeanConversionPattern>(ctx); patterns.insert<ReduceMeanConversionPattern>(ctx);
} }

View File

@@ -6,11 +6,12 @@
#include "mlir/IR/PatternMatch.h" #include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h" #include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
#include "llvm/Support/LogicalResult.h" #include "llvm/Support/LogicalResult.h"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
#define DEFINE_MAP_OP(opname) opname, #define DEFINE_MAP_OP(opname) opname,
#define GET_IMAGE_WIDTH(shapedType) shapedType.getDimSize(2) #define GET_IMAGE_WIDTH(shapedType) shapedType.getDimSize(2)

View File

@@ -6,7 +6,6 @@
#include "llvm/Support/Debug.h" #include "llvm/Support/Debug.h"
#include "llvm/Support/raw_os_ostream.h" #include "llvm/Support/raw_os_ostream.h"
#include <filesystem>
#include <fstream> #include <fstream>
#include "Common/PIMCommon.hpp" #include "Common/PIMCommon.hpp"
@@ -16,7 +15,6 @@
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp"
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
#include "src/Compiler/CompilerOptions.hpp" #include "src/Compiler/CompilerOptions.hpp"
using namespace mlir; using namespace mlir;
@@ -39,7 +37,7 @@ void ONNXToSpatialPass::runOnOperation() {
mergeActivationPatterns.add<matMulToGemmPattern>(ctx); mergeActivationPatterns.add<matMulToGemmPattern>(ctx);
mergeActivationPatterns.add<removeFlattenSameShapePattern>(ctx); mergeActivationPatterns.add<removeFlattenSameShapePattern>(ctx);
if (failed(applyPatternsAndFoldGreedily(moduleOp, std::move(mergeActivationPatterns)))) if (failed(applyPatternsGreedily(moduleOp, std::move(mergeActivationPatterns))))
llvm::dbgs() << "Failed to merge activation patterns, continuing...\n"; llvm::dbgs() << "Failed to merge activation patterns, continuing...\n";
IRRewriter rewriter(moduleOp); IRRewriter rewriter(moduleOp);
@@ -88,7 +86,7 @@ void ONNXToSpatialPass::runOnOperation() {
if (coresCount != -1) { if (coresCount != -1) {
int computeOpsCount = 0; int computeOpsCount = 0;
for (auto& op : funcOp.getFunctionBody().front().getOperations()) for (auto& op : funcOp.getFunctionBody().front().getOperations())
if (isa<spatial::SpatWeightedCompute>(op)) if (isa<SpatWeightedCompute>(op))
computeOpsCount++; computeOpsCount++;
if (computeOpsCount > coresCount) { if (computeOpsCount > coresCount) {

View File

@@ -3,38 +3,26 @@
namespace onnx_mlir { namespace onnx_mlir {
void populateLoweringONNXMatMulOpToSpatialPattern( void populateLoweringONNXMatMulOpToSpatialPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx);
void populateTilingGemmOpPattern( void populateTilingGemmOpPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx); void populateTilingConvOpPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateTilingConvOpPattern(
mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx);
void populatePoolingTilingPattern( void populatePoolingTilingPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx);
void populateDistributeReducePattern( void populateDistributeReducePattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx);
void populateFoldComputePattern( void populateFoldComputePattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx);
void populateONNXConcatToTensorConcatPattern( void populateONNXConcatToTensorConcatPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx);
void populateRemoveUnusedHelperOpsPatterns( void populateRemoveUnusedHelperOpsPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx);
void populateReduceMeanConversionPattern( void populateReduceMeanConversionPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx);
// Experimental patterns. // Experimental patterns.
void populateExperimentalTilingConvOpPattern( void populateExperimentalTilingConvOpPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx); void populateGemmToConvConversionPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateGemmToConvConversionPattern( void populateExperimentalPoolingTilingPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx);
void populateExperimentalPoolingTilingPattern(
mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx);
} // namespace onnx_mlir } // namespace onnx_mlir

View File

@@ -1,19 +1,20 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h" #include "mlir/IR/PatternMatch.h"
#include "src/Dialect/ONNX/ONNXOps.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir; using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
struct ONNXConcatToTensorConcat : public OpConversionPattern<ONNXConcatOp> { struct ONNXConcatToTensorConcat : public OpConversionPattern<ONNXConcatOp> {
ONNXConcatToTensorConcat(MLIRContext *ctx) : OpConversionPattern(ctx) {} ONNXConcatToTensorConcat(MLIRContext* ctx)
: OpConversionPattern(ctx) {}
LogicalResult matchAndRewrite(ONNXConcatOp maxpoolOp, LogicalResult matchAndRewrite(ONNXConcatOp maxpoolOp,
ONNXConcatOpAdaptor adaptor, ONNXConcatOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final { ConversionPatternRewriter& rewriter) const final {
auto inputs = adaptor.getInputs(); auto inputs = adaptor.getInputs();
int64_t axis = adaptor.getAxis(); int64_t axis = adaptor.getAxis();
@@ -23,8 +24,7 @@ struct ONNXConcatToTensorConcat : public OpConversionPattern<ONNXConcatOp> {
} }
}; };
void populateONNXConcatToTensorConcatPattern( void populateONNXConcatToTensorConcatPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
RewritePatternSet &patterns, MLIRContext *ctx) {
patterns.insert<ONNXConcatToTensorConcat>(ctx); patterns.insert<ONNXConcatToTensorConcat>(ctx);
} }

View File

@@ -1,5 +1,6 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h" #include "mlir/IR/PatternMatch.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"

View File

@@ -1,10 +1,10 @@
#include <queue>
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/AnnotateReplication.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/AnnotateReplication.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
#include <queue>
using namespace mlir; using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {

View File

@@ -5,7 +5,6 @@
namespace onnx_mlir { namespace onnx_mlir {
mlir::LogicalResult annotateReplication( mlir::LogicalResult annotateReplication(mlir::func::FuncOp funcOp, mlir::IRRewriter& rewriter);
mlir::func::FuncOp funcOp, mlir::IRRewriter &rewriter);
} // namespace onnx_mlir } // namespace onnx_mlir

View File

@@ -1,32 +1,31 @@
#include "SpatialReducer.hpp"
#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Value.h" #include "mlir/IR/Value.h"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "llvm/Support/raw_ostream.h" #include "llvm/Support/raw_ostream.h"
#include <cassert> #include <cassert>
#include <unordered_map> #include <unordered_map>
#include <utility> #include <utility>
#include "SpatialReducer.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#define GET_COMP(computeOpAndResNum) std::get<0>(computeOpAndResNum) #define GET_COMP(computeOpAndResNum) std::get<0>(computeOpAndResNum)
#define GET_RES_NUM(computeOpAndResNum) std::get<1>(computeOpAndResNum) #define GET_RES_NUM(computeOpAndResNum) std::get<1>(computeOpAndResNum)
namespace onnx_mlir { namespace onnx_mlir {
llvm::SmallPtrSet<Operation *, 16> llvm::SmallPtrSet<Operation*, 16> onnx_mlir::SpatialReducer::oldComputeOpsReplaced;
onnx_mlir::SpatialReducer::oldComputeOpsReplaced;
ResNum SpatialReducer::applyResultProcessing( ResNum SpatialReducer::applyResultProcessing(ComputeAndResNum computeOpAndResNum,
ComputeAndResNum computeOpAndResNum, std::function<Value(const Value&)> processFun,
std::function<Value(const Value &)> processFun, ConversionPatternRewriter& rewriter) {
ConversionPatternRewriter &rewriter) {
assert(processFun); assert(processFun);
auto computeOp = GET_COMP(computeOpAndResNum); auto computeOp = GET_COMP(computeOpAndResNum);
auto resultNum = GET_RES_NUM(computeOpAndResNum); auto resultNum = GET_RES_NUM(computeOpAndResNum);
spatial::SpatYieldOp yieldOp = spatial::SpatYieldOp yieldOp = cast<spatial::SpatYieldOp>(computeOp.getBody().front().getTerminator());
cast<spatial::SpatYieldOp>(computeOp.getBody().front().getTerminator());
Value result = yieldOp->getOperand(resultNum); Value result = yieldOp->getOperand(resultNum);
rewriter.setInsertionPointAfterValue(result); rewriter.setInsertionPointAfterValue(result);
@@ -43,30 +42,24 @@ ResNum SpatialReducer::applyResultProcessing(
return yieldOp.getNumOperands() - 1; return yieldOp.getNumOperands() - 1;
} }
OpAndResNum SpatialReducer::applyReducePattern( OpAndResNum SpatialReducer::applyReducePattern(SmallVector<ComputeAndResNum>& computeOpsAndResNum,
SmallVector<ComputeAndResNum> &computeOpsAndResNum, std::function<Value(const Value&, const Value&)> reduce,
std::function<Value(const Value &, const Value &)> reduce, std::function<Value(const Value&)> preprocess,
std::function<Value(const Value &)> preprocess, std::function<Value(const Value&)> postprocess) {
std::function<Value(const Value &)> postprocess) {
if (preprocess) { if (preprocess)
for (auto &computeOpAndResNum : computeOpsAndResNum) { for (auto& computeOpAndResNum : computeOpsAndResNum)
GET_RES_NUM(computeOpAndResNum) = GET_RES_NUM(computeOpAndResNum) = applyResultProcessing(computeOpAndResNum, preprocess, rewriter);
applyResultProcessing(computeOpAndResNum, preprocess, rewriter);
}
}
// It is possible that `computeOpsAndResNum` contains two entries for the same // It is possible that `computeOpsAndResNum` contains two entries for the same
// computeOp. In this case, we need to apply the reduction within-computef // computeOp. In this case, we need to apply the reduction within-computef
// Keep a map between a computeOp and the last Value for this reduction // Keep a map between a computeOp and the last Value for this reduction
std::unordered_map<Operation *, Value> lastValueForCompute; std::unordered_map<Operation*, Value> lastValueForCompute;
for (auto &computeOpAndResNum : computeOpsAndResNum) { for (auto& computeOpAndResNum : computeOpsAndResNum) {
auto computeOp = GET_COMP(computeOpAndResNum); auto computeOp = GET_COMP(computeOpAndResNum);
auto yieldOp = auto yieldOp = cast<spatial::SpatYieldOp>(computeOp.getBody().front().getTerminator());
cast<spatial::SpatYieldOp>(computeOp.getBody().front().getTerminator()); Value valueWithinCompute = yieldOp->getOperand(GET_RES_NUM(computeOpAndResNum));
Value valueWithinCompute =
yieldOp->getOperand(GET_RES_NUM(computeOpAndResNum));
auto it = lastValueForCompute.find(computeOp.getOperation()); auto it = lastValueForCompute.find(computeOp.getOperation());
@@ -75,15 +68,12 @@ OpAndResNum SpatialReducer::applyReducePattern(
// within-compute // within-compute
Value lastWithinComputeValue = it->second; Value lastWithinComputeValue = it->second;
assert(valueWithinCompute.getDefiningOp() && assert(valueWithinCompute.getDefiningOp() && lastWithinComputeValue.getDefiningOp());
lastWithinComputeValue.getDefiningOp());
if (valueWithinCompute.getDefiningOp()->isBeforeInBlock( if (valueWithinCompute.getDefiningOp()->isBeforeInBlock(lastWithinComputeValue.getDefiningOp()))
lastWithinComputeValue.getDefiningOp())) {
rewriter.setInsertionPointAfterValue(lastWithinComputeValue); rewriter.setInsertionPointAfterValue(lastWithinComputeValue);
} else { else
rewriter.setInsertionPointAfterValue(valueWithinCompute); rewriter.setInsertionPointAfterValue(valueWithinCompute);
}
valueWithinCompute = reduce(lastWithinComputeValue, valueWithinCompute); valueWithinCompute = reduce(lastWithinComputeValue, valueWithinCompute);
lastValueForCompute[computeOp.getOperation()] = valueWithinCompute; lastValueForCompute[computeOp.getOperation()] = valueWithinCompute;
} }
@@ -94,16 +84,15 @@ OpAndResNum SpatialReducer::applyReducePattern(
// Now, reconstruct from the map the computeOpsAndResNum list // Now, reconstruct from the map the computeOpsAndResNum list
computeOpsAndResNum.clear(); computeOpsAndResNum.clear();
computeOpsAndResNum.reserve(lastValueForCompute.size()); computeOpsAndResNum.reserve(lastValueForCompute.size());
for (auto &entry : lastValueForCompute) { for (auto& entry : lastValueForCompute) {
auto computeOp = cast<spatial::SpatWeightedCompute>(entry.first); auto computeOp = cast<spatial::SpatWeightedCompute>(entry.first);
auto valueWithinCompute = entry.second; auto valueWithinCompute = entry.second;
// We check if `valueWithinCompute` is already used by the yieldOp, in that // We check if `valueWithinCompute` is already used by the yieldOp, in that
// case no need to add it // case no need to add it
auto yieldOp = auto yieldOp = cast<spatial::SpatYieldOp>(computeOp.getBody().front().getTerminator());
cast<spatial::SpatYieldOp>(computeOp.getBody().front().getTerminator());
bool yieldOpUseFound = false; bool yieldOpUseFound = false;
for (auto &use : valueWithinCompute.getUses()) { for (auto& use : valueWithinCompute.getUses()) {
if (use.getOwner() == yieldOp.getOperation()) { if (use.getOwner() == yieldOp.getOperation()) {
// If the value is already used by the yieldOp, we can just use it // If the value is already used by the yieldOp, we can just use it
computeOpsAndResNum.push_back({computeOp, use.getOperandNumber()}); computeOpsAndResNum.push_back({computeOp, use.getOperandNumber()});
@@ -111,9 +100,8 @@ OpAndResNum SpatialReducer::applyReducePattern(
break; break;
} }
} }
if (yieldOpUseFound) { if (yieldOpUseFound)
continue; continue;
}
// If this result is not used within a yieldOp, then add it // If this result is not used within a yieldOp, then add it
auto resultNum = yieldOp->getNumOperands(); auto resultNum = yieldOp->getNumOperands();
@@ -147,23 +135,18 @@ OpAndResNum SpatialReducer::applyReducePattern(
// the number of results) // the number of results)
// See below `reducerChanges.push_back` and `finalizeReduceUpdates` // See below `reducerChanges.push_back` and `finalizeReduceUpdates`
auto yieldOpFirstCompute = cast<spatial::SpatYieldOp>( auto yieldOpFirstCompute = cast<spatial::SpatYieldOp>(firstCompute.getBody().front().getTerminator());
firstCompute.getBody().front().getTerminator());
// Add a new operand to the block of the second computeOp // Add a new operand to the block of the second computeOp
Block &secondBlock = secondCompute.getBody().front(); Block& secondBlock = secondCompute.getBody().front();
Value formerRes1 = secondBlock.addArgument( Value formerRes1 = secondBlock.addArgument(yieldOpFirstCompute->getOperand(firstResultNum).getType(), loc);
yieldOpFirstCompute->getOperand(firstResultNum).getType(), loc);
auto secondComputeWeightsNum = auto secondComputeWeightsNum =
secondCompute->getAttrOfType<DenseI32ArrayAttr>( secondCompute->getAttrOfType<DenseI32ArrayAttr>(secondCompute.getOperandSegmentSizesAttrName())[0];
secondCompute.getOperandSegmentSizesAttrName())[0]; auto secondComputeOperandNum = secondComputeWeightsNum + secondBlock.getNumArguments() - 1;
auto secondComputeOperandNum =
secondComputeWeightsNum + secondBlock.getNumArguments() - 1;
// Take the "former-result" from the second computeOp // Take the "former-result" from the second computeOp
spatial::SpatYieldOp secondYield = spatial::SpatYieldOp secondYield = cast<spatial::SpatYieldOp>(secondBlock.getTerminator());
cast<spatial::SpatYieldOp>(secondBlock.getTerminator());
Value formerRes2 = secondYield.getOperand(secondResultNum); Value formerRes2 = secondYield.getOperand(secondResultNum);
// Apply reduction operation // Apply reduction operation
@@ -184,37 +167,31 @@ OpAndResNum SpatialReducer::applyReducePattern(
// We should also add an entry for updating the results of the last // We should also add an entry for updating the results of the last
// operation (the one which never becomes a `firstCompute`): because it is // operation (the one which never becomes a `firstCompute`): because it is
// not tracked by reducerChanges as `fromOp` // not tracked by reducerChanges as `fromOp`
reducerChanges.push_back({firstCompute.getOperation(), firstResultNum, reducerChanges.push_back(
secondCompute.getOperation(), secondComputeOperandNum}); {firstCompute.getOperation(), firstResultNum, secondCompute.getOperation(), secondComputeOperandNum});
nextComputeOps.push_back(std::make_pair(secondCompute, secondResultNum)); nextComputeOps.push_back(std::make_pair(secondCompute, secondResultNum));
} }
// If we have an odd number of inputs, we need to add the last one to the // If we have an odd number of inputs, we need to add the last one to the
// newInputs list. // newInputs list.
if (computeOpsRef.size() % 2 == 1) { if (computeOpsRef.size() % 2 == 1)
nextComputeOps.push_back(computeOpsRef.back()); nextComputeOps.push_back(computeOpsRef.back());
}
// Replace the inputOps list with the new one. // Replace the inputOps list with the new one.
computeOpsRef = computeOpsRef = llvm::OwningArrayRef<ComputeAndResNum>(std::move(nextComputeOps));
llvm::OwningArrayRef<ComputeAndResNum>(std::move(nextComputeOps));
} }
assert(computeOpsRef.size() == 1 && assert(computeOpsRef.size() == 1 && "Internal error: expected a single input at this point.");
"Internal error: expected a single input at this point.");
auto finalComputeAndResNum = computeOpsRef[0]; auto finalComputeAndResNum = computeOpsRef[0];
// Force the update of the results of this computeOp, when finalizing // Force the update of the results of this computeOp, when finalizing
computeOpNeedingResUpdate.push_back(GET_COMP(finalComputeAndResNum)); computeOpNeedingResUpdate.push_back(GET_COMP(finalComputeAndResNum));
if (postprocess) { if (postprocess)
GET_RES_NUM(finalComputeAndResNum) = GET_RES_NUM(finalComputeAndResNum) = applyResultProcessing(finalComputeAndResNum, postprocess, rewriter);
applyResultProcessing(finalComputeAndResNum, postprocess, rewriter);
}
return std::make_pair(GET_COMP(finalComputeAndResNum).getOperation(), return std::make_pair(GET_COMP(finalComputeAndResNum).getOperation(), GET_RES_NUM(finalComputeAndResNum));
GET_RES_NUM(finalComputeAndResNum));
} }
void SpatialReducer::finalizeReduceUpdates() { void SpatialReducer::finalizeReduceUpdates() {
@@ -223,15 +200,13 @@ void SpatialReducer::finalizeReduceUpdates() {
reducesFinalized = true; reducesFinalized = true;
// First, add the results to the computeOps // First, add the results to the computeOps
for (auto &reduceChange : reducerChanges) { for (auto& reduceChange : reducerChanges)
updateResultsOfCompute(reduceChange.fromOp); updateResultsOfCompute(reduceChange.fromOp);
}
for (auto &c : computeOpNeedingResUpdate) { for (auto& c : computeOpNeedingResUpdate)
updateResultsOfCompute(c.getOperation()); updateResultsOfCompute(c.getOperation());
}
for (auto &reducerChange : this->reducerChanges) { for (auto& reducerChange : this->reducerChanges) {
auto fromOp = reducerChange.fromOp; auto fromOp = reducerChange.fromOp;
auto toOp = reducerChange.toOp; auto toOp = reducerChange.toOp;
auto fromOpResNum = reducerChange.fromOpResNum; auto fromOpResNum = reducerChange.fromOpResNum;
@@ -243,16 +218,14 @@ void SpatialReducer::finalizeReduceUpdates() {
// toComputeOp could be the existing pointer, or we have to remap it with // toComputeOp could be the existing pointer, or we have to remap it with
// `opToReplacedCompute` // `opToReplacedCompute`
auto toComputeOp = opToReplacedCompute[toOp]; auto toComputeOp = opToReplacedCompute[toOp];
if (!toComputeOp) { if (!toComputeOp)
toComputeOp = cast<spatial::SpatWeightedCompute>(toOp); toComputeOp = cast<spatial::SpatWeightedCompute>(toOp);
}
assert(toComputeOp != fromComputeOp && assert(toComputeOp != fromComputeOp && "Oops should have caught this earlier!");
"Oops should have caught this earlier!");
assert(toComputeOp->getNumOperands() == toOpOperandNum && assert(toComputeOp->getNumOperands() == toOpOperandNum
"toOpOperandNum should be the last operand of toComputeOp, are the " && "toOpOperandNum should be the last operand of toComputeOp, are the "
"operations in the right order?"); "operations in the right order?");
// Add the new operand to `toComputeOp` // Add the new operand to `toComputeOp`
auto fromResult = fromComputeOp.getResult(fromOpResNum); auto fromResult = fromComputeOp.getResult(fromOpResNum);
@@ -261,24 +234,22 @@ void SpatialReducer::finalizeReduceUpdates() {
} }
} }
Value SpatialReducer::resolveValueFromOpAndResNum(OpAndResNum &opAndResNum) { Value SpatialReducer::resolveValueFromOpAndResNum(OpAndResNum& opAndResNum) {
assert(reducesFinalized && assert(reducesFinalized && "Cannot create resolve values before finalizing the reduce updates.");
"Cannot create resolve values before finalizing the reduce updates.");
Operation *opToCast; Operation* opToCast;
auto it = opToReplacedCompute.find(opAndResNum.first); auto it = opToReplacedCompute.find(opAndResNum.first);
if (it != opToReplacedCompute.end()) { if (it != opToReplacedCompute.end())
opToCast = it->second; opToCast = it->second;
} else { else
opToCast = opAndResNum.first; opToCast = opAndResNum.first;
}
auto computeOp = cast<spatial::SpatWeightedCompute>(opToCast); auto computeOp = cast<spatial::SpatWeightedCompute>(opToCast);
return computeOp.getResult(opAndResNum.second); return computeOp.getResult(opAndResNum.second);
} }
void SpatialReducer::updateResultsOfCompute(Operation *computeOp) { void SpatialReducer::updateResultsOfCompute(Operation* computeOp) {
if (opToReplacedCompute.find(computeOp) != opToReplacedCompute.end()) { if (opToReplacedCompute.find(computeOp) != opToReplacedCompute.end()) {
// If we have already replaced the fromOp, we do not need to do it again // If we have already replaced the fromOp, we do not need to do it again
return; return;
@@ -287,8 +258,7 @@ void SpatialReducer::updateResultsOfCompute(Operation *computeOp) {
auto oldComputeOpNum = oldComputeOp->getNumOperands(); auto oldComputeOpNum = oldComputeOp->getNumOperands();
auto yieldOp = auto yieldOp = cast<spatial::SpatYieldOp>(oldComputeOp.getBody().front().getTerminator());
cast<spatial::SpatYieldOp>(oldComputeOp.getBody().front().getTerminator());
if (yieldOp.getNumOperands() == oldComputeOp->getNumResults()) { if (yieldOp.getNumOperands() == oldComputeOp->getNumResults()) {
// No result was added, just add itself to the map // No result was added, just add itself to the map
@@ -301,9 +271,8 @@ void SpatialReducer::updateResultsOfCompute(Operation *computeOp) {
// Create a new ComputeOp with the new result type, but same operands // Create a new ComputeOp with the new result type, but same operands
rewriter.setInsertionPoint(oldComputeOp); rewriter.setInsertionPoint(oldComputeOp);
auto newComputeOp = auto newComputeOp = rewriter.create<spatial::SpatWeightedCompute>(
rewriter.create<spatial::SpatWeightedCompute>(oldComputeOp->getLoc(), oldComputeOp->getLoc(), newResultTypes, oldComputeOp.getWeights(), oldComputeOp.getInputs());
newResultTypes, oldComputeOp.getWeights(), oldComputeOp.getInputs());
newComputeOp.getBody().takeBody(oldComputeOp.getBody()); newComputeOp.getBody().takeBody(oldComputeOp.getBody());
@@ -329,54 +298,49 @@ void SpatialReducer::updateResultsOfCompute(Operation *computeOp) {
rewriter.eraseOp(oldComputeOp); rewriter.eraseOp(oldComputeOp);
} }
Value SpatialReducer::createImgConcatOp( Value SpatialReducer::createImgConcatOp(SmallVector<SmallVector<SmallVector<OpAndResNum>>>& outputTiles,
SmallVector<SmallVector<SmallVector<OpAndResNum>>> &outputTiles, Location& loc,
Location &loc, Type outputType) { Type outputType) {
assert(reducesFinalized && assert(reducesFinalized && "Cannot create ImgConcatOp before finalizing the reduce updates.");
"Cannot create ImgConcatOp before finalizing the reduce updates.");
// outputTiles are indexed like this: [channelTile][x][y] // outputTiles are indexed like this: [channelTile][x][y]
auto tilesCount = outputTiles.size(); auto tilesCount = outputTiles.size();
auto width = outputTiles[0].size(); auto width = outputTiles[0].size();
auto height = outputTiles[0][0].size(); auto height = outputTiles[0][0].size();
SmallVector<SmallVector<SmallVector<Value>>> remappedOutputTiles(tilesCount, SmallVector<SmallVector<SmallVector<Value>>> remappedOutputTiles(
SmallVector<SmallVector<Value>>(width, SmallVector<Value>(height))); tilesCount, SmallVector<SmallVector<Value>>(width, SmallVector<Value>(height)));
for (size_t t = 0; t < tilesCount; t++) for (size_t t = 0; t < tilesCount; t++)
for (size_t x = 0; x < width; x++) for (size_t x = 0; x < width; x++)
for (size_t y = 0; y < height; y++) for (size_t y = 0; y < height; y++)
remappedOutputTiles[t][x][y] = remappedOutputTiles[t][x][y] = resolveValueFromOpAndResNum(outputTiles[t][x][y]);
resolveValueFromOpAndResNum(outputTiles[t][x][y]);
return ::onnx_mlir::createImgConcatOp( return ::onnx_mlir::createImgConcatOp(remappedOutputTiles, rewriter, loc, outputType);
remappedOutputTiles, rewriter, loc, outputType);
} }
OpAndResNum SpatialReducer::applyAddMapReduction( OpAndResNum SpatialReducer::applyAddMapReduction(SmallVector<ComputeAndResNum>& computeOps,
SmallVector<ComputeAndResNum> &computeOps, ConversionPatternRewriter& rewriter,
ConversionPatternRewriter &rewriter, Value biasTile, MapOperations mapOp) { Value biasTile,
MapOperations mapOp) {
std::function<Value(const Value &)> postprocessing = nullptr; std::function<Value(const Value&)> postprocessing = nullptr;
if (mapOp != MapOperations::None) { if (mapOp != MapOperations::None) {
postprocessing = [&](const Value a) { postprocessing = [&](const Value a) {
Value mapOperand = a; Value mapOperand = a;
if (biasTile) { if (biasTile)
mapOperand = rewriter.create<spatial::SpatVAddOp>( mapOperand = rewriter.create<spatial::SpatVAddOp>(a.getLoc(), a.getType(), a, biasTile);
a.getLoc(), a.getType(), a, biasTile);
}
return createMapOperation(rewriter, mapOp, mapOperand); return createMapOperation(rewriter, mapOp, mapOperand);
}; };
} }
return this->applyReducePattern( return this->applyReducePattern(
computeOps, computeOps,
[&](Value a, Value b) { [&](Value a, Value b) { return rewriter.create<spatial::SpatVAddOp>(a.getLoc(), a.getType(), a, b); },
return rewriter.create<spatial::SpatVAddOp>(a.getLoc(), a.getType(), a, b); /* preprocess = */ nullptr,
}, postprocessing);
/* preprocess = */ nullptr, postprocessing);
} }
} // namespace onnx_mlir } // namespace onnx_mlir

View File

@@ -1,9 +1,10 @@
#pragma once #pragma once
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/Casting.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/Casting.h"
namespace onnx_mlir { namespace onnx_mlir {
@@ -12,48 +13,48 @@ using ResNum = unsigned int;
using ComputeAndResNum = std::pair<spatial::SpatWeightedCompute, ResNum>; using ComputeAndResNum = std::pair<spatial::SpatWeightedCompute, ResNum>;
struct SpatialReducerChange { struct SpatialReducerChange {
Operation *fromOp; Operation* fromOp;
unsigned int fromOpResNum; unsigned int fromOpResNum;
Operation *toOp; Operation* toOp;
unsigned int toOpOperandNum; unsigned int toOpOperandNum;
}; };
using OpAndResNum = std::pair<Operation *, ResNum>; using OpAndResNum = std::pair<Operation*, ResNum>;
class SpatialReducer { class SpatialReducer {
public: public:
SpatialReducer(ConversionPatternRewriter &rewriter) : rewriter(rewriter) {} SpatialReducer(ConversionPatternRewriter& rewriter)
: rewriter(rewriter) {}
OpAndResNum applyReducePattern( OpAndResNum applyReducePattern(SmallVector<ComputeAndResNum>& computeOpsAndResNum,
SmallVector<ComputeAndResNum> &computeOpsAndResNum, std::function<Value(const Value&, const Value&)> reduce,
std::function<Value(const Value &, const Value &)> reduce, std::function<Value(const Value&)> preprocess,
std::function<Value(const Value &)> preprocess, std::function<Value(const Value&)> postprocess);
std::function<Value(const Value &)> postprocess);
OpAndResNum applyAddMapReduction(SmallVector<ComputeAndResNum> &computeOps, OpAndResNum applyAddMapReduction(SmallVector<ComputeAndResNum>& computeOps,
ConversionPatternRewriter &rewriter, Value biasTile, MapOperations mapOp); ConversionPatternRewriter& rewriter,
Value biasTile,
MapOperations mapOp);
void finalizeReduceUpdates(); void finalizeReduceUpdates();
~SpatialReducer() { ~SpatialReducer() {
if (!reducesFinalized) { if (!reducesFinalized)
finalizeReduceUpdates(); finalizeReduceUpdates();
}
} }
Value createImgConcatOp( Value createImgConcatOp(llvm::SmallVector<llvm::SmallVector<llvm::SmallVector<OpAndResNum>>>& outputTiles,
llvm::SmallVector<llvm::SmallVector<llvm::SmallVector<OpAndResNum>>> Location& loc,
&outputTiles, Type outputType);
Location &loc, Type outputType);
Value resolveValueFromOpAndResNum(OpAndResNum &opAndResNum); Value resolveValueFromOpAndResNum(OpAndResNum& opAndResNum);
private: private:
[[nodiscard("computeOp result number gets updated")]] ResNum [[nodiscard("computeOp result number gets updated")]] ResNum
applyResultProcessing(ComputeAndResNum computeOpAndResNum, applyResultProcessing(ComputeAndResNum computeOpAndResNum,
std::function<Value(const Value &)> processFun, std::function<Value(const Value&)> processFun,
ConversionPatternRewriter &rewriter); ConversionPatternRewriter& rewriter);
/** /**
* @brief Update the results of a ComputeOp. * @brief Update the results of a ComputeOp.
@@ -65,9 +66,9 @@ private:
* *
* @param computeOp The ComputeOp to update the results of. * @param computeOp The ComputeOp to update the results of.
*/ */
void updateResultsOfCompute(Operation *computeOp); void updateResultsOfCompute(Operation* computeOp);
ConversionPatternRewriter &rewriter; ConversionPatternRewriter& rewriter;
bool reducesFinalized = false; bool reducesFinalized = false;
// List of changes to be applied after the reduction is finalized // List of changes to be applied after the reduction is finalized
@@ -75,9 +76,9 @@ private:
// List of computeOps that need to be replaced with new results // List of computeOps that need to be replaced with new results
SmallVector<spatial::SpatWeightedCompute> computeOpNeedingResUpdate; SmallVector<spatial::SpatWeightedCompute> computeOpNeedingResUpdate;
std::unordered_map<Operation *, spatial::SpatWeightedCompute> opToReplacedCompute; std::unordered_map<Operation*, spatial::SpatWeightedCompute> opToReplacedCompute;
static llvm::SmallPtrSet<Operation *, 16> oldComputeOpsReplaced; static llvm::SmallPtrSet<Operation*, 16> oldComputeOpsReplaced;
}; };
} // namespace onnx_mlir } // namespace onnx_mlir

View File

@@ -1,11 +1,11 @@
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/WeightSubdivider.hpp"
#include <cassert> #include <cassert>
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/WeightSubdivider.hpp"
namespace onnx_mlir { namespace onnx_mlir {
WeightSubdivider::WeightSubdivider( WeightSubdivider::WeightSubdivider(map<long, map<long, SmallVector<Value>>> weights)
map<long, map<long, SmallVector<Value>>> weights) : weights(std::move(weights)) {}
: weights(std::move(weights)) {}
bool WeightSubdivider::isEmpty() const { return weights.empty(); } bool WeightSubdivider::isEmpty() const { return weights.empty(); }
@@ -13,7 +13,7 @@ TaggedWeights WeightSubdivider::popGroup(size_t amount) {
assert(!weights.empty() && "No weights to extract."); assert(!weights.empty() && "No weights to extract.");
auto it = weights.begin(); auto it = weights.begin();
SmallVector<Value> &values = it->second.begin()->second; SmallVector<Value>& values = it->second.begin()->second;
long inputTile = it->first; long inputTile = it->first;
long outputTile = it->second.begin()->first; long outputTile = it->second.begin()->first;
@@ -26,11 +26,11 @@ TaggedWeights WeightSubdivider::popGroup(size_t amount) {
if (n < values.size()) { if (n < values.size()) {
values.erase(values.begin(), values.begin() + n); values.erase(values.begin(), values.begin() + n);
} else { }
else {
it->second.erase(outputTile); it->second.erase(outputTile);
if (it->second.empty()) { if (it->second.empty())
weights.erase(inputTile); weights.erase(inputTile);
}
} }
return {inputTile, outputTile, crossbarsUsed - n, result}; return {inputTile, outputTile, crossbarsUsed - n, result};

View File

@@ -1,7 +1,9 @@
#pragma once #pragma once
#include "mlir/IR/Value.h" #include "mlir/IR/Value.h"
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include <map> #include <map>
using namespace mlir; using namespace mlir;

View File

@@ -1,23 +1,21 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/Block.h" #include "mlir/IR/Block.h"
#include "mlir/IR/Diagnostics.h" #include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Value.h" #include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/Casting.h" #include "llvm/Support/Casting.h"
#include "llvm/Support/Format.h" #include "llvm/Support/Format.h"
#define FORMAT_OPERATION(op) \ #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
'x' << llvm::format_hex_no_prefix(reinterpret_cast<size_t>(op), 0) #include "src/Accelerators/PIM/Pass/PimPasses.hpp"
#define FORMAT_ARGUMENT(computeOpPointer, argumentNum) \ #include "src/Dialect/ONNX/ONNXOps.hpp"
llvm::format("Arg_%p_%u", computeOpPointer, argumentNum)
#define FORMAT_OPERATION(op) 'x' << llvm::format_hex_no_prefix(reinterpret_cast<size_t>(op), 0)
#define FORMAT_ARGUMENT(computeOpPointer, argumentNum) llvm::format("Arg_%p_%u", computeOpPointer, argumentNum)
using namespace mlir; using namespace mlir;
@@ -25,26 +23,22 @@ namespace onnx_mlir {
namespace { namespace {
struct SpatialToGraphvizPass struct SpatialToGraphvizPass : public PassWrapper<SpatialToGraphvizPass, OperationPass<ModuleOp>> {
: public PassWrapper<SpatialToGraphvizPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SpatialToGraphvizPass) MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SpatialToGraphvizPass)
StringRef getArgument() const override { StringRef getArgument() const override { return "convert-spatial-to-graphviz"; }
return "convert-spatial-to-graphviz";
}
StringRef getDescription() const override { StringRef getDescription() const override { return "Lower ONNX ops to Spatial ops."; }
return "Lower ONNX ops to Spatial ops.";
}
SpatialToGraphvizPass(raw_ostream &os = llvm::errs()) : os(os) {} SpatialToGraphvizPass(raw_ostream& os = llvm::errs())
SpatialToGraphvizPass(const SpatialToGraphvizPass &pass) : os(os) {}
: SpatialToGraphvizPass(pass.os) {} SpatialToGraphvizPass(const SpatialToGraphvizPass& pass)
: SpatialToGraphvizPass(pass.os) {}
void runOnOperation() final; void runOnOperation() final;
private: private:
raw_ostream &os; raw_ostream& os;
/** /**
* Draws the subgraph for a given spatial::SpatWeightedCompute, including: * Draws the subgraph for a given spatial::SpatWeightedCompute, including:
@@ -56,31 +50,27 @@ private:
* @param computeNum The number of the compute operation. * @param computeNum The number of the compute operation.
*/ */
void drawComputeOpSubgraph(spatial::SpatWeightedCompute op, size_t computeNum) { void drawComputeOpSubgraph(spatial::SpatWeightedCompute op, size_t computeNum) {
os << "\tsubgraph cluster" << computeNum << " {\n\t\tlabel=\"Compute" os << "\tsubgraph cluster" << computeNum << " {\n\t\tlabel=\"Compute" << computeNum << "\";\n"
<< computeNum << "\";\n"
<< "\t\tstyle=filled;\n" << "\t\tstyle=filled;\n"
<< "\t\tcolor=lightblue;\n"; << "\t\tcolor=lightblue;\n";
Block &block = op.getBody().front(); Block& block = op.getBody().front();
// Inputs // Inputs
size_t inputNum = 0; size_t inputNum = 0;
for (BlockArgument &input : block.getArguments()) { for (BlockArgument& input : block.getArguments()) {
auto fromOp = FORMAT_ARGUMENT(op.getOperation(), inputNum); auto fromOp = FORMAT_ARGUMENT(op.getOperation(), inputNum);
os << "\t\t" << fromOp << " [label=\"Arg" << inputNum os << "\t\t" << fromOp << " [label=\"Arg" << inputNum << "\",shape=box];\n";
<< "\",shape=box];\n"; for (auto userOp : input.getUsers())
for (auto userOp : input.getUsers()) {
os << "\t\t" << fromOp << " -> " << FORMAT_OPERATION(userOp) << ";\n"; os << "\t\t" << fromOp << " -> " << FORMAT_OPERATION(userOp) << ";\n";
}
inputNum++; inputNum++;
} }
// Iterate operations // Iterate operations
for (auto &childOp : block.getOperations()) { for (auto& childOp : block.getOperations()) {
os << "\t\t" << FORMAT_OPERATION(&childOp) << " [label=\"" os << "\t\t" << FORMAT_OPERATION(&childOp) << " [label=\"" << childOp.getName() << "\"];\n";
<< childOp.getName() << "\"];\n";
drawEdgesFromOpToItsUsers(&childOp); drawEdgesFromOpToItsUsers(&childOp);
} }
@@ -88,7 +78,7 @@ private:
os << "\t}\n"; os << "\t}\n";
// Draw edges from the yield to the users of this computeOp // Draw edges from the yield to the users of this computeOp
Operation *yieldOp = block.getTerminator(); Operation* yieldOp = block.getTerminator();
if (!isa<spatial::SpatYieldOp>(yieldOp)) { if (!isa<spatial::SpatYieldOp>(yieldOp)) {
yieldOp->emitError("Terminator of block must be YieldOp ???"); yieldOp->emitError("Terminator of block must be YieldOp ???");
signalPassFailure(); signalPassFailure();
@@ -96,9 +86,8 @@ private:
} }
for (auto computeOpResult : op->getResults()) { for (auto computeOpResult : op->getResults()) {
for (auto &computeOpUse : computeOpResult.getUses()) { for (auto& computeOpUse : computeOpResult.getUses()) {
auto toOp = FORMAT_ARGUMENT( auto toOp = FORMAT_ARGUMENT(computeOpUse.getOwner(), computeOpUse.getOperandNumber());
computeOpUse.getOwner(), computeOpUse.getOperandNumber());
os << "\t" << FORMAT_OPERATION(yieldOp) << " -> " << toOp << ";\n"; os << "\t" << FORMAT_OPERATION(yieldOp) << " -> " << toOp << ";\n";
} }
} }
@@ -114,9 +103,8 @@ private:
* @param concatOp The concatOp for which the subgraph is drawn. * @param concatOp The concatOp for which the subgraph is drawn.
* @param concatOpNum The number of the concatOp. * @param concatOpNum The number of the concatOp.
*/ */
void drawConcatOpSubgraph(Operation *concatOp, size_t concatOpNum) { void drawConcatOpSubgraph(Operation* concatOp, size_t concatOpNum) {
os << "\tsubgraph clusterconcat" << concatOpNum os << "\tsubgraph clusterconcat" << concatOpNum << " {\n\t\tlabel=\"ConcatOp" << concatOpNum << "\";\n"
<< " {\n\t\tlabel=\"ConcatOp" << concatOpNum << "\";\n"
<< "\t\tstyle=filled;\n" << "\t\tstyle=filled;\n"
<< "\t\tcolor=orange;\n"; << "\t\tcolor=orange;\n";
@@ -126,9 +114,8 @@ private:
auto fromOp = FORMAT_ARGUMENT(concatOp, inputNum); auto fromOp = FORMAT_ARGUMENT(concatOp, inputNum);
os << "\t\t" << fromOp << " [label=\"Input" << inputNum << "\"];\n"; os << "\t\t" << fromOp << " [label=\"Input" << inputNum << "\"];\n";
for (auto userOp : input.getUsers()) { for (auto userOp : input.getUsers())
os << "\t\t" << fromOp << " -> " << FORMAT_OPERATION(userOp) << ";\n"; os << "\t\t" << fromOp << " -> " << FORMAT_OPERATION(userOp) << ";\n";
}
inputNum++; inputNum++;
} }
@@ -139,11 +126,9 @@ private:
// Edges from output to users // Edges from output to users
for (auto &computeOpUse : concatOp->getResult(0).getUses()) { for (auto& computeOpUse : concatOp->getResult(0).getUses()) {
os << "\t" << FORMAT_OPERATION(concatOp) << " -> " os << "\t" << FORMAT_OPERATION(concatOp) << " -> "
<< FORMAT_ARGUMENT( << FORMAT_ARGUMENT(computeOpUse.getOwner(), computeOpUse.getOperandNumber()) << ";\n";
computeOpUse.getOwner(), computeOpUse.getOperandNumber())
<< ";\n";
} }
} }
@@ -164,10 +149,8 @@ private:
sliceOp.getStaticOffsetsAttr().print(os); sliceOp.getStaticOffsetsAttr().print(os);
os << "\",color=lawngreen];\n"; os << "\",color=lawngreen];\n";
for (auto &computeOpUse : sliceOp.getResult().getUses()) { for (auto& computeOpUse : sliceOp.getResult().getUses()) {
os << "\t" << nodeId << " -> " os << "\t" << nodeId << " -> " << FORMAT_ARGUMENT(computeOpUse.getOwner(), computeOpUse.getOperandNumber())
<< FORMAT_ARGUMENT(
computeOpUse.getOwner(), computeOpUse.getOperandNumber())
<< ";\n"; << ";\n";
} }
} }
@@ -178,9 +161,8 @@ private:
sliceOp.getStaticOffsetsAttr().print(os); sliceOp.getStaticOffsetsAttr().print(os);
os << "\",color=lightpink];\n"; os << "\",color=lightpink];\n";
for (auto user : sliceOp.getResult().getUsers()) { for (auto user : sliceOp.getResult().getUsers())
os << "\t" << nodeId << " -> " << FORMAT_OPERATION(user) << ";\n"; os << "\t" << nodeId << " -> " << FORMAT_OPERATION(user) << ";\n";
}
} }
/** /**
@@ -188,13 +170,10 @@ private:
* *
* @param fromOp The operation from which the edges are drawn. * @param fromOp The operation from which the edges are drawn.
*/ */
void drawEdgesFromOpToItsUsers(mlir::Operation *fromOp) { void drawEdgesFromOpToItsUsers(mlir::Operation* fromOp) {
for (auto result : fromOp->getResults()) { for (auto result : fromOp->getResults())
for (auto userOp : result.getUsers()) { for (auto userOp : result.getUsers())
os << "\t\t" << FORMAT_OPERATION(fromOp) << " -> " os << "\t\t" << FORMAT_OPERATION(fromOp) << " -> " << FORMAT_OPERATION(userOp) << ";\n";
<< FORMAT_OPERATION(userOp) << ";\n";
}
}
} }
/** /**
@@ -202,16 +181,15 @@ private:
* *
* @param funcOp The `funcOp` for which to draw input nodes and edges. * @param funcOp The `funcOp` for which to draw input nodes and edges.
*/ */
void drawInputNodesAndEdges(func::FuncOp &funcOp) { void drawInputNodesAndEdges(func::FuncOp& funcOp) {
os << "\tinput [label=\"Module Input\",color=green];\n"; os << "\tinput [label=\"Module Input\",color=green];\n";
size_t funcOpArgNum = 0; size_t funcOpArgNum = 0;
for (BlockArgument &arg : funcOp.getArguments()) { for (BlockArgument& arg : funcOp.getArguments()) {
for (auto &useOp : arg.getUses()) { for (auto& useOp : arg.getUses()) {
os << "\tinput -> " os << "\tinput -> " << FORMAT_ARGUMENT(useOp.getOwner(), useOp.getOperandNumber()) << "[label=" << funcOpArgNum
<< FORMAT_ARGUMENT(useOp.getOwner(), useOp.getOperandNumber()) << "];\n";
<< "[label=" << funcOpArgNum << "];\n";
} }
funcOpArgNum++; funcOpArgNum++;
} }
@@ -237,20 +215,22 @@ void SpatialToGraphvizPass::runOnOperation() {
// Iterate over the ComputeOps within FuncOp: // Iterate over the ComputeOps within FuncOp:
// 1. Print their subgraph // 1. Print their subgraph
// 2. Print the edges from its inputs to its outputs // 2. Print the edges from its inputs to its outputs
for (Operation &op : func.getOps()) { for (Operation& op : func.getOps()) {
if (auto computeOp = dyn_cast<spatial::SpatWeightedCompute>(op)) { if (auto computeOp = dyn_cast<spatial::SpatWeightedCompute>(op)) {
drawComputeOpSubgraph(computeOp, computeNum++); drawComputeOpSubgraph(computeOp, computeNum++);
} else if (auto concatOp = dyn_cast<tensor::ConcatOp>(op)) { }
else if (auto concatOp = dyn_cast<tensor::ConcatOp>(op)) {
drawConcatOpSubgraph(concatOp, concatNum++); drawConcatOpSubgraph(concatOp, concatNum++);
} else if (auto imgConcatOp = dyn_cast<spatial::SpatImgConcatOp>(op)) { }
else if (auto imgConcatOp = dyn_cast<spatial::SpatImgConcatOp>(op)) {
drawConcatOpSubgraph(imgConcatOp, concatNum++); drawConcatOpSubgraph(imgConcatOp, concatNum++);
} else if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op)) { }
else if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op)) {
auto producerOp = extractSliceOp->getOperand(0).getDefiningOp(); auto producerOp = extractSliceOp->getOperand(0).getDefiningOp();
if (producerOp) { if (producerOp) {
// Skip extractSliceOp if producer is constant weights (ONNXConstantOp) // Skip extractSliceOp if producer is constant weights (ONNXConstantOp)
if (llvm::isa<ONNXConstantOp>(producerOp)) { if (llvm::isa<ONNXConstantOp>(producerOp))
continue; continue;
}
// If produced by tosa::ReshapeOp (i.e. it is a bias tile) connect // If produced by tosa::ReshapeOp (i.e. it is a bias tile) connect
// directly to its user, which is not a ComputeOp argument. // directly to its user, which is not a ComputeOp argument.
if (llvm::isa<tosa::ReshapeOp>(producerOp)) { if (llvm::isa<tosa::ReshapeOp>(producerOp)) {
@@ -268,16 +248,13 @@ void SpatialToGraphvizPass::runOnOperation() {
// Draw output node (use the return Operation - argument number=0 - as nodeId) // Draw output node (use the return Operation - argument number=0 - as nodeId)
auto returnOp = func.getBody().front().getTerminator(); auto returnOp = func.getBody().front().getTerminator();
os << '\t' << FORMAT_ARGUMENT(returnOp, 0) os << '\t' << FORMAT_ARGUMENT(returnOp, 0) << " [label=\"Module Output\",color=green];\n";
<< " [label=\"Module Output\",color=green];\n";
os << "}\n"; os << "}\n";
} }
} // namespace } // namespace
std::unique_ptr<Pass> createSpatialToGraphvizPass() { std::unique_ptr<Pass> createSpatialToGraphvizPass() { return std::make_unique<SpatialToGraphvizPass>(); }
return std::make_unique<SpatialToGraphvizPass>();
}
} // namespace onnx_mlir } // namespace onnx_mlir

View File

@@ -148,7 +148,8 @@ void SpatialToPIMPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IR
size_t resultIndexInConcat = resultUses.begin()->getOperandNumber(); size_t resultIndexInConcat = resultUses.begin()->getOperandNumber();
size_t offset = 0; size_t offset = 0;
for (auto operand : concatOp->getOperands().take_front(resultIndexInConcat)) for (auto operand : concatOp->getOperands().take_front(resultIndexInConcat))
offset += cast<ShapedType>(operand.getType()).getNumElements() * cast<ShapedType>(operand.getType()).getElementTypeBitWidth() / 8; offset += cast<ShapedType>(operand.getType()).getNumElements()
* cast<ShapedType>(operand.getType()).getElementTypeBitWidth() / 8;
size_t elementSize = yieldType.getElementTypeBitWidth() / 8; size_t elementSize = yieldType.getElementTypeBitWidth() / 8;

View File

@@ -1,8 +1,5 @@
#pragma once #pragma once
#include <map>
#include <string>
#include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h" #include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -11,6 +8,9 @@
#include "mlir/IR/Dialect.h" #include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpDefinition.h"
#include <map>
#include <string>
/// Include the auto-generated header files containing the declarations /// Include the auto-generated header files containing the declarations
#include "src/Accelerators/PIM/Dialect/PIM/PimDialect.hpp.inc" #include "src/Accelerators/PIM/Dialect/PIM/PimDialect.hpp.inc"

View File

@@ -15,7 +15,10 @@ namespace pim {
struct MemCopyHostToDevOpInterface struct MemCopyHostToDevOpInterface
: DstBufferizableOpInterfaceExternalModel<MemCopyHostToDevOpInterface, PimMemCopyHostToDevOp> { : DstBufferizableOpInterfaceExternalModel<MemCopyHostToDevOpInterface, PimMemCopyHostToDevOp> {
LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const { LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
const BufferizationOptions& options,
BufferizationState& state) const {
auto memCopyHostToDevOp = cast<PimMemCopyHostToDevOp>(op); auto memCopyHostToDevOp = cast<PimMemCopyHostToDevOp>(op);
auto deviceDst = memCopyHostToDevOp.getDeviceDst(); auto deviceDst = memCopyHostToDevOp.getDeviceDst();
auto hostSrc = memCopyHostToDevOp.getHostSrc(); auto hostSrc = memCopyHostToDevOp.getHostSrc();
@@ -44,7 +47,10 @@ struct MemCopyHostToDevOpInterface
struct MemCopyDevToHostOpInterface struct MemCopyDevToHostOpInterface
: DstBufferizableOpInterfaceExternalModel<MemCopyDevToHostOpInterface, PimMemCopyDevToHostOp> { : DstBufferizableOpInterfaceExternalModel<MemCopyDevToHostOpInterface, PimMemCopyDevToHostOp> {
LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const { LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
const BufferizationOptions& options,
BufferizationState& state) const {
auto memCopyDevToHostOp = cast<PimMemCopyDevToHostOp>(op); auto memCopyDevToHostOp = cast<PimMemCopyDevToHostOp>(op);
auto globalDst = memCopyDevToHostOp.getHostDst(); auto globalDst = memCopyDevToHostOp.getHostDst();
@@ -88,7 +94,10 @@ struct VMMOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<VMMOpBu
return false; return false;
} }
LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const { LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
const BufferizationOptions& options,
BufferizationState& state) const {
auto vmmOp = cast<PimVMMOp>(op); auto vmmOp = cast<PimVMMOp>(op);
auto vectorInputOpt = getBuffer(rewriter, vmmOp.getVectorInput(), options, state); auto vectorInputOpt = getBuffer(rewriter, vmmOp.getVectorInput(), options, state);
@@ -110,7 +119,10 @@ struct MVMOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<MVMOpBu
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand); return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
} }
LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const { LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
const BufferizationOptions& options,
BufferizationState& state) const {
auto mvmOp = cast<PimMVMOp>(op); auto mvmOp = cast<PimMVMOp>(op);
auto vectorInputOpt = getBuffer(rewriter, mvmOp.getVectorInput(), options, state); auto vectorInputOpt = getBuffer(rewriter, mvmOp.getVectorInput(), options, state);
@@ -138,7 +150,10 @@ struct VAddOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<VAddOp
return true; return true;
} }
LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const { LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
const BufferizationOptions& options,
BufferizationState& state) const {
auto vaddOp = cast<PimVAddOp>(op); auto vaddOp = cast<PimVAddOp>(op);
auto aOpt = getBuffer(rewriter, vaddOp.getA(), options, state); auto aOpt = getBuffer(rewriter, vaddOp.getA(), options, state);

View File

@@ -1,6 +1,7 @@
#pragma once #pragma once
#include "mlir/IR/DialectRegistry.h" #include "mlir/IR/DialectRegistry.h"
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
using namespace mlir; using namespace mlir;
@@ -8,7 +9,7 @@ using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
namespace pim { namespace pim {
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry); void registerBufferizableOpInterfaceExternalModels(DialectRegistry& registry);
} // namespace pim } // namespace pim
} // namespace onnx_mlir } // namespace onnx_mlir

View File

@@ -1,8 +1,5 @@
#pragma once #pragma once
#include <map>
#include <string>
#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Builders.h" #include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/BuiltinTypes.h"
@@ -10,6 +7,9 @@
#include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Types.h" #include "mlir/IR/Types.h"
#include <map>
#include <string>
/// Include the auto-generated header files containing the declarations /// Include the auto-generated header files containing the declarations
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialDialect.hpp.inc" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialDialect.hpp.inc"

View File

@@ -75,7 +75,10 @@ struct WComputeOpInterface : BufferizableOpInterface::ExternalModel<WComputeOpIn
return {}; return {};
} }
LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const { LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
const BufferizationOptions& options,
BufferizationState& state) const {
// Bufferize its block // Bufferize its block
auto& block = op->getRegion(0).front(); auto& block = op->getRegion(0).front();
@@ -104,7 +107,10 @@ struct VariadicArgumentElementWiseOpInterface : BufferizableOpInterface::Externa
} }
// Cast tensor values into memref values // Cast tensor values into memref values
LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const { LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
const BufferizationOptions& options,
BufferizationState& state) const {
// Turn Tensor Operands into Memref Operands // Turn Tensor Operands into Memref Operands
SmallVector<Value> memrefOperands; SmallVector<Value> memrefOperands;
@@ -151,7 +157,10 @@ struct WeightedMultiplicationsOpInterface : BufferizableOpInterface::ExternalMod
} }
// Cast tensor value into memref value // Cast tensor value into memref value
LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const { LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
const BufferizationOptions& options,
BufferizationState& state) const {
auto memrefOperandOpt = getBuffer(rewriter, op->getOperand(0), options, state); auto memrefOperandOpt = getBuffer(rewriter, op->getOperand(0), options, state);
if (failed(memrefOperandOpt)) if (failed(memrefOperandOpt))
return failure(); return failure();
@@ -190,7 +199,10 @@ struct ChannelReceiveOpInterface
/* /*
* Turn the channel receive to pim.recv * Turn the channel receive to pim.recv
*/ */
LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const { LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
const BufferizationOptions& options,
BufferizationState& state) const {
auto outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter); auto outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter);
@@ -235,7 +247,10 @@ struct ChannelSendOpInterface : BufferizableOpInterface::ExternalModel<ChannelSe
/* /*
* Turn the channel send to pim.send * Turn the channel send to pim.send
*/ */
LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const { LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
const BufferizationOptions& options,
BufferizationState& state) const {
auto srcTensor = op->getOperand(1); auto srcTensor = op->getOperand(1);
auto srcTensorOpt = getBuffer(rewriter, srcTensor, options, state); auto srcTensorOpt = getBuffer(rewriter, srcTensor, options, state);
@@ -278,7 +293,10 @@ struct ChannelBroadcastReceiveOpInterface
/* /*
* Turn the channel receive to pim.load using by creating a new global buffer * Turn the channel receive to pim.load using by creating a new global buffer
*/ */
LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const { LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
const BufferizationOptions& options,
BufferizationState& state) const {
auto outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter); auto outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter);
@@ -340,7 +358,10 @@ struct ChannelBroadcastSendOpInterface
/* /*
* Turn the channel send to pim.send * Turn the channel send to pim.send
*/ */
LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const { LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
const BufferizationOptions& options,
BufferizationState& state) const {
auto srcTensor = op->getOperand(1); auto srcTensor = op->getOperand(1);
auto srcTensorOpt = getBuffer(rewriter, srcTensor, options, state); auto srcTensorOpt = getBuffer(rewriter, srcTensor, options, state);
@@ -414,7 +435,10 @@ struct ApplyFiltersOpInterface : BufferizableOpInterface::ExternalModel<ApplyFil
} }
// Bufferize the operation. // Bufferize the operation.
LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const { LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
const BufferizationOptions& options,
BufferizationState& state) const {
// Get the input tensor buffer. // Get the input tensor buffer.
auto inputBuffer = getBuffer(rewriter, op->getOperand(0), options, state); auto inputBuffer = getBuffer(rewriter, op->getOperand(0), options, state);

View File

@@ -1,6 +1,7 @@
#pragma once #pragma once
#include "mlir/IR/DialectRegistry.h" #include "mlir/IR/DialectRegistry.h"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir; using namespace mlir;

View File

@@ -1,5 +1,6 @@
#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Compiler/CompilerUtils.hpp" #include "src/Compiler/CompilerUtils.hpp"
@@ -10,22 +11,19 @@ namespace onnx_mlir {
namespace { namespace {
struct CountInstructionPass struct CountInstructionPass : public PassWrapper<CountInstructionPass, OperationPass<ModuleOp>> {
: public PassWrapper<CountInstructionPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CountInstructionPass) MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CountInstructionPass)
StringRef getArgument() const override { return "count-instruction-pass"; } StringRef getArgument() const override { return "count-instruction-pass"; }
StringRef getDescription() const override { StringRef getDescription() const override { return "Count instructions for each core/compute in the module"; }
return "Count instructions for each core/compute in the module";
}
// Make sure that we have a valid default constructor and copy // Make sure that we have a valid default constructor and copy
// constructor to make sure that the options are initialized properly. // constructor to make sure that the options are initialized properly.
CountInstructionPass() {} CountInstructionPass() {}
CountInstructionPass(const CountInstructionPass &pass) CountInstructionPass(const CountInstructionPass& pass)
: PassWrapper<CountInstructionPass, OperationPass<ModuleOp>>() {} : PassWrapper<CountInstructionPass, OperationPass<ModuleOp>>() {}
void runOnOperation() final { void runOnOperation() final {
ModuleOp module = getOperation(); ModuleOp module = getOperation();
@@ -37,8 +35,7 @@ struct CountInstructionPass
for (auto computeOp : func.getOps<spatial::SpatWeightedCompute>()) { for (auto computeOp : func.getOps<spatial::SpatWeightedCompute>()) {
unsigned instructionCount = 0; unsigned instructionCount = 0;
instructionCount += computeOp.getBody().front().getOperations().size(); instructionCount += computeOp.getBody().front().getOperations().size();
llvm::outs() << "Compute " << computeId << ": " << instructionCount llvm::outs() << "Compute " << computeId << ": " << instructionCount << " instructions\n";
<< " instructions\n";
totalInstructionCount += instructionCount; totalInstructionCount += instructionCount;
computeId++; computeId++;
} }
@@ -47,21 +44,17 @@ struct CountInstructionPass
for (auto coreOp : func.getOps<pim::PimCoreOp>()) { for (auto coreOp : func.getOps<pim::PimCoreOp>()) {
unsigned instructionCount = 0; unsigned instructionCount = 0;
instructionCount += coreOp.getBody().front().getOperations().size(); instructionCount += coreOp.getBody().front().getOperations().size();
llvm::outs() << "Core " << coreId << ": " << instructionCount llvm::outs() << "Core " << coreId << ": " << instructionCount << " instructions\n";
<< " instructions\n";
totalInstructionCount += instructionCount; totalInstructionCount += instructionCount;
coreId++; coreId++;
} }
llvm::outs() << "Total instruction count: " << totalInstructionCount llvm::outs() << "Total instruction count: " << totalInstructionCount << "\n";
<< "\n";
} }
}; };
} // namespace } // namespace
std::unique_ptr<Pass> createCountInstructionPass() { std::unique_ptr<Pass> createCountInstructionPass() { return std::make_unique<CountInstructionPass>(); }
return std::make_unique<CountInstructionPass>();
}
} // namespace onnx_mlir } // namespace onnx_mlir

View File

@@ -1,6 +1,7 @@
#pragma once #pragma once
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
#include <memory> #include <memory>
using namespace mlir; using namespace mlir;

View File

@@ -1,6 +1,7 @@
#pragma once #pragma once
#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/BuiltinTypes.h"
#include "src/Accelerators/Accelerator.hpp" #include "src/Accelerators/Accelerator.hpp"
namespace onnx_mlir { namespace onnx_mlir {
@@ -9,38 +10,37 @@ namespace accel {
/// Singleton class to construct PIM accelerator. /// Singleton class to construct PIM accelerator.
class PimAccelerator final : public Accelerator { class PimAccelerator final : public Accelerator {
private: private:
static PimAccelerator *instance; static PimAccelerator* instance;
PimAccelerator(); PimAccelerator();
public: public:
/// Singleton should not be clonable or assignable. /// Singleton should not be clonable or assignable.
PimAccelerator(PimAccelerator &) = delete; PimAccelerator(PimAccelerator&) = delete;
void operator=(const PimAccelerator &) = delete; void operator=(const PimAccelerator&) = delete;
~PimAccelerator(); ~PimAccelerator();
/// Creates an instance on the first invocation. Subsequent invocations /// Creates an instance on the first invocation. Subsequent invocations
/// return the existing instance. /// return the existing instance.
static PimAccelerator *getInstance(); static PimAccelerator* getInstance();
/// Define classof to be able to use isa<>, cast<>, dyn_cast<>, etc. /// Define classof to be able to use isa<>, cast<>, dyn_cast<>, etc.
static bool classof(const Accelerator *accel) { static bool classof(const Accelerator* accel) { return accel->getKind() == Accelerator::Kind::PIM; }
return accel->getKind() == Accelerator::Kind::PIM; static bool classof(const PimAccelerator*) { return true; }
}
static bool classof(const PimAccelerator *) { return true; }
uint64_t getVersionNumber() const final; uint64_t getVersionNumber() const final;
//===--------------------------------------------------------------------===// //===--------------------------------------------------------------------===//
// Hooks for onnx-mlir-opt driver // Hooks for onnx-mlir-opt driver
//===--------------------------------------------------------------------===// //===--------------------------------------------------------------------===//
virtual void addPasses(mlir::OwningOpRef<mlir::ModuleOp> &module, virtual void addPasses(mlir::OwningOpRef<mlir::ModuleOp>& module,
mlir::PassManager &pm, onnx_mlir::EmissionTargetType &emissionTarget, mlir::PassManager& pm,
std::string outputNameNoExt) const final; onnx_mlir::EmissionTargetType& emissionTarget,
std::string outputNameNoExt) const final;
//===--------------------------------------------------------------------===// //===--------------------------------------------------------------------===//
// Hooks for onnx-mlir-opt driver // Hooks for onnx-mlir-opt driver
//===--------------------------------------------------------------------===// //===--------------------------------------------------------------------===//
virtual void registerDialects(mlir::DialectRegistry &registry) const final; virtual void registerDialects(mlir::DialectRegistry& registry) const final;
virtual void registerPasses(int optLevel) const final; virtual void registerPasses(int optLevel) const final;
//===--------------------------------------------------------------------===// //===--------------------------------------------------------------------===//
// Hooks for both onnx-mlir and onnx-mlir-opt drivers // Hooks for both onnx-mlir and onnx-mlir-opt drivers
@@ -49,21 +49,19 @@ public:
//===--------------------------------------------------------------------===// //===--------------------------------------------------------------------===//
// Hooks for onnx-to-krnl pass // Hooks for onnx-to-krnl pass
//===--------------------------------------------------------------------===// //===--------------------------------------------------------------------===//
virtual mlir::MemRefType convertTensorTypeToMemRefType( virtual mlir::MemRefType convertTensorTypeToMemRefType(const mlir::TensorType tensorType) const final;
const mlir::TensorType tensorType) const final; virtual void conversionTargetONNXToKrnl(mlir::ConversionTarget& target) const final;
virtual void conversionTargetONNXToKrnl( virtual void rewritePatternONNXToKrnl(mlir::RewritePatternSet& patterns,
mlir::ConversionTarget &target) const final; mlir::TypeConverter& typeConverter,
virtual void rewritePatternONNXToKrnl(mlir::RewritePatternSet &patterns, mlir::MLIRContext* ctx) const final;
mlir::TypeConverter &typeConverter, mlir::MLIRContext *ctx) const final;
//===--------------------------------------------------------------------===// //===--------------------------------------------------------------------===//
// Hooks for krnl-to-llvm pass // Hooks for krnl-to-llvm pass
//===--------------------------------------------------------------------===// //===--------------------------------------------------------------------===//
virtual void conversionTargetKrnlToLLVM( virtual void conversionTargetKrnlToLLVM(mlir::ConversionTarget& target) const final;
mlir::ConversionTarget &target) const final; virtual void rewritePatternKrnlToLLVM(mlir::RewritePatternSet& patterns,
virtual void rewritePatternKrnlToLLVM(mlir::RewritePatternSet &patterns, mlir::LLVMTypeConverter& typeConverter,
mlir::LLVMTypeConverter &typeConverter, mlir::MLIRContext* ctx) const final;
mlir::MLIRContext *ctx) const final;
}; };
} // namespace accel } // namespace accel

View File

@@ -63,7 +63,8 @@ void PimBufferizationPass::runOnOperation() {
void PimBufferizationPass::annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const { void PimBufferizationPass::annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const {
MLIRContext* ctx = funcOp.getContext(); MLIRContext* ctx = funcOp.getContext();
funcOp.walk([&](memref::GetGlobalOp getGlobalOp) { funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
bool isAlwaysWeight = llvm::all_of(getGlobalOp->getUsers(), [](auto user) -> bool { return isa<PimCoreOp>(user); }) && !getGlobalOp->getUsers().empty() ; bool isAlwaysWeight = llvm::all_of(getGlobalOp->getUsers(), [](auto user) -> bool { return isa<PimCoreOp>(user); })
&& !getGlobalOp->getUsers().empty();
if (isAlwaysWeight) { if (isAlwaysWeight) {
auto globalMemrefOp = moduleOp.lookupSymbol<memref::GlobalOp>(getGlobalOp.getName()); auto globalMemrefOp = moduleOp.lookupSymbol<memref::GlobalOp>(getGlobalOp.getName());
assert("Weights must be constants" && globalMemrefOp.getConstant()); assert("Weights must be constants" && globalMemrefOp.getConstant());