add .clang-format

reformat all src
This commit is contained in:
NiccoloN
2026-02-26 19:16:42 +01:00
parent a2c31836ae
commit 810e5e75f9
32 changed files with 902 additions and 953 deletions

143
.clang-format Normal file
View File

@@ -0,0 +1,143 @@
---
# clang-format style configuration for C++ sources (see `Language: Cpp` below).
# Keys are kept in alphabetical order. The comments only summarise what the
# values written here select; consult the clang-format style-option reference
# for the exact semantics of each key.

# --- Alignment -------------------------------------------------------------
# Access modifiers are offset by -2, i.e. the inverse of IndentWidth (2).
AccessModifierOffset: -2
AlignAfterOpenBracket: Align
AlignArrayOfStructures: Left
# Of the "consecutive" alignment families, only short case statements are
# enabled; assignments, bit-fields, declarations and macros are disabled.
AlignConsecutiveShortCaseStatements:
Enabled: true
AlignConsecutiveAssignments:
Enabled: false
AlignConsecutiveBitFields:
Enabled: false
AlignConsecutiveDeclarations:
Enabled: false
AlignConsecutiveMacros:
Enabled: false
AlignEscapedNewlines: Left
AlignOperands: AlignAfterOperator
# Trailing comments are always aligned, including across up to 4 empty lines.
AlignTrailingComments:
Kind: Always
OverEmptyLines: 4

# --- What may stay on a single line ----------------------------------------
AllowAllArgumentsOnNextLine: true
AllowAllParametersOfDeclarationOnNextLine: true
AllowShortBlocksOnASingleLine: Empty
AllowShortCaseExpressionOnASingleLine: true
AllowShortCaseLabelsOnASingleLine: true
AllowShortCompoundRequirementOnASingleLine: true
AllowShortEnumsOnASingleLine: false
# Short functions and lambdas may be collapsed; short ifs and loops may not.
AllowShortFunctionsOnASingleLine: All
AllowShortIfStatementsOnASingleLine: Never
AllowShortLambdasOnASingleLine: All
AllowShortLoopsOnASingleLine: false
AlwaysBreakBeforeMultilineStrings: false
# Bin-packing is off for both call arguments and declaration parameters.
BinPackArguments: false
BinPackParameters: false
BitFieldColonSpacing: Both

# --- Braces and line breaking ----------------------------------------------
# BraceWrapping is only honoured because BreakBeforeBraces is `Custom` below.
# Only the else / catch / while wrapping flags are set here.
BraceWrapping:
BeforeElse: true
BeforeCatch: true
BeforeWhile: true
BreakBeforeBraces: Custom
BracedInitializerIndentWidth: 2
BreakAdjacentStringLiterals: true
BreakAfterAttributes: Never
BreakAfterJavaFieldAnnotations: false
BreakArrays: false
# Wrapped binary expressions break before the operator, except assignments.
BreakBeforeBinaryOperators: NonAssignment
BreakBeforeConceptDeclarations: Always
BreakBeforeInlineASMColon: OnlyMultiline
BreakBeforeTernaryOperators: true
BreakConstructorInitializers: BeforeColon
BreakFunctionDefinitionParameters: false
BreakInheritanceList: AfterComma
BreakStringLiterals: true
BreakTemplateDeclarations: Yes
# Maximum line width.
ColumnLimit: 120
CompactNamespaces: false
ConstructorInitializerIndentWidth: 0
ContinuationIndentWidth: 2
Cpp11BracedListStyle: true
# Pointer alignment is taken from PointerAlignment below, not derived per file.
DerivePointerAlignment: false
DisableFormat: false
EmptyLineAfterAccessModifier: Never
EmptyLineBeforeAccessModifier: Always
FixNamespaceComments: true

# --- Includes --------------------------------------------------------------
# Includes are regrouped into blocks by the priorities below:
#   2: paths starting with mlir (quoted or angle-bracket)
#   3: paths starting with llvm (quoted or angle-bracket)
#   4: any other angle-bracket include
# NOTE(review): includes matching none of these patterns (e.g. quoted project
# headers) get no explicit priority here - confirm the resulting order is the
# intended one.
IncludeBlocks: Regroup
IncludeCategories:
- Regex: (<|")mlir(.*)(>|")
Priority: 2
- Regex: (<|")llvm(.*)(>|")
Priority: 3
- Regex: <.*>
Priority: 4
IncludeIsMainRegex: (Test)?$
IncludeIsMainSourceRegex: ""

# --- Indentation -----------------------------------------------------------
IndentAccessModifiers: false
IndentCaseBlocks: false
IndentCaseLabels: false
IndentExternBlock: Indent
IndentGotoLabels: false
IndentPPDirectives: None
IndentRequiresClause: true
# Indentation unit is 2 columns (tabs are never emitted, see UseTab).
IndentWidth: 2
IndentWrappedFunctionNames: false
InsertBraces: false
InsertNewlineAtEOF: true
KeepEmptyLines:
AtEndOfFile: false
AtStartOfBlock: true
AtStartOfFile: false
LambdaBodyIndentation: Signature
Language: Cpp
LineEnding: LF
# NOTE(review): AngleBracket makes <...> includes the candidates for the
# "main include" - confirm this is intended, given that project headers in
# this code base appear to be included with quotes.
MainIncludeChar: AngleBracket
MaxEmptyLinesToKeep: 1
NamespaceIndentation: None
PackConstructorInitializers: NextLineOnly
# `T* p` style; references follow the pointer setting (`T& r`).
PointerAlignment: Left
ReferenceAlignment: Pointer

# --- Token-removing options ------------------------------------------------
# NOTE(review): the three Remove* options below delete tokens rather than only
# changing whitespace. The clang-format documentation warns that such options
# can produce incorrect code - review the diffs they generate with extra care.
RemoveBracesLLVM: true
RemoveParentheses: ReturnStatement
RemoveSemicolon: true
RequiresClausePosition: OwnLine
RequiresExpressionIndentation: OuterScope
SeparateDefinitionBlocks: Leave
ShortNamespaceLines: 4
SortIncludes: CaseSensitive

# --- Spacing ---------------------------------------------------------------
SpaceAfterCStyleCast: true
SpaceAfterLogicalNot: false
SpaceAfterTemplateKeyword: true
SpaceAroundPointerQualifiers: Before
SpaceBeforeAssignmentOperators: true
SpaceBeforeCaseColon: false
# Puts a space between a type/name and its braced initializer list.
SpaceBeforeCpp11BracedList: true
SpaceBeforeCtorInitializerColon: true
SpaceBeforeInheritanceColon: true
SpaceBeforeJsonColon: false
# A space before `(` is enabled only after control statements and after
# foreach / if macros; every other case listed is disabled.
# NOTE(review): these sub-options normally take effect only together with
# `SpaceBeforeParens: Custom`, which is not set in this file - confirm.
SpaceBeforeParensOptions:
AfterControlStatements: true
AfterForeachMacros: true
AfterFunctionDeclarationName: false
AfterFunctionDefinitionName: false
AfterIfMacros: true
AfterOverloadedOperator: false
AfterPlacementOperator: false
AfterRequiresInClause: false
AfterRequiresInExpression: false
BeforeNonEmptyParentheses: false
SpaceBeforeRangeBasedForLoopColon: true
SpaceBeforeSquareBrackets: false
SpaceInEmptyBlock: false
SpacesBeforeTrailingComments: 1
SpacesInAngles: Never
SpacesInContainerLiterals: false
# Line comments get exactly one space after `//`.
SpacesInLineCommentPrefix:
Minimum: 1
Maximum: 1
SpacesInParens: Never
SpacesInSquareBrackets: false
Standard: Latest
TabWidth: 4
UseTab: Never
VerilogBreakBetweenInstancePorts: true

# --- Macro expansion hints -------------------------------------------------
# Tells the formatter to treat LLVM_DEBUG(X) as if it expanded to X.
# NOTE(review): as written (unquoted scalar) the trailing `\n` is a literal
# backslash followed by `n`, not a newline - confirm this is intended.
Macros:
- LLVM_DEBUG(X)=X\n

View File

