add .clang-format
reformat all src
This commit is contained in:
143
.clang-format
Normal file
143
.clang-format
Normal file
@@ -0,0 +1,143 @@
|
||||
---
|
||||
AccessModifierOffset: -2
|
||||
AlignAfterOpenBracket: Align
|
||||
AlignArrayOfStructures: Left
|
||||
AlignConsecutiveShortCaseStatements:
|
||||
Enabled: true
|
||||
AlignConsecutiveAssignments:
|
||||
Enabled: false
|
||||
AlignConsecutiveBitFields:
|
||||
Enabled: false
|
||||
AlignConsecutiveDeclarations:
|
||||
Enabled: false
|
||||
AlignConsecutiveMacros:
|
||||
Enabled: false
|
||||
AlignEscapedNewlines: Left
|
||||
AlignOperands: AlignAfterOperator
|
||||
AlignTrailingComments:
|
||||
Kind: Always
|
||||
OverEmptyLines: 4
|
||||
AllowAllArgumentsOnNextLine: true
|
||||
AllowAllParametersOfDeclarationOnNextLine: true
|
||||
AllowShortBlocksOnASingleLine: Empty
|
||||
AllowShortCaseExpressionOnASingleLine: true
|
||||
AllowShortCaseLabelsOnASingleLine: true
|
||||
AllowShortCompoundRequirementOnASingleLine: true
|
||||
AllowShortEnumsOnASingleLine: false
|
||||
AllowShortFunctionsOnASingleLine: All
|
||||
AllowShortIfStatementsOnASingleLine: Never
|
||||
AllowShortLambdasOnASingleLine: All
|
||||
AllowShortLoopsOnASingleLine: false
|
||||
AlwaysBreakBeforeMultilineStrings: false
|
||||
BinPackArguments: false
|
||||
BinPackParameters: false
|
||||
BitFieldColonSpacing: Both
|
||||
BraceWrapping:
|
||||
BeforeElse: true
|
||||
BeforeCatch: true
|
||||
BeforeWhile: true
|
||||
BreakBeforeBraces: Custom
|
||||
BracedInitializerIndentWidth: 2
|
||||
BreakAdjacentStringLiterals: true
|
||||
BreakAfterAttributes: Never
|
||||
BreakAfterJavaFieldAnnotations: false
|
||||
BreakArrays: false
|
||||
BreakBeforeBinaryOperators: NonAssignment
|
||||
BreakBeforeConceptDeclarations: Always
|
||||
BreakBeforeInlineASMColon: OnlyMultiline
|
||||
BreakBeforeTernaryOperators: true
|
||||
BreakConstructorInitializers: BeforeColon
|
||||
BreakFunctionDefinitionParameters: false
|
||||
BreakInheritanceList: AfterComma
|
||||
BreakStringLiterals: true
|
||||
BreakTemplateDeclarations: Yes
|
||||
ColumnLimit: 120
|
||||
CompactNamespaces: false
|
||||
ConstructorInitializerIndentWidth: 0
|
||||
ContinuationIndentWidth: 2
|
||||
Cpp11BracedListStyle: true
|
||||
DerivePointerAlignment: false
|
||||
DisableFormat: false
|
||||
EmptyLineAfterAccessModifier: Never
|
||||
EmptyLineBeforeAccessModifier: Always
|
||||
FixNamespaceComments: true
|
||||
IncludeBlocks: Regroup
|
||||
IncludeCategories:
|
||||
- Regex: (<|")mlir(.*)(>|")
|
||||
Priority: 2
|
||||
- Regex: (<|")llvm(.*)(>|")
|
||||
Priority: 3
|
||||
- Regex: <.*>
|
||||
Priority: 4
|
||||
IncludeIsMainRegex: (Test)?$
|
||||
IncludeIsMainSourceRegex: ""
|
||||
IndentAccessModifiers: false
|
||||
IndentCaseBlocks: false
|
||||
IndentCaseLabels: false
|
||||
IndentExternBlock: Indent
|
||||
IndentGotoLabels: false
|
||||
IndentPPDirectives: None
|
||||
IndentRequiresClause: true
|
||||
IndentWidth: 2
|
||||
IndentWrappedFunctionNames: false
|
||||
InsertBraces: false
|
||||
InsertNewlineAtEOF: true
|
||||
KeepEmptyLines:
|
||||
AtEndOfFile: false
|
||||
AtStartOfBlock: true
|
||||
AtStartOfFile: false
|
||||
LambdaBodyIndentation: Signature
|
||||
Language: Cpp
|
||||
LineEnding: LF
|
||||
MainIncludeChar: AngleBracket
|
||||
MaxEmptyLinesToKeep: 1
|
||||
NamespaceIndentation: None
|
||||
PackConstructorInitializers: NextLineOnly
|
||||
PointerAlignment: Left
|
||||
ReferenceAlignment: Pointer
|
||||
RemoveBracesLLVM: true
|
||||
RemoveParentheses: ReturnStatement
|
||||
RemoveSemicolon: true
|
||||
RequiresClausePosition: OwnLine
|
||||
RequiresExpressionIndentation: OuterScope
|
||||
SeparateDefinitionBlocks: Leave
|
||||
ShortNamespaceLines: 4
|
||||
SortIncludes: CaseSensitive
|
||||
SpaceAfterCStyleCast: true
|
||||
SpaceAfterLogicalNot: false
|
||||
SpaceAfterTemplateKeyword: true
|
||||
SpaceAroundPointerQualifiers: Before
|
||||
SpaceBeforeAssignmentOperators: true
|
||||
SpaceBeforeCaseColon: false
|
||||
SpaceBeforeCpp11BracedList: true
|
||||
SpaceBeforeCtorInitializerColon: true
|
||||
SpaceBeforeInheritanceColon: true
|
||||
SpaceBeforeJsonColon: false
|
||||
SpaceBeforeParensOptions:
|
||||
AfterControlStatements: true
|
||||
AfterForeachMacros: true
|
||||
AfterFunctionDeclarationName: false
|
||||
AfterFunctionDefinitionName: false
|
||||
AfterIfMacros: true
|
||||
AfterOverloadedOperator: false
|
||||
AfterPlacementOperator: false
|
||||
AfterRequiresInClause: false
|
||||
AfterRequiresInExpression: false
|
||||
BeforeNonEmptyParentheses: false
|
||||
SpaceBeforeRangeBasedForLoopColon: true
|
||||
SpaceBeforeSquareBrackets: false
|
||||
SpaceInEmptyBlock: false
|
||||
SpacesBeforeTrailingComments: 1
|
||||
SpacesInAngles: Never
|
||||
SpacesInContainerLiterals: false
|
||||
SpacesInLineCommentPrefix:
|
||||
Minimum: 1
|
||||
Maximum: 1
|
||||
SpacesInParens: Never
|
||||
SpacesInSquareBrackets: false
|
||||
Standard: Latest
|
||||
TabWidth: 4
|
||||
UseTab: Never
|
||||
VerilogBreakBetweenInstancePorts: true
|
||||
Macros:
|
||||
- LLVM_DEBUG(X)=X\n
|
||||
@@ -14,9 +14,9 @@ namespace onnx_mlir {
|
||||
|
||||
std::string getOutputDir();
|
||||
|
||||
void createDirectory(const std::string &directory);
|
||||
void createDirectory(const std::string& directory);
|
||||
|
||||
void dumpModule(mlir::ModuleOp moduleOp, const std::string &name);
|
||||
void dumpModule(mlir::ModuleOp moduleOp, const std::string& name);
|
||||
|
||||
llvm::FailureOr<mlir::Operation*>
|
||||
getOtherEndOfChannel(mlir::Operation* op, bool opIsReceive, mlir::RewriterBase& rewriter);
|
||||
|
||||
@@ -11,20 +11,21 @@
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "src/Accelerators/PIM/Common/PIMCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/Support/LogicalResult.h"
|
||||
#include <cstddef>
|
||||
|
||||
#include <cstddef>
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PIMCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace std;
|
||||
|
||||
@@ -40,8 +41,8 @@ namespace onnx_mlir {
|
||||
*/
|
||||
class Core {
|
||||
public:
|
||||
Core(const size_t coreId, ConversionPatternRewriter &rewriter)
|
||||
: coreId(coreId), rewriter(rewriter) {}
|
||||
Core(const size_t coreId, ConversionPatternRewriter& rewriter)
|
||||
: coreId(coreId), rewriter(rewriter) {}
|
||||
|
||||
/**
|
||||
* @brief Add a MVM operation to the core.
|
||||
@@ -52,8 +53,7 @@ public:
|
||||
* @param mvmOutType The result's shape.
|
||||
* @return Value The result of the MVM operation.
|
||||
*/
|
||||
Value addMVM(
|
||||
Value inputTile, size_t xbarIndex, size_t outputTileId, Type mvmOutType) {
|
||||
Value addMVM(Value inputTile, size_t xbarIndex, size_t outputTileId, Type mvmOutType) {
|
||||
// Use the inputTile as the reference location for the MVM operation.
|
||||
Location loc = inputTile.getLoc();
|
||||
|
||||
@@ -72,8 +72,7 @@ public:
|
||||
// is correct.
|
||||
|
||||
// Construct the MVM operation
|
||||
Value result = rewriter.create<spatial::SpatWeightedMVMOp>(
|
||||
loc, mvmOutType, xbarIndex, operand);
|
||||
Value result = rewriter.create<spatial::SpatWeightedMVMOp>(loc, mvmOutType, xbarIndex, operand);
|
||||
|
||||
// Since we are within the same core and no computation can happen in
|
||||
// paralllel, we can just apply a linear reduction in case we have multiple
|
||||
@@ -84,8 +83,7 @@ public:
|
||||
if (lastMVM != outputTileToMVM.end()) {
|
||||
// MVM results should have the same type for reduction.
|
||||
assert(lastMVM->second.getType() == result.getType());
|
||||
result = rewriter.create<spatial::SpatVAddOp>(
|
||||
loc, mvmOutType, lastMVM->second, result);
|
||||
result = rewriter.create<spatial::SpatVAddOp>(loc, mvmOutType, lastMVM->second, result);
|
||||
}
|
||||
|
||||
outputTileToMVM[outputTileId] = result;
|
||||
@@ -139,16 +137,14 @@ public:
|
||||
spatial::SpatWeightedCompute createWComputeOp(Location loc) {
|
||||
// Get the shape of the results.
|
||||
SmallVector<Type> resultTypes;
|
||||
for (const auto &value : results) {
|
||||
for (const auto& value : results)
|
||||
resultTypes.push_back(value.getType());
|
||||
}
|
||||
|
||||
// Create the WComputeOp, with non-remappable operands only.
|
||||
wcomputeOp = rewriter.create<spatial::SpatWeightedCompute>(
|
||||
loc, resultTypes, xbarWeights, operands);
|
||||
wcomputeOp = rewriter.create<spatial::SpatWeightedCompute>(loc, resultTypes, xbarWeights, operands);
|
||||
|
||||
// Add the body to the WComputeOp.
|
||||
Block *releasedBlock = block.release();
|
||||
Block* releasedBlock = block.release();
|
||||
wcomputeOp.getBody().push_back(releasedBlock);
|
||||
|
||||
// Add the `yieldOp` at the end, with the results.
|
||||
@@ -164,21 +160,18 @@ public:
|
||||
void remapResults() {
|
||||
// Remap all the results to the WComputeOp results.
|
||||
assert(resultsToRemap.size() == wcomputeOp->getNumResults());
|
||||
for (size_t i = 0; i < resultsToRemap.size(); i++) {
|
||||
for (size_t i = 0; i < resultsToRemap.size(); i++)
|
||||
*resultsToRemap[i] = wcomputeOp.getResult(i);
|
||||
}
|
||||
}
|
||||
|
||||
void addRemappedOperands() {
|
||||
// Insert the remappableOperands (which were remapped in
|
||||
// `addRemappableOperand` of another Core)
|
||||
for (auto remappedValue : remappableOperands) {
|
||||
for (auto remappedValue : remappableOperands)
|
||||
wcomputeOp->insertOperands(wcomputeOp->getNumOperands(), *remappedValue);
|
||||
}
|
||||
|
||||
// Update the wcomputeOp operandSegmentSize
|
||||
incrementWeightedComputeInputsSegmentSize(
|
||||
wcomputeOp, static_cast<int>(remappableOperands.size()));
|
||||
incrementWeightedComputeInputsSegmentSize(wcomputeOp, static_cast<int>(remappableOperands.size()));
|
||||
}
|
||||
|
||||
size_t addXbarWeight(Value weight) {
|
||||
@@ -199,31 +192,27 @@ public:
|
||||
llvm::outs() << "Core " << coreId << ":\n";
|
||||
// Print the weights
|
||||
llvm::outs() << "Xbar Weights:\n";
|
||||
for (auto weight : xbarWeights) {
|
||||
for (auto weight : xbarWeights)
|
||||
weight.dump();
|
||||
}
|
||||
// Print the operands
|
||||
llvm::outs() << "Operands:\n";
|
||||
for (auto operand : operands) {
|
||||
for (auto operand : operands)
|
||||
llvm::outs() << operand << "\n";
|
||||
}
|
||||
|
||||
// Dump the body block
|
||||
for (auto &op : block->getOperations()) {
|
||||
for (auto& op : block->getOperations())
|
||||
op.dump();
|
||||
}
|
||||
|
||||
// Print the results
|
||||
llvm::outs() << "Results:\n";
|
||||
for (auto result : results) {
|
||||
for (auto result : results)
|
||||
llvm::outs() << result << "\n";
|
||||
}
|
||||
}
|
||||
|
||||
const size_t coreId;
|
||||
|
||||
private:
|
||||
ConversionPatternRewriter &rewriter;
|
||||
ConversionPatternRewriter& rewriter;
|
||||
|
||||
// Should these be set<Value> instead? But I need to keep the order
|
||||
vector<Value> operands;
|
||||
@@ -246,15 +235,16 @@ private:
|
||||
};
|
||||
|
||||
struct ONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
|
||||
ONNXConvOpTile(MLIRContext *ctx) : OpConversionPattern(ctx) {}
|
||||
ONNXConvOpTile(MLIRContext* ctx)
|
||||
: OpConversionPattern(ctx) {}
|
||||
|
||||
struct Producer_t {
|
||||
Value value;
|
||||
shared_ptr<Core> core;
|
||||
};
|
||||
|
||||
LogicalResult matchAndRewrite(ONNXConvOp conv, ONNXConvOpAdaptor convAdaptor,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
LogicalResult
|
||||
matchAndRewrite(ONNXConvOp conv, ONNXConvOpAdaptor convAdaptor, ConversionPatternRewriter& rewriter) const final {
|
||||
ShapedType xShape = mlir::cast<ShapedType>(convAdaptor.getX().getType());
|
||||
ShapedType wShape = mlir::cast<ShapedType>(convAdaptor.getW().getType());
|
||||
ShapedType bShape = mlir::cast<ShapedType>(convAdaptor.getB().getType());
|
||||
@@ -264,11 +254,9 @@ struct ONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
|
||||
unpackOptionalPairVector(conv.getStrides(), stride_x, stride_y);
|
||||
unpackOptionalPairVector(conv.getDilations(), dilation_x, dilation_y);
|
||||
|
||||
auto padUnpackError =
|
||||
unpackOptionalPadsVector(convAdaptor.getPads(), pad_x, pad_y);
|
||||
if (padUnpackError.has_value()) {
|
||||
auto padUnpackError = unpackOptionalPadsVector(convAdaptor.getPads(), pad_x, pad_y);
|
||||
if (padUnpackError.has_value())
|
||||
return rewriter.notifyMatchFailure(conv, padUnpackError.value());
|
||||
}
|
||||
|
||||
// TODO: Pad value at beginning and end of each dimension could be
|
||||
// different. We should handle this case.
|
||||
@@ -296,11 +284,9 @@ struct ONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
|
||||
|
||||
Location loc = conv.getLoc();
|
||||
|
||||
size_t inputTileCount =
|
||||
ceilIntegerDivide(GET_IMAGE_CHANNEL(xShape), crossbarSize.getValue());
|
||||
size_t inputTileCount = ceilIntegerDivide(GET_IMAGE_CHANNEL(xShape), crossbarSize.getValue());
|
||||
size_t inputTileRemainder = GET_IMAGE_CHANNEL(xShape) % crossbarSize;
|
||||
size_t outputTileCount =
|
||||
ceilIntegerDivide(GET_IMAGE_CHANNEL(yShape), crossbarSize.getValue());
|
||||
size_t outputTileCount = ceilIntegerDivide(GET_IMAGE_CHANNEL(yShape), crossbarSize.getValue());
|
||||
size_t outputTileRemainder = GET_IMAGE_CHANNEL(yShape) % crossbarSize;
|
||||
|
||||
// Tile the input tensor
|
||||
@@ -310,22 +296,20 @@ struct ONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
|
||||
// c. Pixel `y` position
|
||||
// For example: inputTiles[channelTile][x][y]
|
||||
// Example complete input tensor: tensor<1x3x6x6xf32> (NxCxWxH)
|
||||
SmallVector<SmallVector<SmallVector<Value>>> inputTiles(inputTileCount,
|
||||
SmallVector<SmallVector<Value>>(input_w, SmallVector<Value>(input_h)));
|
||||
SmallVector<SmallVector<SmallVector<Value>>> inputTiles(
|
||||
inputTileCount, SmallVector<SmallVector<Value>>(input_w, SmallVector<Value>(input_h)));
|
||||
|
||||
auto resolveErrorOpt = resolveImgInputTiles(convAdaptor.getX(), inputTiles,
|
||||
inputTileCount, inputTileRemainder, input_h, input_h, rewriter);
|
||||
if (resolveErrorOpt.has_value()) {
|
||||
auto resolveErrorOpt = resolveImgInputTiles(
|
||||
convAdaptor.getX(), inputTiles, inputTileCount, inputTileRemainder, input_h, input_h, rewriter);
|
||||
if (resolveErrorOpt.has_value())
|
||||
return rewriter.notifyMatchFailure(conv, *resolveErrorOpt);
|
||||
}
|
||||
|
||||
SmallVector<OpFoldResult> strides =
|
||||
SmallVector<OpFoldResult>(4, rewriter.getIndexAttr(1));
|
||||
SmallVector<OpFoldResult> offsets =
|
||||
SmallVector<OpFoldResult>(4, rewriter.getIndexAttr(0));
|
||||
SmallVector<OpFoldResult> sizes = SmallVector<OpFoldResult>{
|
||||
rewriter.getIndexAttr(1), rewriter.getIndexAttr(crossbarSize),
|
||||
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
SmallVector<OpFoldResult> strides = SmallVector<OpFoldResult>(4, rewriter.getIndexAttr(1));
|
||||
SmallVector<OpFoldResult> offsets = SmallVector<OpFoldResult>(4, rewriter.getIndexAttr(0));
|
||||
SmallVector<OpFoldResult> sizes = SmallVector<OpFoldResult> {rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(crossbarSize),
|
||||
rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(1)};
|
||||
|
||||
// Tile the weight tensor
|
||||
// Weight tiles need to be indexed by:
|
||||
@@ -336,31 +320,30 @@ struct ONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
|
||||
// For example: weightTiles[filterTile][channelTile][x][y]
|
||||
// Example complete weight tensor: tensor<32x3x3x3xf32> (FxCxWxH)
|
||||
SmallVector<SmallVector<SmallVector<SmallVector<Value>>>> weightTiles(
|
||||
outputTileCount,
|
||||
SmallVector<SmallVector<SmallVector<Value>>>(inputTileCount,
|
||||
SmallVector<SmallVector<Value>>(krn_w, SmallVector<Value>(krn_h))));
|
||||
outputTileCount,
|
||||
SmallVector<SmallVector<SmallVector<Value>>>(inputTileCount,
|
||||
SmallVector<SmallVector<Value>>(krn_w, SmallVector<Value>(krn_h))));
|
||||
strides = SmallVector<OpFoldResult>(4, rewriter.getIndexAttr(1));
|
||||
offsets = SmallVector<OpFoldResult>(4, rewriter.getIndexAttr(0));
|
||||
sizes = {rewriter.getIndexAttr(crossbarSize),
|
||||
rewriter.getIndexAttr(crossbarSize), rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(1)};
|
||||
rewriter.getIndexAttr(crossbarSize),
|
||||
rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(1)};
|
||||
for (size_t i = 0; i < outputTileCount; i++) {
|
||||
if (i == outputTileCount - 1 && outputTileRemainder != 0) {
|
||||
if (i == outputTileCount - 1 && outputTileRemainder != 0)
|
||||
sizes[0] = rewriter.getIndexAttr(outputTileRemainder);
|
||||
}
|
||||
sizes[1] = rewriter.getIndexAttr(crossbarSize);
|
||||
offsets[0] = rewriter.getIndexAttr(i * crossbarSize);
|
||||
for (size_t j = 0; j < inputTileCount; j++) {
|
||||
if (j == inputTileCount - 1 && inputTileRemainder != 0) {
|
||||
if (j == inputTileCount - 1 && inputTileRemainder != 0)
|
||||
sizes[1] = rewriter.getIndexAttr(inputTileRemainder);
|
||||
}
|
||||
for (size_t x = 0; x < krn_w; x++) {
|
||||
for (size_t y = 0; y < krn_h; y++) {
|
||||
offsets[1] = rewriter.getIndexAttr(j * crossbarSize);
|
||||
offsets[2] = rewriter.getIndexAttr(x);
|
||||
offsets[3] = rewriter.getIndexAttr(y);
|
||||
weightTiles[i][j][x][y] = rewriter.create<tensor::ExtractSliceOp>(
|
||||
loc, convAdaptor.getW(), offsets, sizes, strides);
|
||||
weightTiles[i][j][x][y] =
|
||||
rewriter.create<tensor::ExtractSliceOp>(loc, convAdaptor.getW(), offsets, sizes, strides);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -379,56 +362,45 @@ struct ONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
|
||||
// For example: outputTiles[filterTile][x][y]
|
||||
// Example complete output tensor: tensor<1x32x3x3xf32> (NxFxWxH)
|
||||
SmallVector<SmallVector<SmallVector<shared_ptr<Value>>>> outputTiles(
|
||||
outputTileCount,
|
||||
SmallVector<SmallVector<shared_ptr<Value>>>(
|
||||
output_w, SmallVector<shared_ptr<Value>>(output_h, nullptr)));
|
||||
outputTileCount,
|
||||
SmallVector<SmallVector<shared_ptr<Value>>>(output_w, SmallVector<shared_ptr<Value>>(output_h, nullptr)));
|
||||
|
||||
size_t replicationFactor;
|
||||
if (!conv->hasAttr(REPLICATION_ATTR_NAME)) {
|
||||
if (!conv->hasAttr(REPLICATION_ATTR_NAME))
|
||||
replicationFactor = 1;
|
||||
} else {
|
||||
replicationFactor =
|
||||
conv->getAttrOfType<IntegerAttr>(REPLICATION_ATTR_NAME).getInt();
|
||||
}
|
||||
else
|
||||
replicationFactor = conv->getAttrOfType<IntegerAttr>(REPLICATION_ATTR_NAME).getInt();
|
||||
// producers[outTile][out_x][out_y][producerIndex]
|
||||
vector<vector<vector<vector<Producer_t>>>> producers =
|
||||
vector<vector<vector<vector<Producer_t>>>>(outputTileCount,
|
||||
vector<vector<vector<Producer_t>>>(output_w,
|
||||
vector<vector<Producer_t>>(output_h, vector<Producer_t>())));
|
||||
vector<vector<vector<vector<Producer_t>>>> producers = vector<vector<vector<vector<Producer_t>>>>(
|
||||
outputTileCount,
|
||||
vector<vector<vector<Producer_t>>>(output_w, vector<vector<Producer_t>>(output_h, vector<Producer_t>())));
|
||||
|
||||
// Schedule in cores
|
||||
size_t coreId = 0;
|
||||
vector<shared_ptr<Core>> curCores(replicationFactor);
|
||||
for (size_t i = 0; i < replicationFactor; i++) {
|
||||
for (size_t i = 0; i < replicationFactor; i++)
|
||||
curCores[i] = make_shared<Core>(coreId++, rewriter);
|
||||
}
|
||||
|
||||
vector<shared_ptr<Core>> cores;
|
||||
|
||||
const size_t replicationSliceSize =
|
||||
ceilIntegerDivide(input_w, replicationFactor);
|
||||
const size_t replicationSliceSize = ceilIntegerDivide(input_w, replicationFactor);
|
||||
|
||||
for (size_t krn_x = 0; krn_x < krn_h; krn_x++) {
|
||||
for (size_t krn_y = 0; krn_y < krn_w; krn_y++) {
|
||||
|
||||
RankedTensorType mvmOutType =
|
||||
RankedTensorType::get({1, static_cast<long>(crossbarSize), 1, 1},
|
||||
bShape.getElementType());
|
||||
RankedTensorType::get({1, static_cast<long>(crossbarSize), 1, 1}, bShape.getElementType());
|
||||
|
||||
for (size_t outTile = 0; outTile < outputTileCount; outTile++) {
|
||||
|
||||
if (outTile == outputTileCount - 1 && outputTileRemainder != 0) {
|
||||
mvmOutType = mvmOutType.clone(
|
||||
{1, static_cast<long>(outputTileRemainder), 1, 1});
|
||||
}
|
||||
if (outTile == outputTileCount - 1 && outputTileRemainder != 0)
|
||||
mvmOutType = mvmOutType.clone({1, static_cast<long>(outputTileRemainder), 1, 1});
|
||||
|
||||
for (size_t inTile = 0; inTile < inputTileCount; inTile++) {
|
||||
|
||||
vector<size_t> xbarIndexes(replicationFactor);
|
||||
for (size_t i = 0; i < replicationFactor; i++) {
|
||||
xbarIndexes[i] = curCores[i]->addXbarWeight(
|
||||
weightTiles[outTile][inTile][krn_x][krn_y]);
|
||||
}
|
||||
for (size_t i = 0; i < replicationFactor; i++)
|
||||
xbarIndexes[i] = curCores[i]->addXbarWeight(weightTiles[outTile][inTile][krn_x][krn_y]);
|
||||
|
||||
size_t out_x = 0;
|
||||
for (size_t in_x = 0; in_x < input_w; in_x += stride_x) {
|
||||
@@ -443,25 +415,20 @@ struct ONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
|
||||
|
||||
for (size_t in_y = 0; in_y < input_h; in_y += stride_y) {
|
||||
// Adjust the input based on the kernel
|
||||
int actual_in_x = in_x - ((int)krn_w / 2) + krn_x * dilation_x;
|
||||
int actual_in_y = in_y - ((int)krn_h / 2) + krn_y * dilation_y;
|
||||
int actual_in_x = in_x - ((int) krn_w / 2) + krn_x * dilation_x;
|
||||
int actual_in_y = in_y - ((int) krn_h / 2) + krn_y * dilation_y;
|
||||
|
||||
// Check if we are within the input image
|
||||
if (verifyWithinBoundsAndPaddings(input_w, input_h, actual_in_x,
|
||||
actual_in_y, pad_x, pad_y)
|
||||
.failed()) {
|
||||
if (verifyWithinBoundsAndPaddings(input_w, input_h, actual_in_x, actual_in_y, pad_x, pad_y).failed()) {
|
||||
out_y++;
|
||||
continue;
|
||||
}
|
||||
|
||||
size_t outTileId =
|
||||
outTile * output_w * output_h + out_x * output_h + out_y;
|
||||
size_t outTileId = outTile * output_w * output_h + out_x * output_h + out_y;
|
||||
auto mvm = curCores[coreIndex]->addMVM(
|
||||
inputTiles[inTile][actual_in_x][actual_in_y],
|
||||
xbarIndexes[coreIndex], outTileId, mvmOutType);
|
||||
inputTiles[inTile][actual_in_x][actual_in_y], xbarIndexes[coreIndex], outTileId, mvmOutType);
|
||||
|
||||
producers[outTile][out_x][out_y].push_back(
|
||||
{mvm, curCores[coreIndex]});
|
||||
producers[outTile][out_x][out_y].push_back({mvm, curCores[coreIndex]});
|
||||
|
||||
out_y++;
|
||||
}
|
||||
@@ -481,11 +448,9 @@ struct ONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
|
||||
}
|
||||
}
|
||||
|
||||
for (auto &curCore : curCores) {
|
||||
if (curCore->isCoreEmpty() == false) {
|
||||
for (auto& curCore : curCores)
|
||||
if (curCore->isCoreEmpty() == false)
|
||||
cores.emplace_back(std::move(curCore));
|
||||
}
|
||||
}
|
||||
curCores.clear();
|
||||
// Now, do the reduction of each output pixel tile
|
||||
for (size_t outTile = 0; outTile < outputTileCount; outTile++) {
|
||||
@@ -497,9 +462,8 @@ struct ONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
|
||||
// core.
|
||||
|
||||
std::unordered_map<size_t, Producer_t> withinCoreReducedProducers;
|
||||
for (auto producer : producers[outTile][out_x][out_y]) {
|
||||
for (auto producer : producers[outTile][out_x][out_y])
|
||||
withinCoreReducedProducers[producer.core->coreId] = producer;
|
||||
}
|
||||
|
||||
// Now, we need to apply inter-core reduction
|
||||
|
||||
@@ -509,8 +473,7 @@ struct ONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
|
||||
|
||||
auto singleProducer = withinCoreReducedProducers.begin()->second;
|
||||
// Use last producer as the final result
|
||||
auto reducedValue =
|
||||
singleProducer.core->makeResultRemappable(singleProducer.value);
|
||||
auto reducedValue = singleProducer.core->makeResultRemappable(singleProducer.value);
|
||||
outputTiles[outTile][out_x][out_y] = reducedValue;
|
||||
continue;
|
||||
}
|
||||
@@ -535,9 +498,9 @@ struct ONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
|
||||
auto lastProducerCoreId = lastProducer.core->coreId;
|
||||
auto curProducerCoreId = curProducer.core->coreId;
|
||||
|
||||
assert(lastProducerCoreId != curProducerCoreId &&
|
||||
"We should have already applied within-core reduction, how "
|
||||
"could we have same cores here?");
|
||||
assert(lastProducerCoreId != curProducerCoreId
|
||||
&& "We should have already applied within-core reduction, how "
|
||||
"could we have same cores here?");
|
||||
|
||||
// Sort the cores by coreId
|
||||
if (curProducerCoreId < lastProducerCoreId) {
|
||||
@@ -545,7 +508,8 @@ struct ONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
|
||||
core1Value = curProducer.value;
|
||||
core2 = lastProducer.core;
|
||||
core2Value = lastProducer.value;
|
||||
} else {
|
||||
}
|
||||
else {
|
||||
core1 = lastProducer.core;
|
||||
core1Value = lastProducer.value;
|
||||
core2 = curProducer.core;
|
||||
@@ -556,9 +520,8 @@ struct ONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
|
||||
auto secondCoreBlockArg = core2->addRemappableOperand(newCoreRes);
|
||||
|
||||
rewriter.setInsertionPointAfterValue(core2Value);
|
||||
Value vaddRes =
|
||||
rewriter.create<spatial::SpatVAddOp>(core2Value.getLoc(),
|
||||
core2Value.getType(), core2Value, secondCoreBlockArg);
|
||||
Value vaddRes = rewriter.create<spatial::SpatVAddOp>(
|
||||
core2Value.getLoc(), core2Value.getType(), core2Value, secondCoreBlockArg);
|
||||
|
||||
lastProducer = {vaddRes, core2};
|
||||
|
||||
@@ -568,8 +531,7 @@ struct ONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
|
||||
// TODO: Add the bias and apply mapping (if present)
|
||||
|
||||
// Use last producer as the final result
|
||||
auto reducedValue =
|
||||
lastProducer.core->makeResultRemappable(lastProducer.value);
|
||||
auto reducedValue = lastProducer.core->makeResultRemappable(lastProducer.value);
|
||||
outputTiles[outTile][out_x][out_y] = reducedValue;
|
||||
}
|
||||
}
|
||||
@@ -578,15 +540,14 @@ struct ONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
|
||||
// Now, we need to turn the cores into a spatial::SpatWeightedCompute.
|
||||
rewriter.setInsertionPointAfter(conv);
|
||||
spatial::SpatWeightedCompute lastWComputeOp;
|
||||
for (auto &core : cores) {
|
||||
for (auto& core : cores) {
|
||||
lastWComputeOp = core->createWComputeOp(loc);
|
||||
core->remapResults();
|
||||
rewriter.setInsertionPointAfter(lastWComputeOp);
|
||||
}
|
||||
|
||||
for (auto &core : cores) {
|
||||
for (auto& core : cores)
|
||||
core->addRemappedOperands();
|
||||
}
|
||||
|
||||
// Set the insertion point after the last WComputeOp.
|
||||
rewriter.setInsertionPointAfter(lastWComputeOp);
|
||||
@@ -597,8 +558,7 @@ struct ONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
|
||||
for (size_t outTile = 0; outTile < outputTileCount; outTile++)
|
||||
tilesToConcat.push_back(*outputTiles[outTile][outX][outY]);
|
||||
|
||||
Value outputImage = rewriter.create<spatial::SpatImgConcatOp>(
|
||||
loc, conv.getY().getType(), tilesToConcat);
|
||||
Value outputImage = rewriter.create<spatial::SpatImgConcatOp>(loc, conv.getY().getType(), tilesToConcat);
|
||||
|
||||
// Value outputImage =
|
||||
// createImgConcatOp(outputTiles, rewriter, loc, Y.getType());
|
||||
@@ -616,9 +576,8 @@ struct ONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
|
||||
}
|
||||
};
|
||||
|
||||
void populateTilingConvOpPattern(
|
||||
RewritePatternSet &patterns, MLIRContext *ctx) {
|
||||
void populateTilingConvOpPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||
patterns.insert<ONNXConvOpTile>(ctx);
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -1,6 +1,3 @@
|
||||
#include "Compiler/PimCompilerOptions.hpp"
|
||||
#include "Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
|
||||
#include "Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
@@ -8,13 +5,19 @@
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/WeightSubdivider.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstddef>
|
||||
#include <unistd.h>
|
||||
|
||||
#include "Compiler/PimCompilerOptions.hpp"
|
||||
#include "Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
|
||||
#include "Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/WeightSubdivider.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace std;
|
||||
|
||||
@@ -27,10 +30,11 @@ namespace onnx_mlir {
|
||||
* output tensor.
|
||||
*/
|
||||
struct ExperimentalONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
|
||||
ExperimentalONNXConvOpTile(MLIRContext *ctx) : OpConversionPattern(ctx) {}
|
||||
ExperimentalONNXConvOpTile(MLIRContext* ctx)
|
||||
: OpConversionPattern(ctx) {}
|
||||
|
||||
LogicalResult matchAndRewrite(ONNXConvOp conv, ONNXConvOpAdaptor convAdaptor,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
LogicalResult
|
||||
matchAndRewrite(ONNXConvOp conv, ONNXConvOpAdaptor convAdaptor, ConversionPatternRewriter& rewriter) const final {
|
||||
|
||||
// --------------------------------- //
|
||||
// --- READ OPERATION PARAMETERS --- //
|
||||
@@ -46,12 +50,12 @@ struct ExperimentalONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
|
||||
ShapedType weightsType = cast<ShapedType>(convAdaptor.getW().getType());
|
||||
|
||||
// TODO: Address bigger batches.
|
||||
assert(GET_IMAGE_N(inputType) == 1 && "Batch size must be 1"
|
||||
"for convolution.");
|
||||
assert(GET_IMAGE_N(inputType) == 1
|
||||
&& "Batch size must be 1"
|
||||
"for convolution.");
|
||||
|
||||
// TODO: Address replication.
|
||||
assert(coresCount.getValue() == -1 &&
|
||||
"Replication is not yet supported for convolution.");
|
||||
assert(coresCount.getValue() == -1 && "Replication is not yet supported for convolution.");
|
||||
|
||||
// TODO: Address bias addition.
|
||||
|
||||
@@ -97,11 +101,10 @@ struct ExperimentalONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
|
||||
long inputTile = it;
|
||||
|
||||
// Create the slicing sizes.
|
||||
SmallVector<OpFoldResult> slicingSizes{
|
||||
/* 0 */ rewriter.getIndexAttr(crossbarHeight),
|
||||
/* 1 */ rewriter.getIndexAttr(crossbarWidth),
|
||||
/* 2 */ rewriter.getIndexAttr(1),
|
||||
/* 3 */ rewriter.getIndexAttr(1)};
|
||||
SmallVector<OpFoldResult> slicingSizes {/* 0 */ rewriter.getIndexAttr(crossbarHeight),
|
||||
/* 1 */ rewriter.getIndexAttr(crossbarWidth),
|
||||
/* 2 */ rewriter.getIndexAttr(1),
|
||||
/* 3 */ rewriter.getIndexAttr(1)};
|
||||
|
||||
// - Slicing along the filter x position.
|
||||
// - Slicing along the filter y position.
|
||||
@@ -109,16 +112,14 @@ struct ExperimentalONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
|
||||
for (size_t filterY = 0; filterY < kernelHeight; ++filterY) {
|
||||
|
||||
// Create the slicing offsets.
|
||||
SmallVector<OpFoldResult> slicingOffsets{
|
||||
/* 0 */ rewriter.getIndexAttr(outputTile * crossbarSize),
|
||||
/* 1 */ rewriter.getIndexAttr(inputTile * crossbarSize),
|
||||
/* 2 */ rewriter.getIndexAttr(filterX),
|
||||
/* 3 */ rewriter.getIndexAttr(filterY)};
|
||||
SmallVector<OpFoldResult> slicingOffsets {/* 0 */ rewriter.getIndexAttr(outputTile * crossbarSize),
|
||||
/* 1 */ rewriter.getIndexAttr(inputTile * crossbarSize),
|
||||
/* 2 */ rewriter.getIndexAttr(filterX),
|
||||
/* 3 */ rewriter.getIndexAttr(filterY)};
|
||||
|
||||
// Create the slice extraction operation.
|
||||
auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
|
||||
conv.getLoc(), convAdaptor.getW(), slicingOffsets, slicingSizes,
|
||||
slicingStrides);
|
||||
conv.getLoc(), convAdaptor.getW(), slicingOffsets, slicingSizes, slicingStrides);
|
||||
|
||||
// Add a note to the extractSliceOp, with the filterX and filterY.
|
||||
weightsGroups[inputTile][outputTile].push_back(extractSliceOp);
|
||||
@@ -150,8 +151,7 @@ struct ExperimentalONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
|
||||
// -------------------------------- //
|
||||
|
||||
// Get the next group of weights for the compute unit.
|
||||
SmallVector<TaggedWeights> weightsGroups =
|
||||
weightSubdivider.popGroups(crossbarCountInCore.getValue());
|
||||
SmallVector<TaggedWeights> weightsGroups = weightSubdivider.popGroups(crossbarCountInCore.getValue());
|
||||
|
||||
SmallVector<Value> computeWeights;
|
||||
SmallVector<Value> computeOperands;
|
||||
@@ -168,15 +168,13 @@ struct ExperimentalONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
|
||||
// Iterate over all weights groups for this compute unit.
|
||||
map<long, Value> localSlices; // WRT the current compute unit.
|
||||
for (auto group : weightsGroups) {
|
||||
for (Value weight : group.weights) {
|
||||
for (Value weight : group.weights)
|
||||
computeWeights.push_back(weight);
|
||||
}
|
||||
|
||||
// There might be multiple weight groups for the same input tile, so if
|
||||
// we've already added the input tile, skip it.
|
||||
if (localSlices.find(group.inputTile) != localSlices.end()) {
|
||||
if (localSlices.find(group.inputTile) != localSlices.end())
|
||||
continue;
|
||||
}
|
||||
|
||||
// We might have already sliced the input tensor for some other compute
|
||||
// unit, so if we have, reuse the slicing operation without creating a
|
||||
@@ -188,26 +186,21 @@ struct ExperimentalONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
|
||||
}
|
||||
|
||||
// Create the input tensor slicing offsets.
|
||||
SmallVector<OpFoldResult> slicingOffsets{
|
||||
/* 0 */ rewriter.getIndexAttr(0), // No offset along the batch axis.
|
||||
/* 1 */ rewriter.getIndexAttr(group.inputTile * crossbarSize),
|
||||
/* 2 */ rewriter.getIndexAttr(0),
|
||||
/* 3 */ rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> slicingOffsets {/* 0 */ rewriter.getIndexAttr(0), // No offset along the batch axis.
|
||||
/* 1 */ rewriter.getIndexAttr(group.inputTile * crossbarSize),
|
||||
/* 2 */ rewriter.getIndexAttr(0),
|
||||
/* 3 */ rewriter.getIndexAttr(0)};
|
||||
|
||||
// Create the input tensor slicing sizes.
|
||||
size_t tilingSize = group.inputTile == inputTileCount.quot
|
||||
? inputTileCount.rem
|
||||
: crossbarSize;
|
||||
SmallVector<OpFoldResult> slicingSizes{
|
||||
/* 0 */ rewriter.getIndexAttr(1), // Batch size is always 1.
|
||||
/* 1 */ rewriter.getIndexAttr(tilingSize),
|
||||
/* 2 */ rewriter.getIndexAttr(GET_IMAGE_WIDTH(inputType)),
|
||||
/* 3 */ rewriter.getIndexAttr(GET_IMAGE_HEIGHT(inputType))};
|
||||
size_t tilingSize = group.inputTile == inputTileCount.quot ? inputTileCount.rem : crossbarSize;
|
||||
SmallVector<OpFoldResult> slicingSizes {/* 0 */ rewriter.getIndexAttr(1), // Batch size is always 1.
|
||||
/* 1 */ rewriter.getIndexAttr(tilingSize),
|
||||
/* 2 */ rewriter.getIndexAttr(GET_IMAGE_WIDTH(inputType)),
|
||||
/* 3 */ rewriter.getIndexAttr(GET_IMAGE_HEIGHT(inputType))};
|
||||
|
||||
// Create the slice extraction operation.
|
||||
auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
|
||||
conv.getLoc(), convAdaptor.getX(), slicingOffsets, slicingSizes,
|
||||
slicingStrides);
|
||||
conv.getLoc(), convAdaptor.getX(), slicingOffsets, slicingSizes, slicingStrides);
|
||||
|
||||
computeOperands.push_back(extractSliceOp);
|
||||
|
||||
@@ -229,36 +222,28 @@ struct ExperimentalONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
|
||||
|
||||
// There might be multiple weight groups for the same output tile, so if
|
||||
// we've already added the output tile, skip it.
|
||||
if (outputTileIndices.find(group.outputTile) !=
|
||||
outputTileIndices.end()) {
|
||||
if (outputTileIndices.find(group.outputTile) != outputTileIndices.end())
|
||||
continue;
|
||||
}
|
||||
|
||||
// Additionally, after adding the input slices as operands, also add any
|
||||
// compatible partial results from previous compute units.
|
||||
if (globalPartialResults.find(group.outputTile) !=
|
||||
globalPartialResults.end()) {
|
||||
if (globalPartialResults.find(group.outputTile) != globalPartialResults.end()) {
|
||||
computeOperands.push_back(globalPartialResults[group.outputTile]);
|
||||
reductionTileIndices[group.outputTile] = computeOperands.size() - 1;
|
||||
}
|
||||
|
||||
// Define the output shape for this group.
|
||||
long outputTileSize = group.outputTile == outputTileCount.quot
|
||||
? outputTileCount.rem
|
||||
: crossbarSize;
|
||||
long outputTileSize = group.outputTile == outputTileCount.quot ? outputTileCount.rem : crossbarSize;
|
||||
|
||||
// TODO: Address non-same padding.
|
||||
SmallVector<int64_t> outputShapeArray{
|
||||
/* 0 */ 1, // Batch size is always 1.
|
||||
/* 1 */ outputTileSize,
|
||||
/* 2 */ GET_IMAGE_WIDTH(outputType), // Same padding assumed.
|
||||
/* 3 */ GET_IMAGE_HEIGHT(outputType)};
|
||||
SmallVector<int64_t> outputShapeArray {/* 0 */ 1, // Batch size is always 1.
|
||||
/* 1 */ outputTileSize,
|
||||
/* 2 */ GET_IMAGE_WIDTH(outputType), // Same padding assumed.
|
||||
/* 3 */ GET_IMAGE_HEIGHT(outputType)};
|
||||
|
||||
auto elementType =
|
||||
dyn_cast<RankedTensorType>(conv.getY().getType()).getElementType();
|
||||
auto elementType = dyn_cast<RankedTensorType>(conv.getY().getType()).getElementType();
|
||||
|
||||
computeOutputType.push_back(
|
||||
RankedTensorType::get(outputShapeArray, elementType));
|
||||
computeOutputType.push_back(RankedTensorType::get(outputShapeArray, elementType));
|
||||
|
||||
outputTileIndices[group.outputTile] = computeOutputType.size() - 1;
|
||||
}
|
||||
@@ -268,43 +253,36 @@ struct ExperimentalONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
|
||||
// ----------------------------- //
|
||||
|
||||
// Create the compute unit.
|
||||
spatial::SpatWeightedCompute currentCompute =
|
||||
rewriter.create<spatial::SpatWeightedCompute>(conv.getLoc(),
|
||||
computeOutputType, computeWeights, computeOperands);
|
||||
spatial::SpatWeightedCompute currentCompute = rewriter.create<spatial::SpatWeightedCompute>(
|
||||
conv.getLoc(), computeOutputType, computeWeights, computeOperands);
|
||||
|
||||
// Create a new block for the compute unit and add the operands.
|
||||
Block *block = rewriter.createBlock(¤tCompute.getRegion());
|
||||
Block* block = rewriter.createBlock(¤tCompute.getRegion());
|
||||
rewriter.setInsertionPointToStart(block);
|
||||
for (Value operand : computeOperands) {
|
||||
for (Value operand : computeOperands)
|
||||
block->addArgument(operand.getType(), conv->getLoc());
|
||||
}
|
||||
|
||||
// Initialize a map of local partial results.
|
||||
map<long, Value> localPartialResults; // WRT the current compute unit.
|
||||
|
||||
// If we have any reduction tiles, add them to the local partial results.
|
||||
for (auto reductionTileIndex : reductionTileIndices) {
|
||||
localPartialResults[reductionTileIndex.first] =
|
||||
block->getArgument(reductionTileIndex.second);
|
||||
}
|
||||
for (auto reductionTileIndex : reductionTileIndices)
|
||||
localPartialResults[reductionTileIndex.first] = block->getArgument(reductionTileIndex.second);
|
||||
|
||||
// Add all the applyFilters operations to the block.
|
||||
for (TaggedWeights group : weightsGroups) {
|
||||
|
||||
// Get the outputType for this group.
|
||||
Type outputType =
|
||||
computeOutputType[outputTileIndices[group.outputTile]];
|
||||
Type outputType = computeOutputType[outputTileIndices[group.outputTile]];
|
||||
|
||||
// Create an apply filters operation.
|
||||
BlockArgument blockArgument =
|
||||
block->getArgument(inputTileIndices[group.inputTile]);
|
||||
BlockArgument blockArgument = block->getArgument(inputTileIndices[group.inputTile]);
|
||||
|
||||
// The list of weight indices is group.startingCrossbarIndex + 0, 1, 2,
|
||||
// ... As many weights as the size of group.weights.
|
||||
SmallVector<long> weightIndices;
|
||||
for (size_t i = 0; i < group.weights.size(); ++i) {
|
||||
for (size_t i = 0; i < group.weights.size(); ++i)
|
||||
weightIndices.push_back(group.startingCrossbarIndex + i);
|
||||
}
|
||||
|
||||
SmallVector<int64_t> xKerPos;
|
||||
SmallVector<int64_t> yKerPos;
|
||||
@@ -323,16 +301,14 @@ struct ExperimentalONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
|
||||
ArrayAttr xKerPosAttr = rewriter.getI64ArrayAttr(xKerPos);
|
||||
ArrayAttr yKerPosAttr = rewriter.getI64ArrayAttr(yKerPos);
|
||||
|
||||
Value result =
|
||||
rewriter.create<spatial::SpatApplyFiltersOp>(conv.getLoc(), outputType,
|
||||
weightIndicesAttr, xKerPosAttr, yKerPosAttr, blockArgument);
|
||||
Value result = rewriter.create<spatial::SpatApplyFiltersOp>(
|
||||
conv.getLoc(), outputType, weightIndicesAttr, xKerPosAttr, yKerPosAttr, blockArgument);
|
||||
|
||||
// Perform local reduction if necessary.
|
||||
if (localPartialResults.find(group.outputTile) !=
|
||||
localPartialResults.end()) {
|
||||
if (localPartialResults.find(group.outputTile) != localPartialResults.end()) {
|
||||
|
||||
result = rewriter.create<spatial::SpatVAddOp>(conv.getLoc(),
|
||||
result.getType(), localPartialResults[group.outputTile], result);
|
||||
result = rewriter.create<spatial::SpatVAddOp>(
|
||||
conv.getLoc(), result.getType(), localPartialResults[group.outputTile], result);
|
||||
}
|
||||
|
||||
// Update the partial results map.
|
||||
@@ -385,22 +361,18 @@ struct ExperimentalONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
|
||||
|
||||
// Turn the values into a SmallVector.
|
||||
SmallVector<Value> outputValues;
|
||||
for (long i = 0; i < outputTileCount.quot + (outputTileCount.rem > 0);
|
||||
++i) {
|
||||
for (long i = 0; i < outputTileCount.quot + (outputTileCount.rem > 0); ++i)
|
||||
outputValues.push_back(globalPartialResults[i]);
|
||||
}
|
||||
|
||||
// Assert that the number of output values is correct.
|
||||
assert(outputValues.size() > 0 &&
|
||||
"No output values were generated for the convolution.");
|
||||
assert(outputValues.size() > 0 && "No output values were generated for the convolution.");
|
||||
|
||||
// If the conv's user is a ReLU...
|
||||
if (conv->hasOneUse()) {
|
||||
Operation *user = *conv->getUsers().begin();
|
||||
Operation* user = *conv->getUsers().begin();
|
||||
if (auto relu = dyn_cast<ONNXReluOp>(user)) {
|
||||
// ...then we can just replace the ReLU with the concatenation.
|
||||
rewriter.replaceOp(relu,
|
||||
rewriter.create<tensor::ConcatOp>(conv.getLoc(), 1, outputValues));
|
||||
rewriter.replaceOp(relu, rewriter.create<tensor::ConcatOp>(conv.getLoc(), 1, outputValues));
|
||||
|
||||
// And erase the convolution.
|
||||
rewriter.eraseOp(conv);
|
||||
@@ -409,8 +381,7 @@ struct ExperimentalONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
|
||||
}
|
||||
|
||||
// Return the final output.
|
||||
rewriter.replaceOp(conv,
|
||||
rewriter.create<tensor::ConcatOp>(conv.getLoc(), 1, outputValues));
|
||||
rewriter.replaceOp(conv, rewriter.create<tensor::ConcatOp>(conv.getLoc(), 1, outputValues));
|
||||
|
||||
return success();
|
||||
}
|
||||
@@ -422,9 +393,8 @@ struct ExperimentalONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
|
||||
* @param patterns The pattern set to populate.
|
||||
* @param ctx The MLIR context.
|
||||
*/
|
||||
void populateExperimentalTilingConvOpPattern(
|
||||
RewritePatternSet &patterns, MLIRContext *ctx) {
|
||||
void populateExperimentalTilingConvOpPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||
patterns.insert<ExperimentalONNXConvOpTile>(ctx);
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -1,24 +1,25 @@
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include "Compiler/PimCompilerOptions.hpp"
|
||||
#include "Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
|
||||
#include "Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp"
|
||||
#include "Conversion/ONNXToSpatial/Utils/WeightSubdivider.hpp"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
#include <cstdlib>
|
||||
|
||||
using namespace mlir;
|
||||
using namespace std;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
struct ExperimentalGemmConversionPattern
|
||||
: public OpConversionPattern<ONNXGemmOp> {
|
||||
ExperimentalGemmConversionPattern(MLIRContext *ctx)
|
||||
: OpConversionPattern(ctx) {}
|
||||
struct ExperimentalGemmConversionPattern : public OpConversionPattern<ONNXGemmOp> {
|
||||
ExperimentalGemmConversionPattern(MLIRContext* ctx)
|
||||
: OpConversionPattern(ctx) {}
|
||||
|
||||
LogicalResult matchAndRewrite(ONNXGemmOp gemmOp, ONNXGemmOpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
LogicalResult
|
||||
matchAndRewrite(ONNXGemmOp gemmOp, ONNXGemmOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final {
|
||||
|
||||
// --------------------------------- //
|
||||
// --- READ OPERATION PARAMETERS --- //
|
||||
@@ -34,17 +35,14 @@ struct ExperimentalGemmConversionPattern
|
||||
ShapedType matrixType = cast<ShapedType>(adaptor.getB().getType());
|
||||
|
||||
// TODO: Address bigger batches.
|
||||
assert(inputType.getShape()[0] == 1 &&
|
||||
"Only batch size of 1 is supported for GEMM.");
|
||||
assert(inputType.getShape()[0] == 1 && "Only batch size of 1 is supported for GEMM.");
|
||||
|
||||
// TODO: Address replication.
|
||||
assert(coresCount.getValue() == -1 &&
|
||||
"Replication is not yet supported for GEMM.");
|
||||
assert(coresCount.getValue() == -1 && "Replication is not yet supported for GEMM.");
|
||||
|
||||
// TODO: Address bias addition.
|
||||
|
||||
assert(inputType.getShape()[1] == matrixType.getShape()[0] &&
|
||||
"Input tile size must match the matrix's row size.");
|
||||
assert(inputType.getShape()[1] == matrixType.getShape()[0] && "Input tile size must match the matrix's row size.");
|
||||
|
||||
ldiv_t inputTileCount = div(inputType.getShape()[1], crossbarSize);
|
||||
ldiv_t outputTileCount = div(outputType.getShape()[1], crossbarSize);
|
||||
@@ -88,11 +86,10 @@ struct ExperimentalGemmConversionPattern
|
||||
long inputTile = it;
|
||||
|
||||
// Create the slicing sizes.
|
||||
SmallVector<OpFoldResult> slicingSizes{
|
||||
/* 0 */ rewriter.getIndexAttr(crossbarHeight),
|
||||
/* 1 */ rewriter.getIndexAttr(crossbarWidth),
|
||||
/* 2 */ /* rewriter.getIndexAttr(1), */
|
||||
/* 3 */ /* rewriter.getIndexAttr(1) */};
|
||||
SmallVector<OpFoldResult> slicingSizes {/* 0 */ rewriter.getIndexAttr(crossbarHeight),
|
||||
/* 1 */ rewriter.getIndexAttr(crossbarWidth),
|
||||
/* 2 */ /* rewriter.getIndexAttr(1), */
|
||||
/* 3 */ /* rewriter.getIndexAttr(1) */};
|
||||
|
||||
// - Slicing along the filter x position.
|
||||
// - Slicing along the filter y position.
|
||||
@@ -100,16 +97,14 @@ struct ExperimentalGemmConversionPattern
|
||||
for (size_t filterY = 0; filterY < kernelHeight; ++filterY) {
|
||||
|
||||
// Create the slicing offsets.
|
||||
SmallVector<OpFoldResult> slicingOffsets{
|
||||
/* 0 */ rewriter.getIndexAttr(outputTile * crossbarSize),
|
||||
/* 1 */ rewriter.getIndexAttr(inputTile * crossbarSize),
|
||||
/* 2 */ /* rewriter.getIndexAttr(filterX), */
|
||||
/* 3 */ /* rewriter.getIndexAttr(filterY) */};
|
||||
SmallVector<OpFoldResult> slicingOffsets {/* 0 */ rewriter.getIndexAttr(outputTile * crossbarSize),
|
||||
/* 1 */ rewriter.getIndexAttr(inputTile * crossbarSize),
|
||||
/* 2 */ /* rewriter.getIndexAttr(filterX), */
|
||||
/* 3 */ /* rewriter.getIndexAttr(filterY) */};
|
||||
|
||||
// Create the slice extraction operation.
|
||||
auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
|
||||
gemmOp.getLoc(), adaptor.getB(), slicingOffsets, slicingSizes,
|
||||
slicingStrides);
|
||||
gemmOp.getLoc(), adaptor.getB(), slicingOffsets, slicingSizes, slicingStrides);
|
||||
|
||||
// Add a note to the extractSliceOp, with the filterX and filterY.
|
||||
weightsGroups[inputTile][outputTile].push_back(extractSliceOp);
|
||||
@@ -141,8 +136,7 @@ struct ExperimentalGemmConversionPattern
|
||||
// -------------------------------- //
|
||||
|
||||
// Get the next group of weights for the compute unit.
|
||||
SmallVector<TaggedWeights> weightsGroups =
|
||||
weightSubdivider.popGroups(crossbarCountInCore.getValue());
|
||||
SmallVector<TaggedWeights> weightsGroups = weightSubdivider.popGroups(crossbarCountInCore.getValue());
|
||||
|
||||
SmallVector<Value> computeWeights;
|
||||
SmallVector<Value> computeOperands;
|
||||
@@ -159,15 +153,13 @@ struct ExperimentalGemmConversionPattern
|
||||
// Iterate over all weights groups for this compute unit.
|
||||
map<long, Value> localSlices; // WRT the current compute unit.
|
||||
for (auto group : weightsGroups) {
|
||||
for (Value weight : group.weights) {
|
||||
for (Value weight : group.weights)
|
||||
computeWeights.push_back(weight);
|
||||
}
|
||||
|
||||
// There might be multiple weight groups for the same input tile, so if
|
||||
// we've already added the input tile, skip it.
|
||||
if (localSlices.find(group.inputTile) != localSlices.end()) {
|
||||
if (localSlices.find(group.inputTile) != localSlices.end())
|
||||
continue;
|
||||
}
|
||||
|
||||
// We might have already sliced the input tensor for some other compute
|
||||
// unit, so if we have, reuse the slicing operation without creating a
|
||||
@@ -179,26 +171,21 @@ struct ExperimentalGemmConversionPattern
|
||||
}
|
||||
|
||||
// Create the input tensor slicing offsets.
|
||||
SmallVector<OpFoldResult> slicingOffsets{
|
||||
/* 0 */ rewriter.getIndexAttr(0), // No offset along the batch axis.
|
||||
/* 1 */ rewriter.getIndexAttr(group.inputTile * crossbarSize),
|
||||
/* 2 */ /* rewriter.getIndexAttr(0), */
|
||||
/* 3 */ /* rewriter.getIndexAttr(0) */};
|
||||
SmallVector<OpFoldResult> slicingOffsets {/* 0 */ rewriter.getIndexAttr(0), // No offset along the batch axis.
|
||||
/* 1 */ rewriter.getIndexAttr(group.inputTile * crossbarSize),
|
||||
/* 2 */ /* rewriter.getIndexAttr(0), */
|
||||
/* 3 */ /* rewriter.getIndexAttr(0) */};
|
||||
|
||||
// Create the input tensor slicing sizes.
|
||||
size_t tilingSize = group.inputTile == inputTileCount.quot
|
||||
? inputTileCount.rem
|
||||
: crossbarSize;
|
||||
SmallVector<OpFoldResult> slicingSizes{
|
||||
/* 0 */ rewriter.getIndexAttr(1), // Batch size is always 1.
|
||||
/* 1 */ rewriter.getIndexAttr(tilingSize),
|
||||
/* 2 */ /* rewriter.getIndexAttr(GET_IMAGE_WIDTH(inputType)), */
|
||||
/* 3 */ /* rewriter.getIndexAttr(GET_IMAGE_HEIGHT(inputType)) */};
|
||||
size_t tilingSize = group.inputTile == inputTileCount.quot ? inputTileCount.rem : crossbarSize;
|
||||
SmallVector<OpFoldResult> slicingSizes {/* 0 */ rewriter.getIndexAttr(1), // Batch size is always 1.
|
||||
/* 1 */ rewriter.getIndexAttr(tilingSize),
|
||||
/* 2 */ /* rewriter.getIndexAttr(GET_IMAGE_WIDTH(inputType)), */
|
||||
/* 3 */ /* rewriter.getIndexAttr(GET_IMAGE_HEIGHT(inputType)) */};
|
||||
|
||||
// Create the slice extraction operation.
|
||||
auto extractSliceOp =
|
||||
rewriter.create<tensor::ExtractSliceOp>(gemmOp.getLoc(),
|
||||
adaptor.getA(), slicingOffsets, slicingSizes, slicingStrides);
|
||||
auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
|
||||
gemmOp.getLoc(), adaptor.getA(), slicingOffsets, slicingSizes, slicingStrides);
|
||||
|
||||
computeOperands.push_back(extractSliceOp);
|
||||
|
||||
@@ -220,36 +207,28 @@ struct ExperimentalGemmConversionPattern
|
||||
|
||||
// There might be multiple weight groups for the same output tile, so if
|
||||
// we've already added the output tile, skip it.
|
||||
if (outputTileIndices.find(group.outputTile) !=
|
||||
outputTileIndices.end()) {
|
||||
if (outputTileIndices.find(group.outputTile) != outputTileIndices.end())
|
||||
continue;
|
||||
}
|
||||
|
||||
// Additionally, after adding the input slices as operands, also add any
|
||||
// compatible partial results from previous compute units.
|
||||
if (globalPartialResults.find(group.outputTile) !=
|
||||
globalPartialResults.end()) {
|
||||
if (globalPartialResults.find(group.outputTile) != globalPartialResults.end()) {
|
||||
computeOperands.push_back(globalPartialResults[group.outputTile]);
|
||||
reductionTileIndices[group.outputTile] = computeOperands.size() - 1;
|
||||
}
|
||||
|
||||
// Define the output shape for this group.
|
||||
long outputTileSize = group.outputTile == outputTileCount.quot
|
||||
? outputTileCount.rem
|
||||
: crossbarSize;
|
||||
long outputTileSize = group.outputTile == outputTileCount.quot ? outputTileCount.rem : crossbarSize;
|
||||
|
||||
// TODO: Address non-same padding.
|
||||
SmallVector<int64_t> outputShapeArray{
|
||||
/* 0 */ 1, // Batch size is always 1.
|
||||
/* 1 */ outputTileSize,
|
||||
/* 2 */ /* GET_IMAGE_WIDTH(outputType), */ // Same padding assumed.
|
||||
/* 3 */ /* GET_IMAGE_HEIGHT(outputType) */};
|
||||
SmallVector<int64_t> outputShapeArray {/* 0 */ 1, // Batch size is always 1.
|
||||
/* 1 */ outputTileSize,
|
||||
/* 2 */ /* GET_IMAGE_WIDTH(outputType), */ // Same padding assumed.
|
||||
/* 3 */ /* GET_IMAGE_HEIGHT(outputType) */};
|
||||
|
||||
auto elementType = dyn_cast<RankedTensorType>(gemmOp.getY().getType())
|
||||
.getElementType();
|
||||
auto elementType = dyn_cast<RankedTensorType>(gemmOp.getY().getType()).getElementType();
|
||||
|
||||
computeOutputType.push_back(
|
||||
RankedTensorType::get(outputShapeArray, elementType));
|
||||
computeOutputType.push_back(RankedTensorType::get(outputShapeArray, elementType));
|
||||
|
||||
outputTileIndices[group.outputTile] = computeOutputType.size() - 1;
|
||||
}
|
||||
@@ -259,43 +238,36 @@ struct ExperimentalGemmConversionPattern
|
||||
// ----------------------------- //
|
||||
|
||||
// Create the compute unit.
|
||||
spatial::SpatWeightedCompute currentCompute =
|
||||
rewriter.create<spatial::SpatWeightedCompute>(gemmOp.getLoc(),
|
||||
computeOutputType, computeWeights, computeOperands);
|
||||
spatial::SpatWeightedCompute currentCompute = rewriter.create<spatial::SpatWeightedCompute>(
|
||||
gemmOp.getLoc(), computeOutputType, computeWeights, computeOperands);
|
||||
|
||||
// Create a new block for the compute unit and add the operands.
|
||||
Block *block = rewriter.createBlock(¤tCompute.getRegion());
|
||||
Block* block = rewriter.createBlock(¤tCompute.getRegion());
|
||||
rewriter.setInsertionPointToStart(block);
|
||||
for (Value operand : computeOperands) {
|
||||
for (Value operand : computeOperands)
|
||||
block->addArgument(operand.getType(), gemmOp->getLoc());
|
||||
}
|
||||
|
||||
// Initialize a map of local partial results.
|
||||
map<long, Value> localPartialResults; // WRT the current compute unit.
|
||||
|
||||
// If we have any reduction tiles, add them to the local partial results.
|
||||
for (auto reductionTileIndex : reductionTileIndices) {
|
||||
localPartialResults[reductionTileIndex.first] =
|
||||
block->getArgument(reductionTileIndex.second);
|
||||
}
|
||||
for (auto reductionTileIndex : reductionTileIndices)
|
||||
localPartialResults[reductionTileIndex.first] = block->getArgument(reductionTileIndex.second);
|
||||
|
||||
// Add all the applyFilters operations to the block.
|
||||
for (TaggedWeights group : weightsGroups) {
|
||||
|
||||
// Get the outputType for this group.
|
||||
Type outputType =
|
||||
computeOutputType[outputTileIndices[group.outputTile]];
|
||||
Type outputType = computeOutputType[outputTileIndices[group.outputTile]];
|
||||
|
||||
// Create an apply filters operation.
|
||||
BlockArgument blockArgument =
|
||||
block->getArgument(inputTileIndices[group.inputTile]);
|
||||
BlockArgument blockArgument = block->getArgument(inputTileIndices[group.inputTile]);
|
||||
|
||||
// The list of weight indices is group.startingCrossbarIndex + 0, 1, 2,
|
||||
// ... As many weights as the size of group.weights.
|
||||
SmallVector<long> weightIndices;
|
||||
for (size_t i = 0; i < group.weights.size(); ++i) {
|
||||
for (size_t i = 0; i < group.weights.size(); ++i)
|
||||
weightIndices.push_back(group.startingCrossbarIndex + i);
|
||||
}
|
||||
|
||||
SmallVector<int64_t> xKerPos;
|
||||
SmallVector<int64_t> yKerPos;
|
||||
@@ -313,16 +285,14 @@ struct ExperimentalGemmConversionPattern
|
||||
ArrayAttr xKerPosAttr = rewriter.getI64ArrayAttr(xKerPos);
|
||||
ArrayAttr yKerPosAttr = rewriter.getI64ArrayAttr(yKerPos);
|
||||
|
||||
Value result = rewriter.create<spatial::SpatApplyFiltersOp>(gemmOp.getLoc(),
|
||||
outputType, weightIndicesAttr, xKerPosAttr, yKerPosAttr,
|
||||
blockArgument);
|
||||
Value result = rewriter.create<spatial::SpatApplyFiltersOp>(
|
||||
gemmOp.getLoc(), outputType, weightIndicesAttr, xKerPosAttr, yKerPosAttr, blockArgument);
|
||||
|
||||
// Perform local reduction if necessary.
|
||||
if (localPartialResults.find(group.outputTile) !=
|
||||
localPartialResults.end()) {
|
||||
if (localPartialResults.find(group.outputTile) != localPartialResults.end()) {
|
||||
|
||||
result = rewriter.create<spatial::SpatVAddOp>(gemmOp.getLoc(),
|
||||
result.getType(), localPartialResults[group.outputTile], result);
|
||||
result = rewriter.create<spatial::SpatVAddOp>(
|
||||
gemmOp.getLoc(), result.getType(), localPartialResults[group.outputTile], result);
|
||||
}
|
||||
|
||||
// Update the partial results map.
|
||||
@@ -375,26 +345,21 @@ struct ExperimentalGemmConversionPattern
|
||||
|
||||
// Turn the values into a SmallVector.
|
||||
SmallVector<Value> outputValues;
|
||||
for (long i = 0; i < outputTileCount.quot + (outputTileCount.rem > 0);
|
||||
++i) {
|
||||
for (long i = 0; i < outputTileCount.quot + (outputTileCount.rem > 0); ++i)
|
||||
outputValues.push_back(globalPartialResults[i]);
|
||||
}
|
||||
|
||||
// Assert that the number of output values is correct.
|
||||
assert(outputValues.size() > 0 &&
|
||||
"No output values were generated for the GEMM operation.");
|
||||
assert(outputValues.size() > 0 && "No output values were generated for the GEMM operation.");
|
||||
|
||||
// Return the final output.
|
||||
rewriter.replaceOp(gemmOp,
|
||||
rewriter.create<tensor::ConcatOp>(gemmOp.getLoc(), 1, outputValues));
|
||||
rewriter.replaceOp(gemmOp, rewriter.create<tensor::ConcatOp>(gemmOp.getLoc(), 1, outputValues));
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
void populateGemmToConvConversionPattern(
|
||||
RewritePatternSet &patterns, MLIRContext *ctx) {
|
||||
void populateGemmToConvConversionPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||
patterns.insert<ExperimentalGemmConversionPattern>(ctx);
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -6,20 +6,22 @@
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/IR/ValueRange.h"
|
||||
#include "src/Accelerators/PIM/Common/PIMCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
#include <cassert>
|
||||
#include <cmath>
|
||||
#include <cstddef>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PIMCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
@@ -35,71 +37,68 @@ bool hasPostProcessExperimentalPoolingWindow<ONNXAveragePoolOp>() {
|
||||
}
|
||||
|
||||
template <typename PoolOp>
|
||||
Value postProcessExperimentalPoolingWindow(ConversionPatternRewriter &rewriter,
|
||||
Location loc, PoolOp poolOp, Value valueToDivide, size_t krn_size,
|
||||
size_t tilesSkippedByPadding) {
|
||||
Value postProcessExperimentalPoolingWindow(ConversionPatternRewriter& rewriter,
|
||||
Location loc,
|
||||
PoolOp poolOp,
|
||||
Value valueToDivide,
|
||||
size_t krn_size,
|
||||
size_t tilesSkippedByPadding) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
template <>
|
||||
Value postProcessExperimentalPoolingWindow<ONNXAveragePoolOp>(
|
||||
ConversionPatternRewriter &rewriter, Location loc, ONNXAveragePoolOp poolOp,
|
||||
Value valueToDivide, size_t krn_size, size_t tilesSkippedByPadding) {
|
||||
Value postProcessExperimentalPoolingWindow<ONNXAveragePoolOp>(ConversionPatternRewriter& rewriter,
|
||||
Location loc,
|
||||
ONNXAveragePoolOp poolOp,
|
||||
Value valueToDivide,
|
||||
size_t krn_size,
|
||||
size_t tilesSkippedByPadding) {
|
||||
bool countIncludePad = poolOp.getCountIncludePad() == 1;
|
||||
|
||||
size_t divisorNumber =
|
||||
countIncludePad ? krn_size : krn_size - tilesSkippedByPadding;
|
||||
size_t divisorNumber = countIncludePad ? krn_size : krn_size - tilesSkippedByPadding;
|
||||
|
||||
RankedTensorType scalarTensor =
|
||||
RankedTensorType::get({1}, rewriter.getF32Type());
|
||||
RankedTensorType scalarTensor = RankedTensorType::get({1}, rewriter.getF32Type());
|
||||
|
||||
// Put a spat.const before the computeOp, and use its value. We do this to be
|
||||
// compatible with the current code generation, which assumes constant to be
|
||||
// loaded in global memory, which is allocated by adding a spat.const OP
|
||||
// directly under func.func (i.e. alongside ComputeOps)
|
||||
auto computeOp = cast<spatial::SpatWeightedCompute>(
|
||||
valueToDivide.getDefiningOp()->getParentOp());
|
||||
auto computeOp = cast<spatial::SpatWeightedCompute>(valueToDivide.getDefiningOp()->getParentOp());
|
||||
rewriter.setInsertionPoint(computeOp);
|
||||
auto divisorValue = rewriter.create<spatial::SpatConstantOp>(loc, scalarTensor,
|
||||
rewriter.getI64IntegerAttr(divisorNumber),
|
||||
/* should_allocate = */ rewriter.getBoolAttr(true));
|
||||
auto divisorValue = rewriter.create<spatial::SpatConstantOp>(loc,
|
||||
scalarTensor,
|
||||
rewriter.getI64IntegerAttr(divisorNumber),
|
||||
/* should_allocate = */ rewriter.getBoolAttr(true));
|
||||
|
||||
rewriter.setInsertionPointAfterValue(valueToDivide);
|
||||
return rewriter.create<spatial::SpatVSDivOp>(
|
||||
loc, valueToDivide.getType(), valueToDivide, divisorValue);
|
||||
return rewriter.create<spatial::SpatVSDivOp>(loc, valueToDivide.getType(), valueToDivide, divisorValue);
|
||||
}
|
||||
|
||||
template <typename ReductionOp>
|
||||
Value reduceInputTiles(
|
||||
SmallVector<Value> &inputTiles, ConversionPatternRewriter &rewriter) {
|
||||
if (inputTiles.size() == 1) {
|
||||
Value reduceInputTiles(SmallVector<Value>& inputTiles, ConversionPatternRewriter& rewriter) {
|
||||
if (inputTiles.size() == 1)
|
||||
return inputTiles[0];
|
||||
}
|
||||
|
||||
if (inputTiles.size() == 2) {
|
||||
return rewriter.create<spatial::SpatVMaxOp>(inputTiles[0].getLoc(),
|
||||
inputTiles[0].getType(), inputTiles[0], inputTiles[1]);
|
||||
return rewriter.create<spatial::SpatVMaxOp>(
|
||||
inputTiles[0].getLoc(), inputTiles[0].getType(), inputTiles[0], inputTiles[1]);
|
||||
}
|
||||
|
||||
SmallVector<Value> left(
|
||||
inputTiles.begin(), inputTiles.begin() + inputTiles.size() / 2);
|
||||
SmallVector<Value> right(
|
||||
inputTiles.begin() + inputTiles.size() / 2, inputTiles.end());
|
||||
SmallVector<Value> left(inputTiles.begin(), inputTiles.begin() + inputTiles.size() / 2);
|
||||
SmallVector<Value> right(inputTiles.begin() + inputTiles.size() / 2, inputTiles.end());
|
||||
|
||||
Value leftReduced = reduceInputTiles<ReductionOp>(left, rewriter);
|
||||
Value rightReduced = reduceInputTiles<ReductionOp>(right, rewriter);
|
||||
|
||||
return rewriter.create<ReductionOp>(
|
||||
inputTiles[0].getLoc(), leftReduced.getType(), leftReduced, rightReduced);
|
||||
return rewriter.create<ReductionOp>(inputTiles[0].getLoc(), leftReduced.getType(), leftReduced, rightReduced);
|
||||
}
|
||||
|
||||
template <typename PoolOp, typename PoolOpAdaptor, typename ReduceOp>
|
||||
struct ExperimentalPoolingBaseConverter : public OpConversionPattern<PoolOp> {
|
||||
ExperimentalPoolingBaseConverter(MLIRContext *ctx)
|
||||
: OpConversionPattern<PoolOp>(ctx) {}
|
||||
ExperimentalPoolingBaseConverter(MLIRContext* ctx)
|
||||
: OpConversionPattern<PoolOp>(ctx) {}
|
||||
|
||||
LogicalResult matchAndRewrite(PoolOp poolOp, PoolOpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
LogicalResult matchAndRewrite(PoolOp poolOp, PoolOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final {
|
||||
Value X = adaptor.getX();
|
||||
ShapedType xShape = mlir::cast<ShapedType>(X.getType());
|
||||
Value Y = poolOp.getResult();
|
||||
@@ -110,17 +109,13 @@ struct ExperimentalPoolingBaseConverter : public OpConversionPattern<PoolOp> {
|
||||
unpackOptionalPairVector(adaptor.getDilations(), dilation_x, dilation_y);
|
||||
unpackOptionalPairVector(adaptor.getKernelShape(), krn_w, krn_h);
|
||||
|
||||
if (adaptor.getAutoPad() != "NOTSET") {
|
||||
return rewriter.notifyMatchFailure(
|
||||
poolOp, "auto_pad != NOTSET is deprecated.");
|
||||
}
|
||||
if (adaptor.getAutoPad() != "NOTSET")
|
||||
return rewriter.notifyMatchFailure(poolOp, "auto_pad != NOTSET is deprecated.");
|
||||
|
||||
size_t pad_x, pad_y;
|
||||
auto padUnpackError =
|
||||
unpackOptionalPadsVector(adaptor.getPads(), pad_x, pad_y);
|
||||
if (padUnpackError.has_value()) {
|
||||
auto padUnpackError = unpackOptionalPadsVector(adaptor.getPads(), pad_x, pad_y);
|
||||
if (padUnpackError.has_value())
|
||||
return rewriter.notifyMatchFailure(poolOp, padUnpackError.value());
|
||||
}
|
||||
|
||||
Location loc = poolOp.getLoc();
|
||||
|
||||
@@ -133,10 +128,8 @@ struct ExperimentalPoolingBaseConverter : public OpConversionPattern<PoolOp> {
|
||||
|
||||
// Assert that the input is a tensor.ConcatOp.
|
||||
auto concat = X.getDefiningOp<tensor::ConcatOp>();
|
||||
if (!concat) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
poolOp, "Expected input to be a tensor.ConcatOp");
|
||||
}
|
||||
if (!concat)
|
||||
return rewriter.notifyMatchFailure(poolOp, "Expected input to be a tensor.ConcatOp");
|
||||
|
||||
// Create a [channel_tile][x][y] array to store the input tiles.
|
||||
std::map<long, std::map<long, std::map<long, Value>>> inputTiles;
|
||||
@@ -145,24 +138,21 @@ struct ExperimentalPoolingBaseConverter : public OpConversionPattern<PoolOp> {
|
||||
for (size_t y = 0; y < input_h; ++y) {
|
||||
for (size_t x = 0; x < input_w; ++x) {
|
||||
for (long it = 0; it < tileCount.quot + (tileCount.rem > 0); ++it) {
|
||||
size_t tilingSize =
|
||||
it == tileCount.quot ? tileCount.rem : crossbarSize;
|
||||
size_t tilingSize = it == tileCount.quot ? tileCount.rem : crossbarSize;
|
||||
|
||||
SmallVector<OpFoldResult> strides(4, rewriter.getIndexAttr(1));
|
||||
SmallVector<OpFoldResult> offsets = {/* 0 */ rewriter.getIndexAttr(0),
|
||||
/* 1 */ rewriter.getIndexAttr(0),
|
||||
/* 2 */ rewriter.getIndexAttr(x),
|
||||
/* 3 */ rewriter.getIndexAttr(y)};
|
||||
SmallVector<OpFoldResult> sizes = {
|
||||
/* 0 */ rewriter.getIndexAttr(1), // Batch size is always 1.
|
||||
/* 1 */ rewriter.getIndexAttr(tilingSize),
|
||||
/* 2 */ rewriter.getIndexAttr(1),
|
||||
/* 3 */ rewriter.getIndexAttr(1)};
|
||||
/* 1 */ rewriter.getIndexAttr(0),
|
||||
/* 2 */ rewriter.getIndexAttr(x),
|
||||
/* 3 */ rewriter.getIndexAttr(y)};
|
||||
SmallVector<OpFoldResult> sizes = {/* 0 */ rewriter.getIndexAttr(1), // Batch size is always 1.
|
||||
/* 1 */ rewriter.getIndexAttr(tilingSize),
|
||||
/* 2 */ rewriter.getIndexAttr(1),
|
||||
/* 3 */ rewriter.getIndexAttr(1)};
|
||||
|
||||
// Get the concat's operand that we want to slice.
|
||||
Value concatInput = concat.getOperand(it);
|
||||
Value slicedTile = rewriter.create<tensor::ExtractSliceOp>(
|
||||
loc, concatInput, offsets, sizes, strides);
|
||||
Value slicedTile = rewriter.create<tensor::ExtractSliceOp>(loc, concatInput, offsets, sizes, strides);
|
||||
|
||||
inputTiles[it][x][y] = slicedTile;
|
||||
}
|
||||
@@ -175,19 +165,15 @@ struct ExperimentalPoolingBaseConverter : public OpConversionPattern<PoolOp> {
|
||||
for (size_t y = 0; y < output_h; ++y) {
|
||||
for (size_t x = 0; x < output_w; ++x) {
|
||||
for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) {
|
||||
SmallVector<int64_t> outputShapeArray{
|
||||
/* 0 */ 1, // Batch size is always 1.
|
||||
/* 1 */
|
||||
cast<RankedTensorType>(inputTiles[it][0][0].getType())
|
||||
.getShape()[1],
|
||||
/* 2 */ 1,
|
||||
/* 3 */ 1};
|
||||
SmallVector<int64_t> outputShapeArray {/* 0 */ 1, // Batch size is always 1.
|
||||
/* 1 */
|
||||
cast<RankedTensorType>(inputTiles[it][0][0].getType()).getShape()[1],
|
||||
/* 2 */ 1,
|
||||
/* 3 */ 1};
|
||||
|
||||
auto elementType =
|
||||
dyn_cast<RankedTensorType>(xShape).getElementType();
|
||||
auto elementType = dyn_cast<RankedTensorType>(xShape).getElementType();
|
||||
|
||||
outputTileTypes.push_back(
|
||||
RankedTensorType::get(outputShapeArray, elementType));
|
||||
outputTileTypes.push_back(RankedTensorType::get(outputShapeArray, elementType));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -195,29 +181,25 @@ struct ExperimentalPoolingBaseConverter : public OpConversionPattern<PoolOp> {
|
||||
// Create a plain value list of the input tiles.
|
||||
SmallVector<Value> inputTilesList;
|
||||
for (size_t y = 0; y < input_h; ++y) {
|
||||
for (size_t x = 0; x < input_w; ++x) {
|
||||
for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) {
|
||||
for (size_t x = 0; x < input_w; ++x)
|
||||
for (long it = 0; it < itc.quot + (itc.rem > 0); ++it)
|
||||
inputTilesList.push_back(inputTiles[it][y][x]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Create a single compute to calculate the output.
|
||||
auto computeOp = rewriter.create<spatial::SpatWeightedCompute>(
|
||||
loc, outputTileTypes, SmallVector<Value>(), inputTilesList);
|
||||
auto computeOp =
|
||||
rewriter.create<spatial::SpatWeightedCompute>(loc, outputTileTypes, SmallVector<Value>(), inputTilesList);
|
||||
|
||||
// Create a new block for the compute unit and add the operands.
|
||||
Block *block = rewriter.createBlock(&computeOp.getRegion());
|
||||
Block* block = rewriter.createBlock(&computeOp.getRegion());
|
||||
|
||||
// Fill the block arguments and keep a reference to them.
|
||||
std::map<size_t, std::map<size_t, std::map<size_t, Value>>> inputTilesArgs;
|
||||
for (size_t y = 0; y < input_h; ++y) {
|
||||
for (size_t x = 0; x < input_w; ++x) {
|
||||
for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) {
|
||||
auto tileIndex = y * input_w * (itc.quot + (itc.rem > 0)) +
|
||||
x * (itc.quot + (itc.rem > 0)) + it;
|
||||
inputTilesArgs[it][y][x] = block->addArgument(
|
||||
computeOp->getOperand(tileIndex).getType(), loc);
|
||||
auto tileIndex = y * input_w * (itc.quot + (itc.rem > 0)) + x * (itc.quot + (itc.rem > 0)) + it;
|
||||
inputTilesArgs[it][y][x] = block->addArgument(computeOp->getOperand(tileIndex).getType(), loc);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -236,28 +218,26 @@ struct ExperimentalPoolingBaseConverter : public OpConversionPattern<PoolOp> {
|
||||
size_t end_y = std::min(start_y + krn_h, input_h);
|
||||
|
||||
SmallVector<Value> inputTilesToReduce;
|
||||
for (size_t ky = start_y; ky < end_y; ++ky) {
|
||||
for (size_t kx = start_x; kx < end_x; ++kx) {
|
||||
for (size_t ky = start_y; ky < end_y; ++ky)
|
||||
for (size_t kx = start_x; kx < end_x; ++kx)
|
||||
inputTilesToReduce.push_back(inputTilesArgs[it][ky][kx]);
|
||||
}
|
||||
}
|
||||
|
||||
auto reduceResult =
|
||||
reduceInputTiles<ReduceOp>(inputTilesToReduce, rewriter);
|
||||
auto reduceResult = reduceInputTiles<ReduceOp>(inputTilesToReduce, rewriter);
|
||||
|
||||
// If the reduce op is add, we need to divide the result by the
|
||||
// number of elements in the pooling window.
|
||||
if (hasPostProcessExperimentalPoolingWindow<PoolOp>()) {
|
||||
// Add a spat.const before the computeOp.
|
||||
rewriter.setInsertionPoint(computeOp);
|
||||
auto divisorValue = rewriter.create<spatial::SpatConstantOp>(loc,
|
||||
RankedTensorType::get({1}, rewriter.getF32Type()),
|
||||
rewriter.getI64IntegerAttr(krn_w * krn_h),
|
||||
rewriter.getBoolAttr(true));
|
||||
auto divisorValue =
|
||||
rewriter.create<spatial::SpatConstantOp>(loc,
|
||||
RankedTensorType::get({1}, rewriter.getF32Type()),
|
||||
rewriter.getI64IntegerAttr(krn_w * krn_h),
|
||||
rewriter.getBoolAttr(true));
|
||||
|
||||
rewriter.setInsertionPointAfter(reduceResult.getDefiningOp());
|
||||
reduceResult = rewriter.create<spatial::SpatVSDivOp>(
|
||||
loc, reduceResult.getType(), reduceResult, divisorValue);
|
||||
reduceResult =
|
||||
rewriter.create<spatial::SpatVSDivOp>(loc, reduceResult.getType(), reduceResult, divisorValue);
|
||||
}
|
||||
outputTiles.push_back(reduceResult);
|
||||
}
|
||||
@@ -274,8 +254,7 @@ struct ExperimentalPoolingBaseConverter : public OpConversionPattern<PoolOp> {
|
||||
for (size_t y = 0; y < output_h; ++y) {
|
||||
for (size_t x = 0; x < output_w; ++x) {
|
||||
for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) {
|
||||
auto tileIndex = y * output_w * (itc.quot + (itc.rem > 0)) +
|
||||
x * (itc.quot + (itc.rem > 0)) + it;
|
||||
auto tileIndex = y * output_w * (itc.quot + (itc.rem > 0)) + x * (itc.quot + (itc.rem > 0)) + it;
|
||||
computeOutput[it][y][x] = computeOp.getResult(tileIndex);
|
||||
}
|
||||
}
|
||||
@@ -285,30 +264,25 @@ struct ExperimentalPoolingBaseConverter : public OpConversionPattern<PoolOp> {
|
||||
SmallVector<Value> outputTilesList;
|
||||
for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) {
|
||||
SmallVector<Value> imgConcatTiles;
|
||||
for (size_t y = 0; y < output_h; ++y) {
|
||||
for (size_t x = 0; x < output_w; ++x) {
|
||||
for (size_t y = 0; y < output_h; ++y)
|
||||
for (size_t x = 0; x < output_w; ++x)
|
||||
imgConcatTiles.push_back(computeOutput[it][y][x]);
|
||||
}
|
||||
}
|
||||
|
||||
size_t tilingSize = it == tileCount.quot ? tileCount.rem : crossbarSize;
|
||||
|
||||
SmallVector<int64_t> outputShapeArray{
|
||||
/* 0 */ 1, // Batch size is always 1.
|
||||
/* 1 */ (long)tilingSize,
|
||||
/* 2 */ (long)output_w,
|
||||
/* 3 */ (long)output_h};
|
||||
SmallVector<int64_t> outputShapeArray {/* 0 */ 1, // Batch size is always 1.
|
||||
/* 1 */ (long) tilingSize,
|
||||
/* 2 */ (long) output_w,
|
||||
/* 3 */ (long) output_h};
|
||||
|
||||
auto elementType = dyn_cast<RankedTensorType>(xShape).getElementType();
|
||||
|
||||
outputTilesList.push_back(rewriter.create<spatial::SpatImgConcatOp>(loc,
|
||||
RankedTensorType::get(outputShapeArray, elementType),
|
||||
imgConcatTiles));
|
||||
outputTilesList.push_back(rewriter.create<spatial::SpatImgConcatOp>(
|
||||
loc, RankedTensorType::get(outputShapeArray, elementType), imgConcatTiles));
|
||||
}
|
||||
|
||||
// Create a new tensor.ConcatOp to concatenate the output tiles.
|
||||
Value outputTensor =
|
||||
rewriter.create<tensor::ConcatOp>(loc, 1, outputTilesList);
|
||||
Value outputTensor = rewriter.create<tensor::ConcatOp>(loc, 1, outputTilesList);
|
||||
|
||||
rewriter.replaceOp(poolOp, outputTensor);
|
||||
|
||||
@@ -316,12 +290,11 @@ struct ExperimentalPoolingBaseConverter : public OpConversionPattern<PoolOp> {
|
||||
}
|
||||
};
|
||||
|
||||
void populateExperimentalPoolingTilingPattern(
|
||||
RewritePatternSet &patterns, MLIRContext *ctx) {
|
||||
patterns.insert<ExperimentalPoolingBaseConverter<ONNXMaxPoolSingleOutOp,
|
||||
ONNXMaxPoolSingleOutOpAdaptor, spatial::SpatVMaxOp>>(ctx);
|
||||
patterns.insert<ExperimentalPoolingBaseConverter<ONNXAveragePoolOp,
|
||||
ONNXAveragePoolOpAdaptor, spatial::SpatVAddOp>>(ctx);
|
||||
void populateExperimentalPoolingTilingPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||
patterns.insert<
|
||||
ExperimentalPoolingBaseConverter<ONNXMaxPoolSingleOutOp, ONNXMaxPoolSingleOutOpAdaptor, spatial::SpatVMaxOp>>(ctx);
|
||||
patterns.insert<ExperimentalPoolingBaseConverter<ONNXAveragePoolOp, ONNXAveragePoolOpAdaptor, spatial::SpatVAddOp>>(
|
||||
ctx);
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -6,38 +6,39 @@
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/IR/ValueRange.h"
|
||||
#include "src/Accelerators/PIM/Common/PIMCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
#include <cassert>
|
||||
#include <cmath>
|
||||
#include <cstddef>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PIMCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
llvm::SmallPtrSet<Operation *, 16> oldComputeOpsReplaced;
|
||||
llvm::SmallPtrSet<Operation*, 16> oldComputeOpsReplaced;
|
||||
|
||||
Value applyReducePatternNew(SmallVector<Value> &valuesToReduce,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
std::function<Value(const Value &, const Value &)> reduce,
|
||||
std::function<Value(const Value &)> preprocess,
|
||||
std::function<Value(const Value &)> postprocess) {
|
||||
Value applyReducePatternNew(SmallVector<Value>& valuesToReduce,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
std::function<Value(const Value&, const Value&)> reduce,
|
||||
std::function<Value(const Value&)> preprocess,
|
||||
std::function<Value(const Value&)> postprocess) {
|
||||
// Simple case: if we have only one input, just return it
|
||||
if (valuesToReduce.size() == 1) {
|
||||
if (valuesToReduce.size() == 1)
|
||||
return valuesToReduce[0];
|
||||
}
|
||||
|
||||
if (preprocess) {
|
||||
for (auto &valToReduce : valuesToReduce) {
|
||||
for (auto& valToReduce : valuesToReduce) {
|
||||
rewriter.setInsertionPointAfterValue(valToReduce);
|
||||
valToReduce = preprocess(valToReduce);
|
||||
}
|
||||
@@ -47,9 +48,9 @@ Value applyReducePatternNew(SmallVector<Value> &valuesToReduce,
|
||||
// computeOp. In this case, we need to apply the reduction within-computef
|
||||
|
||||
// Keep a map between a computeOp and the last Value for this reduction
|
||||
std::unordered_map<Operation *, Value> lastValueForCompute;
|
||||
for (auto &valToReduce : valuesToReduce) {
|
||||
Operation *computeOp = valToReduce.getParentBlock()->getParentOp();
|
||||
std::unordered_map<Operation*, Value> lastValueForCompute;
|
||||
for (auto& valToReduce : valuesToReduce) {
|
||||
Operation* computeOp = valToReduce.getParentBlock()->getParentOp();
|
||||
// if (valToReduce.getDefiningOp()) {
|
||||
// // If the value is defined by an operation, we take the parent
|
||||
// operation computeOp = valToReduce.getDefiningOp()->getParentOp();
|
||||
@@ -67,12 +68,10 @@ Value applyReducePatternNew(SmallVector<Value> &valuesToReduce,
|
||||
// within-compute
|
||||
Value lastWithinComputeValue = it->second;
|
||||
|
||||
if (valToReduce.getDefiningOp()->isBeforeInBlock(
|
||||
lastWithinComputeValue.getDefiningOp())) {
|
||||
if (valToReduce.getDefiningOp()->isBeforeInBlock(lastWithinComputeValue.getDefiningOp()))
|
||||
rewriter.setInsertionPointAfterValue(lastWithinComputeValue);
|
||||
} else {
|
||||
else
|
||||
rewriter.setInsertionPointAfterValue(valToReduce);
|
||||
}
|
||||
valToReduce = reduce(lastWithinComputeValue, valToReduce);
|
||||
lastValueForCompute[computeOp] = valToReduce;
|
||||
}
|
||||
@@ -83,9 +82,8 @@ Value applyReducePatternNew(SmallVector<Value> &valuesToReduce,
|
||||
// Now, reconstruct from the map the valuesToReduce list
|
||||
valuesToReduce.clear();
|
||||
valuesToReduce.reserve(lastValueForCompute.size());
|
||||
for (auto &entry : lastValueForCompute) {
|
||||
for (auto& entry : lastValueForCompute)
|
||||
valuesToReduce.push_back(entry.second);
|
||||
}
|
||||
|
||||
Location loc = valuesToReduce[0].getLoc();
|
||||
auto channelType = spatial::SpatChannelType::get(rewriter.getContext());
|
||||
@@ -123,8 +121,7 @@ Value applyReducePatternNew(SmallVector<Value> &valuesToReduce,
|
||||
|
||||
// 3. Add a receiveOp after the second value
|
||||
rewriter.setInsertionPointAfterValue(secondValue);
|
||||
auto receivedValue = rewriter.create<spatial::SpatChannelReceiveOp>(
|
||||
loc, secondValue.getType(), channel);
|
||||
auto receivedValue = rewriter.create<spatial::SpatChannelReceiveOp>(loc, secondValue.getType(), channel);
|
||||
|
||||
// 4. Apply reduction between second value and received value
|
||||
rewriter.setInsertionPointAfterValue(receivedValue);
|
||||
@@ -135,17 +132,14 @@ Value applyReducePatternNew(SmallVector<Value> &valuesToReduce,
|
||||
|
||||
// If we have an odd number of inputs, we need to add the last one to the
|
||||
// newInputs list.
|
||||
if (valuesToReduceRef.size() % 2 == 1) {
|
||||
if (valuesToReduceRef.size() % 2 == 1)
|
||||
nextValuesToReduce.push_back(valuesToReduceRef.back());
|
||||
}
|
||||
|
||||
// Replace the inputOps list with the new one.
|
||||
valuesToReduceRef =
|
||||
llvm::OwningArrayRef<Value>(std::move(nextValuesToReduce));
|
||||
valuesToReduceRef = llvm::OwningArrayRef<Value>(std::move(nextValuesToReduce));
|
||||
}
|
||||
|
||||
assert(valuesToReduceRef.size() == 1 &&
|
||||
"Internal error: expected a single input at this point.");
|
||||
assert(valuesToReduceRef.size() == 1 && "Internal error: expected a single input at this point.");
|
||||
|
||||
auto finalValue = valuesToReduceRef[0];
|
||||
|
||||
@@ -168,46 +162,49 @@ bool hasPostProcessPoolingWindow<ONNXAveragePoolOp>() {
|
||||
}
|
||||
|
||||
template <typename PoolOp>
|
||||
Value postProcessPoolingWindow(ConversionPatternRewriter &rewriter,
|
||||
Location loc, PoolOp poolOp, Value valueToDivide, size_t krn_size,
|
||||
size_t tilesSkippedByPadding) {
|
||||
Value postProcessPoolingWindow(ConversionPatternRewriter& rewriter,
|
||||
Location loc,
|
||||
PoolOp poolOp,
|
||||
Value valueToDivide,
|
||||
size_t krn_size,
|
||||
size_t tilesSkippedByPadding) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
template <>
|
||||
Value postProcessPoolingWindow<ONNXAveragePoolOp>(
|
||||
ConversionPatternRewriter &rewriter, Location loc, ONNXAveragePoolOp poolOp,
|
||||
Value valueToDivide, size_t krn_size, size_t tilesSkippedByPadding) {
|
||||
Value postProcessPoolingWindow<ONNXAveragePoolOp>(ConversionPatternRewriter& rewriter,
|
||||
Location loc,
|
||||
ONNXAveragePoolOp poolOp,
|
||||
Value valueToDivide,
|
||||
size_t krn_size,
|
||||
size_t tilesSkippedByPadding) {
|
||||
bool countIncludePad = poolOp.getCountIncludePad() == 1;
|
||||
|
||||
size_t divisorNumber =
|
||||
countIncludePad ? krn_size : krn_size - tilesSkippedByPadding;
|
||||
size_t divisorNumber = countIncludePad ? krn_size : krn_size - tilesSkippedByPadding;
|
||||
|
||||
RankedTensorType scalarTensor =
|
||||
RankedTensorType::get({1}, rewriter.getF32Type());
|
||||
RankedTensorType scalarTensor = RankedTensorType::get({1}, rewriter.getF32Type());
|
||||
|
||||
// Put a spat.const before the computeOp, and use its value. We do this to be
|
||||
// compatible with the current code generation, which assumes constant to be
|
||||
// loaded in global memory, which is allocated by adding a spat.const OP
|
||||
// directly under func.func (i.e. alongside ComputeOps)
|
||||
auto computeOp = cast<spatial::SpatWeightedCompute>(
|
||||
valueToDivide.getDefiningOp()->getParentOp());
|
||||
auto computeOp = cast<spatial::SpatWeightedCompute>(valueToDivide.getDefiningOp()->getParentOp());
|
||||
rewriter.setInsertionPoint(computeOp);
|
||||
auto divisorValue = rewriter.create<spatial::SpatConstantOp>(loc, scalarTensor,
|
||||
rewriter.getI64IntegerAttr(divisorNumber),
|
||||
/* should_allocate = */ rewriter.getBoolAttr(true));
|
||||
auto divisorValue = rewriter.create<spatial::SpatConstantOp>(loc,
|
||||
scalarTensor,
|
||||
rewriter.getI64IntegerAttr(divisorNumber),
|
||||
/* should_allocate = */ rewriter.getBoolAttr(true));
|
||||
|
||||
rewriter.setInsertionPointAfterValue(valueToDivide);
|
||||
return rewriter.create<spatial::SpatVSDivOp>(
|
||||
loc, valueToDivide.getType(), valueToDivide, divisorValue);
|
||||
return rewriter.create<spatial::SpatVSDivOp>(loc, valueToDivide.getType(), valueToDivide, divisorValue);
|
||||
}
|
||||
|
||||
template <typename PoolOp, typename PoolOpAdaptor, typename ReduceOp>
|
||||
struct PoolingBaseConverter : public OpConversionPattern<PoolOp> {
|
||||
PoolingBaseConverter(MLIRContext *ctx) : OpConversionPattern<PoolOp>(ctx) {}
|
||||
PoolingBaseConverter(MLIRContext* ctx)
|
||||
: OpConversionPattern<PoolOp>(ctx) {}
|
||||
|
||||
LogicalResult matchAndRewrite(PoolOp poolOp, PoolOpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
LogicalResult matchAndRewrite(PoolOp poolOp, PoolOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final {
|
||||
Value X = adaptor.getX();
|
||||
ShapedType xShape = mlir::cast<ShapedType>(X.getType());
|
||||
Value Y = poolOp.getResult();
|
||||
@@ -218,17 +215,13 @@ struct PoolingBaseConverter : public OpConversionPattern<PoolOp> {
|
||||
unpackOptionalPairVector(adaptor.getDilations(), dilation_x, dilation_y);
|
||||
unpackOptionalPairVector(adaptor.getKernelShape(), krn_w, krn_h);
|
||||
|
||||
if (adaptor.getAutoPad() != "NOTSET") {
|
||||
return rewriter.notifyMatchFailure(
|
||||
poolOp, "auto_pad != NOTSET is deprecated.");
|
||||
}
|
||||
if (adaptor.getAutoPad() != "NOTSET")
|
||||
return rewriter.notifyMatchFailure(poolOp, "auto_pad != NOTSET is deprecated.");
|
||||
|
||||
size_t pad_x, pad_y;
|
||||
auto padUnpackError =
|
||||
unpackOptionalPadsVector(adaptor.getPads(), pad_x, pad_y);
|
||||
if (padUnpackError.has_value()) {
|
||||
auto padUnpackError = unpackOptionalPadsVector(adaptor.getPads(), pad_x, pad_y);
|
||||
if (padUnpackError.has_value())
|
||||
return rewriter.notifyMatchFailure(poolOp, padUnpackError.value());
|
||||
}
|
||||
|
||||
Location loc = poolOp.getLoc();
|
||||
|
||||
@@ -236,8 +229,7 @@ struct PoolingBaseConverter : public OpConversionPattern<PoolOp> {
|
||||
size_t input_w = GET_IMAGE_WIDTH(xShape);
|
||||
size_t output_h = GET_IMAGE_HEIGHT(yShape);
|
||||
size_t output_w = GET_IMAGE_WIDTH(yShape);
|
||||
size_t channelTileCount =
|
||||
ceilIntegerDivide(GET_IMAGE_CHANNEL(xShape), crossbarSize.getValue());
|
||||
size_t channelTileCount = ceilIntegerDivide(GET_IMAGE_CHANNEL(xShape), crossbarSize.getValue());
|
||||
size_t channelTileRest = GET_IMAGE_CHANNEL(xShape) % crossbarSize;
|
||||
|
||||
// 1: Tile the input tensor
|
||||
@@ -249,14 +241,13 @@ struct PoolingBaseConverter : public OpConversionPattern<PoolOp> {
|
||||
// Example complete input tensor: tensor<1x3x12x12xf32> (NxCxWxH)
|
||||
// Suppose that the input tensor is produced by concatenating the results of
|
||||
// many ComputeOps. Get the result tiles from these ComputeOps.
|
||||
SmallVector<SmallVector<SmallVector<Value>>> inputTiles(channelTileCount,
|
||||
SmallVector<SmallVector<Value>>(input_w, SmallVector<Value>(input_h)));
|
||||
SmallVector<SmallVector<SmallVector<Value>>> inputTiles(
|
||||
channelTileCount, SmallVector<SmallVector<Value>>(input_w, SmallVector<Value>(input_h)));
|
||||
|
||||
auto resolveErrorOpt = resolveImgInputTiles(X, inputTiles, channelTileCount,
|
||||
channelTileRest, input_w, input_h, rewriter);
|
||||
if (resolveErrorOpt.has_value()) {
|
||||
auto resolveErrorOpt =
|
||||
resolveImgInputTiles(X, inputTiles, channelTileCount, channelTileRest, input_w, input_h, rewriter);
|
||||
if (resolveErrorOpt.has_value())
|
||||
return rewriter.notifyMatchFailure(poolOp, *resolveErrorOpt);
|
||||
}
|
||||
|
||||
// TODO: This requires a core for each input tile, which is not ideal. We
|
||||
// can do better.
|
||||
@@ -265,18 +256,17 @@ struct PoolingBaseConverter : public OpConversionPattern<PoolOp> {
|
||||
for (size_t t = 0; t < channelTileCount; t++) {
|
||||
for (size_t x = 0; x < input_w; x++) {
|
||||
for (size_t y = 0; y < input_h; y++) {
|
||||
if (auto extractSliceOp =
|
||||
inputTiles[t][x][y].getDefiningOp<tensor::ExtractSliceOp>()) {
|
||||
if (auto extractSliceOp = inputTiles[t][x][y].getDefiningOp<tensor::ExtractSliceOp>()) {
|
||||
Location tileLoc = extractSliceOp.getLoc();
|
||||
|
||||
auto tempComputeOp = rewriter.create<spatial::SpatWeightedCompute>(
|
||||
tileLoc, extractSliceOp.getResultType(),
|
||||
/* xbarWeights =*/ValueRange(), extractSliceOp.getResult());
|
||||
auto tempComputeOp = rewriter.create<spatial::SpatWeightedCompute>(tileLoc,
|
||||
extractSliceOp.getResultType(),
|
||||
/* xbarWeights =*/ValueRange(),
|
||||
extractSliceOp.getResult());
|
||||
|
||||
Block *tempComputeOpBlock = new Block();
|
||||
Block* tempComputeOpBlock = new Block();
|
||||
tempComputeOp.getBody().push_back(tempComputeOpBlock);
|
||||
auto tempComputeOpBlockArg = tempComputeOpBlock->addArgument(
|
||||
extractSliceOp.getType(), tileLoc);
|
||||
auto tempComputeOpBlockArg = tempComputeOpBlock->addArgument(extractSliceOp.getType(), tileLoc);
|
||||
|
||||
rewriter.setInsertionPointToStart(tempComputeOpBlock);
|
||||
rewriter.create<spatial::SpatYieldOp>(tileLoc, tempComputeOpBlockArg);
|
||||
@@ -295,8 +285,7 @@ struct PoolingBaseConverter : public OpConversionPattern<PoolOp> {
|
||||
// For example: outputTiles[channelTile][x][y]
|
||||
// Example complete output tensor: tensor<1x3x6x6xf32> (NxCxWxH)
|
||||
SmallVector<SmallVector<SmallVector<Value>>> outputTiles(
|
||||
channelTileCount, SmallVector<SmallVector<Value>>(
|
||||
output_w, SmallVector<Value>(output_h, nullptr)));
|
||||
channelTileCount, SmallVector<SmallVector<Value>>(output_w, SmallVector<Value>(output_h, nullptr)));
|
||||
|
||||
// List of values to pool for each output pixel
|
||||
SmallVector<Value> valuesToPool;
|
||||
@@ -312,15 +301,12 @@ struct PoolingBaseConverter : public OpConversionPattern<PoolOp> {
|
||||
valuesToPool.clear();
|
||||
size_t tilesSkippedByPadding = 0;
|
||||
|
||||
auto [start_x, end_x] = kernel_get_start_and_end(
|
||||
outX, input_w, krn_w, stride_x, dilation_x, pad_x);
|
||||
auto [start_y, end_y] = kernel_get_start_and_end(
|
||||
outY, input_h, krn_h, stride_y, dilation_y, pad_y);
|
||||
auto [start_x, end_x] = kernel_get_start_and_end(outX, input_w, krn_w, stride_x, dilation_x, pad_x);
|
||||
auto [start_y, end_y] = kernel_get_start_and_end(outY, input_h, krn_h, stride_y, dilation_y, pad_y);
|
||||
|
||||
for (size_t inX = start_x; inX < end_x; inX += dilation_x) {
|
||||
for (size_t inY = start_y; inY < end_y; inY += dilation_y) {
|
||||
if (failed(verifyWithinBoundsAndPaddings(
|
||||
input_w, input_h, inX, inY, pad_x, pad_y))) {
|
||||
if (failed(verifyWithinBoundsAndPaddings(input_w, input_h, inX, inY, pad_x, pad_y))) {
|
||||
tilesSkippedByPadding++;
|
||||
continue;
|
||||
}
|
||||
@@ -328,78 +314,73 @@ struct PoolingBaseConverter : public OpConversionPattern<PoolOp> {
|
||||
Value inputTile = inputTiles[outTile][inX][inY];
|
||||
|
||||
Value valueToPool;
|
||||
if (auto computeProducer =
|
||||
inputTile.getDefiningOp<spatial::SpatWeightedCompute>()) {
|
||||
if (auto computeProducer = inputTile.getDefiningOp<spatial::SpatWeightedCompute>()) {
|
||||
|
||||
int resultNumber = getResultIndex(computeProducer, inputTile);
|
||||
|
||||
auto yieldInComputeOp = cast<spatial::SpatYieldOp>(
|
||||
computeProducer.getBody().front().getTerminator());
|
||||
auto yieldInComputeOp = cast<spatial::SpatYieldOp>(computeProducer.getBody().front().getTerminator());
|
||||
valueToPool = yieldInComputeOp.getOperand(resultNumber);
|
||||
} else if (auto receiveProducer =
|
||||
inputTile
|
||||
.getDefiningOp<spatial::SpatChannelReceiveOp>()) {
|
||||
auto sendOpOpt =
|
||||
getOtherEndOfChannel(receiveProducer, true, rewriter);
|
||||
}
|
||||
else if (auto receiveProducer = inputTile.getDefiningOp<spatial::SpatChannelReceiveOp>()) {
|
||||
auto sendOpOpt = getOtherEndOfChannel(receiveProducer, true, rewriter);
|
||||
if (failed(sendOpOpt)) {
|
||||
return rewriter.notifyMatchFailure(poolOp,
|
||||
"ChannelReceiveOp does not have a matching "
|
||||
"ChannelSendOp.");
|
||||
"ChannelReceiveOp does not have a matching "
|
||||
"ChannelSendOp.");
|
||||
}
|
||||
auto sendOp = cast<spatial::SpatChannelSendOp>(*sendOpOpt);
|
||||
|
||||
valueToPool = sendOp.getData();
|
||||
} else {
|
||||
}
|
||||
else {
|
||||
return rewriter.notifyMatchFailure(poolOp,
|
||||
"Input tile for Pooling is not produced by a "
|
||||
"WeightedComputeOp nor a receiveOp");
|
||||
"Input tile for Pooling is not produced by a "
|
||||
"WeightedComputeOp nor a receiveOp");
|
||||
}
|
||||
|
||||
valuesToPool.push_back(valueToPool);
|
||||
}
|
||||
}
|
||||
|
||||
assert(valuesToPool.size() != 0 &&
|
||||
"Pooling computed on zero tiles make no sense.");
|
||||
assert(valuesToPool.size() != 0 && "Pooling computed on zero tiles make no sense.");
|
||||
// assert(computeOpsForPooling.size() != 1 &&
|
||||
// "Pooling computed on one tiles make no sense??? Or maybe
|
||||
// this " "should have been simplified earlier???");
|
||||
|
||||
std::function<Value(const Value &)> postProcessFn = nullptr;
|
||||
std::function<Value(const Value&)> postProcessFn = nullptr;
|
||||
if (hasPostProcessPoolingWindow<PoolOp>()) {
|
||||
postProcessFn = [&](const Value prevFinalRes) {
|
||||
return postProcessPoolingWindow(rewriter, loc, poolOp,
|
||||
prevFinalRes, krn_h * krn_w, tilesSkippedByPadding);
|
||||
return postProcessPoolingWindow(
|
||||
rewriter, loc, poolOp, prevFinalRes, krn_h * krn_w, tilesSkippedByPadding);
|
||||
};
|
||||
}
|
||||
|
||||
Value reducedWithinCompute = applyReducePatternNew(
|
||||
valuesToPool, rewriter,
|
||||
[&](const Value lhs, const Value rhs) {
|
||||
return rewriter.create<ReduceOp>(loc, lhs.getType(), lhs, rhs);
|
||||
},
|
||||
nullptr, postProcessFn);
|
||||
valuesToPool,
|
||||
rewriter,
|
||||
[&](const Value lhs, const Value rhs) { return rewriter.create<ReduceOp>(loc, lhs.getType(), lhs, rhs); },
|
||||
nullptr,
|
||||
postProcessFn);
|
||||
|
||||
// Send this value through a channel, and receive it in the
|
||||
// `func.func`. During lowering, we will need to "move it" into the
|
||||
// users computeOps
|
||||
auto computeOpOfReduced = cast<spatial::SpatWeightedCompute>(
|
||||
reducedWithinCompute.getDefiningOp()->getParentOp());
|
||||
auto computeOpOfReduced =
|
||||
cast<spatial::SpatWeightedCompute>(reducedWithinCompute.getDefiningOp()->getParentOp());
|
||||
|
||||
// Create a new channel before the computeOp
|
||||
rewriter.setInsertionPoint(computeOpOfReduced);
|
||||
auto reduceChannel = rewriter.create<spatial::SpatChannelNewOp>(
|
||||
loc, spatial::SpatChannelType::get(rewriter.getContext()));
|
||||
auto reduceChannel =
|
||||
rewriter.create<spatial::SpatChannelNewOp>(loc, spatial::SpatChannelType::get(rewriter.getContext()));
|
||||
|
||||
// Send value through the channel
|
||||
rewriter.setInsertionPointAfterValue(reducedWithinCompute);
|
||||
rewriter.create<spatial::SpatChannelSendOp>(
|
||||
loc, reduceChannel, reducedWithinCompute);
|
||||
rewriter.create<spatial::SpatChannelSendOp>(loc, reduceChannel, reducedWithinCompute);
|
||||
|
||||
// Receive after the computeOp
|
||||
rewriter.setInsertionPointAfter(computeOpOfReduced);
|
||||
auto receivedValue = rewriter.create<spatial::SpatChannelReceiveOp>(
|
||||
loc, reducedWithinCompute.getType(), reduceChannel);
|
||||
auto receivedValue =
|
||||
rewriter.create<spatial::SpatChannelReceiveOp>(loc, reducedWithinCompute.getType(), reduceChannel);
|
||||
|
||||
outputTiles[outTile][outX][outY] = receivedValue;
|
||||
}
|
||||
@@ -409,9 +390,7 @@ struct PoolingBaseConverter : public OpConversionPattern<PoolOp> {
|
||||
// TODO: outputTiles are not the results of the computeOps! We need to add
|
||||
// them!
|
||||
|
||||
std::unordered_map<Operation *,
|
||||
SmallVector<std::tuple<size_t, size_t, size_t, Value>>>
|
||||
computeOpNeedingResults;
|
||||
std::unordered_map<Operation*, SmallVector<std::tuple<size_t, size_t, size_t, Value>>> computeOpNeedingResults;
|
||||
|
||||
// Iterate each output tile
|
||||
for (size_t outTile = 0; outTile < channelTileCount; outTile++) {
|
||||
@@ -422,18 +401,16 @@ struct PoolingBaseConverter : public OpConversionPattern<PoolOp> {
|
||||
auto outputTileProducer = outputTile.getDefiningOp()->getParentOp();
|
||||
if (!outputTileProducer) {
|
||||
return rewriter.notifyMatchFailure(poolOp,
|
||||
"Output tile for Pooling is not produced by a "
|
||||
"WeightedComputeOp.");
|
||||
"Output tile for Pooling is not produced by a "
|
||||
"WeightedComputeOp.");
|
||||
}
|
||||
|
||||
computeOpNeedingResults[outputTileProducer].push_back(
|
||||
std::make_tuple(outTile, outX, outY, outputTile));
|
||||
computeOpNeedingResults[outputTileProducer].push_back(std::make_tuple(outTile, outX, outY, outputTile));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Value outputImage =
|
||||
createImgConcatOp(outputTiles, rewriter, loc, poolOp.getType());
|
||||
Value outputImage = createImgConcatOp(outputTiles, rewriter, loc, poolOp.getType());
|
||||
|
||||
rewriter.replaceOp(poolOp, outputImage);
|
||||
|
||||
@@ -441,12 +418,10 @@ struct PoolingBaseConverter : public OpConversionPattern<PoolOp> {
|
||||
}
|
||||
};
|
||||
|
||||
void populatePoolingTilingPattern(
|
||||
RewritePatternSet &patterns, MLIRContext *ctx) {
|
||||
patterns.insert<PoolingBaseConverter<ONNXMaxPoolSingleOutOp,
|
||||
ONNXMaxPoolSingleOutOpAdaptor, spatial::SpatVMaxOp>>(ctx);
|
||||
patterns.insert<PoolingBaseConverter<ONNXAveragePoolOp,
|
||||
ONNXAveragePoolOpAdaptor, spatial::SpatVAddOp>>(ctx);
|
||||
void populatePoolingTilingPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||
patterns.insert<PoolingBaseConverter<ONNXMaxPoolSingleOutOp, ONNXMaxPoolSingleOutOpAdaptor, spatial::SpatVMaxOp>>(
|
||||
ctx);
|
||||
patterns.insert<PoolingBaseConverter<ONNXAveragePoolOp, ONNXAveragePoolOpAdaptor, spatial::SpatVAddOp>>(ctx);
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -1,20 +1,19 @@
|
||||
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
namespace onnx_mlir {
|
||||
|
||||
struct ReduceMeanConversionPattern
|
||||
: public OpConversionPattern<ONNXReduceMeanV13Op> {
|
||||
struct ReduceMeanConversionPattern : public OpConversionPattern<ONNXReduceMeanV13Op> {
|
||||
|
||||
ReduceMeanConversionPattern(MLIRContext *ctx) : OpConversionPattern(ctx) {}
|
||||
ReduceMeanConversionPattern(MLIRContext* ctx)
|
||||
: OpConversionPattern(ctx) {}
|
||||
|
||||
LogicalResult matchAndRewrite(ONNXReduceMeanV13Op reduceMean,
|
||||
ONNXReduceMeanV13OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
ONNXReduceMeanV13OpAdaptor adaptor,
|
||||
ConversionPatternRewriter& rewriter) const final {
|
||||
|
||||
// Get the input tensor.
|
||||
Value inputTensor = adaptor.getData();
|
||||
@@ -38,42 +37,42 @@ struct ReduceMeanConversionPattern
|
||||
SmallVector<int64_t> padsVals = {0, 0, 0, 0};
|
||||
|
||||
// Create the ArrayAttrs
|
||||
auto kernelShape = mlir::ArrayAttr::get(rewriter.getContext(),
|
||||
llvm::to_vector(
|
||||
llvm::map_range(kernelShapeVals, [&](int64_t v) -> mlir::Attribute {
|
||||
return rewriter.getI64IntegerAttr(v);
|
||||
})));
|
||||
auto kernelShape = mlir::ArrayAttr::get(
|
||||
rewriter.getContext(), llvm::to_vector(llvm::map_range(kernelShapeVals, [&](int64_t v) -> mlir::Attribute {
|
||||
return rewriter.getI64IntegerAttr(v);
|
||||
})));
|
||||
|
||||
auto strides = mlir::ArrayAttr::get(rewriter.getContext(),
|
||||
llvm::to_vector(
|
||||
llvm::map_range(stridesVals, [&](int64_t v) -> mlir::Attribute {
|
||||
return rewriter.getI64IntegerAttr(v);
|
||||
})));
|
||||
llvm::to_vector(llvm::map_range(stridesVals, [&](int64_t v) -> mlir::Attribute {
|
||||
return rewriter.getI64IntegerAttr(v);
|
||||
})));
|
||||
|
||||
auto dilations = mlir::ArrayAttr::get(rewriter.getContext(),
|
||||
llvm::to_vector(
|
||||
llvm::map_range(dilationsVals, [&](int64_t v) -> mlir::Attribute {
|
||||
return rewriter.getI64IntegerAttr(v);
|
||||
})));
|
||||
auto dilations = mlir::ArrayAttr::get(
|
||||
rewriter.getContext(), llvm::to_vector(llvm::map_range(dilationsVals, [&](int64_t v) -> mlir::Attribute {
|
||||
return rewriter.getI64IntegerAttr(v);
|
||||
})));
|
||||
|
||||
auto pads = mlir::ArrayAttr::get(rewriter.getContext(),
|
||||
llvm::to_vector(
|
||||
llvm::map_range(padsVals, [&](int64_t v) -> mlir::Attribute {
|
||||
return rewriter.getI64IntegerAttr(v);
|
||||
})));
|
||||
llvm::to_vector(llvm::map_range(padsVals, [&](int64_t v) -> mlir::Attribute {
|
||||
return rewriter.getI64IntegerAttr(v);
|
||||
})));
|
||||
|
||||
// Create the resulting tensor type.
|
||||
auto resultType = RankedTensorType::get(
|
||||
/*shape=*/{inputTensorType.getShape()[0], inputTensorType.getShape()[1],
|
||||
1, 1},
|
||||
/*elementType=*/inputTensorType.getElementType());
|
||||
/*shape=*/ {inputTensorType.getShape()[0], inputTensorType.getShape()[1], 1, 1},
|
||||
/*elementType=*/inputTensorType.getElementType());
|
||||
|
||||
// Create the ONNXAveragePoolOp.
|
||||
auto averagePool = rewriter.create<ONNXAveragePoolOp>(reduceMean.getLoc(),
|
||||
resultType, inputTensor, /*auto_pad=*/"NOTSET",
|
||||
/*ceil_mode=*/0, /*count_include_pad=*/1, dilations,
|
||||
/*kernel_shape=*/kernelShape,
|
||||
/*pads=*/pads, /*strides=*/strides);
|
||||
resultType,
|
||||
inputTensor,
|
||||
/*auto_pad=*/"NOTSET",
|
||||
/*ceil_mode=*/0,
|
||||
/*count_include_pad=*/1,
|
||||
dilations,
|
||||
/*kernel_shape=*/kernelShape,
|
||||
/*pads=*/pads,
|
||||
/*strides=*/strides);
|
||||
|
||||
// Replace the ONNXReduceMeanV13Op with the ONNXAveragePoolOp.
|
||||
rewriter.replaceOp(reduceMean, averagePool.getResult());
|
||||
@@ -82,9 +81,8 @@ struct ReduceMeanConversionPattern
|
||||
}
|
||||
};
|
||||
|
||||
void populateReduceMeanConversionPattern(
|
||||
RewritePatternSet &patterns, MLIRContext *ctx) {
|
||||
void populateReduceMeanConversionPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||
patterns.insert<ReduceMeanConversionPattern>(ctx);
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -6,11 +6,12 @@
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
#include "llvm/Support/LogicalResult.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
#define DEFINE_MAP_OP(opname) opname,
|
||||
|
||||
#define GET_IMAGE_WIDTH(shapedType) shapedType.getDimSize(2)
|
||||
|
||||
@@ -6,7 +6,6 @@
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "llvm/Support/raw_os_ostream.h"
|
||||
|
||||
#include <filesystem>
|
||||
#include <fstream>
|
||||
|
||||
#include "Common/PIMCommon.hpp"
|
||||
@@ -16,7 +15,6 @@
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
|
||||
#include "src/Compiler/CompilerOptions.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
@@ -39,7 +37,7 @@ void ONNXToSpatialPass::runOnOperation() {
|
||||
mergeActivationPatterns.add<matMulToGemmPattern>(ctx);
|
||||
mergeActivationPatterns.add<removeFlattenSameShapePattern>(ctx);
|
||||
|
||||
if (failed(applyPatternsAndFoldGreedily(moduleOp, std::move(mergeActivationPatterns))))
|
||||
if (failed(applyPatternsGreedily(moduleOp, std::move(mergeActivationPatterns))))
|
||||
llvm::dbgs() << "Failed to merge activation patterns, continuing...\n";
|
||||
|
||||
IRRewriter rewriter(moduleOp);
|
||||
@@ -88,7 +86,7 @@ void ONNXToSpatialPass::runOnOperation() {
|
||||
if (coresCount != -1) {
|
||||
int computeOpsCount = 0;
|
||||
for (auto& op : funcOp.getFunctionBody().front().getOperations())
|
||||
if (isa<spatial::SpatWeightedCompute>(op))
|
||||
if (isa<SpatWeightedCompute>(op))
|
||||
computeOpsCount++;
|
||||
|
||||
if (computeOpsCount > coresCount) {
|
||||
|
||||
@@ -3,38 +3,26 @@
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
void populateLoweringONNXMatMulOpToSpatialPattern(
|
||||
mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx);
|
||||
void populateLoweringONNXMatMulOpToSpatialPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
void populateTilingGemmOpPattern(
|
||||
mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx);
|
||||
void populateTilingConvOpPattern(
|
||||
mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx);
|
||||
void populateTilingGemmOpPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
void populateTilingConvOpPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
void populatePoolingTilingPattern(
|
||||
mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx);
|
||||
void populatePoolingTilingPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
void populateDistributeReducePattern(
|
||||
mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx);
|
||||
void populateDistributeReducePattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
void populateFoldComputePattern(
|
||||
mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx);
|
||||
void populateFoldComputePattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
void populateONNXConcatToTensorConcatPattern(
|
||||
mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx);
|
||||
void populateONNXConcatToTensorConcatPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
void populateRemoveUnusedHelperOpsPatterns(
|
||||
mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx);
|
||||
void populateRemoveUnusedHelperOpsPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
void populateReduceMeanConversionPattern(
|
||||
mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx);
|
||||
void populateReduceMeanConversionPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
// Experimental patterns.
|
||||
void populateExperimentalTilingConvOpPattern(
|
||||
mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx);
|
||||
void populateGemmToConvConversionPattern(
|
||||
mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx);
|
||||
void populateExperimentalPoolingTilingPattern(
|
||||
mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx);
|
||||
void populateExperimentalTilingConvOpPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
void populateGemmToConvConversionPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
void populateExperimentalPoolingTilingPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -1,19 +1,20 @@
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
struct ONNXConcatToTensorConcat : public OpConversionPattern<ONNXConcatOp> {
|
||||
ONNXConcatToTensorConcat(MLIRContext *ctx) : OpConversionPattern(ctx) {}
|
||||
ONNXConcatToTensorConcat(MLIRContext* ctx)
|
||||
: OpConversionPattern(ctx) {}
|
||||
|
||||
LogicalResult matchAndRewrite(ONNXConcatOp maxpoolOp,
|
||||
ONNXConcatOpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
ONNXConcatOpAdaptor adaptor,
|
||||
ConversionPatternRewriter& rewriter) const final {
|
||||
auto inputs = adaptor.getInputs();
|
||||
int64_t axis = adaptor.getAxis();
|
||||
|
||||
@@ -23,9 +24,8 @@ struct ONNXConcatToTensorConcat : public OpConversionPattern<ONNXConcatOp> {
|
||||
}
|
||||
};
|
||||
|
||||
void populateONNXConcatToTensorConcatPattern(
|
||||
RewritePatternSet &patterns, MLIRContext *ctx) {
|
||||
void populateONNXConcatToTensorConcatPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||
patterns.insert<ONNXConcatToTensorConcat>(ctx);
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
#include <queue>
|
||||
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/AnnotateReplication.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
#include <queue>
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
@@ -5,7 +5,6 @@
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
mlir::LogicalResult annotateReplication(
|
||||
mlir::func::FuncOp funcOp, mlir::IRRewriter &rewriter);
|
||||
mlir::LogicalResult annotateReplication(mlir::func::FuncOp funcOp, mlir::IRRewriter& rewriter);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -1,32 +1,31 @@
|
||||
|
||||
#include "SpatialReducer.hpp"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
#include <cassert>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
|
||||
#include "SpatialReducer.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
#define GET_COMP(computeOpAndResNum) std::get<0>(computeOpAndResNum)
|
||||
#define GET_RES_NUM(computeOpAndResNum) std::get<1>(computeOpAndResNum)
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
llvm::SmallPtrSet<Operation *, 16>
|
||||
onnx_mlir::SpatialReducer::oldComputeOpsReplaced;
|
||||
llvm::SmallPtrSet<Operation*, 16> onnx_mlir::SpatialReducer::oldComputeOpsReplaced;
|
||||
|
||||
ResNum SpatialReducer::applyResultProcessing(
|
||||
ComputeAndResNum computeOpAndResNum,
|
||||
std::function<Value(const Value &)> processFun,
|
||||
ConversionPatternRewriter &rewriter) {
|
||||
ResNum SpatialReducer::applyResultProcessing(ComputeAndResNum computeOpAndResNum,
|
||||
std::function<Value(const Value&)> processFun,
|
||||
ConversionPatternRewriter& rewriter) {
|
||||
assert(processFun);
|
||||
|
||||
auto computeOp = GET_COMP(computeOpAndResNum);
|
||||
auto resultNum = GET_RES_NUM(computeOpAndResNum);
|
||||
|
||||
spatial::SpatYieldOp yieldOp =
|
||||
cast<spatial::SpatYieldOp>(computeOp.getBody().front().getTerminator());
|
||||
spatial::SpatYieldOp yieldOp = cast<spatial::SpatYieldOp>(computeOp.getBody().front().getTerminator());
|
||||
|
||||
Value result = yieldOp->getOperand(resultNum);
|
||||
rewriter.setInsertionPointAfterValue(result);
|
||||
@@ -43,30 +42,24 @@ ResNum SpatialReducer::applyResultProcessing(
|
||||
return yieldOp.getNumOperands() - 1;
|
||||
}
|
||||
|
||||
OpAndResNum SpatialReducer::applyReducePattern(
|
||||
SmallVector<ComputeAndResNum> &computeOpsAndResNum,
|
||||
std::function<Value(const Value &, const Value &)> reduce,
|
||||
std::function<Value(const Value &)> preprocess,
|
||||
std::function<Value(const Value &)> postprocess) {
|
||||
OpAndResNum SpatialReducer::applyReducePattern(SmallVector<ComputeAndResNum>& computeOpsAndResNum,
|
||||
std::function<Value(const Value&, const Value&)> reduce,
|
||||
std::function<Value(const Value&)> preprocess,
|
||||
std::function<Value(const Value&)> postprocess) {
|
||||
|
||||
if (preprocess) {
|
||||
for (auto &computeOpAndResNum : computeOpsAndResNum) {
|
||||
GET_RES_NUM(computeOpAndResNum) =
|
||||
applyResultProcessing(computeOpAndResNum, preprocess, rewriter);
|
||||
}
|
||||
}
|
||||
if (preprocess)
|
||||
for (auto& computeOpAndResNum : computeOpsAndResNum)
|
||||
GET_RES_NUM(computeOpAndResNum) = applyResultProcessing(computeOpAndResNum, preprocess, rewriter);
|
||||
|
||||
// It is possible that `computeOpsAndResNum` contains two entries for the same
|
||||
// computeOp. In this case, we need to apply the reduction within-computef
|
||||
|
||||
// Keep a map between a computeOp and the last Value for this reduction
|
||||
std::unordered_map<Operation *, Value> lastValueForCompute;
|
||||
for (auto &computeOpAndResNum : computeOpsAndResNum) {
|
||||
std::unordered_map<Operation*, Value> lastValueForCompute;
|
||||
for (auto& computeOpAndResNum : computeOpsAndResNum) {
|
||||
auto computeOp = GET_COMP(computeOpAndResNum);
|
||||
auto yieldOp =
|
||||
cast<spatial::SpatYieldOp>(computeOp.getBody().front().getTerminator());
|
||||
Value valueWithinCompute =
|
||||
yieldOp->getOperand(GET_RES_NUM(computeOpAndResNum));
|
||||
auto yieldOp = cast<spatial::SpatYieldOp>(computeOp.getBody().front().getTerminator());
|
||||
Value valueWithinCompute = yieldOp->getOperand(GET_RES_NUM(computeOpAndResNum));
|
||||
|
||||
auto it = lastValueForCompute.find(computeOp.getOperation());
|
||||
|
||||
@@ -75,15 +68,12 @@ OpAndResNum SpatialReducer::applyReducePattern(
|
||||
// within-compute
|
||||
Value lastWithinComputeValue = it->second;
|
||||
|
||||
assert(valueWithinCompute.getDefiningOp() &&
|
||||
lastWithinComputeValue.getDefiningOp());
|
||||
assert(valueWithinCompute.getDefiningOp() && lastWithinComputeValue.getDefiningOp());
|
||||
|
||||
if (valueWithinCompute.getDefiningOp()->isBeforeInBlock(
|
||||
lastWithinComputeValue.getDefiningOp())) {
|
||||
if (valueWithinCompute.getDefiningOp()->isBeforeInBlock(lastWithinComputeValue.getDefiningOp()))
|
||||
rewriter.setInsertionPointAfterValue(lastWithinComputeValue);
|
||||
} else {
|
||||
else
|
||||
rewriter.setInsertionPointAfterValue(valueWithinCompute);
|
||||
}
|
||||
valueWithinCompute = reduce(lastWithinComputeValue, valueWithinCompute);
|
||||
lastValueForCompute[computeOp.getOperation()] = valueWithinCompute;
|
||||
}
|
||||
@@ -94,16 +84,15 @@ OpAndResNum SpatialReducer::applyReducePattern(
|
||||
// Now, reconstruct from the map the computeOpsAndResNum list
|
||||
computeOpsAndResNum.clear();
|
||||
computeOpsAndResNum.reserve(lastValueForCompute.size());
|
||||
for (auto &entry : lastValueForCompute) {
|
||||
for (auto& entry : lastValueForCompute) {
|
||||
auto computeOp = cast<spatial::SpatWeightedCompute>(entry.first);
|
||||
auto valueWithinCompute = entry.second;
|
||||
|
||||
// We check if `valueWithinCompute` is already used by the yieldOp, in that
|
||||
// case no need to add it
|
||||
auto yieldOp =
|
||||
cast<spatial::SpatYieldOp>(computeOp.getBody().front().getTerminator());
|
||||
auto yieldOp = cast<spatial::SpatYieldOp>(computeOp.getBody().front().getTerminator());
|
||||
bool yieldOpUseFound = false;
|
||||
for (auto &use : valueWithinCompute.getUses()) {
|
||||
for (auto& use : valueWithinCompute.getUses()) {
|
||||
if (use.getOwner() == yieldOp.getOperation()) {
|
||||
// If the value is already used by the yieldOp, we can just use it
|
||||
computeOpsAndResNum.push_back({computeOp, use.getOperandNumber()});
|
||||
@@ -111,9 +100,8 @@ OpAndResNum SpatialReducer::applyReducePattern(
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (yieldOpUseFound) {
|
||||
if (yieldOpUseFound)
|
||||
continue;
|
||||
}
|
||||
|
||||
// If this result is not used within a yieldOp, then add it
|
||||
auto resultNum = yieldOp->getNumOperands();
|
||||
@@ -147,23 +135,18 @@ OpAndResNum SpatialReducer::applyReducePattern(
|
||||
// the number of results)
|
||||
// See below `reducerChanges.push_back` and `finalizeReduceUpdates`
|
||||
|
||||
auto yieldOpFirstCompute = cast<spatial::SpatYieldOp>(
|
||||
firstCompute.getBody().front().getTerminator());
|
||||
auto yieldOpFirstCompute = cast<spatial::SpatYieldOp>(firstCompute.getBody().front().getTerminator());
|
||||
|
||||
// Add a new operand to the block of the second computeOp
|
||||
Block &secondBlock = secondCompute.getBody().front();
|
||||
Value formerRes1 = secondBlock.addArgument(
|
||||
yieldOpFirstCompute->getOperand(firstResultNum).getType(), loc);
|
||||
Block& secondBlock = secondCompute.getBody().front();
|
||||
Value formerRes1 = secondBlock.addArgument(yieldOpFirstCompute->getOperand(firstResultNum).getType(), loc);
|
||||
|
||||
auto secondComputeWeightsNum =
|
||||
secondCompute->getAttrOfType<DenseI32ArrayAttr>(
|
||||
secondCompute.getOperandSegmentSizesAttrName())[0];
|
||||
auto secondComputeOperandNum =
|
||||
secondComputeWeightsNum + secondBlock.getNumArguments() - 1;
|
||||
secondCompute->getAttrOfType<DenseI32ArrayAttr>(secondCompute.getOperandSegmentSizesAttrName())[0];
|
||||
auto secondComputeOperandNum = secondComputeWeightsNum + secondBlock.getNumArguments() - 1;
|
||||
|
||||
// Take the "former-result" from the second computeOp
|
||||
spatial::SpatYieldOp secondYield =
|
||||
cast<spatial::SpatYieldOp>(secondBlock.getTerminator());
|
||||
spatial::SpatYieldOp secondYield = cast<spatial::SpatYieldOp>(secondBlock.getTerminator());
|
||||
Value formerRes2 = secondYield.getOperand(secondResultNum);
|
||||
|
||||
// Apply reduction operation
|
||||
@@ -184,37 +167,31 @@ OpAndResNum SpatialReducer::applyReducePattern(
|
||||
// We should also add an entry for updating the results of the last
|
||||
// operation (the one which never becomes a `firstCompute`): because it is
|
||||
// not tracked by reducerChanges as `fromOp`
|
||||
reducerChanges.push_back({firstCompute.getOperation(), firstResultNum,
|
||||
secondCompute.getOperation(), secondComputeOperandNum});
|
||||
reducerChanges.push_back(
|
||||
{firstCompute.getOperation(), firstResultNum, secondCompute.getOperation(), secondComputeOperandNum});
|
||||
nextComputeOps.push_back(std::make_pair(secondCompute, secondResultNum));
|
||||
}
|
||||
|
||||
// If we have an odd number of inputs, we need to add the last one to the
|
||||
// newInputs list.
|
||||
if (computeOpsRef.size() % 2 == 1) {
|
||||
if (computeOpsRef.size() % 2 == 1)
|
||||
nextComputeOps.push_back(computeOpsRef.back());
|
||||
}
|
||||
|
||||
// Replace the inputOps list with the new one.
|
||||
computeOpsRef =
|
||||
llvm::OwningArrayRef<ComputeAndResNum>(std::move(nextComputeOps));
|
||||
computeOpsRef = llvm::OwningArrayRef<ComputeAndResNum>(std::move(nextComputeOps));
|
||||
}
|
||||
|
||||
assert(computeOpsRef.size() == 1 &&
|
||||
"Internal error: expected a single input at this point.");
|
||||
assert(computeOpsRef.size() == 1 && "Internal error: expected a single input at this point.");
|
||||
|
||||
auto finalComputeAndResNum = computeOpsRef[0];
|
||||
|
||||
// Force the update of the results of this computeOp, when finalizing
|
||||
computeOpNeedingResUpdate.push_back(GET_COMP(finalComputeAndResNum));
|
||||
|
||||
if (postprocess) {
|
||||
GET_RES_NUM(finalComputeAndResNum) =
|
||||
applyResultProcessing(finalComputeAndResNum, postprocess, rewriter);
|
||||
}
|
||||
if (postprocess)
|
||||
GET_RES_NUM(finalComputeAndResNum) = applyResultProcessing(finalComputeAndResNum, postprocess, rewriter);
|
||||
|
||||
return std::make_pair(GET_COMP(finalComputeAndResNum).getOperation(),
|
||||
GET_RES_NUM(finalComputeAndResNum));
|
||||
return std::make_pair(GET_COMP(finalComputeAndResNum).getOperation(), GET_RES_NUM(finalComputeAndResNum));
|
||||
}
|
||||
|
||||
void SpatialReducer::finalizeReduceUpdates() {
|
||||
@@ -223,15 +200,13 @@ void SpatialReducer::finalizeReduceUpdates() {
|
||||
reducesFinalized = true;
|
||||
|
||||
// First, add the results to the computeOps
|
||||
for (auto &reduceChange : reducerChanges) {
|
||||
for (auto& reduceChange : reducerChanges)
|
||||
updateResultsOfCompute(reduceChange.fromOp);
|
||||
}
|
||||
|
||||
for (auto &c : computeOpNeedingResUpdate) {
|
||||
for (auto& c : computeOpNeedingResUpdate)
|
||||
updateResultsOfCompute(c.getOperation());
|
||||
}
|
||||
|
||||
for (auto &reducerChange : this->reducerChanges) {
|
||||
for (auto& reducerChange : this->reducerChanges) {
|
||||
auto fromOp = reducerChange.fromOp;
|
||||
auto toOp = reducerChange.toOp;
|
||||
auto fromOpResNum = reducerChange.fromOpResNum;
|
||||
@@ -243,16 +218,14 @@ void SpatialReducer::finalizeReduceUpdates() {
|
||||
// toComputeOp could be the existing pointer, or we have to remap it with
|
||||
// `opToReplacedCompute`
|
||||
auto toComputeOp = opToReplacedCompute[toOp];
|
||||
if (!toComputeOp) {
|
||||
if (!toComputeOp)
|
||||
toComputeOp = cast<spatial::SpatWeightedCompute>(toOp);
|
||||
}
|
||||
|
||||
assert(toComputeOp != fromComputeOp &&
|
||||
"Oops should have caught this earlier!");
|
||||
assert(toComputeOp != fromComputeOp && "Oops should have caught this earlier!");
|
||||
|
||||
assert(toComputeOp->getNumOperands() == toOpOperandNum &&
|
||||
"toOpOperandNum should be the last operand of toComputeOp, are the "
|
||||
"operations in the right order?");
|
||||
assert(toComputeOp->getNumOperands() == toOpOperandNum
|
||||
&& "toOpOperandNum should be the last operand of toComputeOp, are the "
|
||||
"operations in the right order?");
|
||||
|
||||
// Add the new operand to `toComputeOp`
|
||||
auto fromResult = fromComputeOp.getResult(fromOpResNum);
|
||||
@@ -261,24 +234,22 @@ void SpatialReducer::finalizeReduceUpdates() {
|
||||
}
|
||||
}
|
||||
|
||||
Value SpatialReducer::resolveValueFromOpAndResNum(OpAndResNum &opAndResNum) {
|
||||
assert(reducesFinalized &&
|
||||
"Cannot create resolve values before finalizing the reduce updates.");
|
||||
Value SpatialReducer::resolveValueFromOpAndResNum(OpAndResNum& opAndResNum) {
|
||||
assert(reducesFinalized && "Cannot create resolve values before finalizing the reduce updates.");
|
||||
|
||||
Operation *opToCast;
|
||||
Operation* opToCast;
|
||||
auto it = opToReplacedCompute.find(opAndResNum.first);
|
||||
if (it != opToReplacedCompute.end()) {
|
||||
if (it != opToReplacedCompute.end())
|
||||
opToCast = it->second;
|
||||
} else {
|
||||
else
|
||||
opToCast = opAndResNum.first;
|
||||
}
|
||||
|
||||
auto computeOp = cast<spatial::SpatWeightedCompute>(opToCast);
|
||||
|
||||
return computeOp.getResult(opAndResNum.second);
|
||||
}
|
||||
|
||||
void SpatialReducer::updateResultsOfCompute(Operation *computeOp) {
|
||||
void SpatialReducer::updateResultsOfCompute(Operation* computeOp) {
|
||||
if (opToReplacedCompute.find(computeOp) != opToReplacedCompute.end()) {
|
||||
// If we have already replaced the fromOp, we do not need to do it again
|
||||
return;
|
||||
@@ -287,8 +258,7 @@ void SpatialReducer::updateResultsOfCompute(Operation *computeOp) {
|
||||
|
||||
auto oldComputeOpNum = oldComputeOp->getNumOperands();
|
||||
|
||||
auto yieldOp =
|
||||
cast<spatial::SpatYieldOp>(oldComputeOp.getBody().front().getTerminator());
|
||||
auto yieldOp = cast<spatial::SpatYieldOp>(oldComputeOp.getBody().front().getTerminator());
|
||||
|
||||
if (yieldOp.getNumOperands() == oldComputeOp->getNumResults()) {
|
||||
// No result was added, just add itself to the map
|
||||
@@ -301,9 +271,8 @@ void SpatialReducer::updateResultsOfCompute(Operation *computeOp) {
|
||||
|
||||
// Create a new ComputeOp with the new result type, but same operands
|
||||
rewriter.setInsertionPoint(oldComputeOp);
|
||||
auto newComputeOp =
|
||||
rewriter.create<spatial::SpatWeightedCompute>(oldComputeOp->getLoc(),
|
||||
newResultTypes, oldComputeOp.getWeights(), oldComputeOp.getInputs());
|
||||
auto newComputeOp = rewriter.create<spatial::SpatWeightedCompute>(
|
||||
oldComputeOp->getLoc(), newResultTypes, oldComputeOp.getWeights(), oldComputeOp.getInputs());
|
||||
|
||||
newComputeOp.getBody().takeBody(oldComputeOp.getBody());
|
||||
|
||||
@@ -329,54 +298,49 @@ void SpatialReducer::updateResultsOfCompute(Operation *computeOp) {
|
||||
rewriter.eraseOp(oldComputeOp);
|
||||
}
|
||||
|
||||
Value SpatialReducer::createImgConcatOp(
|
||||
SmallVector<SmallVector<SmallVector<OpAndResNum>>> &outputTiles,
|
||||
Location &loc, Type outputType) {
|
||||
Value SpatialReducer::createImgConcatOp(SmallVector<SmallVector<SmallVector<OpAndResNum>>>& outputTiles,
|
||||
Location& loc,
|
||||
Type outputType) {
|
||||
|
||||
assert(reducesFinalized &&
|
||||
"Cannot create ImgConcatOp before finalizing the reduce updates.");
|
||||
assert(reducesFinalized && "Cannot create ImgConcatOp before finalizing the reduce updates.");
|
||||
|
||||
// outputTiles are indexed like this: [channelTile][x][y]
|
||||
auto tilesCount = outputTiles.size();
|
||||
auto width = outputTiles[0].size();
|
||||
auto height = outputTiles[0][0].size();
|
||||
|
||||
SmallVector<SmallVector<SmallVector<Value>>> remappedOutputTiles(tilesCount,
|
||||
SmallVector<SmallVector<Value>>(width, SmallVector<Value>(height)));
|
||||
SmallVector<SmallVector<SmallVector<Value>>> remappedOutputTiles(
|
||||
tilesCount, SmallVector<SmallVector<Value>>(width, SmallVector<Value>(height)));
|
||||
|
||||
for (size_t t = 0; t < tilesCount; t++)
|
||||
for (size_t x = 0; x < width; x++)
|
||||
for (size_t y = 0; y < height; y++)
|
||||
remappedOutputTiles[t][x][y] =
|
||||
resolveValueFromOpAndResNum(outputTiles[t][x][y]);
|
||||
remappedOutputTiles[t][x][y] = resolveValueFromOpAndResNum(outputTiles[t][x][y]);
|
||||
|
||||
return ::onnx_mlir::createImgConcatOp(
|
||||
remappedOutputTiles, rewriter, loc, outputType);
|
||||
return ::onnx_mlir::createImgConcatOp(remappedOutputTiles, rewriter, loc, outputType);
|
||||
}
|
||||
|
||||
OpAndResNum SpatialReducer::applyAddMapReduction(
|
||||
SmallVector<ComputeAndResNum> &computeOps,
|
||||
ConversionPatternRewriter &rewriter, Value biasTile, MapOperations mapOp) {
|
||||
OpAndResNum SpatialReducer::applyAddMapReduction(SmallVector<ComputeAndResNum>& computeOps,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Value biasTile,
|
||||
MapOperations mapOp) {
|
||||
|
||||
std::function<Value(const Value &)> postprocessing = nullptr;
|
||||
std::function<Value(const Value&)> postprocessing = nullptr;
|
||||
|
||||
if (mapOp != MapOperations::None) {
|
||||
postprocessing = [&](const Value a) {
|
||||
Value mapOperand = a;
|
||||
if (biasTile) {
|
||||
mapOperand = rewriter.create<spatial::SpatVAddOp>(
|
||||
a.getLoc(), a.getType(), a, biasTile);
|
||||
}
|
||||
if (biasTile)
|
||||
mapOperand = rewriter.create<spatial::SpatVAddOp>(a.getLoc(), a.getType(), a, biasTile);
|
||||
return createMapOperation(rewriter, mapOp, mapOperand);
|
||||
};
|
||||
}
|
||||
|
||||
return this->applyReducePattern(
|
||||
computeOps,
|
||||
[&](Value a, Value b) {
|
||||
return rewriter.create<spatial::SpatVAddOp>(a.getLoc(), a.getType(), a, b);
|
||||
},
|
||||
/* preprocess = */ nullptr, postprocessing);
|
||||
computeOps,
|
||||
[&](Value a, Value b) { return rewriter.create<spatial::SpatVAddOp>(a.getLoc(), a.getType(), a, b); },
|
||||
/* preprocess = */ nullptr,
|
||||
postprocessing);
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
#pragma once
|
||||
|
||||
#include "llvm/ADT/SmallPtrSet.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "llvm/ADT/SmallPtrSet.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
@@ -12,48 +13,48 @@ using ResNum = unsigned int;
|
||||
using ComputeAndResNum = std::pair<spatial::SpatWeightedCompute, ResNum>;
|
||||
|
||||
struct SpatialReducerChange {
|
||||
Operation *fromOp;
|
||||
Operation* fromOp;
|
||||
unsigned int fromOpResNum;
|
||||
Operation *toOp;
|
||||
Operation* toOp;
|
||||
unsigned int toOpOperandNum;
|
||||
};
|
||||
|
||||
using OpAndResNum = std::pair<Operation *, ResNum>;
|
||||
using OpAndResNum = std::pair<Operation*, ResNum>;
|
||||
|
||||
class SpatialReducer {
|
||||
|
||||
public:
|
||||
SpatialReducer(ConversionPatternRewriter &rewriter) : rewriter(rewriter) {}
|
||||
SpatialReducer(ConversionPatternRewriter& rewriter)
|
||||
: rewriter(rewriter) {}
|
||||
|
||||
OpAndResNum applyReducePattern(
|
||||
SmallVector<ComputeAndResNum> &computeOpsAndResNum,
|
||||
std::function<Value(const Value &, const Value &)> reduce,
|
||||
std::function<Value(const Value &)> preprocess,
|
||||
std::function<Value(const Value &)> postprocess);
|
||||
OpAndResNum applyReducePattern(SmallVector<ComputeAndResNum>& computeOpsAndResNum,
|
||||
std::function<Value(const Value&, const Value&)> reduce,
|
||||
std::function<Value(const Value&)> preprocess,
|
||||
std::function<Value(const Value&)> postprocess);
|
||||
|
||||
OpAndResNum applyAddMapReduction(SmallVector<ComputeAndResNum> &computeOps,
|
||||
ConversionPatternRewriter &rewriter, Value biasTile, MapOperations mapOp);
|
||||
OpAndResNum applyAddMapReduction(SmallVector<ComputeAndResNum>& computeOps,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Value biasTile,
|
||||
MapOperations mapOp);
|
||||
|
||||
void finalizeReduceUpdates();
|
||||
|
||||
~SpatialReducer() {
|
||||
if (!reducesFinalized) {
|
||||
if (!reducesFinalized)
|
||||
finalizeReduceUpdates();
|
||||
}
|
||||
}
|
||||
|
||||
Value createImgConcatOp(
|
||||
llvm::SmallVector<llvm::SmallVector<llvm::SmallVector<OpAndResNum>>>
|
||||
&outputTiles,
|
||||
Location &loc, Type outputType);
|
||||
Value createImgConcatOp(llvm::SmallVector<llvm::SmallVector<llvm::SmallVector<OpAndResNum>>>& outputTiles,
|
||||
Location& loc,
|
||||
Type outputType);
|
||||
|
||||
Value resolveValueFromOpAndResNum(OpAndResNum &opAndResNum);
|
||||
Value resolveValueFromOpAndResNum(OpAndResNum& opAndResNum);
|
||||
|
||||
private:
|
||||
[[nodiscard("computeOp result number gets updated")]] ResNum
|
||||
applyResultProcessing(ComputeAndResNum computeOpAndResNum,
|
||||
std::function<Value(const Value &)> processFun,
|
||||
ConversionPatternRewriter &rewriter);
|
||||
std::function<Value(const Value&)> processFun,
|
||||
ConversionPatternRewriter& rewriter);
|
||||
|
||||
/**
|
||||
* @brief Update the results of a ComputeOp.
|
||||
@@ -65,9 +66,9 @@ private:
|
||||
*
|
||||
* @param computeOp The ComputeOp to update the results of.
|
||||
*/
|
||||
void updateResultsOfCompute(Operation *computeOp);
|
||||
void updateResultsOfCompute(Operation* computeOp);
|
||||
|
||||
ConversionPatternRewriter &rewriter;
|
||||
ConversionPatternRewriter& rewriter;
|
||||
bool reducesFinalized = false;
|
||||
|
||||
// List of changes to be applied after the reduction is finalized
|
||||
@@ -75,9 +76,9 @@ private:
|
||||
// List of computeOps that need to be replaced with new results
|
||||
SmallVector<spatial::SpatWeightedCompute> computeOpNeedingResUpdate;
|
||||
|
||||
std::unordered_map<Operation *, spatial::SpatWeightedCompute> opToReplacedCompute;
|
||||
std::unordered_map<Operation*, spatial::SpatWeightedCompute> opToReplacedCompute;
|
||||
|
||||
static llvm::SmallPtrSet<Operation *, 16> oldComputeOpsReplaced;
|
||||
static llvm::SmallPtrSet<Operation*, 16> oldComputeOpsReplaced;
|
||||
};
|
||||
|
||||
} // namespace onnx_mlir
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/WeightSubdivider.hpp"
|
||||
#include <cassert>
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/WeightSubdivider.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
WeightSubdivider::WeightSubdivider(
|
||||
map<long, map<long, SmallVector<Value>>> weights)
|
||||
: weights(std::move(weights)) {}
|
||||
WeightSubdivider::WeightSubdivider(map<long, map<long, SmallVector<Value>>> weights)
|
||||
: weights(std::move(weights)) {}
|
||||
|
||||
bool WeightSubdivider::isEmpty() const { return weights.empty(); }
|
||||
|
||||
@@ -13,7 +13,7 @@ TaggedWeights WeightSubdivider::popGroup(size_t amount) {
|
||||
assert(!weights.empty() && "No weights to extract.");
|
||||
|
||||
auto it = weights.begin();
|
||||
SmallVector<Value> &values = it->second.begin()->second;
|
||||
SmallVector<Value>& values = it->second.begin()->second;
|
||||
|
||||
long inputTile = it->first;
|
||||
long outputTile = it->second.begin()->first;
|
||||
@@ -26,11 +26,11 @@ TaggedWeights WeightSubdivider::popGroup(size_t amount) {
|
||||
|
||||
if (n < values.size()) {
|
||||
values.erase(values.begin(), values.begin() + n);
|
||||
} else {
|
||||
}
|
||||
else {
|
||||
it->second.erase(outputTile);
|
||||
if (it->second.empty()) {
|
||||
if (it->second.empty())
|
||||
weights.erase(inputTile);
|
||||
}
|
||||
}
|
||||
|
||||
return {inputTile, outputTile, crossbarsUsed - n, result};
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/Value.h"
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include <map>
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
@@ -1,23 +1,21 @@
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
||||
#include "mlir/IR/Block.h"
|
||||
#include "mlir/IR/Diagnostics.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "llvm/Support/Format.h"
|
||||
|
||||
#define FORMAT_OPERATION(op) \
|
||||
'x' << llvm::format_hex_no_prefix(reinterpret_cast<size_t>(op), 0)
|
||||
#define FORMAT_ARGUMENT(computeOpPointer, argumentNum) \
|
||||
llvm::format("Arg_%p_%u", computeOpPointer, argumentNum)
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
#define FORMAT_OPERATION(op) 'x' << llvm::format_hex_no_prefix(reinterpret_cast<size_t>(op), 0)
|
||||
#define FORMAT_ARGUMENT(computeOpPointer, argumentNum) llvm::format("Arg_%p_%u", computeOpPointer, argumentNum)
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
@@ -25,26 +23,22 @@ namespace onnx_mlir {
|
||||
|
||||
namespace {
|
||||
|
||||
struct SpatialToGraphvizPass
|
||||
: public PassWrapper<SpatialToGraphvizPass, OperationPass<ModuleOp>> {
|
||||
struct SpatialToGraphvizPass : public PassWrapper<SpatialToGraphvizPass, OperationPass<ModuleOp>> {
|
||||
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SpatialToGraphvizPass)
|
||||
|
||||
StringRef getArgument() const override {
|
||||
return "convert-spatial-to-graphviz";
|
||||
}
|
||||
StringRef getArgument() const override { return "convert-spatial-to-graphviz"; }
|
||||
|
||||
StringRef getDescription() const override {
|
||||
return "Lower ONNX ops to Spatial ops.";
|
||||
}
|
||||
StringRef getDescription() const override { return "Lower ONNX ops to Spatial ops."; }
|
||||
|
||||
SpatialToGraphvizPass(raw_ostream &os = llvm::errs()) : os(os) {}
|
||||
SpatialToGraphvizPass(const SpatialToGraphvizPass &pass)
|
||||
: SpatialToGraphvizPass(pass.os) {}
|
||||
SpatialToGraphvizPass(raw_ostream& os = llvm::errs())
|
||||
: os(os) {}
|
||||
SpatialToGraphvizPass(const SpatialToGraphvizPass& pass)
|
||||
: SpatialToGraphvizPass(pass.os) {}
|
||||
void runOnOperation() final;
|
||||
|
||||
private:
|
||||
raw_ostream &os;
|
||||
raw_ostream& os;
|
||||
|
||||
/**
|
||||
* Draws the subgraph for a given spatial::SpatWeightedCompute, including:
|
||||
@@ -56,31 +50,27 @@ private:
|
||||
* @param computeNum The number of the compute operation.
|
||||
*/
|
||||
void drawComputeOpSubgraph(spatial::SpatWeightedCompute op, size_t computeNum) {
|
||||
os << "\tsubgraph cluster" << computeNum << " {\n\t\tlabel=\"Compute"
|
||||
<< computeNum << "\";\n"
|
||||
os << "\tsubgraph cluster" << computeNum << " {\n\t\tlabel=\"Compute" << computeNum << "\";\n"
|
||||
<< "\t\tstyle=filled;\n"
|
||||
<< "\t\tcolor=lightblue;\n";
|
||||
|
||||
Block &block = op.getBody().front();
|
||||
Block& block = op.getBody().front();
|
||||
|
||||
// Inputs
|
||||
size_t inputNum = 0;
|
||||
for (BlockArgument &input : block.getArguments()) {
|
||||
for (BlockArgument& input : block.getArguments()) {
|
||||
|
||||
auto fromOp = FORMAT_ARGUMENT(op.getOperation(), inputNum);
|
||||
|
||||
os << "\t\t" << fromOp << " [label=\"Arg" << inputNum
|
||||
<< "\",shape=box];\n";
|
||||
for (auto userOp : input.getUsers()) {
|
||||
os << "\t\t" << fromOp << " [label=\"Arg" << inputNum << "\",shape=box];\n";
|
||||
for (auto userOp : input.getUsers())
|
||||
os << "\t\t" << fromOp << " -> " << FORMAT_OPERATION(userOp) << ";\n";
|
||||
}
|
||||
inputNum++;
|
||||
}
|
||||
|
||||
// Iterate operations
|
||||
for (auto &childOp : block.getOperations()) {
|
||||
os << "\t\t" << FORMAT_OPERATION(&childOp) << " [label=\""
|
||||
<< childOp.getName() << "\"];\n";
|
||||
for (auto& childOp : block.getOperations()) {
|
||||
os << "\t\t" << FORMAT_OPERATION(&childOp) << " [label=\"" << childOp.getName() << "\"];\n";
|
||||
|
||||
drawEdgesFromOpToItsUsers(&childOp);
|
||||
}
|
||||
@@ -88,7 +78,7 @@ private:
|
||||
os << "\t}\n";
|
||||
|
||||
// Draw edges from the yield to the users of this computeOp
|
||||
Operation *yieldOp = block.getTerminator();
|
||||
Operation* yieldOp = block.getTerminator();
|
||||
if (!isa<spatial::SpatYieldOp>(yieldOp)) {
|
||||
yieldOp->emitError("Terminator of block must be YieldOp ???");
|
||||
signalPassFailure();
|
||||
@@ -96,9 +86,8 @@ private:
|
||||
}
|
||||
|
||||
for (auto computeOpResult : op->getResults()) {
|
||||
for (auto &computeOpUse : computeOpResult.getUses()) {
|
||||
auto toOp = FORMAT_ARGUMENT(
|
||||
computeOpUse.getOwner(), computeOpUse.getOperandNumber());
|
||||
for (auto& computeOpUse : computeOpResult.getUses()) {
|
||||
auto toOp = FORMAT_ARGUMENT(computeOpUse.getOwner(), computeOpUse.getOperandNumber());
|
||||
os << "\t" << FORMAT_OPERATION(yieldOp) << " -> " << toOp << ";\n";
|
||||
}
|
||||
}
|
||||
@@ -114,9 +103,8 @@ private:
|
||||
* @param concatOp The concatOp for which the subgraph is drawn.
|
||||
* @param concatOpNum The number of the concatOp.
|
||||
*/
|
||||
void drawConcatOpSubgraph(Operation *concatOp, size_t concatOpNum) {
|
||||
os << "\tsubgraph clusterconcat" << concatOpNum
|
||||
<< " {\n\t\tlabel=\"ConcatOp" << concatOpNum << "\";\n"
|
||||
void drawConcatOpSubgraph(Operation* concatOp, size_t concatOpNum) {
|
||||
os << "\tsubgraph clusterconcat" << concatOpNum << " {\n\t\tlabel=\"ConcatOp" << concatOpNum << "\";\n"
|
||||
<< "\t\tstyle=filled;\n"
|
||||
<< "\t\tcolor=orange;\n";
|
||||
|
||||
@@ -126,9 +114,8 @@ private:
|
||||
auto fromOp = FORMAT_ARGUMENT(concatOp, inputNum);
|
||||
|
||||
os << "\t\t" << fromOp << " [label=\"Input" << inputNum << "\"];\n";
|
||||
for (auto userOp : input.getUsers()) {
|
||||
for (auto userOp : input.getUsers())
|
||||
os << "\t\t" << fromOp << " -> " << FORMAT_OPERATION(userOp) << ";\n";
|
||||
}
|
||||
inputNum++;
|
||||
}
|
||||
|
||||
@@ -139,11 +126,9 @@ private:
|
||||
|
||||
// Edges from output to users
|
||||
|
||||
for (auto &computeOpUse : concatOp->getResult(0).getUses()) {
|
||||
for (auto& computeOpUse : concatOp->getResult(0).getUses()) {
|
||||
os << "\t" << FORMAT_OPERATION(concatOp) << " -> "
|
||||
<< FORMAT_ARGUMENT(
|
||||
computeOpUse.getOwner(), computeOpUse.getOperandNumber())
|
||||
<< ";\n";
|
||||
<< FORMAT_ARGUMENT(computeOpUse.getOwner(), computeOpUse.getOperandNumber()) << ";\n";
|
||||
}
|
||||
}
|
||||
|
||||
@@ -164,10 +149,8 @@ private:
|
||||
sliceOp.getStaticOffsetsAttr().print(os);
|
||||
os << "\",color=lawngreen];\n";
|
||||
|
||||
for (auto &computeOpUse : sliceOp.getResult().getUses()) {
|
||||
os << "\t" << nodeId << " -> "
|
||||
<< FORMAT_ARGUMENT(
|
||||
computeOpUse.getOwner(), computeOpUse.getOperandNumber())
|
||||
for (auto& computeOpUse : sliceOp.getResult().getUses()) {
|
||||
os << "\t" << nodeId << " -> " << FORMAT_ARGUMENT(computeOpUse.getOwner(), computeOpUse.getOperandNumber())
|
||||
<< ";\n";
|
||||
}
|
||||
}
|
||||
@@ -178,9 +161,8 @@ private:
|
||||
sliceOp.getStaticOffsetsAttr().print(os);
|
||||
os << "\",color=lightpink];\n";
|
||||
|
||||
for (auto user : sliceOp.getResult().getUsers()) {
|
||||
for (auto user : sliceOp.getResult().getUsers())
|
||||
os << "\t" << nodeId << " -> " << FORMAT_OPERATION(user) << ";\n";
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -188,13 +170,10 @@ private:
|
||||
*
|
||||
* @param fromOp The operation from which the edges are drawn.
|
||||
*/
|
||||
void drawEdgesFromOpToItsUsers(mlir::Operation *fromOp) {
|
||||
for (auto result : fromOp->getResults()) {
|
||||
for (auto userOp : result.getUsers()) {
|
||||
os << "\t\t" << FORMAT_OPERATION(fromOp) << " -> "
|
||||
<< FORMAT_OPERATION(userOp) << ";\n";
|
||||
}
|
||||
}
|
||||
void drawEdgesFromOpToItsUsers(mlir::Operation* fromOp) {
|
||||
for (auto result : fromOp->getResults())
|
||||
for (auto userOp : result.getUsers())
|
||||
os << "\t\t" << FORMAT_OPERATION(fromOp) << " -> " << FORMAT_OPERATION(userOp) << ";\n";
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -202,16 +181,15 @@ private:
|
||||
*
|
||||
* @param funcOp The `funcOp` for which to draw input nodes and edges.
|
||||
*/
|
||||
void drawInputNodesAndEdges(func::FuncOp &funcOp) {
|
||||
void drawInputNodesAndEdges(func::FuncOp& funcOp) {
|
||||
os << "\tinput [label=\"Module Input\",color=green];\n";
|
||||
|
||||
size_t funcOpArgNum = 0;
|
||||
for (BlockArgument &arg : funcOp.getArguments()) {
|
||||
for (BlockArgument& arg : funcOp.getArguments()) {
|
||||
|
||||
for (auto &useOp : arg.getUses()) {
|
||||
os << "\tinput -> "
|
||||
<< FORMAT_ARGUMENT(useOp.getOwner(), useOp.getOperandNumber())
|
||||
<< "[label=" << funcOpArgNum << "];\n";
|
||||
for (auto& useOp : arg.getUses()) {
|
||||
os << "\tinput -> " << FORMAT_ARGUMENT(useOp.getOwner(), useOp.getOperandNumber()) << "[label=" << funcOpArgNum
|
||||
<< "];\n";
|
||||
}
|
||||
funcOpArgNum++;
|
||||
}
|
||||
@@ -237,20 +215,22 @@ void SpatialToGraphvizPass::runOnOperation() {
|
||||
// Iterate over the ComputeOps within FuncOp:
|
||||
// 1. Print their subgraph
|
||||
// 2. Print the edges from its inputs to its outputs
|
||||
for (Operation &op : func.getOps()) {
|
||||
for (Operation& op : func.getOps()) {
|
||||
if (auto computeOp = dyn_cast<spatial::SpatWeightedCompute>(op)) {
|
||||
drawComputeOpSubgraph(computeOp, computeNum++);
|
||||
} else if (auto concatOp = dyn_cast<tensor::ConcatOp>(op)) {
|
||||
}
|
||||
else if (auto concatOp = dyn_cast<tensor::ConcatOp>(op)) {
|
||||
drawConcatOpSubgraph(concatOp, concatNum++);
|
||||
} else if (auto imgConcatOp = dyn_cast<spatial::SpatImgConcatOp>(op)) {
|
||||
}
|
||||
else if (auto imgConcatOp = dyn_cast<spatial::SpatImgConcatOp>(op)) {
|
||||
drawConcatOpSubgraph(imgConcatOp, concatNum++);
|
||||
} else if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op)) {
|
||||
}
|
||||
else if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op)) {
|
||||
auto producerOp = extractSliceOp->getOperand(0).getDefiningOp();
|
||||
if (producerOp) {
|
||||
// Skip extractSliceOp if producer is constant weights (ONNXConstantOp)
|
||||
if (llvm::isa<ONNXConstantOp>(producerOp)) {
|
||||
if (llvm::isa<ONNXConstantOp>(producerOp))
|
||||
continue;
|
||||
}
|
||||
// If produced by tosa::ReshapeOp (i.e. it is a bias tile) connect
|
||||
// directly to its user, which is not a ComputeOp argument.
|
||||
if (llvm::isa<tosa::ReshapeOp>(producerOp)) {
|
||||
@@ -268,16 +248,13 @@ void SpatialToGraphvizPass::runOnOperation() {
|
||||
|
||||
// Draw output node (use the return Operation - argument number=0 - as nodeId)
|
||||
auto returnOp = func.getBody().front().getTerminator();
|
||||
os << '\t' << FORMAT_ARGUMENT(returnOp, 0)
|
||||
<< " [label=\"Module Output\",color=green];\n";
|
||||
os << '\t' << FORMAT_ARGUMENT(returnOp, 0) << " [label=\"Module Output\",color=green];\n";
|
||||
|
||||
os << "}\n";
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<Pass> createSpatialToGraphvizPass() {
|
||||
return std::make_unique<SpatialToGraphvizPass>();
|
||||
}
|
||||
std::unique_ptr<Pass> createSpatialToGraphvizPass() { return std::make_unique<SpatialToGraphvizPass>(); }
|
||||
|
||||
} // namespace onnx_mlir
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -148,7 +148,8 @@ void SpatialToPIMPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IR
|
||||
size_t resultIndexInConcat = resultUses.begin()->getOperandNumber();
|
||||
size_t offset = 0;
|
||||
for (auto operand : concatOp->getOperands().take_front(resultIndexInConcat))
|
||||
offset += cast<ShapedType>(operand.getType()).getNumElements() * cast<ShapedType>(operand.getType()).getElementTypeBitWidth() / 8;
|
||||
offset += cast<ShapedType>(operand.getType()).getNumElements()
|
||||
* cast<ShapedType>(operand.getType()).getElementTypeBitWidth() / 8;
|
||||
|
||||
size_t elementSize = yieldType.getElementTypeBitWidth() / 8;
|
||||
|
||||
|
||||
@@ -9,4 +9,4 @@ namespace spatial {
|
||||
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -1,8 +1,5 @@
|
||||
#pragma once
|
||||
|
||||
#include <map>
|
||||
#include <string>
|
||||
|
||||
#include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h"
|
||||
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
@@ -11,6 +8,9 @@
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
|
||||
#include <map>
|
||||
#include <string>
|
||||
|
||||
/// Include the auto-generated header files containing the declarations
|
||||
#include "src/Accelerators/PIM/Dialect/PIM/PimDialect.hpp.inc"
|
||||
|
||||
|
||||
@@ -15,7 +15,10 @@ namespace pim {
|
||||
|
||||
struct MemCopyHostToDevOpInterface
|
||||
: DstBufferizableOpInterfaceExternalModel<MemCopyHostToDevOpInterface, PimMemCopyHostToDevOp> {
|
||||
LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const {
|
||||
LogicalResult bufferize(Operation* op,
|
||||
RewriterBase& rewriter,
|
||||
const BufferizationOptions& options,
|
||||
BufferizationState& state) const {
|
||||
auto memCopyHostToDevOp = cast<PimMemCopyHostToDevOp>(op);
|
||||
auto deviceDst = memCopyHostToDevOp.getDeviceDst();
|
||||
auto hostSrc = memCopyHostToDevOp.getHostSrc();
|
||||
@@ -44,7 +47,10 @@ struct MemCopyHostToDevOpInterface
|
||||
|
||||
struct MemCopyDevToHostOpInterface
|
||||
: DstBufferizableOpInterfaceExternalModel<MemCopyDevToHostOpInterface, PimMemCopyDevToHostOp> {
|
||||
LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const {
|
||||
LogicalResult bufferize(Operation* op,
|
||||
RewriterBase& rewriter,
|
||||
const BufferizationOptions& options,
|
||||
BufferizationState& state) const {
|
||||
auto memCopyDevToHostOp = cast<PimMemCopyDevToHostOp>(op);
|
||||
|
||||
auto globalDst = memCopyDevToHostOp.getHostDst();
|
||||
@@ -88,7 +94,10 @@ struct VMMOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<VMMOpBu
|
||||
return false;
|
||||
}
|
||||
|
||||
LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const {
|
||||
LogicalResult bufferize(Operation* op,
|
||||
RewriterBase& rewriter,
|
||||
const BufferizationOptions& options,
|
||||
BufferizationState& state) const {
|
||||
auto vmmOp = cast<PimVMMOp>(op);
|
||||
|
||||
auto vectorInputOpt = getBuffer(rewriter, vmmOp.getVectorInput(), options, state);
|
||||
@@ -110,7 +119,10 @@ struct MVMOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<MVMOpBu
|
||||
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
|
||||
}
|
||||
|
||||
LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const {
|
||||
LogicalResult bufferize(Operation* op,
|
||||
RewriterBase& rewriter,
|
||||
const BufferizationOptions& options,
|
||||
BufferizationState& state) const {
|
||||
auto mvmOp = cast<PimMVMOp>(op);
|
||||
|
||||
auto vectorInputOpt = getBuffer(rewriter, mvmOp.getVectorInput(), options, state);
|
||||
@@ -138,7 +150,10 @@ struct VAddOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<VAddOp
|
||||
return true;
|
||||
}
|
||||
|
||||
LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const {
|
||||
LogicalResult bufferize(Operation* op,
|
||||
RewriterBase& rewriter,
|
||||
const BufferizationOptions& options,
|
||||
BufferizationState& state) const {
|
||||
auto vaddOp = cast<PimVAddOp>(op);
|
||||
|
||||
auto aOpt = getBuffer(rewriter, vaddOp.getA(), options, state);
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/DialectRegistry.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
@@ -8,7 +9,7 @@ using namespace mlir;
|
||||
namespace onnx_mlir {
|
||||
namespace pim {
|
||||
|
||||
void registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry);
|
||||
void registerBufferizableOpInterfaceExternalModels(DialectRegistry& registry);
|
||||
|
||||
} // namespace pim
|
||||
} // namespace onnx_mlir
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -1,8 +1,5 @@
|
||||
#pragma once
|
||||
|
||||
#include <map>
|
||||
#include <string>
|
||||
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
@@ -10,6 +7,9 @@
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
|
||||
#include <map>
|
||||
#include <string>
|
||||
|
||||
/// Include the auto-generated header files containing the declarations
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialDialect.hpp.inc"
|
||||
|
||||
|
||||
@@ -75,7 +75,10 @@ struct WComputeOpInterface : BufferizableOpInterface::ExternalModel<WComputeOpIn
|
||||
return {};
|
||||
}
|
||||
|
||||
LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const {
|
||||
LogicalResult bufferize(Operation* op,
|
||||
RewriterBase& rewriter,
|
||||
const BufferizationOptions& options,
|
||||
BufferizationState& state) const {
|
||||
// Bufferize its block
|
||||
|
||||
auto& block = op->getRegion(0).front();
|
||||
@@ -104,7 +107,10 @@ struct VariadicArgumentElementWiseOpInterface : BufferizableOpInterface::Externa
|
||||
}
|
||||
|
||||
// Cast tensor values into memref values
|
||||
LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const {
|
||||
LogicalResult bufferize(Operation* op,
|
||||
RewriterBase& rewriter,
|
||||
const BufferizationOptions& options,
|
||||
BufferizationState& state) const {
|
||||
|
||||
// Turn Tensor Operands into Memref Operands
|
||||
SmallVector<Value> memrefOperands;
|
||||
@@ -151,7 +157,10 @@ struct WeightedMultiplicationsOpInterface : BufferizableOpInterface::ExternalMod
|
||||
}
|
||||
|
||||
// Cast tensor value into memref value
|
||||
LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const {
|
||||
LogicalResult bufferize(Operation* op,
|
||||
RewriterBase& rewriter,
|
||||
const BufferizationOptions& options,
|
||||
BufferizationState& state) const {
|
||||
auto memrefOperandOpt = getBuffer(rewriter, op->getOperand(0), options, state);
|
||||
if (failed(memrefOperandOpt))
|
||||
return failure();
|
||||
@@ -190,7 +199,10 @@ struct ChannelReceiveOpInterface
|
||||
/*
|
||||
* Turn the channel receive to pim.recv
|
||||
*/
|
||||
LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const {
|
||||
LogicalResult bufferize(Operation* op,
|
||||
RewriterBase& rewriter,
|
||||
const BufferizationOptions& options,
|
||||
BufferizationState& state) const {
|
||||
|
||||
auto outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter);
|
||||
|
||||
@@ -235,7 +247,10 @@ struct ChannelSendOpInterface : BufferizableOpInterface::ExternalModel<ChannelSe
|
||||
/*
|
||||
* Turn the channel send to pim.send
|
||||
*/
|
||||
LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const {
|
||||
LogicalResult bufferize(Operation* op,
|
||||
RewriterBase& rewriter,
|
||||
const BufferizationOptions& options,
|
||||
BufferizationState& state) const {
|
||||
auto srcTensor = op->getOperand(1);
|
||||
|
||||
auto srcTensorOpt = getBuffer(rewriter, srcTensor, options, state);
|
||||
@@ -278,7 +293,10 @@ struct ChannelBroadcastReceiveOpInterface
|
||||
/*
|
||||
* Turn the channel receive to pim.load using by creating a new global buffer
|
||||
*/
|
||||
LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const {
|
||||
LogicalResult bufferize(Operation* op,
|
||||
RewriterBase& rewriter,
|
||||
const BufferizationOptions& options,
|
||||
BufferizationState& state) const {
|
||||
|
||||
auto outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter);
|
||||
|
||||
@@ -340,7 +358,10 @@ struct ChannelBroadcastSendOpInterface
|
||||
/*
|
||||
* Turn the channel send to pim.send
|
||||
*/
|
||||
LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const {
|
||||
LogicalResult bufferize(Operation* op,
|
||||
RewriterBase& rewriter,
|
||||
const BufferizationOptions& options,
|
||||
BufferizationState& state) const {
|
||||
auto srcTensor = op->getOperand(1);
|
||||
|
||||
auto srcTensorOpt = getBuffer(rewriter, srcTensor, options, state);
|
||||
@@ -414,7 +435,10 @@ struct ApplyFiltersOpInterface : BufferizableOpInterface::ExternalModel<ApplyFil
|
||||
}
|
||||
|
||||
// Bufferize the operation.
|
||||
LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const {
|
||||
LogicalResult bufferize(Operation* op,
|
||||
RewriterBase& rewriter,
|
||||
const BufferizationOptions& options,
|
||||
BufferizationState& state) const {
|
||||
|
||||
// Get the input tensor buffer.
|
||||
auto inputBuffer = getBuffer(rewriter, op->getOperand(0), options, state);
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/DialectRegistry.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Compiler/CompilerUtils.hpp"
|
||||
@@ -10,22 +11,19 @@ namespace onnx_mlir {
|
||||
|
||||
namespace {
|
||||
|
||||
struct CountInstructionPass
|
||||
: public PassWrapper<CountInstructionPass, OperationPass<ModuleOp>> {
|
||||
struct CountInstructionPass : public PassWrapper<CountInstructionPass, OperationPass<ModuleOp>> {
|
||||
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CountInstructionPass)
|
||||
|
||||
StringRef getArgument() const override { return "count-instruction-pass"; }
|
||||
|
||||
StringRef getDescription() const override {
|
||||
return "Count instructions for each core/compute in the module";
|
||||
}
|
||||
StringRef getDescription() const override { return "Count instructions for each core/compute in the module"; }
|
||||
|
||||
// Make sure that we have a valid default constructor and copy
|
||||
// constructor to make sure that the options are initialized properly.
|
||||
CountInstructionPass() {}
|
||||
CountInstructionPass(const CountInstructionPass &pass)
|
||||
: PassWrapper<CountInstructionPass, OperationPass<ModuleOp>>() {}
|
||||
CountInstructionPass(const CountInstructionPass& pass)
|
||||
: PassWrapper<CountInstructionPass, OperationPass<ModuleOp>>() {}
|
||||
void runOnOperation() final {
|
||||
ModuleOp module = getOperation();
|
||||
|
||||
@@ -37,8 +35,7 @@ struct CountInstructionPass
|
||||
for (auto computeOp : func.getOps<spatial::SpatWeightedCompute>()) {
|
||||
unsigned instructionCount = 0;
|
||||
instructionCount += computeOp.getBody().front().getOperations().size();
|
||||
llvm::outs() << "Compute " << computeId << ": " << instructionCount
|
||||
<< " instructions\n";
|
||||
llvm::outs() << "Compute " << computeId << ": " << instructionCount << " instructions\n";
|
||||
totalInstructionCount += instructionCount;
|
||||
computeId++;
|
||||
}
|
||||
@@ -47,21 +44,17 @@ struct CountInstructionPass
|
||||
for (auto coreOp : func.getOps<pim::PimCoreOp>()) {
|
||||
unsigned instructionCount = 0;
|
||||
instructionCount += coreOp.getBody().front().getOperations().size();
|
||||
llvm::outs() << "Core " << coreId << ": " << instructionCount
|
||||
<< " instructions\n";
|
||||
llvm::outs() << "Core " << coreId << ": " << instructionCount << " instructions\n";
|
||||
totalInstructionCount += instructionCount;
|
||||
coreId++;
|
||||
}
|
||||
|
||||
llvm::outs() << "Total instruction count: " << totalInstructionCount
|
||||
<< "\n";
|
||||
llvm::outs() << "Total instruction count: " << totalInstructionCount << "\n";
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<Pass> createCountInstructionPass() {
|
||||
return std::make_unique<CountInstructionPass>();
|
||||
}
|
||||
std::unique_ptr<Pass> createCountInstructionPass() { return std::make_unique<CountInstructionPass>(); }
|
||||
|
||||
} // namespace onnx_mlir
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
using namespace mlir;
|
||||
@@ -21,4 +22,4 @@ std::unique_ptr<Pass> createMessagePass(std::string message);
|
||||
|
||||
std::unique_ptr<Pass> createCountInstructionPass();
|
||||
|
||||
} // namespace onnx_mlir
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
|
||||
#include "src/Accelerators/Accelerator.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
@@ -9,38 +10,37 @@ namespace accel {
|
||||
/// Singleton class to construct PIM accelerator.
|
||||
class PimAccelerator final : public Accelerator {
|
||||
private:
|
||||
static PimAccelerator *instance;
|
||||
static PimAccelerator* instance;
|
||||
PimAccelerator();
|
||||
|
||||
public:
|
||||
/// Singleton should not be clonable or assignable.
|
||||
PimAccelerator(PimAccelerator &) = delete;
|
||||
void operator=(const PimAccelerator &) = delete;
|
||||
PimAccelerator(PimAccelerator&) = delete;
|
||||
void operator=(const PimAccelerator&) = delete;
|
||||
|
||||
~PimAccelerator();
|
||||
|
||||
/// Creates an instance on the first invocation. Subsequent invocations
|
||||
/// return the existing instance.
|
||||
static PimAccelerator *getInstance();
|
||||
static PimAccelerator* getInstance();
|
||||
|
||||
/// Define classof to be able to use isa<>, cast<>, dyn_cast<>, etc.
|
||||
static bool classof(const Accelerator *accel) {
|
||||
return accel->getKind() == Accelerator::Kind::PIM;
|
||||
}
|
||||
static bool classof(const PimAccelerator *) { return true; }
|
||||
static bool classof(const Accelerator* accel) { return accel->getKind() == Accelerator::Kind::PIM; }
|
||||
static bool classof(const PimAccelerator*) { return true; }
|
||||
|
||||
uint64_t getVersionNumber() const final;
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Hooks for onnx-mlir-opt driver
|
||||
//===--------------------------------------------------------------------===//
|
||||
virtual void addPasses(mlir::OwningOpRef<mlir::ModuleOp> &module,
|
||||
mlir::PassManager &pm, onnx_mlir::EmissionTargetType &emissionTarget,
|
||||
std::string outputNameNoExt) const final;
|
||||
virtual void addPasses(mlir::OwningOpRef<mlir::ModuleOp>& module,
|
||||
mlir::PassManager& pm,
|
||||
onnx_mlir::EmissionTargetType& emissionTarget,
|
||||
std::string outputNameNoExt) const final;
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Hooks for onnx-mlir-opt driver
|
||||
//===--------------------------------------------------------------------===//
|
||||
virtual void registerDialects(mlir::DialectRegistry ®istry) const final;
|
||||
virtual void registerDialects(mlir::DialectRegistry& registry) const final;
|
||||
virtual void registerPasses(int optLevel) const final;
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Hooks for both onnx-mlir and onnx-mlir-opt drivers
|
||||
@@ -49,21 +49,19 @@ public:
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Hooks for onnx-to-krnl pass
|
||||
//===--------------------------------------------------------------------===//
|
||||
virtual mlir::MemRefType convertTensorTypeToMemRefType(
|
||||
const mlir::TensorType tensorType) const final;
|
||||
virtual void conversionTargetONNXToKrnl(
|
||||
mlir::ConversionTarget &target) const final;
|
||||
virtual void rewritePatternONNXToKrnl(mlir::RewritePatternSet &patterns,
|
||||
mlir::TypeConverter &typeConverter, mlir::MLIRContext *ctx) const final;
|
||||
virtual mlir::MemRefType convertTensorTypeToMemRefType(const mlir::TensorType tensorType) const final;
|
||||
virtual void conversionTargetONNXToKrnl(mlir::ConversionTarget& target) const final;
|
||||
virtual void rewritePatternONNXToKrnl(mlir::RewritePatternSet& patterns,
|
||||
mlir::TypeConverter& typeConverter,
|
||||
mlir::MLIRContext* ctx) const final;
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Hooks for krnl-to-llvm pass
|
||||
//===--------------------------------------------------------------------===//
|
||||
virtual void conversionTargetKrnlToLLVM(
|
||||
mlir::ConversionTarget &target) const final;
|
||||
virtual void rewritePatternKrnlToLLVM(mlir::RewritePatternSet &patterns,
|
||||
mlir::LLVMTypeConverter &typeConverter,
|
||||
mlir::MLIRContext *ctx) const final;
|
||||
virtual void conversionTargetKrnlToLLVM(mlir::ConversionTarget& target) const final;
|
||||
virtual void rewritePatternKrnlToLLVM(mlir::RewritePatternSet& patterns,
|
||||
mlir::LLVMTypeConverter& typeConverter,
|
||||
mlir::MLIRContext* ctx) const final;
|
||||
};
|
||||
|
||||
} // namespace accel
|
||||
|
||||
@@ -63,7 +63,8 @@ void PimBufferizationPass::runOnOperation() {
|
||||
void PimBufferizationPass::annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const {
|
||||
MLIRContext* ctx = funcOp.getContext();
|
||||
funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
|
||||
bool isAlwaysWeight = llvm::all_of(getGlobalOp->getUsers(), [](auto user) -> bool { return isa<PimCoreOp>(user); }) && !getGlobalOp->getUsers().empty() ;
|
||||
bool isAlwaysWeight = llvm::all_of(getGlobalOp->getUsers(), [](auto user) -> bool { return isa<PimCoreOp>(user); })
|
||||
&& !getGlobalOp->getUsers().empty();
|
||||
if (isAlwaysWeight) {
|
||||
auto globalMemrefOp = moduleOp.lookupSymbol<memref::GlobalOp>(getGlobalOp.getName());
|
||||
assert("Weights must be constants" && globalMemrefOp.getConstant());
|
||||
|
||||
Reference in New Issue
Block a user