@@ -14,9 +14,9 @@ namespace onnx_mlir {
std::string getOutputDir();
void createDirectory(const std::string &directory);
void createDirectory(const std::string& directory);
void dumpModule(mlir::ModuleOp moduleOp, const std::string &name);
void dumpModule(mlir::ModuleOp moduleOp, const std::string& name);
llvm::FailureOr<mlir::Operation*>
getOtherEndOfChannel(mlir::Operation* op, bool opIsReceive, mlir::RewriterBase& rewriter);

View File

@@ -11,20 +11,21 @@
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "src/Accelerators/PIM/Common/PIMCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/LogicalResult.h"
#include <cstddef>
#include <cstddef>
#include <memory>
#include <unordered_map>
#include <vector>
#include "src/Accelerators/PIM/Common/PIMCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
using namespace std;
@@ -40,8 +41,8 @@ namespace onnx_mlir {
*/
class Core {
public:
Core(const size_t coreId, ConversionPatternRewriter &rewriter)
: coreId(coreId), rewriter(rewriter) {}
Core(const size_t coreId, ConversionPatternRewriter& rewriter)
: coreId(coreId), rewriter(rewriter) {}
/**
* @brief Add a MVM operation to the core.
@@ -52,8 +53,7 @@ public:
* @param mvmOutType The result's shape.
* @return Value The result of the MVM operation.
*/
Value addMVM(
Value inputTile, size_t xbarIndex, size_t outputTileId, Type mvmOutType) {
Value addMVM(Value inputTile, size_t xbarIndex, size_t outputTileId, Type mvmOutType) {
// Use the inputTile as the reference location for the MVM operation.
Location loc = inputTile.getLoc();
@@ -72,8 +72,7 @@ public:
// is correct.
// Construct the MVM operation
Value result = rewriter.create<spatial::SpatWeightedMVMOp>(
loc, mvmOutType, xbarIndex, operand);
Value result = rewriter.create<spatial::SpatWeightedMVMOp>(loc, mvmOutType, xbarIndex, operand);
// Since we are within the same core and no computation can happen in
// parallel, we can just apply a linear reduction in case we have multiple
@@ -84,8 +83,7 @@ public:
if (lastMVM != outputTileToMVM.end()) {
// MVM results should have the same type for reduction.
assert(lastMVM->second.getType() == result.getType());
result = rewriter.create<spatial::SpatVAddOp>(
loc, mvmOutType, lastMVM->second, result);
result = rewriter.create<spatial::SpatVAddOp>(loc, mvmOutType, lastMVM->second, result);
}
outputTileToMVM[outputTileId] = result;
@@ -139,16 +137,14 @@ public:
spatial::SpatWeightedCompute createWComputeOp(Location loc) {
// Get the shape of the results.
SmallVector<Type> resultTypes;
for (const auto &value : results) {
for (const auto& value : results)
resultTypes.push_back(value.getType());
}
// Create the WComputeOp, with non-remappable operands only.
wcomputeOp = rewriter.create<spatial::SpatWeightedCompute>(
loc, resultTypes, xbarWeights, operands);
wcomputeOp = rewriter.create<spatial::SpatWeightedCompute>(loc, resultTypes, xbarWeights, operands);
// Add the body to the WComputeOp.
Block *releasedBlock = block.release();
Block* releasedBlock = block.release();
wcomputeOp.getBody().push_back(releasedBlock);
// Add the `yieldOp` at the end, with the results.
@@ -164,21 +160,18 @@ public:
void remapResults() {
// Remap all the results to the WComputeOp results.
assert(resultsToRemap.size() == wcomputeOp->getNumResults());
for (size_t i = 0; i < resultsToRemap.size(); i++) {
for (size_t i = 0; i < resultsToRemap.size(); i++)
*resultsToRemap[i] = wcomputeOp.getResult(i);
}
}
void addRemappedOperands() {
// Insert the remappableOperands (which were remapped in
// `addRemappableOperand` of another Core)
for (auto remappedValue : remappableOperands) {
for (auto remappedValue : remappableOperands)
wcomputeOp->insertOperands(wcomputeOp->getNumOperands(), *remappedValue);
}
// Update the wcomputeOp operandSegmentSize
incrementWeightedComputeInputsSegmentSize(
wcomputeOp, static_cast<int>(remappableOperands.size()));
incrementWeightedComputeInputsSegmentSize(wcomputeOp, static_cast<int>(remappableOperands.size()));
}
size_t addXbarWeight(Value weight) {
@@ -199,31 +192,27 @@ public:
llvm::outs() << "Core " << coreId << ":\n";
// Print the weights
llvm::outs() << "Xbar Weights:\n";
for (auto weight : xbarWeights) {
for (auto weight : xbarWeights)
weight.dump();
}
// Print the operands
llvm::outs() << "Operands:\n";
for (auto operand : operands) {
for (auto operand : operands)
llvm::outs() << operand << "\n";
}
// Dump the body block
for (auto &op : block->getOperations()) {
for (auto& op : block->getOperations())
op.dump();
}
// Print the results
llvm::outs() << "Results:\n";
for (auto result : results) {
for (auto result : results)
llvm::outs() << result << "\n";
}
}
const size_t coreId;
private:
ConversionPatternRewriter &rewriter;
ConversionPatternRewriter& rewriter;
// Should these be set<Value> instead? But I need to keep the order
vector<Value> operands;
@@ -246,15 +235,16 @@ private:
};
struct ONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
ONNXConvOpTile(MLIRContext *ctx) : OpConversionPattern(ctx) {}
ONNXConvOpTile(MLIRContext* ctx)
: OpConversionPattern(ctx) {}
struct Producer_t {
Value value;
shared_ptr<Core> core;
};
LogicalResult matchAndRewrite(ONNXConvOp conv, ONNXConvOpAdaptor convAdaptor,
ConversionPatternRewriter &rewriter) const final {
LogicalResult
matchAndRewrite(ONNXConvOp conv, ONNXConvOpAdaptor convAdaptor, ConversionPatternRewriter& rewriter) const final {
ShapedType xShape = mlir::cast<ShapedType>(convAdaptor.getX().getType());
ShapedType wShape = mlir::cast<ShapedType>(convAdaptor.getW().getType());
ShapedType bShape = mlir::cast<ShapedType>(convAdaptor.getB().getType());
@@ -264,11 +254,9 @@ struct ONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
unpackOptionalPairVector(conv.getStrides(), stride_x, stride_y);
unpackOptionalPairVector(conv.getDilations(), dilation_x, dilation_y);
auto padUnpackError =
unpackOptionalPadsVector(convAdaptor.getPads(), pad_x, pad_y);
if (padUnpackError.has_value()) {
auto padUnpackError = unpackOptionalPadsVector(convAdaptor.getPads(), pad_x, pad_y);
if (padUnpackError.has_value())
return rewriter.notifyMatchFailure(conv, padUnpackError.value());
}
// TODO: Pad value at beginning and end of each dimension could be
// different. We should handle this case.
@@ -296,11 +284,9 @@ struct ONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
Location loc = conv.getLoc();
size_t inputTileCount =
ceilIntegerDivide(GET_IMAGE_CHANNEL(xShape), crossbarSize.getValue());
size_t inputTileCount = ceilIntegerDivide(GET_IMAGE_CHANNEL(xShape), crossbarSize.getValue());
size_t inputTileRemainder = GET_IMAGE_CHANNEL(xShape) % crossbarSize;
size_t outputTileCount =
ceilIntegerDivide(GET_IMAGE_CHANNEL(yShape), crossbarSize.getValue());
size_t outputTileCount = ceilIntegerDivide(GET_IMAGE_CHANNEL(yShape), crossbarSize.getValue());
size_t outputTileRemainder = GET_IMAGE_CHANNEL(yShape) % crossbarSize;
// Tile the input tensor
@@ -310,22 +296,20 @@ struct ONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
// c. Pixel `y` position
// For example: inputTiles[channelTile][x][y]
// Example complete input tensor: tensor<1x3x6x6xf32> (NxCxWxH)
SmallVector<SmallVector<SmallVector<Value>>> inputTiles(inputTileCount,
SmallVector<SmallVector<Value>>(input_w, SmallVector<Value>(input_h)));
SmallVector<SmallVector<SmallVector<Value>>> inputTiles(
inputTileCount, SmallVector<SmallVector<Value>>(input_w, SmallVector<Value>(input_h)));
auto resolveErrorOpt = resolveImgInputTiles(convAdaptor.getX(), inputTiles,
inputTileCount, inputTileRemainder, input_h, input_h, rewriter);
if (resolveErrorOpt.has_value()) {
auto resolveErrorOpt = resolveImgInputTiles(
convAdaptor.getX(), inputTiles, inputTileCount, inputTileRemainder, input_h, input_h, rewriter);
if (resolveErrorOpt.has_value())
return rewriter.notifyMatchFailure(conv, *resolveErrorOpt);
}
SmallVector<OpFoldResult> strides =
SmallVector<OpFoldResult>(4, rewriter.getIndexAttr(1));
SmallVector<OpFoldResult> offsets =
SmallVector<OpFoldResult>(4, rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> sizes = SmallVector<OpFoldResult>{
rewriter.getIndexAttr(1), rewriter.getIndexAttr(crossbarSize),
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
SmallVector<OpFoldResult> strides = SmallVector<OpFoldResult>(4, rewriter.getIndexAttr(1));
SmallVector<OpFoldResult> offsets = SmallVector<OpFoldResult>(4, rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> sizes = SmallVector<OpFoldResult> {rewriter.getIndexAttr(1),
rewriter.getIndexAttr(crossbarSize),
rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1)};
// Tile the weight tensor
// Weight tiles need to be indexed by:
@@ -336,31 +320,30 @@ struct ONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
// For example: weightTiles[filterTile][channelTile][x][y]
// Example complete weight tensor: tensor<32x3x3x3xf32> (FxCxWxH)
SmallVector<SmallVector<SmallVector<SmallVector<Value>>>> weightTiles(
outputTileCount,
SmallVector<SmallVector<SmallVector<Value>>>(inputTileCount,
SmallVector<SmallVector<Value>>(krn_w, SmallVector<Value>(krn_h))));
outputTileCount,
SmallVector<SmallVector<SmallVector<Value>>>(inputTileCount,
SmallVector<SmallVector<Value>>(krn_w, SmallVector<Value>(krn_h))));
strides = SmallVector<OpFoldResult>(4, rewriter.getIndexAttr(1));
offsets = SmallVector<OpFoldResult>(4, rewriter.getIndexAttr(0));
sizes = {rewriter.getIndexAttr(crossbarSize),
rewriter.getIndexAttr(crossbarSize), rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1)};
rewriter.getIndexAttr(crossbarSize),
rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1)};
for (size_t i = 0; i < outputTileCount; i++) {
if (i == outputTileCount - 1 && outputTileRemainder != 0) {
if (i == outputTileCount - 1 && outputTileRemainder != 0)
sizes[0] = rewriter.getIndexAttr(outputTileRemainder);
}
sizes[1] = rewriter.getIndexAttr(crossbarSize);
offsets[0] = rewriter.getIndexAttr(i * crossbarSize);
for (size_t j = 0; j < inputTileCount; j++) {
if (j == inputTileCount - 1 && inputTileRemainder != 0) {
if (j == inputTileCount - 1 && inputTileRemainder != 0)
sizes[1] = rewriter.getIndexAttr(inputTileRemainder);
}
for (size_t x = 0; x < krn_w; x++) {
for (size_t y = 0; y < krn_h; y++) {
offsets[1] = rewriter.getIndexAttr(j * crossbarSize);
offsets[2] = rewriter.getIndexAttr(x);
offsets[3] = rewriter.getIndexAttr(y);
weightTiles[i][j][x][y] = rewriter.create<tensor::ExtractSliceOp>(
loc, convAdaptor.getW(), offsets, sizes, strides);
weightTiles[i][j][x][y] =
rewriter.create<tensor::ExtractSliceOp>(loc, convAdaptor.getW(), offsets, sizes, strides);
}
}
}
@@ -379,56 +362,45 @@ struct ONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
// For example: outputTiles[filterTile][x][y]
// Example complete output tensor: tensor<1x32x3x3xf32> (NxFxWxH)
SmallVector<SmallVector<SmallVector<shared_ptr<Value>>>> outputTiles(
outputTileCount,
SmallVector<SmallVector<shared_ptr<Value>>>(
output_w, SmallVector<shared_ptr<Value>>(output_h, nullptr)));
outputTileCount,
SmallVector<SmallVector<shared_ptr<Value>>>(output_w, SmallVector<shared_ptr<Value>>(output_h, nullptr)));
size_t replicationFactor;
if (!conv->hasAttr(REPLICATION_ATTR_NAME)) {
if (!conv->hasAttr(REPLICATION_ATTR_NAME))
replicationFactor = 1;
} else {
replicationFactor =
conv->getAttrOfType<IntegerAttr>(REPLICATION_ATTR_NAME).getInt();
}
else
replicationFactor = conv->getAttrOfType<IntegerAttr>(REPLICATION_ATTR_NAME).getInt();
// producers[outTile][out_x][out_y][producerIndex]
vector<vector<vector<vector<Producer_t>>>> producers =
vector<vector<vector<vector<Producer_t>>>>(outputTileCount,
vector<vector<vector<Producer_t>>>(output_w,
vector<vector<Producer_t>>(output_h, vector<Producer_t>())));
vector<vector<vector<vector<Producer_t>>>> producers = vector<vector<vector<vector<Producer_t>>>>(
outputTileCount,
vector<vector<vector<Producer_t>>>(output_w, vector<vector<Producer_t>>(output_h, vector<Producer_t>())));
// Schedule in cores
size_t coreId = 0;
vector<shared_ptr<Core>> curCores(replicationFactor);
for (size_t i = 0; i < replicationFactor; i++) {
for (size_t i = 0; i < replicationFactor; i++)
curCores[i] = make_shared<Core>(coreId++, rewriter);
}
vector<shared_ptr<Core>> cores;
const size_t replicationSliceSize =
ceilIntegerDivide(input_w, replicationFactor);
const size_t replicationSliceSize = ceilIntegerDivide(input_w, replicationFactor);
for (size_t krn_x = 0; krn_x < krn_h; krn_x++) {
for (size_t krn_y = 0; krn_y < krn_w; krn_y++) {
RankedTensorType mvmOutType =
RankedTensorType::get({1, static_cast<long>(crossbarSize), 1, 1},
bShape.getElementType());
RankedTensorType::get({1, static_cast<long>(crossbarSize), 1, 1}, bShape.getElementType());
for (size_t outTile = 0; outTile < outputTileCount; outTile++) {
if (outTile == outputTileCount - 1 && outputTileRemainder != 0) {
mvmOutType = mvmOutType.clone(
{1, static_cast<long>(outputTileRemainder), 1, 1});
}
if (outTile == outputTileCount - 1 && outputTileRemainder != 0)
mvmOutType = mvmOutType.clone({1, static_cast<long>(outputTileRemainder), 1, 1});
for (size_t inTile = 0; inTile < inputTileCount; inTile++) {
vector<size_t> xbarIndexes(replicationFactor);
for (size_t i = 0; i < replicationFactor; i++) {
xbarIndexes[i] = curCores[i]->addXbarWeight(
weightTiles[outTile][inTile][krn_x][krn_y]);
}
for (size_t i = 0; i < replicationFactor; i++)
xbarIndexes[i] = curCores[i]->addXbarWeight(weightTiles[outTile][inTile][krn_x][krn_y]);
size_t out_x = 0;
for (size_t in_x = 0; in_x < input_w; in_x += stride_x) {
@@ -443,25 +415,20 @@ struct ONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
for (size_t in_y = 0; in_y < input_h; in_y += stride_y) {
// Adjust the input based on the kernel
int actual_in_x = in_x - ((int)krn_w / 2) + krn_x * dilation_x;
int actual_in_y = in_y - ((int)krn_h / 2) + krn_y * dilation_y;
int actual_in_x = in_x - ((int) krn_w / 2) + krn_x * dilation_x;
int actual_in_y = in_y - ((int) krn_h / 2) + krn_y * dilation_y;
// Check if we are within the input image
if (verifyWithinBoundsAndPaddings(input_w, input_h, actual_in_x,
actual_in_y, pad_x, pad_y)
.failed()) {
if (verifyWithinBoundsAndPaddings(input_w, input_h, actual_in_x, actual_in_y, pad_x, pad_y).failed()) {
out_y++;
continue;
}
size_t outTileId =
outTile * output_w * output_h + out_x * output_h + out_y;
size_t outTileId = outTile * output_w * output_h + out_x * output_h + out_y;
auto mvm = curCores[coreIndex]->addMVM(
inputTiles[inTile][actual_in_x][actual_in_y],
xbarIndexes[coreIndex], outTileId, mvmOutType);
inputTiles[inTile][actual_in_x][actual_in_y], xbarIndexes[coreIndex], outTileId, mvmOutType);
producers[outTile][out_x][out_y].push_back(
{mvm, curCores[coreIndex]});
producers[outTile][out_x][out_y].push_back({mvm, curCores[coreIndex]});
out_y++;
}
@@ -481,11 +448,9 @@ struct ONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
}
}
for (auto &curCore : curCores) {
if (curCore->isCoreEmpty() == false) {
for (auto& curCore : curCores)
if (curCore->isCoreEmpty() == false)
cores.emplace_back(std::move(curCore));
}
}
curCores.clear();
// Now, do the reduction of each output pixel tile
for (size_t outTile = 0; outTile < outputTileCount; outTile++) {
@@ -497,9 +462,8 @@ struct ONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
// core.
std::unordered_map<size_t, Producer_t> withinCoreReducedProducers;
for (auto producer : producers[outTile][out_x][out_y]) {
for (auto producer : producers[outTile][out_x][out_y])
withinCoreReducedProducers[producer.core->coreId] = producer;
}
// Now, we need to apply inter-core reduction
@@ -509,8 +473,7 @@ struct ONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
auto singleProducer = withinCoreReducedProducers.begin()->second;
// Use last producer as the final result
auto reducedValue =
singleProducer.core->makeResultRemappable(singleProducer.value);
auto reducedValue = singleProducer.core->makeResultRemappable(singleProducer.value);
outputTiles[outTile][out_x][out_y] = reducedValue;
continue;
}
@@ -535,9 +498,9 @@ struct ONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
auto lastProducerCoreId = lastProducer.core->coreId;
auto curProducerCoreId = curProducer.core->coreId;
assert(lastProducerCoreId != curProducerCoreId &&
"We should have already applied within-core reduction, how "
"could we have same cores here?");
assert(lastProducerCoreId != curProducerCoreId
&& "We should have already applied within-core reduction, how "
"could we have same cores here?");
// Sort the cores by coreId
if (curProducerCoreId < lastProducerCoreId) {
@@ -545,7 +508,8 @@ struct ONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
core1Value = curProducer.value;
core2 = lastProducer.core;
core2Value = lastProducer.value;
} else {
}
else {
core1 = lastProducer.core;
core1Value = lastProducer.value;
core2 = curProducer.core;
@@ -556,9 +520,8 @@ struct ONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
auto secondCoreBlockArg = core2->addRemappableOperand(newCoreRes);
rewriter.setInsertionPointAfterValue(core2Value);
Value vaddRes =
rewriter.create<spatial::SpatVAddOp>(core2Value.getLoc(),
core2Value.getType(), core2Value, secondCoreBlockArg);
Value vaddRes = rewriter.create<spatial::SpatVAddOp>(
core2Value.getLoc(), core2Value.getType(), core2Value, secondCoreBlockArg);
lastProducer = {vaddRes, core2};
@@ -568,8 +531,7 @@ struct ONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
// TODO: Add the bias and apply mapping (if present)
// Use last producer as the final result
auto reducedValue =
lastProducer.core->makeResultRemappable(lastProducer.value);
auto reducedValue = lastProducer.core->makeResultRemappable(lastProducer.value);
outputTiles[outTile][out_x][out_y] = reducedValue;
}
}
@@ -578,15 +540,14 @@ struct ONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
// Now, we need to turn the cores into a spatial::SpatWeightedCompute.
rewriter.setInsertionPointAfter(conv);
spatial::SpatWeightedCompute lastWComputeOp;
for (auto &core : cores) {
for (auto& core : cores) {
lastWComputeOp = core->createWComputeOp(loc);
core->remapResults();
rewriter.setInsertionPointAfter(lastWComputeOp);
}
for (auto &core : cores) {
for (auto& core : cores)
core->addRemappedOperands();
}
// Set the insertion point after the last WComputeOp.
rewriter.setInsertionPointAfter(lastWComputeOp);
@@ -597,8 +558,7 @@ struct ONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
for (size_t outTile = 0; outTile < outputTileCount; outTile++)
tilesToConcat.push_back(*outputTiles[outTile][outX][outY]);
Value outputImage = rewriter.create<spatial::SpatImgConcatOp>(
loc, conv.getY().getType(), tilesToConcat);
Value outputImage = rewriter.create<spatial::SpatImgConcatOp>(loc, conv.getY().getType(), tilesToConcat);
// Value outputImage =
// createImgConcatOp(outputTiles, rewriter, loc, Y.getType());
@@ -616,9 +576,8 @@ struct ONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
}
};
void populateTilingConvOpPattern(
RewritePatternSet &patterns, MLIRContext *ctx) {
void populateTilingConvOpPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.insert<ONNXConvOpTile>(ctx);
}
} // namespace onnx_mlir
} // namespace onnx_mlir

View File

@@ -1,6 +1,3 @@
#include "Compiler/PimCompilerOptions.hpp"
#include "Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "Dialect/Spatial/SpatialOps.hpp"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
@@ -8,13 +5,19 @@
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/WeightSubdivider.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
#include "llvm/ADT/SmallVector.h"
#include <algorithm>
#include <cstddef>
#include <unistd.h>
#include "Compiler/PimCompilerOptions.hpp"
#include "Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/WeightSubdivider.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
using namespace std;
@@ -27,10 +30,11 @@ namespace onnx_mlir {
* output tensor.
*/
struct ExperimentalONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
ExperimentalONNXConvOpTile(MLIRContext *ctx) : OpConversionPattern(ctx) {}
ExperimentalONNXConvOpTile(MLIRContext* ctx)
: OpConversionPattern(ctx) {}
LogicalResult matchAndRewrite(ONNXConvOp conv, ONNXConvOpAdaptor convAdaptor,
ConversionPatternRewriter &rewriter) const final {
LogicalResult
matchAndRewrite(ONNXConvOp conv, ONNXConvOpAdaptor convAdaptor, ConversionPatternRewriter& rewriter) const final {
// --------------------------------- //
// --- READ OPERATION PARAMETERS --- //
@@ -46,12 +50,12 @@ struct ExperimentalONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
ShapedType weightsType = cast<ShapedType>(convAdaptor.getW().getType());
// TODO: Address bigger batches.
assert(GET_IMAGE_N(inputType) == 1 && "Batch size must be 1"
"for convolution.");
assert(GET_IMAGE_N(inputType) == 1
&& "Batch size must be 1"
"for convolution.");
// TODO: Address replication.
assert(coresCount.getValue() == -1 &&
"Replication is not yet supported for convolution.");
assert(coresCount.getValue() == -1 && "Replication is not yet supported for convolution.");
// TODO: Address bias addition.
@@ -97,11 +101,10 @@ struct ExperimentalONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
long inputTile = it;
// Create the slicing sizes.
SmallVector<OpFoldResult> slicingSizes{
/* 0 */ rewriter.getIndexAttr(crossbarHeight),
/* 1 */ rewriter.getIndexAttr(crossbarWidth),
/* 2 */ rewriter.getIndexAttr(1),
/* 3 */ rewriter.getIndexAttr(1)};
SmallVector<OpFoldResult> slicingSizes {/* 0 */ rewriter.getIndexAttr(crossbarHeight),
/* 1 */ rewriter.getIndexAttr(crossbarWidth),
/* 2 */ rewriter.getIndexAttr(1),
/* 3 */ rewriter.getIndexAttr(1)};
// - Slicing along the filter x position.
// - Slicing along the filter y position.
@@ -109,16 +112,14 @@ struct ExperimentalONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
for (size_t filterY = 0; filterY < kernelHeight; ++filterY) {
// Create the slicing offsets.
SmallVector<OpFoldResult> slicingOffsets{
/* 0 */ rewriter.getIndexAttr(outputTile * crossbarSize),
/* 1 */ rewriter.getIndexAttr(inputTile * crossbarSize),
/* 2 */ rewriter.getIndexAttr(filterX),
/* 3 */ rewriter.getIndexAttr(filterY)};
SmallVector<OpFoldResult> slicingOffsets {/* 0 */ rewriter.getIndexAttr(outputTile * crossbarSize),
/* 1 */ rewriter.getIndexAttr(inputTile * crossbarSize),
/* 2 */ rewriter.getIndexAttr(filterX),
/* 3 */ rewriter.getIndexAttr(filterY)};
// Create the slice extraction operation.
auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
conv.getLoc(), convAdaptor.getW(), slicingOffsets, slicingSizes,
slicingStrides);
conv.getLoc(), convAdaptor.getW(), slicingOffsets, slicingSizes, slicingStrides);
// Add a note to the extractSliceOp, with the filterX and filterY.
weightsGroups[inputTile][outputTile].push_back(extractSliceOp);
@@ -150,8 +151,7 @@ struct ExperimentalONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
// -------------------------------- //
// Get the next group of weights for the compute unit.
SmallVector<TaggedWeights> weightsGroups =
weightSubdivider.popGroups(crossbarCountInCore.getValue());
SmallVector<TaggedWeights> weightsGroups = weightSubdivider.popGroups(crossbarCountInCore.getValue());
SmallVector<Value> computeWeights;
SmallVector<Value> computeOperands;
@@ -168,15 +168,13 @@ struct ExperimentalONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
// Iterate over all weights groups for this compute unit.
map<long, Value> localSlices; // WRT the current compute unit.
for (auto group : weightsGroups) {
for (Value weight : group.weights) {
for (Value weight : group.weights)
computeWeights.push_back(weight);
}
// There might be multiple weight groups for the same input tile, so if
// we've already added the input tile, skip it.
if (localSlices.find(group.inputTile) != localSlices.end()) {
if (localSlices.find(group.inputTile) != localSlices.end())
continue;
}
// We might have already sliced the input tensor for some other compute
// unit, so if we have, reuse the slicing operation without creating a
@@ -188,26 +186,21 @@ struct ExperimentalONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
}
// Create the input tensor slicing offsets.
SmallVector<OpFoldResult> slicingOffsets{
/* 0 */ rewriter.getIndexAttr(0), // No offset along the batch axis.
/* 1 */ rewriter.getIndexAttr(group.inputTile * crossbarSize),
/* 2 */ rewriter.getIndexAttr(0),
/* 3 */ rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> slicingOffsets {/* 0 */ rewriter.getIndexAttr(0), // No offset along the batch axis.
/* 1 */ rewriter.getIndexAttr(group.inputTile * crossbarSize),
/* 2 */ rewriter.getIndexAttr(0),
/* 3 */ rewriter.getIndexAttr(0)};
// Create the input tensor slicing sizes.
size_t tilingSize = group.inputTile == inputTileCount.quot
? inputTileCount.rem
: crossbarSize;
SmallVector<OpFoldResult> slicingSizes{
/* 0 */ rewriter.getIndexAttr(1), // Batch size is always 1.
/* 1 */ rewriter.getIndexAttr(tilingSize),
/* 2 */ rewriter.getIndexAttr(GET_IMAGE_WIDTH(inputType)),
/* 3 */ rewriter.getIndexAttr(GET_IMAGE_HEIGHT(inputType))};
size_t tilingSize = group.inputTile == inputTileCount.quot ? inputTileCount.rem : crossbarSize;
SmallVector<OpFoldResult> slicingSizes {/* 0 */ rewriter.getIndexAttr(1), // Batch size is always 1.
/* 1 */ rewriter.getIndexAttr(tilingSize),
/* 2 */ rewriter.getIndexAttr(GET_IMAGE_WIDTH(inputType)),
/* 3 */ rewriter.getIndexAttr(GET_IMAGE_HEIGHT(inputType))};
// Create the slice extraction operation.
auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
conv.getLoc(), convAdaptor.getX(), slicingOffsets, slicingSizes,
slicingStrides);
conv.getLoc(), convAdaptor.getX(), slicingOffsets, slicingSizes, slicingStrides);
computeOperands.push_back(extractSliceOp);
@@ -229,36 +222,28 @@ struct ExperimentalONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
// There might be multiple weight groups for the same output tile, so if
// we've already added the output tile, skip it.
if (outputTileIndices.find(group.outputTile) !=
outputTileIndices.end()) {
if (outputTileIndices.find(group.outputTile) != outputTileIndices.end())
continue;
}
// Additionally, after adding the input slices as operands, also add any
// compatible partial results from previous compute units.
if (globalPartialResults.find(group.outputTile) !=
globalPartialResults.end()) {
if (globalPartialResults.find(group.outputTile) != globalPartialResults.end()) {
computeOperands.push_back(globalPartialResults[group.outputTile]);
reductionTileIndices[group.outputTile] = computeOperands.size() - 1;
}
// Define the output shape for this group.
long outputTileSize = group.outputTile == outputTileCount.quot
? outputTileCount.rem
: crossbarSize;
long outputTileSize = group.outputTile == outputTileCount.quot ? outputTileCount.rem : crossbarSize;
// TODO: Address non-same padding.
SmallVector<int64_t> outputShapeArray{
/* 0 */ 1, // Batch size is always 1.
/* 1 */ outputTileSize,
/* 2 */ GET_IMAGE_WIDTH(outputType), // Same padding assumed.
/* 3 */ GET_IMAGE_HEIGHT(outputType)};
SmallVector<int64_t> outputShapeArray {/* 0 */ 1, // Batch size is always 1.
/* 1 */ outputTileSize,
/* 2 */ GET_IMAGE_WIDTH(outputType), // Same padding assumed.
/* 3 */ GET_IMAGE_HEIGHT(outputType)};
auto elementType =
dyn_cast<RankedTensorType>(conv.getY().getType()).getElementType();
auto elementType = dyn_cast<RankedTensorType>(conv.getY().getType()).getElementType();
computeOutputType.push_back(
RankedTensorType::get(outputShapeArray, elementType));
computeOutputType.push_back(RankedTensorType::get(outputShapeArray, elementType));
outputTileIndices[group.outputTile] = computeOutputType.size() - 1;
}
@@ -268,43 +253,36 @@ struct ExperimentalONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
// ----------------------------- //
// Create the compute unit.
spatial::SpatWeightedCompute currentCompute =
rewriter.create<spatial::SpatWeightedCompute>(conv.getLoc(),
computeOutputType, computeWeights, computeOperands);
spatial::SpatWeightedCompute currentCompute = rewriter.create<spatial::SpatWeightedCompute>(
conv.getLoc(), computeOutputType, computeWeights, computeOperands);
// Create a new block for the compute unit and add the operands.
Block *block = rewriter.createBlock(&currentCompute.getRegion());
Block* block = rewriter.createBlock(&currentCompute.getRegion());
rewriter.setInsertionPointToStart(block);
for (Value operand : computeOperands) {
for (Value operand : computeOperands)
block->addArgument(operand.getType(), conv->getLoc());
}
// Initialize a map of local partial results.
map<long, Value> localPartialResults; // WRT the current compute unit.
// If we have any reduction tiles, add them to the local partial results.
for (auto reductionTileIndex : reductionTileIndices) {
localPartialResults[reductionTileIndex.first] =
block->getArgument(reductionTileIndex.second);
}
for (auto reductionTileIndex : reductionTileIndices)
localPartialResults[reductionTileIndex.first] = block->getArgument(reductionTileIndex.second);
// Add all the applyFilters operations to the block.
for (TaggedWeights group : weightsGroups) {
// Get the outputType for this group.
Type outputType =
computeOutputType[outputTileIndices[group.outputTile]];
Type outputType = computeOutputType[outputTileIndices[group.outputTile]];
// Create an apply filters operation.
BlockArgument blockArgument =
block->getArgument(inputTileIndices[group.inputTile]);
BlockArgument blockArgument = block->getArgument(inputTileIndices[group.inputTile]);
// The list of weight indices is group.startingCrossbarIndex + 0, 1, 2,
// ... As many weights as the size of group.weights.
SmallVector<long> weightIndices;
for (size_t i = 0; i < group.weights.size(); ++i) {
for (size_t i = 0; i < group.weights.size(); ++i)
weightIndices.push_back(group.startingCrossbarIndex + i);
}
SmallVector<int64_t> xKerPos;
SmallVector<int64_t> yKerPos;
@@ -323,16 +301,14 @@ struct ExperimentalONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
ArrayAttr xKerPosAttr = rewriter.getI64ArrayAttr(xKerPos);
ArrayAttr yKerPosAttr = rewriter.getI64ArrayAttr(yKerPos);
Value result =
rewriter.create<spatial::SpatApplyFiltersOp>(conv.getLoc(), outputType,
weightIndicesAttr, xKerPosAttr, yKerPosAttr, blockArgument);
Value result = rewriter.create<spatial::SpatApplyFiltersOp>(
conv.getLoc(), outputType, weightIndicesAttr, xKerPosAttr, yKerPosAttr, blockArgument);
// Perform local reduction if necessary.
if (localPartialResults.find(group.outputTile) !=
localPartialResults.end()) {
if (localPartialResults.find(group.outputTile) != localPartialResults.end()) {
result = rewriter.create<spatial::SpatVAddOp>(conv.getLoc(),
result.getType(), localPartialResults[group.outputTile], result);
result = rewriter.create<spatial::SpatVAddOp>(
conv.getLoc(), result.getType(), localPartialResults[group.outputTile], result);
}
// Update the partial results map.
@@ -385,22 +361,18 @@ struct ExperimentalONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
// Turn the values into a SmallVector.
SmallVector<Value> outputValues;
for (long i = 0; i < outputTileCount.quot + (outputTileCount.rem > 0);
++i) {
for (long i = 0; i < outputTileCount.quot + (outputTileCount.rem > 0); ++i)
outputValues.push_back(globalPartialResults[i]);
}
// Assert that the number of output values is correct.
assert(outputValues.size() > 0 &&
"No output values were generated for the convolution.");
assert(outputValues.size() > 0 && "No output values were generated for the convolution.");
// If the conv's user is a ReLU...
if (conv->hasOneUse()) {
Operation *user = *conv->getUsers().begin();
Operation* user = *conv->getUsers().begin();
if (auto relu = dyn_cast<ONNXReluOp>(user)) {
// ...then we can just replace the ReLU with the concatenation.
rewriter.replaceOp(relu,
rewriter.create<tensor::ConcatOp>(conv.getLoc(), 1, outputValues));
rewriter.replaceOp(relu, rewriter.create<tensor::ConcatOp>(conv.getLoc(), 1, outputValues));
// And erase the convolution.
rewriter.eraseOp(conv);
@@ -409,8 +381,7 @@ struct ExperimentalONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
}
// Return the final output.
rewriter.replaceOp(conv,
rewriter.create<tensor::ConcatOp>(conv.getLoc(), 1, outputValues));
rewriter.replaceOp(conv, rewriter.create<tensor::ConcatOp>(conv.getLoc(), 1, outputValues));
return success();
}
@@ -422,9 +393,8 @@ struct ExperimentalONNXConvOpTile : public OpConversionPattern<ONNXConvOp> {
* @param patterns The pattern set to populate.
* @param ctx The MLIR context.
*/
void populateExperimentalTilingConvOpPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
  // Register the experimental tiled lowering of ONNXConvOp in the caller's pattern set.
  patterns.insert<ExperimentalONNXConvOpTile>(ctx);
}
} // namespace onnx_mlir
} // namespace onnx_mlir

View File

@@ -1,24 +1,25 @@
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Transforms/DialectConversion.h"
#include <cstdlib>
#include "Compiler/PimCompilerOptions.hpp"
#include "Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp"
#include "Conversion/ONNXToSpatial/Utils/WeightSubdivider.hpp"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Transforms/DialectConversion.h"
#include "src/Dialect/ONNX/ONNXOps.hpp"
#include <cstdlib>
using namespace mlir;
using namespace std;
namespace onnx_mlir {
struct ExperimentalGemmConversionPattern
: public OpConversionPattern<ONNXGemmOp> {
ExperimentalGemmConversionPattern(MLIRContext *ctx)
: OpConversionPattern(ctx) {}
struct ExperimentalGemmConversionPattern : public OpConversionPattern<ONNXGemmOp> {
ExperimentalGemmConversionPattern(MLIRContext* ctx)
: OpConversionPattern(ctx) {}
LogicalResult matchAndRewrite(ONNXGemmOp gemmOp, ONNXGemmOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
LogicalResult
matchAndRewrite(ONNXGemmOp gemmOp, ONNXGemmOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final {
// --------------------------------- //
// --- READ OPERATION PARAMETERS --- //
@@ -34,17 +35,14 @@ struct ExperimentalGemmConversionPattern
ShapedType matrixType = cast<ShapedType>(adaptor.getB().getType());
// TODO: Address bigger batches.
assert(inputType.getShape()[0] == 1 &&
"Only batch size of 1 is supported for GEMM.");
assert(inputType.getShape()[0] == 1 && "Only batch size of 1 is supported for GEMM.");
// TODO: Address replication.
assert(coresCount.getValue() == -1 &&
"Replication is not yet supported for GEMM.");
assert(coresCount.getValue() == -1 && "Replication is not yet supported for GEMM.");
// TODO: Address bias addition.
assert(inputType.getShape()[1] == matrixType.getShape()[0] &&
"Input tile size must match the matrix's row size.");
assert(inputType.getShape()[1] == matrixType.getShape()[0] && "Input tile size must match the matrix's row size.");
ldiv_t inputTileCount = div(inputType.getShape()[1], crossbarSize);
ldiv_t outputTileCount = div(outputType.getShape()[1], crossbarSize);
@@ -88,11 +86,10 @@ struct ExperimentalGemmConversionPattern
long inputTile = it;
// Create the slicing sizes.
SmallVector<OpFoldResult> slicingSizes{
/* 0 */ rewriter.getIndexAttr(crossbarHeight),
/* 1 */ rewriter.getIndexAttr(crossbarWidth),
/* 2 */ /* rewriter.getIndexAttr(1), */
/* 3 */ /* rewriter.getIndexAttr(1) */};
SmallVector<OpFoldResult> slicingSizes {/* 0 */ rewriter.getIndexAttr(crossbarHeight),
/* 1 */ rewriter.getIndexAttr(crossbarWidth),
/* 2 */ /* rewriter.getIndexAttr(1), */
/* 3 */ /* rewriter.getIndexAttr(1) */};
// - Slicing along the filter x position.
// - Slicing along the filter y position.
@@ -100,16 +97,14 @@ struct ExperimentalGemmConversionPattern
for (size_t filterY = 0; filterY < kernelHeight; ++filterY) {
// Create the slicing offsets.
SmallVector<OpFoldResult> slicingOffsets{
/* 0 */ rewriter.getIndexAttr(outputTile * crossbarSize),
/* 1 */ rewriter.getIndexAttr(inputTile * crossbarSize),
/* 2 */ /* rewriter.getIndexAttr(filterX), */
/* 3 */ /* rewriter.getIndexAttr(filterY) */};
SmallVector<OpFoldResult> slicingOffsets {/* 0 */ rewriter.getIndexAttr(outputTile * crossbarSize),
/* 1 */ rewriter.getIndexAttr(inputTile * crossbarSize),
/* 2 */ /* rewriter.getIndexAttr(filterX), */
/* 3 */ /* rewriter.getIndexAttr(filterY) */};
// Create the slice extraction operation.
auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
gemmOp.getLoc(), adaptor.getB(), slicingOffsets, slicingSizes,
slicingStrides);
gemmOp.getLoc(), adaptor.getB(), slicingOffsets, slicingSizes, slicingStrides);
// Add a note to the extractSliceOp, with the filterX and filterY.
weightsGroups[inputTile][outputTile].push_back(extractSliceOp);
@@ -141,8 +136,7 @@ struct ExperimentalGemmConversionPattern
// -------------------------------- //
// Get the next group of weights for the compute unit.
SmallVector<TaggedWeights> weightsGroups =
weightSubdivider.popGroups(crossbarCountInCore.getValue());
SmallVector<TaggedWeights> weightsGroups = weightSubdivider.popGroups(crossbarCountInCore.getValue());
SmallVector<Value> computeWeights;
SmallVector<Value> computeOperands;
@@ -159,15 +153,13 @@ struct ExperimentalGemmConversionPattern
// Iterate over all weights groups for this compute unit.
map<long, Value> localSlices; // WRT the current compute unit.
for (auto group : weightsGroups) {
for (Value weight : group.weights) {
for (Value weight : group.weights)
computeWeights.push_back(weight);
}
// There might be multiple weight groups for the same input tile, so if
// we've already added the input tile, skip it.
if (localSlices.find(group.inputTile) != localSlices.end()) {
if (localSlices.find(group.inputTile) != localSlices.end())
continue;
}
// We might have already sliced the input tensor for some other compute
// unit, so if we have, reuse the slicing operation without creating a
@@ -179,26 +171,21 @@ struct ExperimentalGemmConversionPattern
}
// Create the input tensor slicing offsets.
SmallVector<OpFoldResult> slicingOffsets{
/* 0 */ rewriter.getIndexAttr(0), // No offset along the batch axis.
/* 1 */ rewriter.getIndexAttr(group.inputTile * crossbarSize),
/* 2 */ /* rewriter.getIndexAttr(0), */
/* 3 */ /* rewriter.getIndexAttr(0) */};
SmallVector<OpFoldResult> slicingOffsets {/* 0 */ rewriter.getIndexAttr(0), // No offset along the batch axis.
/* 1 */ rewriter.getIndexAttr(group.inputTile * crossbarSize),
/* 2 */ /* rewriter.getIndexAttr(0), */
/* 3 */ /* rewriter.getIndexAttr(0) */};
// Create the input tensor slicing sizes.
size_t tilingSize = group.inputTile == inputTileCount.quot
? inputTileCount.rem
: crossbarSize;
SmallVector<OpFoldResult> slicingSizes{
/* 0 */ rewriter.getIndexAttr(1), // Batch size is always 1.
/* 1 */ rewriter.getIndexAttr(tilingSize),
/* 2 */ /* rewriter.getIndexAttr(GET_IMAGE_WIDTH(inputType)), */
/* 3 */ /* rewriter.getIndexAttr(GET_IMAGE_HEIGHT(inputType)) */};
size_t tilingSize = group.inputTile == inputTileCount.quot ? inputTileCount.rem : crossbarSize;
SmallVector<OpFoldResult> slicingSizes {/* 0 */ rewriter.getIndexAttr(1), // Batch size is always 1.
/* 1 */ rewriter.getIndexAttr(tilingSize),
/* 2 */ /* rewriter.getIndexAttr(GET_IMAGE_WIDTH(inputType)), */
/* 3 */ /* rewriter.getIndexAttr(GET_IMAGE_HEIGHT(inputType)) */};
// Create the slice extraction operation.
auto extractSliceOp =
rewriter.create<tensor::ExtractSliceOp>(gemmOp.getLoc(),
adaptor.getA(), slicingOffsets, slicingSizes, slicingStrides);
auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
gemmOp.getLoc(), adaptor.getA(), slicingOffsets, slicingSizes, slicingStrides);
computeOperands.push_back(extractSliceOp);
@@ -220,36 +207,28 @@ struct ExperimentalGemmConversionPattern
// There might be multiple weight groups for the same output tile, so if
// we've already added the output tile, skip it.
if (outputTileIndices.find(group.outputTile) !=
outputTileIndices.end()) {
if (outputTileIndices.find(group.outputTile) != outputTileIndices.end())
continue;
}
// Additionally, after adding the input slices as operands, also add any
// compatible partial results from previous compute units.
if (globalPartialResults.find(group.outputTile) !=
globalPartialResults.end()) {
if (globalPartialResults.find(group.outputTile) != globalPartialResults.end()) {
computeOperands.push_back(globalPartialResults[group.outputTile]);
reductionTileIndices[group.outputTile] = computeOperands.size() - 1;
}
// Define the output shape for this group.
long outputTileSize = group.outputTile == outputTileCount.quot
? outputTileCount.rem
: crossbarSize;
long outputTileSize = group.outputTile == outputTileCount.quot ? outputTileCount.rem : crossbarSize;
// TODO: Address non-same padding.
SmallVector<int64_t> outputShapeArray{
/* 0 */ 1, // Batch size is always 1.
/* 1 */ outputTileSize,
/* 2 */ /* GET_IMAGE_WIDTH(outputType), */ // Same padding assumed.
/* 3 */ /* GET_IMAGE_HEIGHT(outputType) */};
SmallVector<int64_t> outputShapeArray {/* 0 */ 1, // Batch size is always 1.
/* 1 */ outputTileSize,
/* 2 */ /* GET_IMAGE_WIDTH(outputType), */ // Same padding assumed.
/* 3 */ /* GET_IMAGE_HEIGHT(outputType) */};
auto elementType = dyn_cast<RankedTensorType>(gemmOp.getY().getType())
.getElementType();
auto elementType = dyn_cast<RankedTensorType>(gemmOp.getY().getType()).getElementType();
computeOutputType.push_back(
RankedTensorType::get(outputShapeArray, elementType));
computeOutputType.push_back(RankedTensorType::get(outputShapeArray, elementType));
outputTileIndices[group.outputTile] = computeOutputType.size() - 1;
}
@@ -259,43 +238,36 @@ struct ExperimentalGemmConversionPattern
// ----------------------------- //
// Create the compute unit.
spatial::SpatWeightedCompute currentCompute =
rewriter.create<spatial::SpatWeightedCompute>(gemmOp.getLoc(),
computeOutputType, computeWeights, computeOperands);
spatial::SpatWeightedCompute currentCompute = rewriter.create<spatial::SpatWeightedCompute>(
gemmOp.getLoc(), computeOutputType, computeWeights, computeOperands);
// Create a new block for the compute unit and add the operands.
Block *block = rewriter.createBlock(&currentCompute.getRegion());
Block* block = rewriter.createBlock(&currentCompute.getRegion());
rewriter.setInsertionPointToStart(block);
for (Value operand : computeOperands) {
for (Value operand : computeOperands)
block->addArgument(operand.getType(), gemmOp->getLoc());
}
// Initialize a map of local partial results.
map<long, Value> localPartialResults; // WRT the current compute unit.
// If we have any reduction tiles, add them to the local partial results.
for (auto reductionTileIndex : reductionTileIndices) {
localPartialResults[reductionTileIndex.first] =
block->getArgument(reductionTileIndex.second);
}
for (auto reductionTileIndex : reductionTileIndices)
localPartialResults[reductionTileIndex.first] = block->getArgument(reductionTileIndex.second);
// Add all the applyFilters operations to the block.
for (TaggedWeights group : weightsGroups) {
// Get the outputType for this group.
Type outputType =
computeOutputType[outputTileIndices[group.outputTile]];
Type outputType = computeOutputType[outputTileIndices[group.outputTile]];
// Create an apply filters operation.
BlockArgument blockArgument =
block->getArgument(inputTileIndices[group.inputTile]);
BlockArgument blockArgument = block->getArgument(inputTileIndices[group.inputTile]);
// The list of weight indices is group.startingCrossbarIndex + 0, 1, 2,
// ... As many weights as the size of group.weights.
SmallVector<long> weightIndices;
for (size_t i = 0; i < group.weights.size(); ++i) {
for (size_t i = 0; i < group.weights.size(); ++i)
weightIndices.push_back(group.startingCrossbarIndex + i);
}
SmallVector<int64_t> xKerPos;
SmallVector<int64_t> yKerPos;
@@ -313,16 +285,14 @@ struct ExperimentalGemmConversionPattern
ArrayAttr xKerPosAttr = rewriter.getI64ArrayAttr(xKerPos);
ArrayAttr yKerPosAttr = rewriter.getI64ArrayAttr(yKerPos);
Value result = rewriter.create<spatial::SpatApplyFiltersOp>(gemmOp.getLoc(),
outputType, weightIndicesAttr, xKerPosAttr, yKerPosAttr,
blockArgument);
Value result = rewriter.create<spatial::SpatApplyFiltersOp>(
gemmOp.getLoc(), outputType, weightIndicesAttr, xKerPosAttr, yKerPosAttr, blockArgument);
// Perform local reduction if necessary.
if (localPartialResults.find(group.outputTile) !=
localPartialResults.end()) {
if (localPartialResults.find(group.outputTile) != localPartialResults.end()) {
result = rewriter.create<spatial::SpatVAddOp>(gemmOp.getLoc(),
result.getType(), localPartialResults[group.outputTile], result);
result = rewriter.create<spatial::SpatVAddOp>(
gemmOp.getLoc(), result.getType(), localPartialResults[group.outputTile], result);
}
// Update the partial results map.
@@ -375,26 +345,21 @@ struct ExperimentalGemmConversionPattern
// Turn the values into a SmallVector.
SmallVector<Value> outputValues;
for (long i = 0; i < outputTileCount.quot + (outputTileCount.rem > 0);
++i) {
for (long i = 0; i < outputTileCount.quot + (outputTileCount.rem > 0); ++i)
outputValues.push_back(globalPartialResults[i]);
}
// Assert that the number of output values is correct.
assert(outputValues.size() > 0 &&
"No output values were generated for the GEMM operation.");
assert(outputValues.size() > 0 && "No output values were generated for the GEMM operation.");
// Return the final output.
rewriter.replaceOp(gemmOp,
rewriter.create<tensor::ConcatOp>(gemmOp.getLoc(), 1, outputValues));
rewriter.replaceOp(gemmOp, rewriter.create<tensor::ConcatOp>(gemmOp.getLoc(), 1, outputValues));
return success();
}
};
/**
 * @brief Registers ExperimentalGemmConversionPattern in the given pattern set.
 *
 * @param patterns The pattern set to populate.
 * @param ctx The MLIR context.
 */
void populateGemmToConvConversionPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
  patterns.insert<ExperimentalGemmConversionPattern>(ctx);
}
} // namespace onnx_mlir
} // namespace onnx_mlir

View File

@@ -6,20 +6,22 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "src/Accelerators/PIM/Common/PIMCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <cassert>
#include <cmath>
#include <cstddef>
#include "src/Accelerators/PIM/Common/PIMCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
@@ -35,71 +37,68 @@ bool hasPostProcessExperimentalPoolingWindow<ONNXAveragePoolOp>() {
}
/**
 * @brief Primary template of the pooling-window post-processing hook.
 *
 * This generic version performs no post-processing: it ignores all of its
 * arguments and returns a null Value. Pooling ops that need a post-processing
 * step provide an explicit specialization.
 *
 * @param rewriter Unused in the primary template.
 * @param loc Unused in the primary template.
 * @param poolOp Unused in the primary template.
 * @param valueToDivide Unused in the primary template.
 * @param krn_size Unused in the primary template.
 * @param tilesSkippedByPadding Unused in the primary template.
 * @return A null Value.
 */
template <typename PoolOp>
Value postProcessExperimentalPoolingWindow(ConversionPatternRewriter& rewriter,
                                           Location loc,
                                           PoolOp poolOp,
                                           Value valueToDivide,
                                           size_t krn_size,
                                           size_t tilesSkippedByPadding) {
  return nullptr;
}
/**
 * @brief Average-pool post-processing: divides a reduced window by its element count.
 *
 * The divisor is `krn_size` when the op's count_include_pad attribute equals 1,
 * otherwise `krn_size - tilesSkippedByPadding`. It is materialized as a
 * spat.const placed right before the enclosing SpatWeightedCompute, and the
 * division itself is emitted right after `valueToDivide`.
 *
 * @param rewriter Rewriter used to create the new operations. Its insertion
 *                 point is left right after `valueToDivide`.
 * @param loc Location attached to the created operations.
 * @param poolOp The average pool operation being lowered.
 * @param valueToDivide Reduced window value; must be defined by an operation
 *                      nested directly inside a SpatWeightedCompute.
 * @param krn_size Number of elements in the pooling window.
 * @param tilesSkippedByPadding Number of window elements that fell into padding.
 * @return The result of the SpatVSDivOp dividing `valueToDivide` by the divisor.
 */
template <>
Value postProcessExperimentalPoolingWindow<ONNXAveragePoolOp>(ConversionPatternRewriter& rewriter,
                                                              Location loc,
                                                              ONNXAveragePoolOp poolOp,
                                                              Value valueToDivide,
                                                              size_t krn_size,
                                                              size_t tilesSkippedByPadding) {
  bool countIncludePad = poolOp.getCountIncludePad() == 1;

  // The subtraction below is unsigned: a skipped count >= krn_size would wrap
  // around (or yield a zero divisor), so reject it up front.
  assert((countIncludePad || tilesSkippedByPadding < krn_size)
         && "Padding cannot cover the whole pooling window.");
  size_t divisorNumber = countIncludePad ? krn_size : krn_size - tilesSkippedByPadding;

  RankedTensorType scalarTensor = RankedTensorType::get({1}, rewriter.getF32Type());

  // Put a spat.const before the computeOp, and use its value. We do this to be
  // compatible with the current code generation, which assumes constant to be
  // loaded in global memory, which is allocated by adding a spat.const OP
  // directly under func.func (i.e. alongside ComputeOps)
  auto* definingOp = valueToDivide.getDefiningOp();
  assert(definingOp && "valueToDivide must be produced by an operation, not a block argument.");
  auto computeOp = cast<spatial::SpatWeightedCompute>(definingOp->getParentOp());
  rewriter.setInsertionPoint(computeOp);
  // NOTE(review): the constant's result type is an f32 tensor while the value
  // is supplied as an i64 attribute - confirm SpatConstantOp expects this.
  auto divisorValue = rewriter.create<spatial::SpatConstantOp>(loc,
                                                               scalarTensor,
                                                               rewriter.getI64IntegerAttr(divisorNumber),
                                                               /* should_allocate = */ rewriter.getBoolAttr(true));

  rewriter.setInsertionPointAfterValue(valueToDivide);
  return rewriter.create<spatial::SpatVSDivOp>(loc, valueToDivide.getType(), valueToDivide, divisorValue);
}
/**
 * @brief Reduces a list of tiles to a single value with a balanced tree of ReductionOp.
 *
 * The list is split in half recursively; each pair of partial results is
 * combined with one ReductionOp. Every created op uses the location of the
 * first tile of the (sub)list being reduced and the type of its left operand.
 *
 * @tparam ReductionOp Binary operation used to combine two tiles.
 * @param inputTiles Tiles to reduce; expected to be non-empty.
 * @param rewriter Rewriter used to create the reduction operations at its
 *                 current insertion point.
 * @return The single tile when only one is given, otherwise the result of the
 *         root ReductionOp. A null Value is returned for an empty list.
 */
template <typename ReductionOp>
Value reduceInputTiles(SmallVector<Value>& inputTiles, ConversionPatternRewriter& rewriter) {
  // Splitting an empty list yields two empty halves, so the recursion would
  // never terminate: flag it in debug builds and bail out in release builds.
  assert(!inputTiles.empty() && "Cannot reduce an empty list of tiles.");
  if (inputTiles.empty())
    return nullptr;

  if (inputTiles.size() == 1)
    return inputTiles[0];

  // No special case for two tiles: the generic split below handles it and,
  // unlike a hard-coded max, always combines the tiles with ReductionOp.
  auto middle = inputTiles.begin() + inputTiles.size() / 2;
  SmallVector<Value> left(inputTiles.begin(), middle);
  SmallVector<Value> right(middle, inputTiles.end());

  Value leftReduced = reduceInputTiles<ReductionOp>(left, rewriter);
  Value rightReduced = reduceInputTiles<ReductionOp>(right, rewriter);

  return rewriter.create<ReductionOp>(inputTiles[0].getLoc(), leftReduced.getType(), leftReduced, rightReduced);
}
template <typename PoolOp, typename PoolOpAdaptor, typename ReduceOp>
struct ExperimentalPoolingBaseConverter : public OpConversionPattern<PoolOp> {
ExperimentalPoolingBaseConverter(MLIRContext *ctx)
: OpConversionPattern<PoolOp>(ctx) {}
ExperimentalPoolingBaseConverter(MLIRContext* ctx)
: OpConversionPattern<PoolOp>(ctx) {}
LogicalResult matchAndRewrite(PoolOp poolOp, PoolOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
LogicalResult matchAndRewrite(PoolOp poolOp, PoolOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final {
Value X = adaptor.getX();
ShapedType xShape = mlir::cast<ShapedType>(X.getType());
Value Y = poolOp.getResult();
@@ -110,17 +109,13 @@ struct ExperimentalPoolingBaseConverter : public OpConversionPattern<PoolOp> {
unpackOptionalPairVector(adaptor.getDilations(), dilation_x, dilation_y);
unpackOptionalPairVector(adaptor.getKernelShape(), krn_w, krn_h);
if (adaptor.getAutoPad() != "NOTSET") {
return rewriter.notifyMatchFailure(
poolOp, "auto_pad != NOTSET is deprecated.");
}
if (adaptor.getAutoPad() != "NOTSET")
return rewriter.notifyMatchFailure(poolOp, "auto_pad != NOTSET is deprecated.");
size_t pad_x, pad_y;
auto padUnpackError =
unpackOptionalPadsVector(adaptor.getPads(), pad_x, pad_y);
if (padUnpackError.has_value()) {
auto padUnpackError = unpackOptionalPadsVector(adaptor.getPads(), pad_x, pad_y);
if (padUnpackError.has_value())
return rewriter.notifyMatchFailure(poolOp, padUnpackError.value());
}
Location loc = poolOp.getLoc();
@@ -133,10 +128,8 @@ struct ExperimentalPoolingBaseConverter : public OpConversionPattern<PoolOp> {
// Assert that the input is a tensor.ConcatOp.
auto concat = X.getDefiningOp<tensor::ConcatOp>();
if (!concat) {
return rewriter.notifyMatchFailure(
poolOp, "Expected input to be a tensor.ConcatOp");
}
if (!concat)
return rewriter.notifyMatchFailure(poolOp, "Expected input to be a tensor.ConcatOp");
// Create a [channel_tile][x][y] array to store the input tiles.
std::map<long, std::map<long, std::map<long, Value>>> inputTiles;
@@ -145,24 +138,21 @@ struct ExperimentalPoolingBaseConverter : public OpConversionPattern<PoolOp> {
for (size_t y = 0; y < input_h; ++y) {
for (size_t x = 0; x < input_w; ++x) {
for (long it = 0; it < tileCount.quot + (tileCount.rem > 0); ++it) {
size_t tilingSize =
it == tileCount.quot ? tileCount.rem : crossbarSize;
size_t tilingSize = it == tileCount.quot ? tileCount.rem : crossbarSize;
SmallVector<OpFoldResult> strides(4, rewriter.getIndexAttr(1));
SmallVector<OpFoldResult> offsets = {/* 0 */ rewriter.getIndexAttr(0),
/* 1 */ rewriter.getIndexAttr(0),
/* 2 */ rewriter.getIndexAttr(x),
/* 3 */ rewriter.getIndexAttr(y)};
SmallVector<OpFoldResult> sizes = {
/* 0 */ rewriter.getIndexAttr(1), // Batch size is always 1.
/* 1 */ rewriter.getIndexAttr(tilingSize),
/* 2 */ rewriter.getIndexAttr(1),
/* 3 */ rewriter.getIndexAttr(1)};
/* 1 */ rewriter.getIndexAttr(0),
/* 2 */ rewriter.getIndexAttr(x),
/* 3 */ rewriter.getIndexAttr(y)};
SmallVector<OpFoldResult> sizes = {/* 0 */ rewriter.getIndexAttr(1), // Batch size is always 1.
/* 1 */ rewriter.getIndexAttr(tilingSize),
/* 2 */ rewriter.getIndexAttr(1),
/* 3 */ rewriter.getIndexAttr(1)};
// Get the concat's operand that we want to slice.
Value concatInput = concat.getOperand(it);
Value slicedTile = rewriter.create<tensor::ExtractSliceOp>(
loc, concatInput, offsets, sizes, strides);
Value slicedTile = rewriter.create<tensor::ExtractSliceOp>(loc, concatInput, offsets, sizes, strides);
inputTiles[it][x][y] = slicedTile;
}
@@ -175,19 +165,15 @@ struct ExperimentalPoolingBaseConverter : public OpConversionPattern<PoolOp> {
for (size_t y = 0; y < output_h; ++y) {
for (size_t x = 0; x < output_w; ++x) {
for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) {
SmallVector<int64_t> outputShapeArray{
/* 0 */ 1, // Batch size is always 1.
/* 1 */
cast<RankedTensorType>(inputTiles[it][0][0].getType())
.getShape()[1],
/* 2 */ 1,
/* 3 */ 1};
SmallVector<int64_t> outputShapeArray {/* 0 */ 1, // Batch size is always 1.
/* 1 */
cast<RankedTensorType>(inputTiles[it][0][0].getType()).getShape()[1],
/* 2 */ 1,
/* 3 */ 1};
auto elementType =
dyn_cast<RankedTensorType>(xShape).getElementType();
auto elementType = dyn_cast<RankedTensorType>(xShape).getElementType();
outputTileTypes.push_back(
RankedTensorType::get(outputShapeArray, elementType));
outputTileTypes.push_back(RankedTensorType::get(outputShapeArray, elementType));
}
}
}
@@ -195,29 +181,25 @@ struct ExperimentalPoolingBaseConverter : public OpConversionPattern<PoolOp> {
// Create a plain value list of the input tiles.
SmallVector<Value> inputTilesList;
for (size_t y = 0; y < input_h; ++y) {
for (size_t x = 0; x < input_w; ++x) {
for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) {
for (size_t x = 0; x < input_w; ++x)
for (long it = 0; it < itc.quot + (itc.rem > 0); ++it)
inputTilesList.push_back(inputTiles[it][y][x]);
}
}
}
// Create a single compute to calculate the output.
auto computeOp = rewriter.create<spatial::SpatWeightedCompute>(
loc, outputTileTypes, SmallVector<Value>(), inputTilesList);
auto computeOp =
rewriter.create<spatial::SpatWeightedCompute>(loc, outputTileTypes, SmallVector<Value>(), inputTilesList);
// Create a new block for the compute unit and add the operands.
Block *block = rewriter.createBlock(&computeOp.getRegion());
Block* block = rewriter.createBlock(&computeOp.getRegion());
// Fill the block arguments and keep a reference to them.
std::map<size_t, std::map<size_t, std::map<size_t, Value>>> inputTilesArgs;
for (size_t y = 0; y < input_h; ++y) {
for (size_t x = 0; x < input_w; ++x) {
for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) {
auto tileIndex = y * input_w * (itc.quot + (itc.rem > 0)) +
x * (itc.quot + (itc.rem > 0)) + it;
inputTilesArgs[it][y][x] = block->addArgument(
computeOp->getOperand(tileIndex).getType(), loc);
auto tileIndex = y * input_w * (itc.quot + (itc.rem > 0)) + x * (itc.quot + (itc.rem > 0)) + it;
inputTilesArgs[it][y][x] = block->addArgument(computeOp->getOperand(tileIndex).getType(), loc);
}
}
}
@@ -236,28 +218,26 @@ struct ExperimentalPoolingBaseConverter : public OpConversionPattern<PoolOp> {
size_t end_y = std::min(start_y + krn_h, input_h);
SmallVector<Value> inputTilesToReduce;
for (size_t ky = start_y; ky < end_y; ++ky) {
for (size_t kx = start_x; kx < end_x; ++kx) {
for (size_t ky = start_y; ky < end_y; ++ky)
for (size_t kx = start_x; kx < end_x; ++kx)
inputTilesToReduce.push_back(inputTilesArgs[it][ky][kx]);
}
}
auto reduceResult =
reduceInputTiles<ReduceOp>(inputTilesToReduce, rewriter);
auto reduceResult = reduceInputTiles<ReduceOp>(inputTilesToReduce, rewriter);
// If the reduce op is add, we need to divide the result by the
// number of elements in the pooling window.
if (hasPostProcessExperimentalPoolingWindow<PoolOp>()) {
// Add a spat.const before the computeOp.
rewriter.setInsertionPoint(computeOp);
auto divisorValue = rewriter.create<spatial::SpatConstantOp>(loc,
RankedTensorType::get({1}, rewriter.getF32Type()),
rewriter.getI64IntegerAttr(krn_w * krn_h),
rewriter.getBoolAttr(true));
auto divisorValue =
rewriter.create<spatial::SpatConstantOp>(loc,
RankedTensorType::get({1}, rewriter.getF32Type()),
rewriter.getI64IntegerAttr(krn_w * krn_h),
rewriter.getBoolAttr(true));
rewriter.setInsertionPointAfter(reduceResult.getDefiningOp());
reduceResult = rewriter.create<spatial::SpatVSDivOp>(
loc, reduceResult.getType(), reduceResult, divisorValue);
reduceResult =
rewriter.create<spatial::SpatVSDivOp>(loc, reduceResult.getType(), reduceResult, divisorValue);
}
outputTiles.push_back(reduceResult);
}
@@ -274,8 +254,7 @@ struct ExperimentalPoolingBaseConverter : public OpConversionPattern<PoolOp> {
for (size_t y = 0; y < output_h; ++y) {
for (size_t x = 0; x < output_w; ++x) {
for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) {
auto tileIndex = y * output_w * (itc.quot + (itc.rem > 0)) +
x * (itc.quot + (itc.rem > 0)) + it;
auto tileIndex = y * output_w * (itc.quot + (itc.rem > 0)) + x * (itc.quot + (itc.rem > 0)) + it;
computeOutput[it][y][x] = computeOp.getResult(tileIndex);
}
}
@@ -285,30 +264,25 @@ struct ExperimentalPoolingBaseConverter : public OpConversionPattern<PoolOp> {
SmallVector<Value> outputTilesList;
for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) {
SmallVector<Value> imgConcatTiles;
for (size_t y = 0; y < output_h; ++y) {
for (size_t x = 0; x < output_w; ++x) {
for (size_t y = 0; y < output_h; ++y)
for (size_t x = 0; x < output_w; ++x)
imgConcatTiles.push_back(computeOutput[it][y][x]);
}
}
size_t tilingSize = it == tileCount.quot ? tileCount.rem : crossbarSize;
SmallVector<int64_t> outputShapeArray{
/* 0 */ 1, // Batch size is always 1.
/* 1 */ (long)tilingSize,
/* 2 */ (long)output_w,
/* 3 */ (long)output_h};
SmallVector<int64_t> outputShapeArray {/* 0 */ 1, // Batch size is always 1.
/* 1 */ (long) tilingSize,
/* 2 */ (long) output_w,
/* 3 */ (long) output_h};
auto elementType = dyn_cast<RankedTensorType>(xShape).getElementType();
outputTilesList.push_back(rewriter.create<spatial::SpatImgConcatOp>(loc,
RankedTensorType::get(outputShapeArray, elementType),
imgConcatTiles));
outputTilesList.push_back(rewriter.create<spatial::SpatImgConcatOp>(
loc, RankedTensorType::get(outputShapeArray, elementType), imgConcatTiles));
}
// Create a new tensor.ConcatOp to concatenate the output tiles.
Value outputTensor =
rewriter.create<tensor::ConcatOp>(loc, 1, outputTilesList);
Value outputTensor = rewriter.create<tensor::ConcatOp>(loc, 1, outputTilesList);
rewriter.replaceOp(poolOp, outputTensor);
@@ -316,12 +290,11 @@ struct ExperimentalPoolingBaseConverter : public OpConversionPattern<PoolOp> {
}
};
// Adds two ExperimentalPoolingBaseConverter instantiations to `patterns`, each
// constructed with `ctx`:
//   - ONNXMaxPoolSingleOutOp / ONNXMaxPoolSingleOutOpAdaptor with
//     spatial::SpatVMaxOp as the third template argument;
//   - ONNXAveragePoolOp / ONNXAveragePoolOpAdaptor with spatial::SpatVAddOp as
//     the third template argument.
void populateExperimentalPoolingTilingPattern(
RewritePatternSet &patterns, MLIRContext *ctx) {
patterns.insert<ExperimentalPoolingBaseConverter<ONNXMaxPoolSingleOutOp,
ONNXMaxPoolSingleOutOpAdaptor, spatial::SpatVMaxOp>>(ctx);
patterns.insert<ExperimentalPoolingBaseConverter<ONNXAveragePoolOp,
ONNXAveragePoolOpAdaptor, spatial::SpatVAddOp>>(ctx);
void populateExperimentalPoolingTilingPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.insert<
ExperimentalPoolingBaseConverter<ONNXMaxPoolSingleOutOp, ONNXMaxPoolSingleOutOpAdaptor, spatial::SpatVMaxOp>>(ctx);
patterns.insert<ExperimentalPoolingBaseConverter<ONNXAveragePoolOp, ONNXAveragePoolOpAdaptor, spatial::SpatVAddOp>>(
ctx);
}
} // namespace onnx_mlir
} // namespace onnx_mlir

View File

@@ -6,38 +6,39 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "src/Accelerators/PIM/Common/PIMCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <cassert>
#include <cmath>
#include <cstddef>
#include "src/Accelerators/PIM/Common/PIMCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
// Namespace-scope set of Operation pointers (SmallPtrSet with small-size 16).
// NOTE(review): no read or write of this set is visible in this excerpt;
// presumably it records compute ops that were already replaced — confirm
// against its users.
llvm::SmallPtrSet<Operation *, 16> oldComputeOpsReplaced;
llvm::SmallPtrSet<Operation*, 16> oldComputeOpsReplaced;
Value applyReducePatternNew(SmallVector<Value> &valuesToReduce,
ConversionPatternRewriter &rewriter,
std::function<Value(const Value &, const Value &)> reduce,
std::function<Value(const Value &)> preprocess,
std::function<Value(const Value &)> postprocess) {
Value applyReducePatternNew(SmallVector<Value>& valuesToReduce,
ConversionPatternRewriter& rewriter,
std::function<Value(const Value&, const Value&)> reduce,
std::function<Value(const Value&)> preprocess,
std::function<Value(const Value&)> postprocess) {
// Simple case: if we have only one input, just return it
if (valuesToReduce.size() == 1) {
if (valuesToReduce.size() == 1)
return valuesToReduce[0];
}
if (preprocess) {
for (auto &valToReduce : valuesToReduce) {
for (auto& valToReduce : valuesToReduce) {
rewriter.setInsertionPointAfterValue(valToReduce);
valToReduce = preprocess(valToReduce);
}
@@ -47,9 +48,9 @@ Value applyReducePatternNew(SmallVector<Value> &valuesToReduce,
// computeOp. In this case, we need to apply the reduction within-compute.
// Keep a map between a computeOp and the last Value for this reduction
std::unordered_map<Operation *, Value> lastValueForCompute;
for (auto &valToReduce : valuesToReduce) {
Operation *computeOp = valToReduce.getParentBlock()->getParentOp();
std::unordered_map<Operation*, Value> lastValueForCompute;
for (auto& valToReduce : valuesToReduce) {
Operation* computeOp = valToReduce.getParentBlock()->getParentOp();
// if (valToReduce.getDefiningOp()) {
// // If the value is defined by an operation, we take the parent
// operation computeOp = valToReduce.getDefiningOp()->getParentOp();
@@ -67,12 +68,10 @@ Value applyReducePatternNew(SmallVector<Value> &valuesToReduce,
// within-compute
Value lastWithinComputeValue = it->second;
if (valToReduce.getDefiningOp()->isBeforeInBlock(
lastWithinComputeValue.getDefiningOp())) {
if (valToReduce.getDefiningOp()->isBeforeInBlock(lastWithinComputeValue.getDefiningOp()))
rewriter.setInsertionPointAfterValue(lastWithinComputeValue);
} else {
else
rewriter.setInsertionPointAfterValue(valToReduce);
}
valToReduce = reduce(lastWithinComputeValue, valToReduce);
lastValueForCompute[computeOp] = valToReduce;
}
@@ -83,9 +82,8 @@ Value applyReducePatternNew(SmallVector<Value> &valuesToReduce,
// Now, reconstruct from the map the valuesToReduce list
valuesToReduce.clear();
valuesToReduce.reserve(lastValueForCompute.size());
for (auto &entry : lastValueForCompute) {
for (auto& entry : lastValueForCompute)
valuesToReduce.push_back(entry.second);
}
Location loc = valuesToReduce[0].getLoc();
auto channelType = spatial::SpatChannelType::get(rewriter.getContext());
@@ -123,8 +121,7 @@ Value applyReducePatternNew(SmallVector<Value> &valuesToReduce,
// 3. Add a receiveOp after the second value
rewriter.setInsertionPointAfterValue(secondValue);
auto receivedValue = rewriter.create<spatial::SpatChannelReceiveOp>(
loc, secondValue.getType(), channel);
auto receivedValue = rewriter.create<spatial::SpatChannelReceiveOp>(loc, secondValue.getType(), channel);
// 4. Apply reduction between second value and received value
rewriter.setInsertionPointAfterValue(receivedValue);
@@ -135,17 +132,14 @@ Value applyReducePatternNew(SmallVector<Value> &valuesToReduce,
// If we have an odd number of inputs, we need to add the last one to the
// newInputs list.
if (valuesToReduceRef.size() % 2 == 1) {
if (valuesToReduceRef.size() % 2 == 1)
nextValuesToReduce.push_back(valuesToReduceRef.back());
}
// Replace the inputOps list with the new one.
valuesToReduceRef =
llvm::OwningArrayRef<Value>(std::move(nextValuesToReduce));
valuesToReduceRef = llvm::OwningArrayRef<Value>(std::move(nextValuesToReduce));
}
assert(valuesToReduceRef.size() == 1 &&
"Internal error: expected a single input at this point.");
assert(valuesToReduceRef.size() == 1 && "Internal error: expected a single input at this point.");
auto finalValue = valuesToReduceRef[0];
@@ -168,46 +162,49 @@ bool hasPostProcessPoolingWindow<ONNXAveragePoolOp>() {
}
// Primary template of the per-window post-processing hook. It ignores every
// argument and returns a null Value. The ONNXAveragePoolOp specialization
// supplies the only real implementation.
template <typename PoolOp>
Value postProcessPoolingWindow(ConversionPatternRewriter &rewriter,
Location loc, PoolOp poolOp, Value valueToDivide, size_t krn_size,
size_t tilesSkippedByPadding) {
Value postProcessPoolingWindow(ConversionPatternRewriter& rewriter,
Location loc,
PoolOp poolOp,
Value valueToDivide,
size_t krn_size,
size_t tilesSkippedByPadding) {
return nullptr;
}
// ONNXAveragePoolOp specialization: turns an accumulated window value into an
// average by dividing it by the number of contributing elements.
//
//   rewriter              - used to create the new ops; its insertion point is
//                           changed by this function and not restored.
//   loc                   - location attached to the created ops.
//   poolOp                - queried only for its count_include_pad attribute.
//   valueToDivide         - value to divide. getDefiningOp() is dereferenced
//                           and its parent is cast<> to SpatWeightedCompute, so
//                           it must be an op result nested directly in one.
//   krn_size              - divisor when count_include_pad == 1.
//   tilesSkippedByPadding - subtracted from krn_size when count_include_pad
//                           is not 1.
//
// Returns the result of a SpatVSDivOp (same type as `valueToDivide`) inserted
// right after `valueToDivide`. The divisor is a SpatConstantOp of type
// tensor<1xf32> created immediately before the enclosing compute op.
// NOTE(review): the divisor is passed as an I64 integer attribute while the
// constant's result type is an f32 tensor — confirm SpatConstantOp handles
// that combination as intended.
template <>
Value postProcessPoolingWindow<ONNXAveragePoolOp>(
ConversionPatternRewriter &rewriter, Location loc, ONNXAveragePoolOp poolOp,
Value valueToDivide, size_t krn_size, size_t tilesSkippedByPadding) {
Value postProcessPoolingWindow<ONNXAveragePoolOp>(ConversionPatternRewriter& rewriter,
Location loc,
ONNXAveragePoolOp poolOp,
Value valueToDivide,
size_t krn_size,
size_t tilesSkippedByPadding) {
bool countIncludePad = poolOp.getCountIncludePad() == 1;
size_t divisorNumber =
countIncludePad ? krn_size : krn_size - tilesSkippedByPadding;
size_t divisorNumber = countIncludePad ? krn_size : krn_size - tilesSkippedByPadding;
RankedTensorType scalarTensor =
RankedTensorType::get({1}, rewriter.getF32Type());
RankedTensorType scalarTensor = RankedTensorType::get({1}, rewriter.getF32Type());
// Put a spat.const before the computeOp, and use its value. We do this to be
// compatible with the current code generation, which assumes constant to be
// loaded in global memory, which is allocated by adding a spat.const OP
// directly under func.func (i.e. alongside ComputeOps)
auto computeOp = cast<spatial::SpatWeightedCompute>(
valueToDivide.getDefiningOp()->getParentOp());
auto computeOp = cast<spatial::SpatWeightedCompute>(valueToDivide.getDefiningOp()->getParentOp());
rewriter.setInsertionPoint(computeOp);
auto divisorValue = rewriter.create<spatial::SpatConstantOp>(loc, scalarTensor,
rewriter.getI64IntegerAttr(divisorNumber),
/* should_allocate = */ rewriter.getBoolAttr(true));
auto divisorValue = rewriter.create<spatial::SpatConstantOp>(loc,
scalarTensor,
rewriter.getI64IntegerAttr(divisorNumber),
/* should_allocate = */ rewriter.getBoolAttr(true));
rewriter.setInsertionPointAfterValue(valueToDivide);
return rewriter.create<spatial::SpatVSDivOp>(
loc, valueToDivide.getType(), valueToDivide, divisorValue);
return rewriter.create<spatial::SpatVSDivOp>(loc, valueToDivide.getType(), valueToDivide, divisorValue);
}
template <typename PoolOp, typename PoolOpAdaptor, typename ReduceOp>
struct PoolingBaseConverter : public OpConversionPattern<PoolOp> {
PoolingBaseConverter(MLIRContext *ctx) : OpConversionPattern<PoolOp>(ctx) {}
PoolingBaseConverter(MLIRContext* ctx)
: OpConversionPattern<PoolOp>(ctx) {}
LogicalResult matchAndRewrite(PoolOp poolOp, PoolOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
LogicalResult matchAndRewrite(PoolOp poolOp, PoolOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final {
Value X = adaptor.getX();
ShapedType xShape = mlir::cast<ShapedType>(X.getType());
Value Y = poolOp.getResult();
@@ -218,17 +215,13 @@ struct PoolingBaseConverter : public OpConversionPattern<PoolOp> {
unpackOptionalPairVector(adaptor.getDilations(), dilation_x, dilation_y);
unpackOptionalPairVector(adaptor.getKernelShape(), krn_w, krn_h);
if (adaptor.getAutoPad() != "NOTSET") {
return rewriter.notifyMatchFailure(
poolOp, "auto_pad != NOTSET is deprecated.");
}
if (adaptor.getAutoPad() != "NOTSET")
return rewriter.notifyMatchFailure(poolOp, "auto_pad != NOTSET is deprecated.");
size_t pad_x, pad_y;
auto padUnpackError =
unpackOptionalPadsVector(adaptor.getPads(), pad_x, pad_y);
if (padUnpackError.has_value()) {
auto padUnpackError = unpackOptionalPadsVector(adaptor.getPads(), pad_x, pad_y);
if (padUnpackError.has_value())
return rewriter.notifyMatchFailure(poolOp, padUnpackError.value());
}
Location loc = poolOp.getLoc();
@@ -236,8 +229,7 @@ struct PoolingBaseConverter : public OpConversionPattern<PoolOp> {
size_t input_w = GET_IMAGE_WIDTH(xShape);
size_t output_h = GET_IMAGE_HEIGHT(yShape);
size_t output_w = GET_IMAGE_WIDTH(yShape);
size_t channelTileCount =
ceilIntegerDivide(GET_IMAGE_CHANNEL(xShape), crossbarSize.getValue());
size_t channelTileCount = ceilIntegerDivide(GET_IMAGE_CHANNEL(xShape), crossbarSize.getValue());
size_t channelTileRest = GET_IMAGE_CHANNEL(xShape) % crossbarSize;
// 1: Tile the input tensor
@@ -249,14 +241,13 @@ struct PoolingBaseConverter : public OpConversionPattern<PoolOp> {
// Example complete input tensor: tensor<1x3x12x12xf32> (NxCxWxH)
// Suppose that the input tensor is produced by concatenating the results of
// many ComputeOps. Get the result tiles from these ComputeOps.
SmallVector<SmallVector<SmallVector<Value>>> inputTiles(channelTileCount,
SmallVector<SmallVector<Value>>(input_w, SmallVector<Value>(input_h)));
SmallVector<SmallVector<SmallVector<Value>>> inputTiles(
channelTileCount, SmallVector<SmallVector<Value>>(input_w, SmallVector<Value>(input_h)));
auto resolveErrorOpt = resolveImgInputTiles(X, inputTiles, channelTileCount,
channelTileRest, input_w, input_h, rewriter);
if (resolveErrorOpt.has_value()) {
auto resolveErrorOpt =
resolveImgInputTiles(X, inputTiles, channelTileCount, channelTileRest, input_w, input_h, rewriter);
if (resolveErrorOpt.has_value())
return rewriter.notifyMatchFailure(poolOp, *resolveErrorOpt);
}
// TODO: This requires a core for each input tile, which is not ideal. We
// can do better.
@@ -265,18 +256,17 @@ struct PoolingBaseConverter : public OpConversionPattern<PoolOp> {
for (size_t t = 0; t < channelTileCount; t++) {
for (size_t x = 0; x < input_w; x++) {
for (size_t y = 0; y < input_h; y++) {
if (auto extractSliceOp =
inputTiles[t][x][y].getDefiningOp<tensor::ExtractSliceOp>()) {
if (auto extractSliceOp = inputTiles[t][x][y].getDefiningOp<tensor::ExtractSliceOp>()) {
Location tileLoc = extractSliceOp.getLoc();
auto tempComputeOp = rewriter.create<spatial::SpatWeightedCompute>(
tileLoc, extractSliceOp.getResultType(),
/* xbarWeights =*/ValueRange(), extractSliceOp.getResult());
auto tempComputeOp = rewriter.create<spatial::SpatWeightedCompute>(tileLoc,
extractSliceOp.getResultType(),
/* xbarWeights =*/ValueRange(),
extractSliceOp.getResult());
Block *tempComputeOpBlock = new Block();
Block* tempComputeOpBlock = new Block();
tempComputeOp.getBody().push_back(tempComputeOpBlock);
auto tempComputeOpBlockArg = tempComputeOpBlock->addArgument(
extractSliceOp.getType(), tileLoc);
auto tempComputeOpBlockArg = tempComputeOpBlock->addArgument(extractSliceOp.getType(), tileLoc);
rewriter.setInsertionPointToStart(tempComputeOpBlock);
rewriter.create<spatial::SpatYieldOp>(tileLoc, tempComputeOpBlockArg);
@@ -295,8 +285,7 @@ struct PoolingBaseConverter : public OpConversionPattern<PoolOp> {
// For example: outputTiles[channelTile][x][y]
// Example complete output tensor: tensor<1x3x6x6xf32> (NxCxWxH)
SmallVector<SmallVector<SmallVector<Value>>> outputTiles(
channelTileCount, SmallVector<SmallVector<Value>>(
output_w, SmallVector<Value>(output_h, nullptr)));
channelTileCount, SmallVector<SmallVector<Value>>(output_w, SmallVector<Value>(output_h, nullptr)));
// List of values to pool for each output pixel
SmallVector<Value> valuesToPool;
@@ -312,15 +301,12 @@ struct PoolingBaseConverter : public OpConversionPattern<PoolOp> {
valuesToPool.clear();
size_t tilesSkippedByPadding = 0;
auto [start_x, end_x] = kernel_get_start_and_end(
outX, input_w, krn_w, stride_x, dilation_x, pad_x);
auto [start_y, end_y] = kernel_get_start_and_end(
outY, input_h, krn_h, stride_y, dilation_y, pad_y);
auto [start_x, end_x] = kernel_get_start_and_end(outX, input_w, krn_w, stride_x, dilation_x, pad_x);
auto [start_y, end_y] = kernel_get_start_and_end(outY, input_h, krn_h, stride_y, dilation_y, pad_y);
for (size_t inX = start_x; inX < end_x; inX += dilation_x) {
for (size_t inY = start_y; inY < end_y; inY += dilation_y) {
if (failed(verifyWithinBoundsAndPaddings(
input_w, input_h, inX, inY, pad_x, pad_y))) {
if (failed(verifyWithinBoundsAndPaddings(input_w, input_h, inX, inY, pad_x, pad_y))) {
tilesSkippedByPadding++;
continue;
}
@@ -328,78 +314,73 @@ struct PoolingBaseConverter : public OpConversionPattern<PoolOp> {
Value inputTile = inputTiles[outTile][inX][inY];
Value valueToPool;
if (auto computeProducer =
inputTile.getDefiningOp<spatial::SpatWeightedCompute>()) {
if (auto computeProducer = inputTile.getDefiningOp<spatial::SpatWeightedCompute>()) {
int resultNumber = getResultIndex(computeProducer, inputTile);
auto yieldInComputeOp = cast<spatial::SpatYieldOp>(
computeProducer.getBody().front().getTerminator());
auto yieldInComputeOp = cast<spatial::SpatYieldOp>(computeProducer.getBody().front().getTerminator());
valueToPool = yieldInComputeOp.getOperand(resultNumber);
} else if (auto receiveProducer =
inputTile
.getDefiningOp<spatial::SpatChannelReceiveOp>()) {
auto sendOpOpt =
getOtherEndOfChannel(receiveProducer, true, rewriter);
}
else if (auto receiveProducer = inputTile.getDefiningOp<spatial::SpatChannelReceiveOp>()) {
auto sendOpOpt = getOtherEndOfChannel(receiveProducer, true, rewriter);
if (failed(sendOpOpt)) {
return rewriter.notifyMatchFailure(poolOp,
"ChannelReceiveOp does not have a matching "
"ChannelSendOp.");
"ChannelReceiveOp does not have a matching "
"ChannelSendOp.");
}
auto sendOp = cast<spatial::SpatChannelSendOp>(*sendOpOpt);
valueToPool = sendOp.getData();
} else {
}
else {
return rewriter.notifyMatchFailure(poolOp,
"Input tile for Pooling is not produced by a "
"WeightedComputeOp nor a receiveOp");
"Input tile for Pooling is not produced by a "
"WeightedComputeOp nor a receiveOp");
}
valuesToPool.push_back(valueToPool);
}
}
assert(valuesToPool.size() != 0 &&
"Pooling computed on zero tiles make no sense.");
assert(valuesToPool.size() != 0 && "Pooling computed on zero tiles make no sense.");
// assert(computeOpsForPooling.size() != 1 &&
// "Pooling computed on one tiles make no sense??? Or maybe
// this " "should have been simplified earlier???");
std::function<Value(const Value &)> postProcessFn = nullptr;
std::function<Value(const Value&)> postProcessFn = nullptr;
if (hasPostProcessPoolingWindow<PoolOp>()) {
postProcessFn = [&](const Value prevFinalRes) {
return postProcessPoolingWindow(rewriter, loc, poolOp,
prevFinalRes, krn_h * krn_w, tilesSkippedByPadding);
return postProcessPoolingWindow(
rewriter, loc, poolOp, prevFinalRes, krn_h * krn_w, tilesSkippedByPadding);
};
}
Value reducedWithinCompute = applyReducePatternNew(
valuesToPool, rewriter,
[&](const Value lhs, const Value rhs) {
return rewriter.create<ReduceOp>(loc, lhs.getType(), lhs, rhs);
},
nullptr, postProcessFn);
valuesToPool,
rewriter,
[&](const Value lhs, const Value rhs) { return rewriter.create<ReduceOp>(loc, lhs.getType(), lhs, rhs); },
nullptr,
postProcessFn);
// Send this value through a channel, and receive it in the
// `func.func`. During lowering, we will need to "move it" into the
// users computeOps
auto computeOpOfReduced = cast<spatial::SpatWeightedCompute>(
reducedWithinCompute.getDefiningOp()->getParentOp());
auto computeOpOfReduced =
cast<spatial::SpatWeightedCompute>(reducedWithinCompute.getDefiningOp()->getParentOp());
// Create a new channel before the computeOp
rewriter.setInsertionPoint(computeOpOfReduced);
auto reduceChannel = rewriter.create<spatial::SpatChannelNewOp>(
loc, spatial::SpatChannelType::get(rewriter.getContext()));
auto reduceChannel =
rewriter.create<spatial::SpatChannelNewOp>(loc, spatial::SpatChannelType::get(rewriter.getContext()));
// Send value through the channel
rewriter.setInsertionPointAfterValue(reducedWithinCompute);
rewriter.create<spatial::SpatChannelSendOp>(
loc, reduceChannel, reducedWithinCompute);
rewriter.create<spatial::SpatChannelSendOp>(loc, reduceChannel, reducedWithinCompute);
// Receive after the computeOp
rewriter.setInsertionPointAfter(computeOpOfReduced);
auto receivedValue = rewriter.create<spatial::SpatChannelReceiveOp>(
loc, reducedWithinCompute.getType(), reduceChannel);
auto receivedValue =
rewriter.create<spatial::SpatChannelReceiveOp>(loc, reducedWithinCompute.getType(), reduceChannel);
outputTiles[outTile][outX][outY] = receivedValue;
}
@@ -409,9 +390,7 @@ struct PoolingBaseConverter : public OpConversionPattern<PoolOp> {
// TODO: outputTiles are not the results of the computeOps! We need to add
// them!
std::unordered_map<Operation *,
SmallVector<std::tuple<size_t, size_t, size_t, Value>>>
computeOpNeedingResults;
std::unordered_map<Operation*, SmallVector<std::tuple<size_t, size_t, size_t, Value>>> computeOpNeedingResults;
// Iterate each output tile
for (size_t outTile = 0; outTile < channelTileCount; outTile++) {
@@ -422,18 +401,16 @@ struct PoolingBaseConverter : public OpConversionPattern<PoolOp> {
auto outputTileProducer = outputTile.getDefiningOp()->getParentOp();
if (!outputTileProducer) {
return rewriter.notifyMatchFailure(poolOp,
"Output tile for Pooling is not produced by a "
"WeightedComputeOp.");
"Output tile for Pooling is not produced by a "
"WeightedComputeOp.");
}
computeOpNeedingResults[outputTileProducer].push_back(
std::make_tuple(outTile, outX, outY, outputTile));
computeOpNeedingResults[outputTileProducer].push_back(std::make_tuple(outTile, outX, outY, outputTile));
}
}
}
Value outputImage =
createImgConcatOp(outputTiles, rewriter, loc, poolOp.getType());
Value outputImage = createImgConcatOp(outputTiles, rewriter, loc, poolOp.getType());
rewriter.replaceOp(poolOp, outputImage);
@@ -441,12 +418,10 @@ struct PoolingBaseConverter : public OpConversionPattern<PoolOp> {
}
};
// Adds two PoolingBaseConverter instantiations to `patterns`, each constructed
// with `ctx`:
//   - ONNXMaxPoolSingleOutOp, with spatial::SpatVMaxOp as the ReduceOp;
//   - ONNXAveragePoolOp, with spatial::SpatVAddOp as the ReduceOp.
void populatePoolingTilingPattern(
RewritePatternSet &patterns, MLIRContext *ctx) {
patterns.insert<PoolingBaseConverter<ONNXMaxPoolSingleOutOp,
ONNXMaxPoolSingleOutOpAdaptor, spatial::SpatVMaxOp>>(ctx);
patterns.insert<PoolingBaseConverter<ONNXAveragePoolOp,
ONNXAveragePoolOpAdaptor, spatial::SpatVAddOp>>(ctx);
void populatePoolingTilingPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.insert<PoolingBaseConverter<ONNXMaxPoolSingleOutOp, ONNXMaxPoolSingleOutOpAdaptor, spatial::SpatVMaxOp>>(
ctx);
patterns.insert<PoolingBaseConverter<ONNXAveragePoolOp, ONNXAveragePoolOpAdaptor, spatial::SpatVAddOp>>(ctx);
}
} // namespace onnx_mlir
} // namespace onnx_mlir

View File

@@ -1,20 +1,19 @@
#include "mlir/Transforms/DialectConversion.h"
#include "Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp"
#include "mlir/Transforms/DialectConversion.h"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
struct ReduceMeanConversionPattern
: public OpConversionPattern<ONNXReduceMeanV13Op> {
struct ReduceMeanConversionPattern : public OpConversionPattern<ONNXReduceMeanV13Op> {
ReduceMeanConversionPattern(MLIRContext *ctx) : OpConversionPattern(ctx) {}
ReduceMeanConversionPattern(MLIRContext* ctx)
: OpConversionPattern(ctx) {}
LogicalResult matchAndRewrite(ONNXReduceMeanV13Op reduceMean,
ONNXReduceMeanV13OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
ONNXReduceMeanV13OpAdaptor adaptor,
ConversionPatternRewriter& rewriter) const final {
// Get the input tensor.
Value inputTensor = adaptor.getData();
@@ -38,42 +37,42 @@ struct ReduceMeanConversionPattern
SmallVector<int64_t> padsVals = {0, 0, 0, 0};
// Create the ArrayAttrs
auto kernelShape = mlir::ArrayAttr::get(rewriter.getContext(),
llvm::to_vector(
llvm::map_range(kernelShapeVals, [&](int64_t v) -> mlir::Attribute {
return rewriter.getI64IntegerAttr(v);
})));
auto kernelShape = mlir::ArrayAttr::get(
rewriter.getContext(), llvm::to_vector(llvm::map_range(kernelShapeVals, [&](int64_t v) -> mlir::Attribute {
return rewriter.getI64IntegerAttr(v);
})));
auto strides = mlir::ArrayAttr::get(rewriter.getContext(),
llvm::to_vector(
llvm::map_range(stridesVals, [&](int64_t v) -> mlir::Attribute {
return rewriter.getI64IntegerAttr(v);
})));
llvm::to_vector(llvm::map_range(stridesVals, [&](int64_t v) -> mlir::Attribute {
return rewriter.getI64IntegerAttr(v);
})));
auto dilations = mlir::ArrayAttr::get(rewriter.getContext(),
llvm::to_vector(
llvm::map_range(dilationsVals, [&](int64_t v) -> mlir::Attribute {
return rewriter.getI64IntegerAttr(v);
})));
auto dilations = mlir::ArrayAttr::get(
rewriter.getContext(), llvm::to_vector(llvm::map_range(dilationsVals, [&](int64_t v) -> mlir::Attribute {
return rewriter.getI64IntegerAttr(v);
})));
auto pads = mlir::ArrayAttr::get(rewriter.getContext(),
llvm::to_vector(
llvm::map_range(padsVals, [&](int64_t v) -> mlir::Attribute {
return rewriter.getI64IntegerAttr(v);
})));
llvm::to_vector(llvm::map_range(padsVals, [&](int64_t v) -> mlir::Attribute {
return rewriter.getI64IntegerAttr(v);
})));
// Create the resulting tensor type.
auto resultType = RankedTensorType::get(
/*shape=*/{inputTensorType.getShape()[0], inputTensorType.getShape()[1],
1, 1},
/*elementType=*/inputTensorType.getElementType());
/*shape=*/ {inputTensorType.getShape()[0], inputTensorType.getShape()[1], 1, 1},
/*elementType=*/inputTensorType.getElementType());
// Create the ONNXAveragePoolOp.
auto averagePool = rewriter.create<ONNXAveragePoolOp>(reduceMean.getLoc(),
resultType, inputTensor, /*auto_pad=*/"NOTSET",
/*ceil_mode=*/0, /*count_include_pad=*/1, dilations,
/*kernel_shape=*/kernelShape,
/*pads=*/pads, /*strides=*/strides);
resultType,
inputTensor,
/*auto_pad=*/"NOTSET",
/*ceil_mode=*/0,
/*count_include_pad=*/1,
dilations,
/*kernel_shape=*/kernelShape,
/*pads=*/pads,
/*strides=*/strides);
// Replace the ONNXReduceMeanV13Op with the ONNXAveragePoolOp.
rewriter.replaceOp(reduceMean, averagePool.getResult());
@@ -82,9 +81,8 @@ struct ReduceMeanConversionPattern
}
};
// Adds ReduceMeanConversionPattern, constructed with `ctx`, to `patterns`.
// That pattern matches ONNXReduceMeanV13Op and replaces it with an
// ONNXAveragePoolOp.
void populateReduceMeanConversionPattern(
RewritePatternSet &patterns, MLIRContext *ctx) {
void populateReduceMeanConversionPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.insert<ReduceMeanConversionPattern>(ctx);
}
} // namespace onnx_mlir
} // namespace onnx_mlir

View File

@@ -6,11 +6,12 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
#include "llvm/Support/LogicalResult.h"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
#define DEFINE_MAP_OP(opname) opname,
#define GET_IMAGE_WIDTH(shapedType) shapedType.getDimSize(2)

View File

@@ -6,7 +6,6 @@
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_os_ostream.h"
#include <filesystem>
#include <fstream>
#include "Common/PIMCommon.hpp"
@@ -16,7 +15,6 @@
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp"
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
#include "src/Compiler/CompilerOptions.hpp"
using namespace mlir;
@@ -39,7 +37,7 @@ void ONNXToSpatialPass::runOnOperation() {
mergeActivationPatterns.add<matMulToGemmPattern>(ctx);
mergeActivationPatterns.add<removeFlattenSameShapePattern>(ctx);
if (failed(applyPatternsAndFoldGreedily(moduleOp, std::move(mergeActivationPatterns))))
if (failed(applyPatternsGreedily(moduleOp, std::move(mergeActivationPatterns))))
llvm::dbgs() << "Failed to merge activation patterns, continuing...\n";
IRRewriter rewriter(moduleOp);
@@ -88,7 +86,7 @@ void ONNXToSpatialPass::runOnOperation() {
if (coresCount != -1) {
int computeOpsCount = 0;
for (auto& op : funcOp.getFunctionBody().front().getOperations())
if (isa<spatial::SpatWeightedCompute>(op))
if (isa<SpatWeightedCompute>(op))
computeOpsCount++;
if (computeOpsCount > coresCount) {

View File

@@ -3,38 +3,26 @@
namespace onnx_mlir {
// Declarations of the pattern-population entry points. Every function takes
// the RewritePatternSet to extend and an MLIRContext, and returns nothing.
// NOTE(review): the definitions live in other translation units; judging by
// the names, each one registers a group of ONNX-to-Spatial rewrite patterns —
// confirm per function before relying on that.
void populateLoweringONNXMatMulOpToSpatialPattern(
mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx);
void populateLoweringONNXMatMulOpToSpatialPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateTilingGemmOpPattern(
mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx);
void populateTilingConvOpPattern(
mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx);
void populateTilingGemmOpPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateTilingConvOpPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populatePoolingTilingPattern(
mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx);
void populatePoolingTilingPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateDistributeReducePattern(
mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx);
void populateDistributeReducePattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateFoldComputePattern(
mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx);
void populateFoldComputePattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateONNXConcatToTensorConcatPattern(
mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx);
void populateONNXConcatToTensorConcatPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateRemoveUnusedHelperOpsPatterns(
mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx);
void populateRemoveUnusedHelperOpsPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateReduceMeanConversionPattern(
mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx);
void populateReduceMeanConversionPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
// Experimental patterns.
void populateExperimentalTilingConvOpPattern(
mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx);
void populateGemmToConvConversionPattern(
mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx);
void populateExperimentalPoolingTilingPattern(
mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx);
void populateExperimentalTilingConvOpPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateGemmToConvConversionPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateExperimentalPoolingTilingPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
} // namespace onnx_mlir
} // namespace onnx_mlir

View File

@@ -1,19 +1,20 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"
#include "src/Dialect/ONNX/ONNXOps.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
struct ONNXConcatToTensorConcat : public OpConversionPattern<ONNXConcatOp> {
ONNXConcatToTensorConcat(MLIRContext *ctx) : OpConversionPattern(ctx) {}
ONNXConcatToTensorConcat(MLIRContext* ctx)
: OpConversionPattern(ctx) {}
LogicalResult matchAndRewrite(ONNXConcatOp maxpoolOp,
ONNXConcatOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
ONNXConcatOpAdaptor adaptor,
ConversionPatternRewriter& rewriter) const final {
auto inputs = adaptor.getInputs();
int64_t axis = adaptor.getAxis();
@@ -23,9 +24,8 @@ struct ONNXConcatToTensorConcat : public OpConversionPattern<ONNXConcatOp> {
}
};
// Adds the ONNXConcatToTensorConcat conversion pattern (an
// OpConversionPattern<ONNXConcatOp>), constructed with `ctx`, to `patterns`.
void populateONNXConcatToTensorConcatPattern(
RewritePatternSet &patterns, MLIRContext *ctx) {
void populateONNXConcatToTensorConcatPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.insert<ONNXConcatToTensorConcat>(ctx);
}
} // namespace onnx_mlir
} // namespace onnx_mlir

View File

@@ -1,5 +1,6 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"

View File

@@ -1,10 +1,10 @@
#include <queue>
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/AnnotateReplication.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
#include <queue>
using namespace mlir;
namespace onnx_mlir {

View File

@@ -5,7 +5,6 @@
namespace onnx_mlir {
// Declaration only: takes the func::FuncOp to process and an IRRewriter, and
// reports the outcome as an mlir::LogicalResult.
// NOTE(review): the implementation is not visible here; what is annotated and
// when failure is returned must be checked in the definition.
mlir::LogicalResult annotateReplication(
mlir::func::FuncOp funcOp, mlir::IRRewriter &rewriter);
mlir::LogicalResult annotateReplication(mlir::func::FuncOp funcOp, mlir::IRRewriter& rewriter);
} // namespace onnx_mlir
} // namespace onnx_mlir

View File

@@ -1,32 +1,31 @@
#include "SpatialReducer.hpp"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Value.h"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "llvm/Support/raw_ostream.h"
#include <cassert>
#include <unordered_map>
#include <utility>
#include "SpatialReducer.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#define GET_COMP(computeOpAndResNum) std::get<0>(computeOpAndResNum)
#define GET_RES_NUM(computeOpAndResNum) std::get<1>(computeOpAndResNum)
namespace onnx_mlir {
// Out-of-class definition of the SpatialReducer::oldComputeOpsReplaced member,
// a set of Operation pointers.
// NOTE(review): no use of the set is visible in this excerpt; presumably it
// tracks compute ops that have already been replaced — confirm in the class.
llvm::SmallPtrSet<Operation *, 16>
onnx_mlir::SpatialReducer::oldComputeOpsReplaced;
llvm::SmallPtrSet<Operation*, 16> onnx_mlir::SpatialReducer::oldComputeOpsReplaced;
ResNum SpatialReducer::applyResultProcessing(
ComputeAndResNum computeOpAndResNum,
std::function<Value(const Value &)> processFun,
ConversionPatternRewriter &rewriter) {
ResNum SpatialReducer::applyResultProcessing(ComputeAndResNum computeOpAndResNum,
std::function<Value(const Value&)> processFun,
ConversionPatternRewriter& rewriter) {
assert(processFun);
auto computeOp = GET_COMP(computeOpAndResNum);
auto resultNum = GET_RES_NUM(computeOpAndResNum);
spatial::SpatYieldOp yieldOp =
cast<spatial::SpatYieldOp>(computeOp.getBody().front().getTerminator());
spatial::SpatYieldOp yieldOp = cast<spatial::SpatYieldOp>(computeOp.getBody().front().getTerminator());
Value result = yieldOp->getOperand(resultNum);
rewriter.setInsertionPointAfterValue(result);
@@ -43,30 +42,24 @@ ResNum SpatialReducer::applyResultProcessing(
return yieldOp.getNumOperands() - 1;
}
OpAndResNum SpatialReducer::applyReducePattern(
SmallVector<ComputeAndResNum> &computeOpsAndResNum,
std::function<Value(const Value &, const Value &)> reduce,
std::function<Value(const Value &)> preprocess,
std::function<Value(const Value &)> postprocess) {
OpAndResNum SpatialReducer::applyReducePattern(SmallVector<ComputeAndResNum>& computeOpsAndResNum,
std::function<Value(const Value&, const Value&)> reduce,
std::function<Value(const Value&)> preprocess,
std::function<Value(const Value&)> postprocess) {
if (preprocess) {
for (auto &computeOpAndResNum : computeOpsAndResNum) {
GET_RES_NUM(computeOpAndResNum) =
applyResultProcessing(computeOpAndResNum, preprocess, rewriter);
}
}
if (preprocess)
for (auto& computeOpAndResNum : computeOpsAndResNum)
GET_RES_NUM(computeOpAndResNum) = applyResultProcessing(computeOpAndResNum, preprocess, rewriter);
// It is possible that `computeOpsAndResNum` contains two entries for the same
// computeOp. In this case, we need to apply the reduction within-compute.
// Keep a map between a computeOp and the last Value for this reduction
std::unordered_map<Operation *, Value> lastValueForCompute;
for (auto &computeOpAndResNum : computeOpsAndResNum) {
std::unordered_map<Operation*, Value> lastValueForCompute;
for (auto& computeOpAndResNum : computeOpsAndResNum) {
auto computeOp = GET_COMP(computeOpAndResNum);
auto yieldOp =
cast<spatial::SpatYieldOp>(computeOp.getBody().front().getTerminator());
Value valueWithinCompute =
yieldOp->getOperand(GET_RES_NUM(computeOpAndResNum));
auto yieldOp = cast<spatial::SpatYieldOp>(computeOp.getBody().front().getTerminator());
Value valueWithinCompute = yieldOp->getOperand(GET_RES_NUM(computeOpAndResNum));
auto it = lastValueForCompute.find(computeOp.getOperation());
@@ -75,15 +68,12 @@ OpAndResNum SpatialReducer::applyReducePattern(
// within-compute
Value lastWithinComputeValue = it->second;
assert(valueWithinCompute.getDefiningOp() &&
lastWithinComputeValue.getDefiningOp());
assert(valueWithinCompute.getDefiningOp() && lastWithinComputeValue.getDefiningOp());
if (valueWithinCompute.getDefiningOp()->isBeforeInBlock(
lastWithinComputeValue.getDefiningOp())) {
if (valueWithinCompute.getDefiningOp()->isBeforeInBlock(lastWithinComputeValue.getDefiningOp()))
rewriter.setInsertionPointAfterValue(lastWithinComputeValue);
} else {
else
rewriter.setInsertionPointAfterValue(valueWithinCompute);
}
valueWithinCompute = reduce(lastWithinComputeValue, valueWithinCompute);
lastValueForCompute[computeOp.getOperation()] = valueWithinCompute;
}
@@ -94,16 +84,15 @@ OpAndResNum SpatialReducer::applyReducePattern(
// Now, reconstruct from the map the computeOpsAndResNum list
computeOpsAndResNum.clear();
computeOpsAndResNum.reserve(lastValueForCompute.size());
for (auto &entry : lastValueForCompute) {
for (auto& entry : lastValueForCompute) {
auto computeOp = cast<spatial::SpatWeightedCompute>(entry.first);
auto valueWithinCompute = entry.second;
// We check if `valueWithinCompute` is already used by the yieldOp, in that
// case no need to add it
auto yieldOp =
cast<spatial::SpatYieldOp>(computeOp.getBody().front().getTerminator());
auto yieldOp = cast<spatial::SpatYieldOp>(computeOp.getBody().front().getTerminator());
bool yieldOpUseFound = false;
for (auto &use : valueWithinCompute.getUses()) {
for (auto& use : valueWithinCompute.getUses()) {
if (use.getOwner() == yieldOp.getOperation()) {
// If the value is already used by the yieldOp, we can just use it
computeOpsAndResNum.push_back({computeOp, use.getOperandNumber()});
@@ -111,9 +100,8 @@ OpAndResNum SpatialReducer::applyReducePattern(
break;
}
}
if (yieldOpUseFound) {
if (yieldOpUseFound)
continue;
}
// If this result is not used within a yieldOp, then add it
auto resultNum = yieldOp->getNumOperands();
@@ -147,23 +135,18 @@ OpAndResNum SpatialReducer::applyReducePattern(
// the number of results)
// See below `reducerChanges.push_back` and `finalizeReduceUpdates`
auto yieldOpFirstCompute = cast<spatial::SpatYieldOp>(
firstCompute.getBody().front().getTerminator());
auto yieldOpFirstCompute = cast<spatial::SpatYieldOp>(firstCompute.getBody().front().getTerminator());
// Add a new operand to the block of the second computeOp
Block &secondBlock = secondCompute.getBody().front();
Value formerRes1 = secondBlock.addArgument(
yieldOpFirstCompute->getOperand(firstResultNum).getType(), loc);
Block& secondBlock = secondCompute.getBody().front();
Value formerRes1 = secondBlock.addArgument(yieldOpFirstCompute->getOperand(firstResultNum).getType(), loc);
auto secondComputeWeightsNum =
secondCompute->getAttrOfType<DenseI32ArrayAttr>(
secondCompute.getOperandSegmentSizesAttrName())[0];
auto secondComputeOperandNum =
secondComputeWeightsNum + secondBlock.getNumArguments() - 1;
secondCompute->getAttrOfType<DenseI32ArrayAttr>(secondCompute.getOperandSegmentSizesAttrName())[0];
auto secondComputeOperandNum = secondComputeWeightsNum + secondBlock.getNumArguments() - 1;
// Take the "former-result" from the second computeOp
spatial::SpatYieldOp secondYield =
cast<spatial::SpatYieldOp>(secondBlock.getTerminator());
spatial::SpatYieldOp secondYield = cast<spatial::SpatYieldOp>(secondBlock.getTerminator());
Value formerRes2 = secondYield.getOperand(secondResultNum);
// Apply reduction operation
@@ -184,37 +167,31 @@ OpAndResNum SpatialReducer::applyReducePattern(
// We should also add an entry for updating the results of the last
// operation (the one which never becomes a `firstCompute`): because it is
// not tracked by reducerChanges as `fromOp`
reducerChanges.push_back({firstCompute.getOperation(), firstResultNum,
secondCompute.getOperation(), secondComputeOperandNum});
reducerChanges.push_back(
{firstCompute.getOperation(), firstResultNum, secondCompute.getOperation(), secondComputeOperandNum});
nextComputeOps.push_back(std::make_pair(secondCompute, secondResultNum));
}
// If we have an odd number of inputs, we need to add the last one to the
// newInputs list.
if (computeOpsRef.size() % 2 == 1) {
if (computeOpsRef.size() % 2 == 1)
nextComputeOps.push_back(computeOpsRef.back());
}
// Replace the inputOps list with the new one.
computeOpsRef =
llvm::OwningArrayRef<ComputeAndResNum>(std::move(nextComputeOps));
computeOpsRef = llvm::OwningArrayRef<ComputeAndResNum>(std::move(nextComputeOps));
}
assert(computeOpsRef.size() == 1 &&
"Internal error: expected a single input at this point.");
assert(computeOpsRef.size() == 1 && "Internal error: expected a single input at this point.");
auto finalComputeAndResNum = computeOpsRef[0];
// Force the update of the results of this computeOp, when finalizing
computeOpNeedingResUpdate.push_back(GET_COMP(finalComputeAndResNum));
if (postprocess) {
GET_RES_NUM(finalComputeAndResNum) =
applyResultProcessing(finalComputeAndResNum, postprocess, rewriter);
}
if (postprocess)
GET_RES_NUM(finalComputeAndResNum) = applyResultProcessing(finalComputeAndResNum, postprocess, rewriter);
return std::make_pair(GET_COMP(finalComputeAndResNum).getOperation(),
GET_RES_NUM(finalComputeAndResNum));
return std::make_pair(GET_COMP(finalComputeAndResNum).getOperation(), GET_RES_NUM(finalComputeAndResNum));
}
void SpatialReducer::finalizeReduceUpdates() {
@@ -223,15 +200,13 @@ void SpatialReducer::finalizeReduceUpdates() {
reducesFinalized = true;
// First, add the results to the computeOps
for (auto &reduceChange : reducerChanges) {
for (auto& reduceChange : reducerChanges)
updateResultsOfCompute(reduceChange.fromOp);
}
for (auto &c : computeOpNeedingResUpdate) {
for (auto& c : computeOpNeedingResUpdate)
updateResultsOfCompute(c.getOperation());
}
for (auto &reducerChange : this->reducerChanges) {
for (auto& reducerChange : this->reducerChanges) {
auto fromOp = reducerChange.fromOp;
auto toOp = reducerChange.toOp;
auto fromOpResNum = reducerChange.fromOpResNum;
@@ -243,16 +218,14 @@ void SpatialReducer::finalizeReduceUpdates() {
// toComputeOp could be the existing pointer, or we have to remap it with
// `opToReplacedCompute`
auto toComputeOp = opToReplacedCompute[toOp];
if (!toComputeOp) {
if (!toComputeOp)
toComputeOp = cast<spatial::SpatWeightedCompute>(toOp);
}
assert(toComputeOp != fromComputeOp &&
"Oops should have caught this earlier!");
assert(toComputeOp != fromComputeOp && "Oops should have caught this earlier!");
assert(toComputeOp->getNumOperands() == toOpOperandNum &&
"toOpOperandNum should be the last operand of toComputeOp, are the "
"operations in the right order?");
assert(toComputeOp->getNumOperands() == toOpOperandNum
&& "toOpOperandNum should be the last operand of toComputeOp, are the "
"operations in the right order?");
// Add the new operand to `toComputeOp`
auto fromResult = fromComputeOp.getResult(fromOpResNum);
@@ -261,24 +234,22 @@ void SpatialReducer::finalizeReduceUpdates() {
}
}
Value SpatialReducer::resolveValueFromOpAndResNum(OpAndResNum &opAndResNum) {
assert(reducesFinalized &&
"Cannot create resolve values before finalizing the reduce updates.");
Value SpatialReducer::resolveValueFromOpAndResNum(OpAndResNum& opAndResNum) {
assert(reducesFinalized && "Cannot create resolve values before finalizing the reduce updates.");
Operation *opToCast;
Operation* opToCast;
auto it = opToReplacedCompute.find(opAndResNum.first);
if (it != opToReplacedCompute.end()) {
if (it != opToReplacedCompute.end())
opToCast = it->second;
} else {
else
opToCast = opAndResNum.first;
}
auto computeOp = cast<spatial::SpatWeightedCompute>(opToCast);
return computeOp.getResult(opAndResNum.second);
}
void SpatialReducer::updateResultsOfCompute(Operation *computeOp) {
void SpatialReducer::updateResultsOfCompute(Operation* computeOp) {
if (opToReplacedCompute.find(computeOp) != opToReplacedCompute.end()) {
// If we have already replaced the fromOp, we do not need to do it again
return;
@@ -287,8 +258,7 @@ void SpatialReducer::updateResultsOfCompute(Operation *computeOp) {
auto oldComputeOpNum = oldComputeOp->getNumOperands();
auto yieldOp =
cast<spatial::SpatYieldOp>(oldComputeOp.getBody().front().getTerminator());
auto yieldOp = cast<spatial::SpatYieldOp>(oldComputeOp.getBody().front().getTerminator());
if (yieldOp.getNumOperands() == oldComputeOp->getNumResults()) {
// No result was added, just add itself to the map
@@ -301,9 +271,8 @@ void SpatialReducer::updateResultsOfCompute(Operation *computeOp) {
// Create a new ComputeOp with the new result type, but same operands
rewriter.setInsertionPoint(oldComputeOp);
auto newComputeOp =
rewriter.create<spatial::SpatWeightedCompute>(oldComputeOp->getLoc(),
newResultTypes, oldComputeOp.getWeights(), oldComputeOp.getInputs());
auto newComputeOp = rewriter.create<spatial::SpatWeightedCompute>(
oldComputeOp->getLoc(), newResultTypes, oldComputeOp.getWeights(), oldComputeOp.getInputs());
newComputeOp.getBody().takeBody(oldComputeOp.getBody());
@@ -329,54 +298,49 @@ void SpatialReducer::updateResultsOfCompute(Operation *computeOp) {
rewriter.eraseOp(oldComputeOp);
}
Value SpatialReducer::createImgConcatOp(
SmallVector<SmallVector<SmallVector<OpAndResNum>>> &outputTiles,
Location &loc, Type outputType) {
Value SpatialReducer::createImgConcatOp(SmallVector<SmallVector<SmallVector<OpAndResNum>>>& outputTiles,
Location& loc,
Type outputType) {
assert(reducesFinalized &&
"Cannot create ImgConcatOp before finalizing the reduce updates.");
assert(reducesFinalized && "Cannot create ImgConcatOp before finalizing the reduce updates.");
// outputTiles are indexed like this: [channelTile][x][y]
auto tilesCount = outputTiles.size();
auto width = outputTiles[0].size();
auto height = outputTiles[0][0].size();
SmallVector<SmallVector<SmallVector<Value>>> remappedOutputTiles(tilesCount,
SmallVector<SmallVector<Value>>(width, SmallVector<Value>(height)));
SmallVector<SmallVector<SmallVector<Value>>> remappedOutputTiles(
tilesCount, SmallVector<SmallVector<Value>>(width, SmallVector<Value>(height)));
for (size_t t = 0; t < tilesCount; t++)
for (size_t x = 0; x < width; x++)
for (size_t y = 0; y < height; y++)
remappedOutputTiles[t][x][y] =
resolveValueFromOpAndResNum(outputTiles[t][x][y]);
remappedOutputTiles[t][x][y] = resolveValueFromOpAndResNum(outputTiles[t][x][y]);
return ::onnx_mlir::createImgConcatOp(
remappedOutputTiles, rewriter, loc, outputType);
return ::onnx_mlir::createImgConcatOp(remappedOutputTiles, rewriter, loc, outputType);
}
OpAndResNum SpatialReducer::applyAddMapReduction(
SmallVector<ComputeAndResNum> &computeOps,
ConversionPatternRewriter &rewriter, Value biasTile, MapOperations mapOp) {
OpAndResNum SpatialReducer::applyAddMapReduction(SmallVector<ComputeAndResNum>& computeOps,
ConversionPatternRewriter& rewriter,
Value biasTile,
MapOperations mapOp) {
std::function<Value(const Value &)> postprocessing = nullptr;
std::function<Value(const Value&)> postprocessing = nullptr;
if (mapOp != MapOperations::None) {
postprocessing = [&](const Value a) {
Value mapOperand = a;
if (biasTile) {
mapOperand = rewriter.create<spatial::SpatVAddOp>(
a.getLoc(), a.getType(), a, biasTile);
}
if (biasTile)
mapOperand = rewriter.create<spatial::SpatVAddOp>(a.getLoc(), a.getType(), a, biasTile);
return createMapOperation(rewriter, mapOp, mapOperand);
};
}
return this->applyReducePattern(
computeOps,
[&](Value a, Value b) {
return rewriter.create<spatial::SpatVAddOp>(a.getLoc(), a.getType(), a, b);
},
/* preprocess = */ nullptr, postprocessing);
computeOps,
[&](Value a, Value b) { return rewriter.create<spatial::SpatVAddOp>(a.getLoc(), a.getType(), a, b); },
/* preprocess = */ nullptr,
postprocessing);
}
} // namespace onnx_mlir
} // namespace onnx_mlir

View File

@@ -1,9 +1,10 @@
#pragma once
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/Casting.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/Casting.h"
namespace onnx_mlir {
@@ -12,48 +13,48 @@ using ResNum = unsigned int;
using ComputeAndResNum = std::pair<spatial::SpatWeightedCompute, ResNum>;
struct SpatialReducerChange {
Operation *fromOp;
Operation* fromOp;
unsigned int fromOpResNum;
Operation *toOp;
Operation* toOp;
unsigned int toOpOperandNum;
};
using OpAndResNum = std::pair<Operation *, ResNum>;
using OpAndResNum = std::pair<Operation*, ResNum>;
class SpatialReducer {
public:
SpatialReducer(ConversionPatternRewriter &rewriter) : rewriter(rewriter) {}
SpatialReducer(ConversionPatternRewriter& rewriter)
: rewriter(rewriter) {}
OpAndResNum applyReducePattern(
SmallVector<ComputeAndResNum> &computeOpsAndResNum,
std::function<Value(const Value &, const Value &)> reduce,
std::function<Value(const Value &)> preprocess,
std::function<Value(const Value &)> postprocess);
OpAndResNum applyReducePattern(SmallVector<ComputeAndResNum>& computeOpsAndResNum,
std::function<Value(const Value&, const Value&)> reduce,
std::function<Value(const Value&)> preprocess,
std::function<Value(const Value&)> postprocess);
OpAndResNum applyAddMapReduction(SmallVector<ComputeAndResNum> &computeOps,
ConversionPatternRewriter &rewriter, Value biasTile, MapOperations mapOp);
OpAndResNum applyAddMapReduction(SmallVector<ComputeAndResNum>& computeOps,
ConversionPatternRewriter& rewriter,
Value biasTile,
MapOperations mapOp);
void finalizeReduceUpdates();
~SpatialReducer() {
if (!reducesFinalized) {
if (!reducesFinalized)
finalizeReduceUpdates();
}
}
Value createImgConcatOp(
llvm::SmallVector<llvm::SmallVector<llvm::SmallVector<OpAndResNum>>>
&outputTiles,
Location &loc, Type outputType);
Value createImgConcatOp(llvm::SmallVector<llvm::SmallVector<llvm::SmallVector<OpAndResNum>>>& outputTiles,
Location& loc,
Type outputType);
Value resolveValueFromOpAndResNum(OpAndResNum &opAndResNum);
Value resolveValueFromOpAndResNum(OpAndResNum& opAndResNum);
private:
[[nodiscard("computeOp result number gets updated")]] ResNum
applyResultProcessing(ComputeAndResNum computeOpAndResNum,
std::function<Value(const Value &)> processFun,
ConversionPatternRewriter &rewriter);
std::function<Value(const Value&)> processFun,
ConversionPatternRewriter& rewriter);
/**
* @brief Update the results of a ComputeOp.
@@ -65,9 +66,9 @@ private:
*
* @param computeOp The ComputeOp to update the results of.
*/
void updateResultsOfCompute(Operation *computeOp);
void updateResultsOfCompute(Operation* computeOp);
ConversionPatternRewriter &rewriter;
ConversionPatternRewriter& rewriter;
bool reducesFinalized = false;
// List of changes to be applied after the reduction is finalized
@@ -75,9 +76,9 @@ private:
// List of computeOps that need to be replaced with new results
SmallVector<spatial::SpatWeightedCompute> computeOpNeedingResUpdate;
std::unordered_map<Operation *, spatial::SpatWeightedCompute> opToReplacedCompute;
std::unordered_map<Operation*, spatial::SpatWeightedCompute> opToReplacedCompute;
static llvm::SmallPtrSet<Operation *, 16> oldComputeOpsReplaced;
static llvm::SmallPtrSet<Operation*, 16> oldComputeOpsReplaced;
};
} // namespace onnx_mlir
} // namespace onnx_mlir

View File

@@ -1,11 +1,11 @@
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/WeightSubdivider.hpp"
#include <cassert>
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/WeightSubdivider.hpp"
namespace onnx_mlir {
WeightSubdivider::WeightSubdivider(
map<long, map<long, SmallVector<Value>>> weights)
: weights(std::move(weights)) {}
WeightSubdivider::WeightSubdivider(map<long, map<long, SmallVector<Value>>> weights)
: weights(std::move(weights)) {}
bool WeightSubdivider::isEmpty() const { return weights.empty(); }
@@ -13,7 +13,7 @@ TaggedWeights WeightSubdivider::popGroup(size_t amount) {
assert(!weights.empty() && "No weights to extract.");
auto it = weights.begin();
SmallVector<Value> &values = it->second.begin()->second;
SmallVector<Value>& values = it->second.begin()->second;
long inputTile = it->first;
long outputTile = it->second.begin()->first;
@@ -26,11 +26,11 @@ TaggedWeights WeightSubdivider::popGroup(size_t amount) {
if (n < values.size()) {
values.erase(values.begin(), values.begin() + n);
} else {
}
else {
it->second.erase(outputTile);
if (it->second.empty()) {
if (it->second.empty())
weights.erase(inputTile);
}
}
return {inputTile, outputTile, crossbarsUsed - n, result};

View File

@@ -1,7 +1,9 @@
#pragma once
#include "mlir/IR/Value.h"
#include "llvm/ADT/SmallVector.h"
#include <map>
using namespace mlir;

View File

@@ -1,23 +1,21 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Format.h"
#define FORMAT_OPERATION(op) \
'x' << llvm::format_hex_no_prefix(reinterpret_cast<size_t>(op), 0)
#define FORMAT_ARGUMENT(computeOpPointer, argumentNum) \
llvm::format("Arg_%p_%u", computeOpPointer, argumentNum)
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
#define FORMAT_OPERATION(op) 'x' << llvm::format_hex_no_prefix(reinterpret_cast<size_t>(op), 0)
#define FORMAT_ARGUMENT(computeOpPointer, argumentNum) llvm::format("Arg_%p_%u", computeOpPointer, argumentNum)
using namespace mlir;
@@ -25,26 +23,22 @@ namespace onnx_mlir {
namespace {
struct SpatialToGraphvizPass
: public PassWrapper<SpatialToGraphvizPass, OperationPass<ModuleOp>> {
struct SpatialToGraphvizPass : public PassWrapper<SpatialToGraphvizPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SpatialToGraphvizPass)
StringRef getArgument() const override {
return "convert-spatial-to-graphviz";
}
StringRef getArgument() const override { return "convert-spatial-to-graphviz"; }
StringRef getDescription() const override {
return "Lower ONNX ops to Spatial ops.";
}
StringRef getDescription() const override { return "Lower ONNX ops to Spatial ops."; }
SpatialToGraphvizPass(raw_ostream &os = llvm::errs()) : os(os) {}
SpatialToGraphvizPass(const SpatialToGraphvizPass &pass)
: SpatialToGraphvizPass(pass.os) {}
SpatialToGraphvizPass(raw_ostream& os = llvm::errs())
: os(os) {}
SpatialToGraphvizPass(const SpatialToGraphvizPass& pass)
: SpatialToGraphvizPass(pass.os) {}
void runOnOperation() final;
private:
raw_ostream &os;
raw_ostream& os;
/**
* Draws the subgraph for a given spatial::SpatWeightedCompute, including:
@@ -56,31 +50,27 @@ private:
* @param computeNum The number of the compute operation.
*/
void drawComputeOpSubgraph(spatial::SpatWeightedCompute op, size_t computeNum) {
os << "\tsubgraph cluster" << computeNum << " {\n\t\tlabel=\"Compute"
<< computeNum << "\";\n"
os << "\tsubgraph cluster" << computeNum << " {\n\t\tlabel=\"Compute" << computeNum << "\";\n"
<< "\t\tstyle=filled;\n"
<< "\t\tcolor=lightblue;\n";
Block &block = op.getBody().front();
Block& block = op.getBody().front();
// Inputs
size_t inputNum = 0;
for (BlockArgument &input : block.getArguments()) {
for (BlockArgument& input : block.getArguments()) {
auto fromOp = FORMAT_ARGUMENT(op.getOperation(), inputNum);
os << "\t\t" << fromOp << " [label=\"Arg" << inputNum
<< "\",shape=box];\n";
for (auto userOp : input.getUsers()) {
os << "\t\t" << fromOp << " [label=\"Arg" << inputNum << "\",shape=box];\n";
for (auto userOp : input.getUsers())
os << "\t\t" << fromOp << " -> " << FORMAT_OPERATION(userOp) << ";\n";
}
inputNum++;
}
// Iterate operations
for (auto &childOp : block.getOperations()) {
os << "\t\t" << FORMAT_OPERATION(&childOp) << " [label=\""
<< childOp.getName() << "\"];\n";
for (auto& childOp : block.getOperations()) {
os << "\t\t" << FORMAT_OPERATION(&childOp) << " [label=\"" << childOp.getName() << "\"];\n";
drawEdgesFromOpToItsUsers(&childOp);
}
@@ -88,7 +78,7 @@ private:
os << "\t}\n";
// Draw edges from the yield to the users of this computeOp
Operation *yieldOp = block.getTerminator();
Operation* yieldOp = block.getTerminator();
if (!isa<spatial::SpatYieldOp>(yieldOp)) {
yieldOp->emitError("Terminator of block must be YieldOp ???");
signalPassFailure();
@@ -96,9 +86,8 @@ private:
}
for (auto computeOpResult : op->getResults()) {
for (auto &computeOpUse : computeOpResult.getUses()) {
auto toOp = FORMAT_ARGUMENT(
computeOpUse.getOwner(), computeOpUse.getOperandNumber());
for (auto& computeOpUse : computeOpResult.getUses()) {
auto toOp = FORMAT_ARGUMENT(computeOpUse.getOwner(), computeOpUse.getOperandNumber());
os << "\t" << FORMAT_OPERATION(yieldOp) << " -> " << toOp << ";\n";
}
}
@@ -114,9 +103,8 @@ private:
* @param concatOp The concatOp for which the subgraph is drawn.
* @param concatOpNum The number of the concatOp.
*/
void drawConcatOpSubgraph(Operation *concatOp, size_t concatOpNum) {
os << "\tsubgraph clusterconcat" << concatOpNum
<< " {\n\t\tlabel=\"ConcatOp" << concatOpNum << "\";\n"
void drawConcatOpSubgraph(Operation* concatOp, size_t concatOpNum) {
os << "\tsubgraph clusterconcat" << concatOpNum << " {\n\t\tlabel=\"ConcatOp" << concatOpNum << "\";\n"
<< "\t\tstyle=filled;\n"
<< "\t\tcolor=orange;\n";
@@ -126,9 +114,8 @@ private:
auto fromOp = FORMAT_ARGUMENT(concatOp, inputNum);
os << "\t\t" << fromOp << " [label=\"Input" << inputNum << "\"];\n";
for (auto userOp : input.getUsers()) {
for (auto userOp : input.getUsers())
os << "\t\t" << fromOp << " -> " << FORMAT_OPERATION(userOp) << ";\n";
}
inputNum++;
}
@@ -139,11 +126,9 @@ private:
// Edges from output to users
for (auto &computeOpUse : concatOp->getResult(0).getUses()) {
for (auto& computeOpUse : concatOp->getResult(0).getUses()) {
os << "\t" << FORMAT_OPERATION(concatOp) << " -> "
<< FORMAT_ARGUMENT(
computeOpUse.getOwner(), computeOpUse.getOperandNumber())
<< ";\n";
<< FORMAT_ARGUMENT(computeOpUse.getOwner(), computeOpUse.getOperandNumber()) << ";\n";
}
}
@@ -164,10 +149,8 @@ private:
sliceOp.getStaticOffsetsAttr().print(os);
os << "\",color=lawngreen];\n";
for (auto &computeOpUse : sliceOp.getResult().getUses()) {
os << "\t" << nodeId << " -> "
<< FORMAT_ARGUMENT(
computeOpUse.getOwner(), computeOpUse.getOperandNumber())
for (auto& computeOpUse : sliceOp.getResult().getUses()) {
os << "\t" << nodeId << " -> " << FORMAT_ARGUMENT(computeOpUse.getOwner(), computeOpUse.getOperandNumber())
<< ";\n";
}
}
@@ -178,9 +161,8 @@ private:
sliceOp.getStaticOffsetsAttr().print(os);
os << "\",color=lightpink];\n";
for (auto user : sliceOp.getResult().getUsers()) {
for (auto user : sliceOp.getResult().getUsers())
os << "\t" << nodeId << " -> " << FORMAT_OPERATION(user) << ";\n";
}
}
/**
@@ -188,13 +170,10 @@ private:
*
* @param fromOp The operation from which the edges are drawn.
*/
void drawEdgesFromOpToItsUsers(mlir::Operation *fromOp) {
for (auto result : fromOp->getResults()) {
for (auto userOp : result.getUsers()) {
os << "\t\t" << FORMAT_OPERATION(fromOp) << " -> "
<< FORMAT_OPERATION(userOp) << ";\n";
}
}
void drawEdgesFromOpToItsUsers(mlir::Operation* fromOp) {
for (auto result : fromOp->getResults())
for (auto userOp : result.getUsers())
os << "\t\t" << FORMAT_OPERATION(fromOp) << " -> " << FORMAT_OPERATION(userOp) << ";\n";
}
/**
@@ -202,16 +181,15 @@ private:
*
* @param funcOp The `funcOp` for which to draw input nodes and edges.
*/
void drawInputNodesAndEdges(func::FuncOp &funcOp) {
void drawInputNodesAndEdges(func::FuncOp& funcOp) {
os << "\tinput [label=\"Module Input\",color=green];\n";
size_t funcOpArgNum = 0;
for (BlockArgument &arg : funcOp.getArguments()) {
for (BlockArgument& arg : funcOp.getArguments()) {
for (auto &useOp : arg.getUses()) {
os << "\tinput -> "
<< FORMAT_ARGUMENT(useOp.getOwner(), useOp.getOperandNumber())
<< "[label=" << funcOpArgNum << "];\n";
for (auto& useOp : arg.getUses()) {
os << "\tinput -> " << FORMAT_ARGUMENT(useOp.getOwner(), useOp.getOperandNumber()) << "[label=" << funcOpArgNum
<< "];\n";
}
funcOpArgNum++;
}
@@ -237,20 +215,22 @@ void SpatialToGraphvizPass::runOnOperation() {
// Iterate over the ComputeOps within FuncOp:
// 1. Print their subgraph
// 2. Print the edges from its inputs to its outputs
for (Operation &op : func.getOps()) {
for (Operation& op : func.getOps()) {
if (auto computeOp = dyn_cast<spatial::SpatWeightedCompute>(op)) {
drawComputeOpSubgraph(computeOp, computeNum++);
} else if (auto concatOp = dyn_cast<tensor::ConcatOp>(op)) {
}
else if (auto concatOp = dyn_cast<tensor::ConcatOp>(op)) {
drawConcatOpSubgraph(concatOp, concatNum++);
} else if (auto imgConcatOp = dyn_cast<spatial::SpatImgConcatOp>(op)) {
}
else if (auto imgConcatOp = dyn_cast<spatial::SpatImgConcatOp>(op)) {
drawConcatOpSubgraph(imgConcatOp, concatNum++);
} else if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op)) {
}
else if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op)) {
auto producerOp = extractSliceOp->getOperand(0).getDefiningOp();
if (producerOp) {
// Skip extractSliceOp if producer is constant weights (ONNXConstantOp)
if (llvm::isa<ONNXConstantOp>(producerOp)) {
if (llvm::isa<ONNXConstantOp>(producerOp))
continue;
}
// If produced by tosa::ReshapeOp (i.e. it is a bias tile) connect
// directly to its user, which is not a ComputeOp argument.
if (llvm::isa<tosa::ReshapeOp>(producerOp)) {
@@ -268,16 +248,13 @@ void SpatialToGraphvizPass::runOnOperation() {
// Draw output node (use the return Operation - argument number=0 - as nodeId)
auto returnOp = func.getBody().front().getTerminator();
os << '\t' << FORMAT_ARGUMENT(returnOp, 0)
<< " [label=\"Module Output\",color=green];\n";
os << '\t' << FORMAT_ARGUMENT(returnOp, 0) << " [label=\"Module Output\",color=green];\n";
os << "}\n";
}
} // namespace
std::unique_ptr<Pass> createSpatialToGraphvizPass() {
return std::make_unique<SpatialToGraphvizPass>();
}
std::unique_ptr<Pass> createSpatialToGraphvizPass() { return std::make_unique<SpatialToGraphvizPass>(); }
} // namespace onnx_mlir
} // namespace onnx_mlir

View File

@@ -148,7 +148,8 @@ void SpatialToPIMPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IR
size_t resultIndexInConcat = resultUses.begin()->getOperandNumber();
size_t offset = 0;
for (auto operand : concatOp->getOperands().take_front(resultIndexInConcat))
offset += cast<ShapedType>(operand.getType()).getNumElements() * cast<ShapedType>(operand.getType()).getElementTypeBitWidth() / 8;
offset += cast<ShapedType>(operand.getType()).getNumElements()
* cast<ShapedType>(operand.getType()).getElementTypeBitWidth() / 8;
size_t elementSize = yieldType.getElementTypeBitWidth() / 8;

View File

@@ -9,4 +9,4 @@ namespace spatial {
}
} // namespace onnx_mlir
} // namespace onnx_mlir

View File

@@ -1,8 +1,5 @@
#pragma once
#include <map>
#include <string>
#include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -11,6 +8,9 @@
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include <map>
#include <string>
/// Include the auto-generated header files containing the declarations
#include "src/Accelerators/PIM/Dialect/PIM/PimDialect.hpp.inc"

View File

@@ -15,7 +15,10 @@ namespace pim {
struct MemCopyHostToDevOpInterface
: DstBufferizableOpInterfaceExternalModel<MemCopyHostToDevOpInterface, PimMemCopyHostToDevOp> {
LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const {
LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
const BufferizationOptions& options,
BufferizationState& state) const {
auto memCopyHostToDevOp = cast<PimMemCopyHostToDevOp>(op);
auto deviceDst = memCopyHostToDevOp.getDeviceDst();
auto hostSrc = memCopyHostToDevOp.getHostSrc();
@@ -44,7 +47,10 @@ struct MemCopyHostToDevOpInterface
struct MemCopyDevToHostOpInterface
: DstBufferizableOpInterfaceExternalModel<MemCopyDevToHostOpInterface, PimMemCopyDevToHostOp> {
LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const {
LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
const BufferizationOptions& options,
BufferizationState& state) const {
auto memCopyDevToHostOp = cast<PimMemCopyDevToHostOp>(op);
auto globalDst = memCopyDevToHostOp.getHostDst();
@@ -88,7 +94,10 @@ struct VMMOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<VMMOpBu
return false;
}
LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const {
LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
const BufferizationOptions& options,
BufferizationState& state) const {
auto vmmOp = cast<PimVMMOp>(op);
auto vectorInputOpt = getBuffer(rewriter, vmmOp.getVectorInput(), options, state);
@@ -110,7 +119,10 @@ struct MVMOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<MVMOpBu
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
}
LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const {
LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
const BufferizationOptions& options,
BufferizationState& state) const {
auto mvmOp = cast<PimMVMOp>(op);
auto vectorInputOpt = getBuffer(rewriter, mvmOp.getVectorInput(), options, state);
@@ -138,7 +150,10 @@ struct VAddOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel<VAddOp
return true;
}
LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const {
LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
const BufferizationOptions& options,
BufferizationState& state) const {
auto vaddOp = cast<PimVAddOp>(op);
auto aOpt = getBuffer(rewriter, vaddOp.getA(), options, state);

View File

@@ -1,6 +1,7 @@
#pragma once
#include "mlir/IR/DialectRegistry.h"
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
using namespace mlir;
@@ -8,7 +9,7 @@ using namespace mlir;
namespace onnx_mlir {
namespace pim {
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
void registerBufferizableOpInterfaceExternalModels(DialectRegistry& registry);
} // namespace pim
} // namespace onnx_mlir
} // namespace onnx_mlir

View File

@@ -1,8 +1,5 @@
#pragma once
#include <map>
#include <string>
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
@@ -10,6 +7,9 @@
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Types.h"
#include <map>
#include <string>
/// Include the auto-generated header files containing the declarations
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialDialect.hpp.inc"

View File

@@ -75,7 +75,10 @@ struct WComputeOpInterface : BufferizableOpInterface::ExternalModel<WComputeOpIn
return {};
}
LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const {
LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
const BufferizationOptions& options,
BufferizationState& state) const {
// Bufferize its block
auto& block = op->getRegion(0).front();
@@ -104,7 +107,10 @@ struct VariadicArgumentElementWiseOpInterface : BufferizableOpInterface::Externa
}
// Cast tensor values into memref values
LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const {
LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
const BufferizationOptions& options,
BufferizationState& state) const {
// Turn Tensor Operands into Memref Operands
SmallVector<Value> memrefOperands;
@@ -151,7 +157,10 @@ struct WeightedMultiplicationsOpInterface : BufferizableOpInterface::ExternalMod
}
// Cast tensor value into memref value
LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const {
LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
const BufferizationOptions& options,
BufferizationState& state) const {
auto memrefOperandOpt = getBuffer(rewriter, op->getOperand(0), options, state);
if (failed(memrefOperandOpt))
return failure();
@@ -190,7 +199,10 @@ struct ChannelReceiveOpInterface
/*
* Turn the channel receive to pim.recv
*/
LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const {
LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
const BufferizationOptions& options,
BufferizationState& state) const {
auto outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter);
@@ -235,7 +247,10 @@ struct ChannelSendOpInterface : BufferizableOpInterface::ExternalModel<ChannelSe
/*
* Turn the channel send to pim.send
*/
LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const {
LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
const BufferizationOptions& options,
BufferizationState& state) const {
auto srcTensor = op->getOperand(1);
auto srcTensorOpt = getBuffer(rewriter, srcTensor, options, state);
@@ -278,7 +293,10 @@ struct ChannelBroadcastReceiveOpInterface
/*
* Turn the channel receive to pim.load using by creating a new global buffer
*/
LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const {
LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
const BufferizationOptions& options,
BufferizationState& state) const {
auto outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter);
@@ -340,7 +358,10 @@ struct ChannelBroadcastSendOpInterface
/*
* Turn the channel send to pim.send
*/
LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const {
LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
const BufferizationOptions& options,
BufferizationState& state) const {
auto srcTensor = op->getOperand(1);
auto srcTensorOpt = getBuffer(rewriter, srcTensor, options, state);
@@ -414,7 +435,10 @@ struct ApplyFiltersOpInterface : BufferizableOpInterface::ExternalModel<ApplyFil
}
// Bufferize the operation.
LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const {
LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
const BufferizationOptions& options,
BufferizationState& state) const {
// Get the input tensor buffer.
auto inputBuffer = getBuffer(rewriter, op->getOperand(0), options, state);

View File

@@ -1,6 +1,7 @@
#pragma once
#include "mlir/IR/DialectRegistry.h"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;

View File

@@ -1,5 +1,6 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h"
#include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Compiler/CompilerUtils.hpp"
@@ -10,22 +11,19 @@ namespace onnx_mlir {
namespace {
struct CountInstructionPass
: public PassWrapper<CountInstructionPass, OperationPass<ModuleOp>> {
struct CountInstructionPass : public PassWrapper<CountInstructionPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CountInstructionPass)
StringRef getArgument() const override { return "count-instruction-pass"; }
StringRef getDescription() const override {
return "Count instructions for each core/compute in the module";
}
StringRef getDescription() const override { return "Count instructions for each core/compute in the module"; }
// Make sure that we have a valid default constructor and copy
// constructor to make sure that the options are initialized properly.
CountInstructionPass() {}
CountInstructionPass(const CountInstructionPass &pass)
: PassWrapper<CountInstructionPass, OperationPass<ModuleOp>>() {}
CountInstructionPass(const CountInstructionPass& pass)
: PassWrapper<CountInstructionPass, OperationPass<ModuleOp>>() {}
void runOnOperation() final {
ModuleOp module = getOperation();
@@ -37,8 +35,7 @@ struct CountInstructionPass
for (auto computeOp : func.getOps<spatial::SpatWeightedCompute>()) {
unsigned instructionCount = 0;
instructionCount += computeOp.getBody().front().getOperations().size();
llvm::outs() << "Compute " << computeId << ": " << instructionCount
<< " instructions\n";
llvm::outs() << "Compute " << computeId << ": " << instructionCount << " instructions\n";
totalInstructionCount += instructionCount;
computeId++;
}
@@ -47,21 +44,17 @@ struct CountInstructionPass
for (auto coreOp : func.getOps<pim::PimCoreOp>()) {
unsigned instructionCount = 0;
instructionCount += coreOp.getBody().front().getOperations().size();
llvm::outs() << "Core " << coreId << ": " << instructionCount
<< " instructions\n";
llvm::outs() << "Core " << coreId << ": " << instructionCount << " instructions\n";
totalInstructionCount += instructionCount;
coreId++;
}
llvm::outs() << "Total instruction count: " << totalInstructionCount
<< "\n";
llvm::outs() << "Total instruction count: " << totalInstructionCount << "\n";
}
};
} // namespace
std::unique_ptr<Pass> createCountInstructionPass() {
return std::make_unique<CountInstructionPass>();
}
std::unique_ptr<Pass> createCountInstructionPass() { return std::make_unique<CountInstructionPass>(); }
} // namespace onnx_mlir
} // namespace onnx_mlir

View File

@@ -1,6 +1,7 @@
#pragma once
#include "mlir/Pass/Pass.h"
#include <memory>
using namespace mlir;
@@ -21,4 +22,4 @@ std::unique_ptr<Pass> createMessagePass(std::string message);
std::unique_ptr<Pass> createCountInstructionPass();
} // namespace onnx_mlir
} // namespace onnx_mlir

View File

@@ -1,6 +1,7 @@
#pragma once
#include "mlir/IR/BuiltinTypes.h"
#include "src/Accelerators/Accelerator.hpp"
namespace onnx_mlir {
@@ -9,38 +10,37 @@ namespace accel {
/// Singleton class to construct PIM accelerator.
class PimAccelerator final : public Accelerator {
private:
static PimAccelerator *instance;
static PimAccelerator* instance;
PimAccelerator();
public:
/// Singleton should not be clonable or assignable.
PimAccelerator(PimAccelerator &) = delete;
void operator=(const PimAccelerator &) = delete;
PimAccelerator(PimAccelerator&) = delete;
void operator=(const PimAccelerator&) = delete;
~PimAccelerator();
/// Creates an instance on the first invocation. Subsequent invocations
/// return the existing instance.
static PimAccelerator *getInstance();
static PimAccelerator* getInstance();
/// Define classof to be able to use isa<>, cast<>, dyn_cast<>, etc.
static bool classof(const Accelerator *accel) {
return accel->getKind() == Accelerator::Kind::PIM;
}
static bool classof(const PimAccelerator *) { return true; }
static bool classof(const Accelerator* accel) { return accel->getKind() == Accelerator::Kind::PIM; }
static bool classof(const PimAccelerator*) { return true; }
uint64_t getVersionNumber() const final;
//===--------------------------------------------------------------------===//
// Hooks for onnx-mlir-opt driver
//===--------------------------------------------------------------------===//
virtual void addPasses(mlir::OwningOpRef<mlir::ModuleOp> &module,
mlir::PassManager &pm, onnx_mlir::EmissionTargetType &emissionTarget,
std::string outputNameNoExt) const final;
virtual void addPasses(mlir::OwningOpRef<mlir::ModuleOp>& module,
mlir::PassManager& pm,
onnx_mlir::EmissionTargetType& emissionTarget,
std::string outputNameNoExt) const final;
//===--------------------------------------------------------------------===//
// Hooks for onnx-mlir-opt driver
//===--------------------------------------------------------------------===//
virtual void registerDialects(mlir::DialectRegistry &registry) const final;
virtual void registerDialects(mlir::DialectRegistry& registry) const final;
virtual void registerPasses(int optLevel) const final;
//===--------------------------------------------------------------------===//
// Hooks for both onnx-mlir and onnx-mlir-opt drivers
@@ -49,21 +49,19 @@ public:
//===--------------------------------------------------------------------===//
// Hooks for onnx-to-krnl pass
//===--------------------------------------------------------------------===//
virtual mlir::MemRefType convertTensorTypeToMemRefType(
const mlir::TensorType tensorType) const final;
virtual void conversionTargetONNXToKrnl(
mlir::ConversionTarget &target) const final;
virtual void rewritePatternONNXToKrnl(mlir::RewritePatternSet &patterns,
mlir::TypeConverter &typeConverter, mlir::MLIRContext *ctx) const final;
virtual mlir::MemRefType convertTensorTypeToMemRefType(const mlir::TensorType tensorType) const final;
virtual void conversionTargetONNXToKrnl(mlir::ConversionTarget& target) const final;
virtual void rewritePatternONNXToKrnl(mlir::RewritePatternSet& patterns,
mlir::TypeConverter& typeConverter,
mlir::MLIRContext* ctx) const final;
//===--------------------------------------------------------------------===//
// Hooks for krnl-to-llvm pass
//===--------------------------------------------------------------------===//
virtual void conversionTargetKrnlToLLVM(
mlir::ConversionTarget &target) const final;
virtual void rewritePatternKrnlToLLVM(mlir::RewritePatternSet &patterns,
mlir::LLVMTypeConverter &typeConverter,
mlir::MLIRContext *ctx) const final;
virtual void conversionTargetKrnlToLLVM(mlir::ConversionTarget& target) const final;
virtual void rewritePatternKrnlToLLVM(mlir::RewritePatternSet& patterns,
mlir::LLVMTypeConverter& typeConverter,
mlir::MLIRContext* ctx) const final;
};
} // namespace accel

View File

@@ -63,7 +63,8 @@ void PimBufferizationPass::runOnOperation() {
void PimBufferizationPass::annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const {
MLIRContext* ctx = funcOp.getContext();
funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
bool isAlwaysWeight = llvm::all_of(getGlobalOp->getUsers(), [](auto user) -> bool { return isa<PimCoreOp>(user); }) && !getGlobalOp->getUsers().empty() ;
bool isAlwaysWeight = llvm::all_of(getGlobalOp->getUsers(), [](auto user) -> bool { return isa<PimCoreOp>(user); })
&& !getGlobalOp->getUsers().empty();
if (isAlwaysWeight) {
auto globalMemrefOp = moduleOp.lookupSymbol<memref::GlobalOp>(getGlobalOp.getName());
assert("Weights must be constants" && globalMemrefOp.getConstant());