diff --git a/.clang-format b/.clang-format new file mode 100644 index 0000000..c9f05fd --- /dev/null +++ b/.clang-format @@ -0,0 +1,143 @@ +--- +AccessModifierOffset: -2 +AlignAfterOpenBracket: Align +AlignArrayOfStructures: Left +AlignConsecutiveShortCaseStatements: + Enabled: true +AlignConsecutiveAssignments: + Enabled: false +AlignConsecutiveBitFields: + Enabled: false +AlignConsecutiveDeclarations: + Enabled: false +AlignConsecutiveMacros: + Enabled: false +AlignEscapedNewlines: Left +AlignOperands: AlignAfterOperator +AlignTrailingComments: + Kind: Always + OverEmptyLines: 4 +AllowAllArgumentsOnNextLine: true +AllowAllParametersOfDeclarationOnNextLine: true +AllowShortBlocksOnASingleLine: Empty +AllowShortCaseExpressionOnASingleLine: true +AllowShortCaseLabelsOnASingleLine: true +AllowShortCompoundRequirementOnASingleLine: true +AllowShortEnumsOnASingleLine: false +AllowShortFunctionsOnASingleLine: All +AllowShortIfStatementsOnASingleLine: Never +AllowShortLambdasOnASingleLine: All +AllowShortLoopsOnASingleLine: false +AlwaysBreakBeforeMultilineStrings: false +BinPackArguments: false +BinPackParameters: false +BitFieldColonSpacing: Both +BraceWrapping: + BeforeElse: true + BeforeCatch: true + BeforeWhile: true +BreakBeforeBraces: Custom +BracedInitializerIndentWidth: 2 +BreakAdjacentStringLiterals: true +BreakAfterAttributes: Never +BreakAfterJavaFieldAnnotations: false +BreakArrays: false +BreakBeforeBinaryOperators: NonAssignment +BreakBeforeConceptDeclarations: Always +BreakBeforeInlineASMColon: OnlyMultiline +BreakBeforeTernaryOperators: true +BreakConstructorInitializers: BeforeColon +BreakFunctionDefinitionParameters: false +BreakInheritanceList: AfterComma +BreakStringLiterals: true +BreakTemplateDeclarations: Yes +ColumnLimit: 120 +CompactNamespaces: false +ConstructorInitializerIndentWidth: 0 +ContinuationIndentWidth: 2 +Cpp11BracedListStyle: true +DerivePointerAlignment: false +DisableFormat: false +EmptyLineAfterAccessModifier: Never +EmptyLineBeforeAccessModifier: Always +FixNamespaceComments: true +IncludeBlocks: Regroup +IncludeCategories: + - Regex: (<|")mlir(.*)(>|") + Priority: 2 + - Regex: (<|")llvm(.*)(>|") + Priority: 3 + - Regex: <.*> + Priority: 4 +IncludeIsMainRegex: (Test)?$ +IncludeIsMainSourceRegex: "" +IndentAccessModifiers: false +IndentCaseBlocks: false +IndentCaseLabels: false +IndentExternBlock: Indent +IndentGotoLabels: false +IndentPPDirectives: None +IndentRequiresClause: true +IndentWidth: 2 +IndentWrappedFunctionNames: false +InsertBraces: false +InsertNewlineAtEOF: true +KeepEmptyLines: + AtEndOfFile: false + AtStartOfBlock: true + AtStartOfFile: false +LambdaBodyIndentation: Signature +Language: Cpp +LineEnding: LF +MainIncludeChar: AngleBracket +MaxEmptyLinesToKeep: 1 +NamespaceIndentation: None +PackConstructorInitializers: NextLineOnly +PointerAlignment: Left +ReferenceAlignment: Pointer +RemoveBracesLLVM: true +RemoveParentheses: ReturnStatement +RemoveSemicolon: true +RequiresClausePosition: OwnLine +RequiresExpressionIndentation: OuterScope +SeparateDefinitionBlocks: Leave +ShortNamespaceLines: 4 +SortIncludes: CaseSensitive +SpaceAfterCStyleCast: true +SpaceAfterLogicalNot: false +SpaceAfterTemplateKeyword: true +SpaceAroundPointerQualifiers: Before +SpaceBeforeAssignmentOperators: true +SpaceBeforeCaseColon: false +SpaceBeforeCpp11BracedList: true +SpaceBeforeCtorInitializerColon: true +SpaceBeforeInheritanceColon: true +SpaceBeforeJsonColon: false +SpaceBeforeParensOptions: + AfterControlStatements: true + AfterForeachMacros: true + AfterFunctionDeclarationName: false + AfterFunctionDefinitionName: false + AfterIfMacros: true + AfterOverloadedOperator: false + AfterPlacementOperator: false + AfterRequiresInClause: false + AfterRequiresInExpression: false + BeforeNonEmptyParentheses: false +SpaceBeforeRangeBasedForLoopColon: true +SpaceBeforeSquareBrackets: false +SpaceInEmptyBlock: false +SpacesBeforeTrailingComments: 1 +SpacesInAngles: Never +SpacesInContainerLiterals: false +SpacesInLineCommentPrefix: + Minimum: 1 + Maximum: 1 +SpacesInParens: Never +SpacesInSquareBrackets: false +Standard: Latest +TabWidth: 4 +UseTab: Never +VerilogBreakBetweenInstancePorts: true +Macros: + - LLVM_DEBUG(X)=X\n diff --git a/src/PIM/Common/PIMCommon.hpp b/src/PIM/Common/PIMCommon.hpp index fac20a6..530b973 100644 --- a/src/PIM/Common/PIMCommon.hpp +++ b/src/PIM/Common/PIMCommon.hpp @@ -14,9 +14,9 @@ namespace onnx_mlir { std::string getOutputDir(); -void createDirectory(const std::string &directory); +void createDirectory(const std::string& directory); -void dumpModule(mlir::ModuleOp moduleOp, const std::string &name); +void dumpModule(mlir::ModuleOp moduleOp, const std::string& name); llvm::FailureOr getOtherEndOfChannel(mlir::Operation* op, bool opIsReceive, mlir::RewriterBase& rewriter); diff --git a/src/PIM/Conversion/ONNXToSpatial/Math/Conv.cpp b/src/PIM/Conversion/ONNXToSpatial/Math/Conv.cpp index 255ca55..08e6837 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Math/Conv.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Math/Conv.cpp @@ -11,20 +11,21 @@ #include "mlir/IR/Value.h" #include "mlir/Support/LLVM.h" #include "mlir/Transforms/DialectConversion.h" -#include "src/Accelerators/PIM/Common/PIMCommon.hpp" -#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" -#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" -#include "src/Dialect/ONNX/ONNXOps.hpp" -#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/LogicalResult.h" -#include +#include #include #include #include +#include "src/Accelerators/PIM/Common/PIMCommon.hpp" +#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp" +#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" +#include "src/Dialect/ONNX/ONNXOps.hpp" + using namespace mlir; using namespace std; @@ -40,8 +41,8 @@ namespace onnx_mlir { */ class Core { public: - Core(const size_t coreId, ConversionPatternRewriter &rewriter) - : coreId(coreId), rewriter(rewriter) {} + Core(const size_t coreId, ConversionPatternRewriter& rewriter) + : coreId(coreId), rewriter(rewriter) {} /** * @brief Add a MVM operation to the core. @@ -52,8 +53,7 @@ public: * @param mvmOutType The result's shape. * @return Value The result of the MVM operation. */ - Value addMVM( - Value inputTile, size_t xbarIndex, size_t outputTileId, Type mvmOutType) { + Value addMVM(Value inputTile, size_t xbarIndex, size_t outputTileId, Type mvmOutType) { // Use the inputTile as the reference location for the MVM operation. Location loc = inputTile.getLoc(); @@ -72,8 +72,7 @@ public: // is correct. // Construct the MVM operation - Value result = rewriter.create( - loc, mvmOutType, xbarIndex, operand); + Value result = rewriter.create(loc, mvmOutType, xbarIndex, operand); // Since we are within the same core and no computation can happen in // paralllel, we can just apply a linear reduction in case we have multiple @@ -84,8 +83,7 @@ public: if (lastMVM != outputTileToMVM.end()) { // MVM results should have the same type for reduction. assert(lastMVM->second.getType() == result.getType()); - result = rewriter.create( - loc, mvmOutType, lastMVM->second, result); + result = rewriter.create(loc, mvmOutType, lastMVM->second, result); } outputTileToMVM[outputTileId] = result; @@ -139,16 +137,14 @@ public: spatial::SpatWeightedCompute createWComputeOp(Location loc) { // Get the shape of the results. SmallVector resultTypes; - for (const auto &value : results) { + for (const auto& value : results) resultTypes.push_back(value.getType()); - } // Create the WComputeOp, with non-remappable operands only. - wcomputeOp = rewriter.create( - loc, resultTypes, xbarWeights, operands); + wcomputeOp = rewriter.create(loc, resultTypes, xbarWeights, operands); // Add the body to the WComputeOp. - Block *releasedBlock = block.release(); + Block* releasedBlock = block.release(); wcomputeOp.getBody().push_back(releasedBlock); // Add the `yieldOp` at the end, with the results. @@ -164,21 +160,18 @@ public: void remapResults() { // Remap all the results to the WComputeOp results. assert(resultsToRemap.size() == wcomputeOp->getNumResults()); - for (size_t i = 0; i < resultsToRemap.size(); i++) { + for (size_t i = 0; i < resultsToRemap.size(); i++) *resultsToRemap[i] = wcomputeOp.getResult(i); - } } void addRemappedOperands() { // Insert the remappableOperands (which were remapped in // `addRemappableOperand` of another Core) - for (auto remappedValue : remappableOperands) { + for (auto remappedValue : remappableOperands) wcomputeOp->insertOperands(wcomputeOp->getNumOperands(), *remappedValue); - } // Update the wcomputeOp operandSegmentSize - incrementWeightedComputeInputsSegmentSize( - wcomputeOp, static_cast(remappableOperands.size())); + incrementWeightedComputeInputsSegmentSize(wcomputeOp, static_cast(remappableOperands.size())); } size_t addXbarWeight(Value weight) { @@ -199,31 +192,27 @@ public: llvm::outs() << "Core " << coreId << ":\n"; // Print the weights llvm::outs() << "Xbar Weights:\n"; - for (auto weight : xbarWeights) { + for (auto weight : xbarWeights) weight.dump(); - } // Print the operands llvm::outs() << "Operands:\n"; - for (auto operand : operands) { + for (auto operand : operands) llvm::outs() << operand << "\n"; - } // Dump the body block - for (auto &op : block->getOperations()) { + for (auto& op : block->getOperations()) op.dump(); - } // Print the results llvm::outs() << "Results:\n"; - for (auto result : results) { + for (auto result : results) llvm::outs() << result << "\n"; - } } const size_t coreId; private: - ConversionPatternRewriter &rewriter; + ConversionPatternRewriter& rewriter; // Should these be set instead? But I need to keep the order vector operands; @@ -246,15 +235,16 @@ private: }; struct ONNXConvOpTile : public OpConversionPattern { - ONNXConvOpTile(MLIRContext *ctx) : OpConversionPattern(ctx) {} + ONNXConvOpTile(MLIRContext* ctx) + : OpConversionPattern(ctx) {} struct Producer_t { Value value; shared_ptr core; }; - LogicalResult matchAndRewrite(ONNXConvOp conv, ONNXConvOpAdaptor convAdaptor, - ConversionPatternRewriter &rewriter) const final { + LogicalResult + matchAndRewrite(ONNXConvOp conv, ONNXConvOpAdaptor convAdaptor, ConversionPatternRewriter& rewriter) const final { ShapedType xShape = mlir::cast(convAdaptor.getX().getType()); ShapedType wShape = mlir::cast(convAdaptor.getW().getType()); ShapedType bShape = mlir::cast(convAdaptor.getB().getType()); @@ -264,11 +254,9 @@ struct ONNXConvOpTile : public OpConversionPattern { unpackOptionalPairVector(conv.getStrides(), stride_x, stride_y); unpackOptionalPairVector(conv.getDilations(), dilation_x, dilation_y); - auto padUnpackError = - unpackOptionalPadsVector(convAdaptor.getPads(), pad_x, pad_y); - if (padUnpackError.has_value()) { + auto padUnpackError = unpackOptionalPadsVector(convAdaptor.getPads(), pad_x, pad_y); + if (padUnpackError.has_value()) return rewriter.notifyMatchFailure(conv, padUnpackError.value()); - } // TODO: Pad value at beginning and end of each dimension could be // different. We should handle this case. @@ -296,11 +284,9 @@ struct ONNXConvOpTile : public OpConversionPattern { Location loc = conv.getLoc(); - size_t inputTileCount = - ceilIntegerDivide(GET_IMAGE_CHANNEL(xShape), crossbarSize.getValue()); + size_t inputTileCount = ceilIntegerDivide(GET_IMAGE_CHANNEL(xShape), crossbarSize.getValue()); size_t inputTileRemainder = GET_IMAGE_CHANNEL(xShape) % crossbarSize; - size_t outputTileCount = - ceilIntegerDivide(GET_IMAGE_CHANNEL(yShape), crossbarSize.getValue()); + size_t outputTileCount = ceilIntegerDivide(GET_IMAGE_CHANNEL(yShape), crossbarSize.getValue()); size_t outputTileRemainder = GET_IMAGE_CHANNEL(yShape) % crossbarSize; // Tile the input tensor @@ -310,22 +296,20 @@ struct ONNXConvOpTile : public OpConversionPattern { // c. Pixel `y` position // For example: inputTiles[channelTile][x][y] // Example complete input tensor: tensor<1x3x6x6xf32> (NxCxWxH) - SmallVector>> inputTiles(inputTileCount, - SmallVector>(input_w, SmallVector(input_h))); + SmallVector>> inputTiles( + inputTileCount, SmallVector>(input_w, SmallVector(input_h))); - auto resolveErrorOpt = resolveImgInputTiles(convAdaptor.getX(), inputTiles, - inputTileCount, inputTileRemainder, input_h, input_h, rewriter); - if (resolveErrorOpt.has_value()) { + auto resolveErrorOpt = resolveImgInputTiles( + convAdaptor.getX(), inputTiles, inputTileCount, inputTileRemainder, input_h, input_h, rewriter); + if (resolveErrorOpt.has_value()) return rewriter.notifyMatchFailure(conv, *resolveErrorOpt); - } - SmallVector strides = - SmallVector(4, rewriter.getIndexAttr(1)); - SmallVector offsets = - SmallVector(4, rewriter.getIndexAttr(0)); - SmallVector sizes = SmallVector{ - rewriter.getIndexAttr(1), rewriter.getIndexAttr(crossbarSize), - rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; + SmallVector strides = SmallVector(4, rewriter.getIndexAttr(1)); + SmallVector offsets = SmallVector(4, rewriter.getIndexAttr(0)); + SmallVector sizes = SmallVector {rewriter.getIndexAttr(1), + rewriter.getIndexAttr(crossbarSize), + rewriter.getIndexAttr(1), + rewriter.getIndexAttr(1)}; // Tile the weight tensor // Weight tiles need to be indexed by: @@ -336,31 +320,30 @@ struct ONNXConvOpTile : public OpConversionPattern { // For example: weightTiles[filterTile][channelTile][x][y] // Example complete weight tensor: tensor<32x3x3x3xf32> (FxCxWxH) SmallVector>>> weightTiles( - outputTileCount, - SmallVector>>(inputTileCount, - SmallVector>(krn_w, SmallVector(krn_h)))); + outputTileCount, + SmallVector>>(inputTileCount, + SmallVector>(krn_w, SmallVector(krn_h)))); strides = SmallVector(4, rewriter.getIndexAttr(1)); offsets = SmallVector(4, rewriter.getIndexAttr(0)); sizes = {rewriter.getIndexAttr(crossbarSize), - rewriter.getIndexAttr(crossbarSize), rewriter.getIndexAttr(1), - rewriter.getIndexAttr(1)}; + rewriter.getIndexAttr(crossbarSize), + rewriter.getIndexAttr(1), + rewriter.getIndexAttr(1)}; for (size_t i = 0; i < outputTileCount; i++) { - if (i == outputTileCount - 1 && outputTileRemainder != 0) { + if (i == outputTileCount - 1 && outputTileRemainder != 0) sizes[0] = rewriter.getIndexAttr(outputTileRemainder); - } sizes[1] = rewriter.getIndexAttr(crossbarSize); offsets[0] = rewriter.getIndexAttr(i * crossbarSize); for (size_t j = 0; j < inputTileCount; j++) { - if (j == inputTileCount - 1 && inputTileRemainder != 0) { + if (j == inputTileCount - 1 && inputTileRemainder != 0) sizes[1] = rewriter.getIndexAttr(inputTileRemainder); - } for (size_t x = 0; x < krn_w; x++) { for (size_t y = 0; y < krn_h; y++) { offsets[1] = rewriter.getIndexAttr(j * crossbarSize); offsets[2] = rewriter.getIndexAttr(x); offsets[3] = rewriter.getIndexAttr(y); - weightTiles[i][j][x][y] = rewriter.create( - loc, convAdaptor.getW(), offsets, sizes, strides); + weightTiles[i][j][x][y] = + rewriter.create(loc, convAdaptor.getW(), offsets, sizes, strides); } } } @@ -379,56 +362,45 @@ struct ONNXConvOpTile : public OpConversionPattern { // For example: outputTiles[filterTile][x][y] // Example complete output tensor: tensor<1x32x3x3xf32> (NxFxWxH) SmallVector>>> outputTiles( - outputTileCount, - SmallVector>>( - output_w, SmallVector>(output_h, nullptr))); + outputTileCount, + SmallVector>>(output_w, SmallVector>(output_h, nullptr))); size_t replicationFactor; - if (!conv->hasAttr(REPLICATION_ATTR_NAME)) { + if (!conv->hasAttr(REPLICATION_ATTR_NAME)) replicationFactor = 1; - } else { - replicationFactor = - conv->getAttrOfType(REPLICATION_ATTR_NAME).getInt(); - } + else + replicationFactor = conv->getAttrOfType(REPLICATION_ATTR_NAME).getInt(); // producers[outTile][out_x][out_y][producerIndex] - vector>>> producers = - vector>>>(outputTileCount, - vector>>(output_w, - vector>(output_h, vector()))); + vector>>> producers = vector>>>( + outputTileCount, + vector>>(output_w, vector>(output_h, vector()))); // Schedule in cores size_t coreId = 0; vector> curCores(replicationFactor); - for (size_t i = 0; i < replicationFactor; i++) { + for (size_t i = 0; i < replicationFactor; i++) curCores[i] = make_shared(coreId++, rewriter); - } vector> cores; - const size_t replicationSliceSize = - ceilIntegerDivide(input_w, replicationFactor); + const size_t replicationSliceSize = ceilIntegerDivide(input_w, replicationFactor); for (size_t krn_x = 0; krn_x < krn_h; krn_x++) { for (size_t krn_y = 0; krn_y < krn_w; krn_y++) { RankedTensorType mvmOutType = - RankedTensorType::get({1, static_cast(crossbarSize), 1, 1}, - bShape.getElementType()); + RankedTensorType::get({1, static_cast(crossbarSize), 1, 1}, bShape.getElementType()); for (size_t outTile = 0; outTile < outputTileCount; outTile++) { - if (outTile == outputTileCount - 1 && outputTileRemainder != 0) { - mvmOutType = mvmOutType.clone( - {1, static_cast(outputTileRemainder), 1, 1}); - } + if (outTile == outputTileCount - 1 && outputTileRemainder != 0) + mvmOutType = mvmOutType.clone({1, static_cast(outputTileRemainder), 1, 1}); for (size_t inTile = 0; inTile < inputTileCount; inTile++) { vector xbarIndexes(replicationFactor); - for (size_t i = 0; i < replicationFactor; i++) { - xbarIndexes[i] = curCores[i]->addXbarWeight( - weightTiles[outTile][inTile][krn_x][krn_y]); - } + for (size_t i = 0; i < replicationFactor; i++) + xbarIndexes[i] = curCores[i]->addXbarWeight(weightTiles[outTile][inTile][krn_x][krn_y]); size_t out_x = 0; for (size_t in_x = 0; in_x < input_w; in_x += stride_x) { @@ -443,25 +415,20 @@ struct ONNXConvOpTile : public OpConversionPattern { for (size_t in_y = 0; in_y < input_h; in_y += stride_y) { // Adjust the input based on the kernel - int actual_in_x = in_x - ((int)krn_w / 2) + krn_x * dilation_x; - int actual_in_y = in_y - ((int)krn_h / 2) + krn_y * dilation_y; + int actual_in_x = in_x - ((int) krn_w / 2) + krn_x * dilation_x; + int actual_in_y = in_y - ((int) krn_h / 2) + krn_y * dilation_y; // Check if we are within the input image - if (verifyWithinBoundsAndPaddings(input_w, input_h, actual_in_x, - actual_in_y, pad_x, pad_y) - .failed()) { + if (verifyWithinBoundsAndPaddings(input_w, input_h, actual_in_x, actual_in_y, pad_x, pad_y).failed()) { out_y++; continue; } - size_t outTileId = - outTile * output_w * output_h + out_x * output_h + out_y; + size_t outTileId = outTile * output_w * output_h + out_x * output_h + out_y; auto mvm = curCores[coreIndex]->addMVM( - inputTiles[inTile][actual_in_x][actual_in_y], - xbarIndexes[coreIndex], outTileId, mvmOutType); + inputTiles[inTile][actual_in_x][actual_in_y], xbarIndexes[coreIndex], outTileId, mvmOutType); - producers[outTile][out_x][out_y].push_back( - {mvm, curCores[coreIndex]}); + producers[outTile][out_x][out_y].push_back({mvm, curCores[coreIndex]}); out_y++; } @@ -481,11 +448,9 @@ struct ONNXConvOpTile : public OpConversionPattern { } } - for (auto &curCore : curCores) { - if (curCore->isCoreEmpty() == false) { + for (auto& curCore : curCores) + if (curCore->isCoreEmpty() == false) cores.emplace_back(std::move(curCore)); - } - } curCores.clear(); // Now, do the reduction of each output pixel tile for (size_t outTile = 0; outTile < outputTileCount; outTile++) { @@ -497,9 +462,8 @@ struct ONNXConvOpTile : public OpConversionPattern { // core. std::unordered_map withinCoreReducedProducers; - for (auto producer : producers[outTile][out_x][out_y]) { + for (auto producer : producers[outTile][out_x][out_y]) withinCoreReducedProducers[producer.core->coreId] = producer; - } // Now, we need to apply inter-core reduction @@ -509,8 +473,7 @@ struct ONNXConvOpTile : public OpConversionPattern { auto singleProducer = withinCoreReducedProducers.begin()->second; // Use last producer as the final result - auto reducedValue = - singleProducer.core->makeResultRemappable(singleProducer.value); + auto reducedValue = singleProducer.core->makeResultRemappable(singleProducer.value); outputTiles[outTile][out_x][out_y] = reducedValue; continue; } @@ -535,9 +498,9 @@ struct ONNXConvOpTile : public OpConversionPattern { auto lastProducerCoreId = lastProducer.core->coreId; auto curProducerCoreId = curProducer.core->coreId; - assert(lastProducerCoreId != curProducerCoreId && - "We should have already applied within-core reduction, how " - "could we have same cores here?"); + assert(lastProducerCoreId != curProducerCoreId + && "We should have already applied within-core reduction, how " + "could we have same cores here?"); // Sort the cores by coreId if (curProducerCoreId < lastProducerCoreId) { @@ -545,7 +508,8 @@ struct ONNXConvOpTile : public OpConversionPattern { core1Value = curProducer.value; core2 = lastProducer.core; core2Value = lastProducer.value; - } else { + } + else { core1 = lastProducer.core; core1Value = lastProducer.value; core2 = curProducer.core; @@ -556,9 +520,8 @@ struct ONNXConvOpTile : public OpConversionPattern { auto secondCoreBlockArg = core2->addRemappableOperand(newCoreRes); rewriter.setInsertionPointAfterValue(core2Value); - Value vaddRes = - rewriter.create(core2Value.getLoc(), - core2Value.getType(), core2Value, secondCoreBlockArg); + Value vaddRes = rewriter.create( + core2Value.getLoc(), core2Value.getType(), core2Value, secondCoreBlockArg); lastProducer = {vaddRes, core2}; @@ -568,8 +531,7 @@ struct ONNXConvOpTile : public OpConversionPattern { // TODO: Add the bias and apply mapping (if present) // Use last producer as the final result - auto reducedValue = - lastProducer.core->makeResultRemappable(lastProducer.value); + auto reducedValue = lastProducer.core->makeResultRemappable(lastProducer.value); outputTiles[outTile][out_x][out_y] = reducedValue; } } @@ -578,15 +540,14 @@ struct ONNXConvOpTile : public OpConversionPattern { // Now, we need to turn the cores into a spatial::SpatWeightedCompute. rewriter.setInsertionPointAfter(conv); spatial::SpatWeightedCompute lastWComputeOp; - for (auto &core : cores) { + for (auto& core : cores) { lastWComputeOp = core->createWComputeOp(loc); core->remapResults(); rewriter.setInsertionPointAfter(lastWComputeOp); } - for (auto &core : cores) { + for (auto& core : cores) core->addRemappedOperands(); - } // Set the insertion point after the last WComputeOp. rewriter.setInsertionPointAfter(lastWComputeOp); @@ -597,8 +558,7 @@ struct ONNXConvOpTile : public OpConversionPattern { for (size_t outTile = 0; outTile < outputTileCount; outTile++) tilesToConcat.push_back(*outputTiles[outTile][outX][outY]); - Value outputImage = rewriter.create( - loc, conv.getY().getType(), tilesToConcat); + Value outputImage = rewriter.create(loc, conv.getY().getType(), tilesToConcat); // Value outputImage = // createImgConcatOp(outputTiles, rewriter, loc, Y.getType()); @@ -616,9 +576,8 @@ struct ONNXConvOpTile : public OpConversionPattern { } }; -void populateTilingConvOpPattern( - RewritePatternSet &patterns, MLIRContext *ctx) { +void populateTilingConvOpPattern(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.insert(ctx); } -} // namespace onnx_mlir \ No newline at end of file +} // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/Math/ExperimentalConv.cpp b/src/PIM/Conversion/ONNXToSpatial/Math/ExperimentalConv.cpp index df70637..fae01b7 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Math/ExperimentalConv.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Math/ExperimentalConv.cpp @@ -1,6 +1,3 @@ -#include "Compiler/PimCompilerOptions.hpp" -#include "Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp" -#include "Dialect/Spatial/SpatialOps.hpp" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Operation.h" #include "mlir/IR/PatternMatch.h" @@ -8,13 +5,19 @@ #include "mlir/IR/Value.h" #include "mlir/Support/LLVM.h" #include "mlir/Transforms/DialectConversion.h" -#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/WeightSubdivider.hpp" -#include "src/Dialect/ONNX/ONNXOps.hpp" + #include "llvm/ADT/SmallVector.h" + #include #include #include +#include "Compiler/PimCompilerOptions.hpp" +#include "Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp" +#include "Dialect/Spatial/SpatialOps.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/WeightSubdivider.hpp" +#include "src/Dialect/ONNX/ONNXOps.hpp" + using namespace mlir; using namespace std; @@ -27,10 +30,11 @@ namespace onnx_mlir { * output tensor. */ struct ExperimentalONNXConvOpTile : public OpConversionPattern { - ExperimentalONNXConvOpTile(MLIRContext *ctx) : OpConversionPattern(ctx) {} + ExperimentalONNXConvOpTile(MLIRContext* ctx) + : OpConversionPattern(ctx) {} - LogicalResult matchAndRewrite(ONNXConvOp conv, ONNXConvOpAdaptor convAdaptor, - ConversionPatternRewriter &rewriter) const final { + LogicalResult + matchAndRewrite(ONNXConvOp conv, ONNXConvOpAdaptor convAdaptor, ConversionPatternRewriter& rewriter) const final { // --------------------------------- // // --- READ OPERATION PARAMETERS --- // @@ -46,12 +50,12 @@ struct ExperimentalONNXConvOpTile : public OpConversionPattern { ShapedType weightsType = cast(convAdaptor.getW().getType()); // TODO: Address bigger batches. - assert(GET_IMAGE_N(inputType) == 1 && "Batch size must be 1" - "for convolution."); + assert(GET_IMAGE_N(inputType) == 1 + && "Batch size must be 1" + "for convolution."); // TODO: Address replication. - assert(coresCount.getValue() == -1 && - "Replication is not yet supported for convolution."); + assert(coresCount.getValue() == -1 && "Replication is not yet supported for convolution."); // TODO: Address bias addition. @@ -97,11 +101,10 @@ struct ExperimentalONNXConvOpTile : public OpConversionPattern { long inputTile = it; // Create the slicing sizes. - SmallVector slicingSizes{ - /* 0 */ rewriter.getIndexAttr(crossbarHeight), - /* 1 */ rewriter.getIndexAttr(crossbarWidth), - /* 2 */ rewriter.getIndexAttr(1), - /* 3 */ rewriter.getIndexAttr(1)}; + SmallVector slicingSizes {/* 0 */ rewriter.getIndexAttr(crossbarHeight), + /* 1 */ rewriter.getIndexAttr(crossbarWidth), + /* 2 */ rewriter.getIndexAttr(1), + /* 3 */ rewriter.getIndexAttr(1)}; // - Slicing along the filter x position. // - Slicing along the filter y position. @@ -109,16 +112,14 @@ struct ExperimentalONNXConvOpTile : public OpConversionPattern { for (size_t filterY = 0; filterY < kernelHeight; ++filterY) { // Create the slicing offsets. - SmallVector slicingOffsets{ - /* 0 */ rewriter.getIndexAttr(outputTile * crossbarSize), - /* 1 */ rewriter.getIndexAttr(inputTile * crossbarSize), - /* 2 */ rewriter.getIndexAttr(filterX), - /* 3 */ rewriter.getIndexAttr(filterY)}; + SmallVector slicingOffsets {/* 0 */ rewriter.getIndexAttr(outputTile * crossbarSize), + /* 1 */ rewriter.getIndexAttr(inputTile * crossbarSize), + /* 2 */ rewriter.getIndexAttr(filterX), + /* 3 */ rewriter.getIndexAttr(filterY)}; // Create the slice extraction operation. auto extractSliceOp = rewriter.create( - conv.getLoc(), convAdaptor.getW(), slicingOffsets, slicingSizes, - slicingStrides); + conv.getLoc(), convAdaptor.getW(), slicingOffsets, slicingSizes, slicingStrides); // Add a note to the extractSliceOp, with the filterX and filterY. weightsGroups[inputTile][outputTile].push_back(extractSliceOp); @@ -150,8 +151,7 @@ struct ExperimentalONNXConvOpTile : public OpConversionPattern { // -------------------------------- // // Get the next group of weights for the compute unit. - SmallVector weightsGroups = - weightSubdivider.popGroups(crossbarCountInCore.getValue()); + SmallVector weightsGroups = weightSubdivider.popGroups(crossbarCountInCore.getValue()); SmallVector computeWeights; SmallVector computeOperands; @@ -168,15 +168,13 @@ struct ExperimentalONNXConvOpTile : public OpConversionPattern { // Iterate over all weights groups for this compute unit. map localSlices; // WRT the current compute unit. for (auto group : weightsGroups) { - for (Value weight : group.weights) { + for (Value weight : group.weights) computeWeights.push_back(weight); - } // There might be multiple weight groups for the same input tile, so if // we've already added the input tile, skip it. - if (localSlices.find(group.inputTile) != localSlices.end()) { + if (localSlices.find(group.inputTile) != localSlices.end()) continue; - } // We might have already sliced the input tensor for some other compute // unit, so if we have, reuse the slicing operation without creating a @@ -188,26 +186,21 @@ struct ExperimentalONNXConvOpTile : public OpConversionPattern { } // Create the input tensor slicing offsets. - SmallVector slicingOffsets{ - /* 0 */ rewriter.getIndexAttr(0), // No offset along the batch axis. - /* 1 */ rewriter.getIndexAttr(group.inputTile * crossbarSize), - /* 2 */ rewriter.getIndexAttr(0), - /* 3 */ rewriter.getIndexAttr(0)}; + SmallVector slicingOffsets {/* 0 */ rewriter.getIndexAttr(0), // No offset along the batch axis. + /* 1 */ rewriter.getIndexAttr(group.inputTile * crossbarSize), + /* 2 */ rewriter.getIndexAttr(0), + /* 3 */ rewriter.getIndexAttr(0)}; // Create the input tensor slicing sizes. - size_t tilingSize = group.inputTile == inputTileCount.quot - ? inputTileCount.rem - : crossbarSize; - SmallVector slicingSizes{ - /* 0 */ rewriter.getIndexAttr(1), // Batch size is always 1. - /* 1 */ rewriter.getIndexAttr(tilingSize), - /* 2 */ rewriter.getIndexAttr(GET_IMAGE_WIDTH(inputType)), - /* 3 */ rewriter.getIndexAttr(GET_IMAGE_HEIGHT(inputType))}; + size_t tilingSize = group.inputTile == inputTileCount.quot ? inputTileCount.rem : crossbarSize; + SmallVector slicingSizes {/* 0 */ rewriter.getIndexAttr(1), // Batch size is always 1. + /* 1 */ rewriter.getIndexAttr(tilingSize), + /* 2 */ rewriter.getIndexAttr(GET_IMAGE_WIDTH(inputType)), + /* 3 */ rewriter.getIndexAttr(GET_IMAGE_HEIGHT(inputType))}; // Create the slice extraction operation. auto extractSliceOp = rewriter.create( - conv.getLoc(), convAdaptor.getX(), slicingOffsets, slicingSizes, - slicingStrides); + conv.getLoc(), convAdaptor.getX(), slicingOffsets, slicingSizes, slicingStrides); computeOperands.push_back(extractSliceOp); @@ -229,36 +222,28 @@ struct ExperimentalONNXConvOpTile : public OpConversionPattern { // There might be multiple weight groups for the same output tile, so if // we've already added the output tile, skip it. - if (outputTileIndices.find(group.outputTile) != - outputTileIndices.end()) { + if (outputTileIndices.find(group.outputTile) != outputTileIndices.end()) continue; - } // Additionally, after adding the input slices as operands, also add any // compatible partial results from previous compute units. - if (globalPartialResults.find(group.outputTile) != - globalPartialResults.end()) { + if (globalPartialResults.find(group.outputTile) != globalPartialResults.end()) { computeOperands.push_back(globalPartialResults[group.outputTile]); reductionTileIndices[group.outputTile] = computeOperands.size() - 1; } // Define the output shape for this group. - long outputTileSize = group.outputTile == outputTileCount.quot - ? outputTileCount.rem - : crossbarSize; + long outputTileSize = group.outputTile == outputTileCount.quot ? outputTileCount.rem : crossbarSize; // TODO: Address non-same padding. - SmallVector outputShapeArray{ - /* 0 */ 1, // Batch size is always 1. - /* 1 */ outputTileSize, - /* 2 */ GET_IMAGE_WIDTH(outputType), // Same padding assumed. - /* 3 */ GET_IMAGE_HEIGHT(outputType)}; + SmallVector outputShapeArray {/* 0 */ 1, // Batch size is always 1. + /* 1 */ outputTileSize, + /* 2 */ GET_IMAGE_WIDTH(outputType), // Same padding assumed. + /* 3 */ GET_IMAGE_HEIGHT(outputType)}; - auto elementType = - dyn_cast(conv.getY().getType()).getElementType(); + auto elementType = dyn_cast(conv.getY().getType()).getElementType(); - computeOutputType.push_back( - RankedTensorType::get(outputShapeArray, elementType)); + computeOutputType.push_back(RankedTensorType::get(outputShapeArray, elementType)); outputTileIndices[group.outputTile] = computeOutputType.size() - 1; } @@ -268,43 +253,36 @@ struct ExperimentalONNXConvOpTile : public OpConversionPattern { // ----------------------------- // // Create the compute unit. - spatial::SpatWeightedCompute currentCompute = - rewriter.create(conv.getLoc(), - computeOutputType, computeWeights, computeOperands); + spatial::SpatWeightedCompute currentCompute = rewriter.create( + conv.getLoc(), computeOutputType, computeWeights, computeOperands); // Create a new block for the compute unit and add the operands. - Block *block = rewriter.createBlock(¤tCompute.getRegion()); + Block* block = rewriter.createBlock(¤tCompute.getRegion()); rewriter.setInsertionPointToStart(block); - for (Value operand : computeOperands) { + for (Value operand : computeOperands) block->addArgument(operand.getType(), conv->getLoc()); - } // Initialize a map of local partial results. map localPartialResults; // WRT the current compute unit. // If we have any reduction tiles, add them to the local partial results. - for (auto reductionTileIndex : reductionTileIndices) { - localPartialResults[reductionTileIndex.first] = - block->getArgument(reductionTileIndex.second); - } + for (auto reductionTileIndex : reductionTileIndices) + localPartialResults[reductionTileIndex.first] = block->getArgument(reductionTileIndex.second); // Add all the applyFilters operations to the block. for (TaggedWeights group : weightsGroups) { // Get the outputType for this group. - Type outputType = - computeOutputType[outputTileIndices[group.outputTile]]; + Type outputType = computeOutputType[outputTileIndices[group.outputTile]]; // Create an apply filters operation. - BlockArgument blockArgument = - block->getArgument(inputTileIndices[group.inputTile]); + BlockArgument blockArgument = block->getArgument(inputTileIndices[group.inputTile]); // The list of weight indices is group.startingCrossbarIndex + 0, 1, 2, // ... As many weights as the size of group.weights. SmallVector weightIndices; - for (size_t i = 0; i < group.weights.size(); ++i) { + for (size_t i = 0; i < group.weights.size(); ++i) weightIndices.push_back(group.startingCrossbarIndex + i); - } SmallVector xKerPos; SmallVector yKerPos; @@ -323,16 +301,14 @@ struct ExperimentalONNXConvOpTile : public OpConversionPattern { ArrayAttr xKerPosAttr = rewriter.getI64ArrayAttr(xKerPos); ArrayAttr yKerPosAttr = rewriter.getI64ArrayAttr(yKerPos); - Value result = - rewriter.create(conv.getLoc(), outputType, - weightIndicesAttr, xKerPosAttr, yKerPosAttr, blockArgument); + Value result = rewriter.create( + conv.getLoc(), outputType, weightIndicesAttr, xKerPosAttr, yKerPosAttr, blockArgument); // Perform local reduction if necessary. - if (localPartialResults.find(group.outputTile) != - localPartialResults.end()) { + if (localPartialResults.find(group.outputTile) != localPartialResults.end()) { - result = rewriter.create(conv.getLoc(), - result.getType(), localPartialResults[group.outputTile], result); + result = rewriter.create( + conv.getLoc(), result.getType(), localPartialResults[group.outputTile], result); } // Update the partial results map. @@ -385,22 +361,18 @@ struct ExperimentalONNXConvOpTile : public OpConversionPattern { // Turn the values into a SmallVector. SmallVector outputValues; - for (long i = 0; i < outputTileCount.quot + (outputTileCount.rem > 0); - ++i) { + for (long i = 0; i < outputTileCount.quot + (outputTileCount.rem > 0); ++i) outputValues.push_back(globalPartialResults[i]); - } // Assert that the number of output values is correct. - assert(outputValues.size() > 0 && - "No output values were generated for the convolution."); + assert(outputValues.size() > 0 && "No output values were generated for the convolution."); // If the conv's user is a ReLU... if (conv->hasOneUse()) { - Operation *user = *conv->getUsers().begin(); + Operation* user = *conv->getUsers().begin(); if (auto relu = dyn_cast(user)) { // ...then we can just replace the ReLU with the concatenation. - rewriter.replaceOp(relu, - rewriter.create(conv.getLoc(), 1, outputValues)); + rewriter.replaceOp(relu, rewriter.create(conv.getLoc(), 1, outputValues)); // And erase the convolution. rewriter.eraseOp(conv); @@ -409,8 +381,7 @@ struct ExperimentalONNXConvOpTile : public OpConversionPattern { } // Return the final output. - rewriter.replaceOp(conv, - rewriter.create(conv.getLoc(), 1, outputValues)); + rewriter.replaceOp(conv, rewriter.create(conv.getLoc(), 1, outputValues)); return success(); } @@ -422,9 +393,8 @@ struct ExperimentalONNXConvOpTile : public OpConversionPattern { * @param patterns The pattern set to populate. * @param ctx The MLIR context. */ -void populateExperimentalTilingConvOpPattern( - RewritePatternSet &patterns, MLIRContext *ctx) { +void populateExperimentalTilingConvOpPattern(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.insert(ctx); } -} // namespace onnx_mlir \ No newline at end of file +} // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/Math/ExperimentalGemm.cpp b/src/PIM/Conversion/ONNXToSpatial/Math/ExperimentalGemm.cpp index 9292875..6dc34ab 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Math/ExperimentalGemm.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Math/ExperimentalGemm.cpp @@ -1,24 +1,25 @@ +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/Transforms/DialectConversion.h" + +#include + #include "Compiler/PimCompilerOptions.hpp" #include "Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp" #include "Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp" #include "Conversion/ONNXToSpatial/Utils/WeightSubdivider.hpp" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/Transforms/DialectConversion.h" #include "src/Dialect/ONNX/ONNXOps.hpp" -#include using namespace mlir; using namespace std; namespace onnx_mlir { -struct ExperimentalGemmConversionPattern - : public OpConversionPattern { - ExperimentalGemmConversionPattern(MLIRContext *ctx) - : OpConversionPattern(ctx) {} +struct ExperimentalGemmConversionPattern : public OpConversionPattern { + ExperimentalGemmConversionPattern(MLIRContext* ctx) + : OpConversionPattern(ctx) {} - LogicalResult matchAndRewrite(ONNXGemmOp gemmOp, ONNXGemmOpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const final { + LogicalResult + matchAndRewrite(ONNXGemmOp gemmOp, ONNXGemmOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final { // --------------------------------- // // --- READ OPERATION PARAMETERS --- // @@ -34,17 +35,14 @@ struct ExperimentalGemmConversionPattern ShapedType matrixType = cast(adaptor.getB().getType()); // TODO: Address bigger batches. - assert(inputType.getShape()[0] == 1 && - "Only batch size of 1 is supported for GEMM."); + assert(inputType.getShape()[0] == 1 && "Only batch size of 1 is supported for GEMM."); // TODO: Address replication. - assert(coresCount.getValue() == -1 && - "Replication is not yet supported for GEMM."); + assert(coresCount.getValue() == -1 && "Replication is not yet supported for GEMM."); // TODO: Address bias addition. - assert(inputType.getShape()[1] == matrixType.getShape()[0] && - "Input tile size must match the matrix's row size."); + assert(inputType.getShape()[1] == matrixType.getShape()[0] && "Input tile size must match the matrix's row size."); ldiv_t inputTileCount = div(inputType.getShape()[1], crossbarSize); ldiv_t outputTileCount = div(outputType.getShape()[1], crossbarSize); @@ -88,11 +86,10 @@ struct ExperimentalGemmConversionPattern long inputTile = it; // Create the slicing sizes. - SmallVector slicingSizes{ - /* 0 */ rewriter.getIndexAttr(crossbarHeight), - /* 1 */ rewriter.getIndexAttr(crossbarWidth), - /* 2 */ /* rewriter.getIndexAttr(1), */ - /* 3 */ /* rewriter.getIndexAttr(1) */}; + SmallVector slicingSizes {/* 0 */ rewriter.getIndexAttr(crossbarHeight), + /* 1 */ rewriter.getIndexAttr(crossbarWidth), + /* 2 */ /* rewriter.getIndexAttr(1), */ + /* 3 */ /* rewriter.getIndexAttr(1) */}; // - Slicing along the filter x position. // - Slicing along the filter y position. @@ -100,16 +97,14 @@ struct ExperimentalGemmConversionPattern for (size_t filterY = 0; filterY < kernelHeight; ++filterY) { // Create the slicing offsets. - SmallVector slicingOffsets{ - /* 0 */ rewriter.getIndexAttr(outputTile * crossbarSize), - /* 1 */ rewriter.getIndexAttr(inputTile * crossbarSize), - /* 2 */ /* rewriter.getIndexAttr(filterX), */ - /* 3 */ /* rewriter.getIndexAttr(filterY) */}; + SmallVector slicingOffsets {/* 0 */ rewriter.getIndexAttr(outputTile * crossbarSize), + /* 1 */ rewriter.getIndexAttr(inputTile * crossbarSize), + /* 2 */ /* rewriter.getIndexAttr(filterX), */ + /* 3 */ /* rewriter.getIndexAttr(filterY) */}; // Create the slice extraction operation. auto extractSliceOp = rewriter.create( - gemmOp.getLoc(), adaptor.getB(), slicingOffsets, slicingSizes, - slicingStrides); + gemmOp.getLoc(), adaptor.getB(), slicingOffsets, slicingSizes, slicingStrides); // Add a note to the extractSliceOp, with the filterX and filterY. weightsGroups[inputTile][outputTile].push_back(extractSliceOp); @@ -141,8 +136,7 @@ struct ExperimentalGemmConversionPattern // -------------------------------- // // Get the next group of weights for the compute unit. - SmallVector weightsGroups = - weightSubdivider.popGroups(crossbarCountInCore.getValue()); + SmallVector weightsGroups = weightSubdivider.popGroups(crossbarCountInCore.getValue()); SmallVector computeWeights; SmallVector computeOperands; @@ -159,15 +153,13 @@ struct ExperimentalGemmConversionPattern // Iterate over all weights groups for this compute unit. map localSlices; // WRT the current compute unit. for (auto group : weightsGroups) { - for (Value weight : group.weights) { + for (Value weight : group.weights) computeWeights.push_back(weight); - } // There might be multiple weight groups for the same input tile, so if // we've already added the input tile, skip it. - if (localSlices.find(group.inputTile) != localSlices.end()) { + if (localSlices.find(group.inputTile) != localSlices.end()) continue; - } // We might have already sliced the input tensor for some other compute // unit, so if we have, reuse the slicing operation without creating a @@ -179,26 +171,21 @@ struct ExperimentalGemmConversionPattern } // Create the input tensor slicing offsets. - SmallVector slicingOffsets{ - /* 0 */ rewriter.getIndexAttr(0), // No offset along the batch axis. - /* 1 */ rewriter.getIndexAttr(group.inputTile * crossbarSize), - /* 2 */ /* rewriter.getIndexAttr(0), */ - /* 3 */ /* rewriter.getIndexAttr(0) */}; + SmallVector slicingOffsets {/* 0 */ rewriter.getIndexAttr(0), // No offset along the batch axis. + /* 1 */ rewriter.getIndexAttr(group.inputTile * crossbarSize), + /* 2 */ /* rewriter.getIndexAttr(0), */ + /* 3 */ /* rewriter.getIndexAttr(0) */}; // Create the input tensor slicing sizes. - size_t tilingSize = group.inputTile == inputTileCount.quot - ? inputTileCount.rem - : crossbarSize; - SmallVector slicingSizes{ - /* 0 */ rewriter.getIndexAttr(1), // Batch size is always 1. - /* 1 */ rewriter.getIndexAttr(tilingSize), - /* 2 */ /* rewriter.getIndexAttr(GET_IMAGE_WIDTH(inputType)), */ - /* 3 */ /* rewriter.getIndexAttr(GET_IMAGE_HEIGHT(inputType)) */}; + size_t tilingSize = group.inputTile == inputTileCount.quot ? inputTileCount.rem : crossbarSize; + SmallVector slicingSizes {/* 0 */ rewriter.getIndexAttr(1), // Batch size is always 1. + /* 1 */ rewriter.getIndexAttr(tilingSize), + /* 2 */ /* rewriter.getIndexAttr(GET_IMAGE_WIDTH(inputType)), */ + /* 3 */ /* rewriter.getIndexAttr(GET_IMAGE_HEIGHT(inputType)) */}; // Create the slice extraction operation. - auto extractSliceOp = - rewriter.create(gemmOp.getLoc(), - adaptor.getA(), slicingOffsets, slicingSizes, slicingStrides); + auto extractSliceOp = rewriter.create( + gemmOp.getLoc(), adaptor.getA(), slicingOffsets, slicingSizes, slicingStrides); computeOperands.push_back(extractSliceOp); @@ -220,36 +207,28 @@ struct ExperimentalGemmConversionPattern // There might be multiple weight groups for the same output tile, so if // we've already added the output tile, skip it. - if (outputTileIndices.find(group.outputTile) != - outputTileIndices.end()) { + if (outputTileIndices.find(group.outputTile) != outputTileIndices.end()) continue; - } // Additionally, after adding the input slices as operands, also add any // compatible partial results from previous compute units. - if (globalPartialResults.find(group.outputTile) != - globalPartialResults.end()) { + if (globalPartialResults.find(group.outputTile) != globalPartialResults.end()) { computeOperands.push_back(globalPartialResults[group.outputTile]); reductionTileIndices[group.outputTile] = computeOperands.size() - 1; } // Define the output shape for this group. - long outputTileSize = group.outputTile == outputTileCount.quot - ? outputTileCount.rem - : crossbarSize; + long outputTileSize = group.outputTile == outputTileCount.quot ? outputTileCount.rem : crossbarSize; // TODO: Address non-same padding. - SmallVector outputShapeArray{ - /* 0 */ 1, // Batch size is always 1. - /* 1 */ outputTileSize, - /* 2 */ /* GET_IMAGE_WIDTH(outputType), */ // Same padding assumed. - /* 3 */ /* GET_IMAGE_HEIGHT(outputType) */}; + SmallVector outputShapeArray {/* 0 */ 1, // Batch size is always 1. + /* 1 */ outputTileSize, + /* 2 */ /* GET_IMAGE_WIDTH(outputType), */ // Same padding assumed. + /* 3 */ /* GET_IMAGE_HEIGHT(outputType) */}; - auto elementType = dyn_cast(gemmOp.getY().getType()) - .getElementType(); + auto elementType = dyn_cast(gemmOp.getY().getType()).getElementType(); - computeOutputType.push_back( - RankedTensorType::get(outputShapeArray, elementType)); + computeOutputType.push_back(RankedTensorType::get(outputShapeArray, elementType)); outputTileIndices[group.outputTile] = computeOutputType.size() - 1; } @@ -259,43 +238,36 @@ struct ExperimentalGemmConversionPattern // ----------------------------- // // Create the compute unit. - spatial::SpatWeightedCompute currentCompute = - rewriter.create(gemmOp.getLoc(), - computeOutputType, computeWeights, computeOperands); + spatial::SpatWeightedCompute currentCompute = rewriter.create( + gemmOp.getLoc(), computeOutputType, computeWeights, computeOperands); // Create a new block for the compute unit and add the operands. - Block *block = rewriter.createBlock(¤tCompute.getRegion()); + Block* block = rewriter.createBlock(¤tCompute.getRegion()); rewriter.setInsertionPointToStart(block); - for (Value operand : computeOperands) { + for (Value operand : computeOperands) block->addArgument(operand.getType(), gemmOp->getLoc()); - } // Initialize a map of local partial results. map localPartialResults; // WRT the current compute unit. // If we have any reduction tiles, add them to the local partial results. - for (auto reductionTileIndex : reductionTileIndices) { - localPartialResults[reductionTileIndex.first] = - block->getArgument(reductionTileIndex.second); - } + for (auto reductionTileIndex : reductionTileIndices) + localPartialResults[reductionTileIndex.first] = block->getArgument(reductionTileIndex.second); // Add all the applyFilters operations to the block. for (TaggedWeights group : weightsGroups) { // Get the outputType for this group. - Type outputType = - computeOutputType[outputTileIndices[group.outputTile]]; + Type outputType = computeOutputType[outputTileIndices[group.outputTile]]; // Create an apply filters operation. - BlockArgument blockArgument = - block->getArgument(inputTileIndices[group.inputTile]); + BlockArgument blockArgument = block->getArgument(inputTileIndices[group.inputTile]); // The list of weight indices is group.startingCrossbarIndex + 0, 1, 2, // ... As many weights as the size of group.weights. SmallVector weightIndices; - for (size_t i = 0; i < group.weights.size(); ++i) { + for (size_t i = 0; i < group.weights.size(); ++i) weightIndices.push_back(group.startingCrossbarIndex + i); - } SmallVector xKerPos; SmallVector yKerPos; @@ -313,16 +285,14 @@ struct ExperimentalGemmConversionPattern ArrayAttr xKerPosAttr = rewriter.getI64ArrayAttr(xKerPos); ArrayAttr yKerPosAttr = rewriter.getI64ArrayAttr(yKerPos); - Value result = rewriter.create(gemmOp.getLoc(), - outputType, weightIndicesAttr, xKerPosAttr, yKerPosAttr, - blockArgument); + Value result = rewriter.create( + gemmOp.getLoc(), outputType, weightIndicesAttr, xKerPosAttr, yKerPosAttr, blockArgument); // Perform local reduction if necessary. - if (localPartialResults.find(group.outputTile) != - localPartialResults.end()) { + if (localPartialResults.find(group.outputTile) != localPartialResults.end()) { - result = rewriter.create(gemmOp.getLoc(), - result.getType(), localPartialResults[group.outputTile], result); + result = rewriter.create( + gemmOp.getLoc(), result.getType(), localPartialResults[group.outputTile], result); } // Update the partial results map. @@ -375,26 +345,21 @@ struct ExperimentalGemmConversionPattern // Turn the values into a SmallVector. SmallVector outputValues; - for (long i = 0; i < outputTileCount.quot + (outputTileCount.rem > 0); - ++i) { + for (long i = 0; i < outputTileCount.quot + (outputTileCount.rem > 0); ++i) outputValues.push_back(globalPartialResults[i]); - } // Assert that the number of output values is correct. - assert(outputValues.size() > 0 && - "No output values were generated for the GEMM operation."); + assert(outputValues.size() > 0 && "No output values were generated for the GEMM operation."); // Return the final output. - rewriter.replaceOp(gemmOp, - rewriter.create(gemmOp.getLoc(), 1, outputValues)); + rewriter.replaceOp(gemmOp, rewriter.create(gemmOp.getLoc(), 1, outputValues)); return success(); } }; -void populateGemmToConvConversionPattern( - RewritePatternSet &patterns, MLIRContext *ctx) { +void populateGemmToConvConversionPattern(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.insert(ctx); } -} // namespace onnx_mlir \ No newline at end of file +} // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/NN/ExperimentalPooling.cpp b/src/PIM/Conversion/ONNXToSpatial/NN/ExperimentalPooling.cpp index 2530e72..eae57f5 100644 --- a/src/PIM/Conversion/ONNXToSpatial/NN/ExperimentalPooling.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/NN/ExperimentalPooling.cpp @@ -6,20 +6,22 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/IR/Value.h" #include "mlir/IR/ValueRange.h" -#include "src/Accelerators/PIM/Common/PIMCommon.hpp" -#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" -#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" -#include "src/Dialect/ONNX/ONNXOps.hpp" -#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp" -#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" + #include #include #include +#include "src/Accelerators/PIM/Common/PIMCommon.hpp" +#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp" +#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" +#include "src/Dialect/ONNX/ONNXOps.hpp" + using namespace mlir; namespace onnx_mlir { @@ -35,71 +37,68 @@ bool hasPostProcessExperimentalPoolingWindow() { } template -Value postProcessExperimentalPoolingWindow(ConversionPatternRewriter &rewriter, - Location loc, PoolOp poolOp, Value valueToDivide, size_t krn_size, - size_t tilesSkippedByPadding) { +Value postProcessExperimentalPoolingWindow(ConversionPatternRewriter& rewriter, + Location loc, + PoolOp poolOp, + Value valueToDivide, + size_t krn_size, + size_t tilesSkippedByPadding) { return nullptr; } template <> -Value postProcessExperimentalPoolingWindow( - ConversionPatternRewriter &rewriter, Location loc, ONNXAveragePoolOp poolOp, - Value valueToDivide, size_t krn_size, size_t tilesSkippedByPadding) { +Value postProcessExperimentalPoolingWindow(ConversionPatternRewriter& rewriter, + Location loc, + ONNXAveragePoolOp poolOp, + Value valueToDivide, + size_t krn_size, + size_t tilesSkippedByPadding) { bool countIncludePad = poolOp.getCountIncludePad() == 1; - size_t divisorNumber = - countIncludePad ? krn_size : krn_size - tilesSkippedByPadding; + size_t divisorNumber = countIncludePad ? krn_size : krn_size - tilesSkippedByPadding; - RankedTensorType scalarTensor = - RankedTensorType::get({1}, rewriter.getF32Type()); + RankedTensorType scalarTensor = RankedTensorType::get({1}, rewriter.getF32Type()); // Put a spat.const before the computeOp, and use its value. We do this to be // compatible with the current code generation, which assumes constant to be // loaded in global memory, which is allocated by adding a spat.const OP // directly under func.func (i.e. alongside ComputeOps) - auto computeOp = cast( - valueToDivide.getDefiningOp()->getParentOp()); + auto computeOp = cast(valueToDivide.getDefiningOp()->getParentOp()); rewriter.setInsertionPoint(computeOp); - auto divisorValue = rewriter.create(loc, scalarTensor, - rewriter.getI64IntegerAttr(divisorNumber), - /* should_allocate = */ rewriter.getBoolAttr(true)); + auto divisorValue = rewriter.create(loc, + scalarTensor, + rewriter.getI64IntegerAttr(divisorNumber), + /* should_allocate = */ rewriter.getBoolAttr(true)); rewriter.setInsertionPointAfterValue(valueToDivide); - return rewriter.create( - loc, valueToDivide.getType(), valueToDivide, divisorValue); + return rewriter.create(loc, valueToDivide.getType(), valueToDivide, divisorValue); } template -Value reduceInputTiles( - SmallVector &inputTiles, ConversionPatternRewriter &rewriter) { - if (inputTiles.size() == 1) { +Value reduceInputTiles(SmallVector& inputTiles, ConversionPatternRewriter& rewriter) { + if (inputTiles.size() == 1) return inputTiles[0]; - } if (inputTiles.size() == 2) { - return rewriter.create(inputTiles[0].getLoc(), - inputTiles[0].getType(), inputTiles[0], inputTiles[1]); + return rewriter.create( + inputTiles[0].getLoc(), inputTiles[0].getType(), inputTiles[0], inputTiles[1]); } - SmallVector left( - inputTiles.begin(), inputTiles.begin() + inputTiles.size() / 2); - SmallVector right( - inputTiles.begin() + inputTiles.size() / 2, inputTiles.end()); + SmallVector left(inputTiles.begin(), inputTiles.begin() + inputTiles.size() / 2); + SmallVector right(inputTiles.begin() + inputTiles.size() / 2, inputTiles.end()); Value leftReduced = reduceInputTiles(left, rewriter); Value rightReduced = reduceInputTiles(right, rewriter); - return rewriter.create( - inputTiles[0].getLoc(), leftReduced.getType(), leftReduced, rightReduced); + return rewriter.create(inputTiles[0].getLoc(), leftReduced.getType(), leftReduced, rightReduced); } template struct ExperimentalPoolingBaseConverter : public OpConversionPattern { - ExperimentalPoolingBaseConverter(MLIRContext *ctx) - : OpConversionPattern(ctx) {} + ExperimentalPoolingBaseConverter(MLIRContext* ctx) + : OpConversionPattern(ctx) {} - LogicalResult matchAndRewrite(PoolOp poolOp, PoolOpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const final { + LogicalResult matchAndRewrite(PoolOp poolOp, PoolOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final { Value X = adaptor.getX(); ShapedType xShape = mlir::cast(X.getType()); Value Y = poolOp.getResult(); @@ -110,17 +109,13 @@ struct ExperimentalPoolingBaseConverter : public OpConversionPattern { unpackOptionalPairVector(adaptor.getDilations(), dilation_x, dilation_y); unpackOptionalPairVector(adaptor.getKernelShape(), krn_w, krn_h); - if (adaptor.getAutoPad() != "NOTSET") { - return rewriter.notifyMatchFailure( - poolOp, "auto_pad != NOTSET is deprecated."); - } + if (adaptor.getAutoPad() != "NOTSET") + return rewriter.notifyMatchFailure(poolOp, "auto_pad != NOTSET is deprecated."); size_t pad_x, pad_y; - auto padUnpackError = - unpackOptionalPadsVector(adaptor.getPads(), pad_x, pad_y); - if (padUnpackError.has_value()) { + auto padUnpackError = unpackOptionalPadsVector(adaptor.getPads(), pad_x, pad_y); + if (padUnpackError.has_value()) return rewriter.notifyMatchFailure(poolOp, padUnpackError.value()); - } Location loc = poolOp.getLoc(); @@ -133,10 +128,8 @@ struct ExperimentalPoolingBaseConverter : public OpConversionPattern { // Assert that the input is a tensor.ConcatOp. auto concat = X.getDefiningOp(); - if (!concat) { - return rewriter.notifyMatchFailure( - poolOp, "Expected input to be a tensor.ConcatOp"); - } + if (!concat) + return rewriter.notifyMatchFailure(poolOp, "Expected input to be a tensor.ConcatOp"); // Create a [channel_tile][x][y] array to store the input tiles. std::map>> inputTiles; @@ -145,24 +138,21 @@ struct ExperimentalPoolingBaseConverter : public OpConversionPattern { for (size_t y = 0; y < input_h; ++y) { for (size_t x = 0; x < input_w; ++x) { for (long it = 0; it < tileCount.quot + (tileCount.rem > 0); ++it) { - size_t tilingSize = - it == tileCount.quot ? tileCount.rem : crossbarSize; + size_t tilingSize = it == tileCount.quot ? tileCount.rem : crossbarSize; SmallVector strides(4, rewriter.getIndexAttr(1)); SmallVector offsets = {/* 0 */ rewriter.getIndexAttr(0), - /* 1 */ rewriter.getIndexAttr(0), - /* 2 */ rewriter.getIndexAttr(x), - /* 3 */ rewriter.getIndexAttr(y)}; - SmallVector sizes = { - /* 0 */ rewriter.getIndexAttr(1), // Batch size is always 1. - /* 1 */ rewriter.getIndexAttr(tilingSize), - /* 2 */ rewriter.getIndexAttr(1), - /* 3 */ rewriter.getIndexAttr(1)}; + /* 1 */ rewriter.getIndexAttr(0), + /* 2 */ rewriter.getIndexAttr(x), + /* 3 */ rewriter.getIndexAttr(y)}; + SmallVector sizes = {/* 0 */ rewriter.getIndexAttr(1), // Batch size is always 1. + /* 1 */ rewriter.getIndexAttr(tilingSize), + /* 2 */ rewriter.getIndexAttr(1), + /* 3 */ rewriter.getIndexAttr(1)}; // Get the concat's operand that we want to slice. Value concatInput = concat.getOperand(it); - Value slicedTile = rewriter.create( - loc, concatInput, offsets, sizes, strides); + Value slicedTile = rewriter.create(loc, concatInput, offsets, sizes, strides); inputTiles[it][x][y] = slicedTile; } @@ -175,19 +165,15 @@ struct ExperimentalPoolingBaseConverter : public OpConversionPattern { for (size_t y = 0; y < output_h; ++y) { for (size_t x = 0; x < output_w; ++x) { for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) { - SmallVector outputShapeArray{ - /* 0 */ 1, // Batch size is always 1. - /* 1 */ - cast(inputTiles[it][0][0].getType()) - .getShape()[1], - /* 2 */ 1, - /* 3 */ 1}; + SmallVector outputShapeArray {/* 0 */ 1, // Batch size is always 1. + /* 1 */ + cast(inputTiles[it][0][0].getType()).getShape()[1], + /* 2 */ 1, + /* 3 */ 1}; - auto elementType = - dyn_cast(xShape).getElementType(); + auto elementType = dyn_cast(xShape).getElementType(); - outputTileTypes.push_back( - RankedTensorType::get(outputShapeArray, elementType)); + outputTileTypes.push_back(RankedTensorType::get(outputShapeArray, elementType)); } } } @@ -195,29 +181,25 @@ struct ExperimentalPoolingBaseConverter : public OpConversionPattern { // Create a plain value list of the input tiles. SmallVector inputTilesList; for (size_t y = 0; y < input_h; ++y) { - for (size_t x = 0; x < input_w; ++x) { - for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) { + for (size_t x = 0; x < input_w; ++x) + for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) inputTilesList.push_back(inputTiles[it][y][x]); - } - } } // Create a single compute to calculate the output. - auto computeOp = rewriter.create( - loc, outputTileTypes, SmallVector(), inputTilesList); + auto computeOp = + rewriter.create(loc, outputTileTypes, SmallVector(), inputTilesList); // Create a new block for the compute unit and add the operands. - Block *block = rewriter.createBlock(&computeOp.getRegion()); + Block* block = rewriter.createBlock(&computeOp.getRegion()); // Fill the block arguments and keep a reference to them. std::map>> inputTilesArgs; for (size_t y = 0; y < input_h; ++y) { for (size_t x = 0; x < input_w; ++x) { for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) { - auto tileIndex = y * input_w * (itc.quot + (itc.rem > 0)) + - x * (itc.quot + (itc.rem > 0)) + it; - inputTilesArgs[it][y][x] = block->addArgument( - computeOp->getOperand(tileIndex).getType(), loc); + auto tileIndex = y * input_w * (itc.quot + (itc.rem > 0)) + x * (itc.quot + (itc.rem > 0)) + it; + inputTilesArgs[it][y][x] = block->addArgument(computeOp->getOperand(tileIndex).getType(), loc); } } } @@ -236,28 +218,26 @@ struct ExperimentalPoolingBaseConverter : public OpConversionPattern { size_t end_y = std::min(start_y + krn_h, input_h); SmallVector inputTilesToReduce; - for (size_t ky = start_y; ky < end_y; ++ky) { - for (size_t kx = start_x; kx < end_x; ++kx) { + for (size_t ky = start_y; ky < end_y; ++ky) + for (size_t kx = start_x; kx < end_x; ++kx) inputTilesToReduce.push_back(inputTilesArgs[it][ky][kx]); - } - } - auto reduceResult = - reduceInputTiles(inputTilesToReduce, rewriter); + auto reduceResult = reduceInputTiles(inputTilesToReduce, rewriter); // If the reduce op is add, we need to divide the result by the // number of elements in the pooling window. if (hasPostProcessExperimentalPoolingWindow()) { // Add a spat.const before the computeOp. rewriter.setInsertionPoint(computeOp); - auto divisorValue = rewriter.create(loc, - RankedTensorType::get({1}, rewriter.getF32Type()), - rewriter.getI64IntegerAttr(krn_w * krn_h), - rewriter.getBoolAttr(true)); + auto divisorValue = + rewriter.create(loc, + RankedTensorType::get({1}, rewriter.getF32Type()), + rewriter.getI64IntegerAttr(krn_w * krn_h), + rewriter.getBoolAttr(true)); rewriter.setInsertionPointAfter(reduceResult.getDefiningOp()); - reduceResult = rewriter.create( - loc, reduceResult.getType(), reduceResult, divisorValue); + reduceResult = + rewriter.create(loc, reduceResult.getType(), reduceResult, divisorValue); } outputTiles.push_back(reduceResult); } @@ -274,8 +254,7 @@ struct ExperimentalPoolingBaseConverter : public OpConversionPattern { for (size_t y = 0; y < output_h; ++y) { for (size_t x = 0; x < output_w; ++x) { for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) { - auto tileIndex = y * output_w * (itc.quot + (itc.rem > 0)) + - x * (itc.quot + (itc.rem > 0)) + it; + auto tileIndex = y * output_w * (itc.quot + (itc.rem > 0)) + x * (itc.quot + (itc.rem > 0)) + it; computeOutput[it][y][x] = computeOp.getResult(tileIndex); } } @@ -285,30 +264,25 @@ struct ExperimentalPoolingBaseConverter : public OpConversionPattern { SmallVector outputTilesList; for (long it = 0; it < itc.quot + (itc.rem > 0); ++it) { SmallVector imgConcatTiles; - for (size_t y = 0; y < output_h; ++y) { - for (size_t x = 0; x < output_w; ++x) { + for (size_t y = 0; y < output_h; ++y) + for (size_t x = 0; x < output_w; ++x) imgConcatTiles.push_back(computeOutput[it][y][x]); - } - } size_t tilingSize = it == tileCount.quot ? tileCount.rem : crossbarSize; - SmallVector outputShapeArray{ - /* 0 */ 1, // Batch size is always 1. - /* 1 */ (long)tilingSize, - /* 2 */ (long)output_w, - /* 3 */ (long)output_h}; + SmallVector outputShapeArray {/* 0 */ 1, // Batch size is always 1. + /* 1 */ (long) tilingSize, + /* 2 */ (long) output_w, + /* 3 */ (long) output_h}; auto elementType = dyn_cast(xShape).getElementType(); - outputTilesList.push_back(rewriter.create(loc, - RankedTensorType::get(outputShapeArray, elementType), - imgConcatTiles)); + outputTilesList.push_back(rewriter.create( + loc, RankedTensorType::get(outputShapeArray, elementType), imgConcatTiles)); } // Create a new tensor.ConcatOp to concatenate the output tiles. - Value outputTensor = - rewriter.create(loc, 1, outputTilesList); + Value outputTensor = rewriter.create(loc, 1, outputTilesList); rewriter.replaceOp(poolOp, outputTensor); @@ -316,12 +290,11 @@ struct ExperimentalPoolingBaseConverter : public OpConversionPattern { } }; -void populateExperimentalPoolingTilingPattern( - RewritePatternSet &patterns, MLIRContext *ctx) { - patterns.insert>(ctx); - patterns.insert>(ctx); +void populateExperimentalPoolingTilingPattern(RewritePatternSet& patterns, MLIRContext* ctx) { + patterns.insert< + ExperimentalPoolingBaseConverter>(ctx); + patterns.insert>( + ctx); } -} // namespace onnx_mlir \ No newline at end of file +} // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/NN/Pooling.cpp b/src/PIM/Conversion/ONNXToSpatial/NN/Pooling.cpp index 58dc549..e6995bb 100644 --- a/src/PIM/Conversion/ONNXToSpatial/NN/Pooling.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/NN/Pooling.cpp @@ -6,38 +6,39 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/IR/Value.h" #include "mlir/IR/ValueRange.h" -#include "src/Accelerators/PIM/Common/PIMCommon.hpp" -#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" -#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" -#include "src/Dialect/ONNX/ONNXOps.hpp" -#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp" -#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" + #include #include #include +#include "src/Accelerators/PIM/Common/PIMCommon.hpp" +#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp" +#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" +#include "src/Dialect/ONNX/ONNXOps.hpp" + using namespace mlir; namespace onnx_mlir { -llvm::SmallPtrSet oldComputeOpsReplaced; +llvm::SmallPtrSet oldComputeOpsReplaced; -Value applyReducePatternNew(SmallVector &valuesToReduce, - ConversionPatternRewriter &rewriter, - std::function reduce, - std::function preprocess, - std::function postprocess) { +Value applyReducePatternNew(SmallVector& valuesToReduce, + ConversionPatternRewriter& rewriter, + std::function reduce, + std::function preprocess, + std::function postprocess) { // Simple case: if we have only one input, just return it - if (valuesToReduce.size() == 1) { + if (valuesToReduce.size() == 1) return valuesToReduce[0]; - } if (preprocess) { - for (auto &valToReduce : valuesToReduce) { + for (auto& valToReduce : valuesToReduce) { rewriter.setInsertionPointAfterValue(valToReduce); valToReduce = preprocess(valToReduce); } @@ -47,9 +48,9 @@ Value applyReducePatternNew(SmallVector &valuesToReduce, // computeOp. In this case, we need to apply the reduction within-computef // Keep a map between a computeOp and the last Value for this reduction - std::unordered_map lastValueForCompute; - for (auto &valToReduce : valuesToReduce) { - Operation *computeOp = valToReduce.getParentBlock()->getParentOp(); + std::unordered_map lastValueForCompute; + for (auto& valToReduce : valuesToReduce) { + Operation* computeOp = valToReduce.getParentBlock()->getParentOp(); // if (valToReduce.getDefiningOp()) { // // If the value is defined by an operation, we take the parent // operation computeOp = valToReduce.getDefiningOp()->getParentOp(); @@ -67,12 +68,10 @@ Value applyReducePatternNew(SmallVector &valuesToReduce, // within-compute Value lastWithinComputeValue = it->second; - if (valToReduce.getDefiningOp()->isBeforeInBlock( - lastWithinComputeValue.getDefiningOp())) { + if (valToReduce.getDefiningOp()->isBeforeInBlock(lastWithinComputeValue.getDefiningOp())) rewriter.setInsertionPointAfterValue(lastWithinComputeValue); - } else { + else rewriter.setInsertionPointAfterValue(valToReduce); - } valToReduce = reduce(lastWithinComputeValue, valToReduce); lastValueForCompute[computeOp] = valToReduce; } @@ -83,9 +82,8 @@ Value applyReducePatternNew(SmallVector &valuesToReduce, // Now, reconstruct from the map the valuesToReduce list valuesToReduce.clear(); valuesToReduce.reserve(lastValueForCompute.size()); - for (auto &entry : lastValueForCompute) { + for (auto& entry : lastValueForCompute) valuesToReduce.push_back(entry.second); - } Location loc = valuesToReduce[0].getLoc(); auto channelType = spatial::SpatChannelType::get(rewriter.getContext()); @@ -123,8 +121,7 @@ Value applyReducePatternNew(SmallVector &valuesToReduce, // 3. Add a receiveOp after the second value rewriter.setInsertionPointAfterValue(secondValue); - auto receivedValue = rewriter.create( - loc, secondValue.getType(), channel); + auto receivedValue = rewriter.create(loc, secondValue.getType(), channel); // 4. Apply reduction between second value and received value rewriter.setInsertionPointAfterValue(receivedValue); @@ -135,17 +132,14 @@ Value applyReducePatternNew(SmallVector &valuesToReduce, // If we have an odd number of inputs, we need to add the last one to the // newInputs list. - if (valuesToReduceRef.size() % 2 == 1) { + if (valuesToReduceRef.size() % 2 == 1) nextValuesToReduce.push_back(valuesToReduceRef.back()); - } // Replace the inputOps list with the new one. - valuesToReduceRef = - llvm::OwningArrayRef(std::move(nextValuesToReduce)); + valuesToReduceRef = llvm::OwningArrayRef(std::move(nextValuesToReduce)); } - assert(valuesToReduceRef.size() == 1 && - "Internal error: expected a single input at this point."); + assert(valuesToReduceRef.size() == 1 && "Internal error: expected a single input at this point."); auto finalValue = valuesToReduceRef[0]; @@ -168,46 +162,49 @@ bool hasPostProcessPoolingWindow() { } template -Value postProcessPoolingWindow(ConversionPatternRewriter &rewriter, - Location loc, PoolOp poolOp, Value valueToDivide, size_t krn_size, - size_t tilesSkippedByPadding) { +Value postProcessPoolingWindow(ConversionPatternRewriter& rewriter, + Location loc, + PoolOp poolOp, + Value valueToDivide, + size_t krn_size, + size_t tilesSkippedByPadding) { return nullptr; } template <> -Value postProcessPoolingWindow( - ConversionPatternRewriter &rewriter, Location loc, ONNXAveragePoolOp poolOp, - Value valueToDivide, size_t krn_size, size_t tilesSkippedByPadding) { +Value postProcessPoolingWindow(ConversionPatternRewriter& rewriter, + Location loc, + ONNXAveragePoolOp poolOp, + Value valueToDivide, + size_t krn_size, + size_t tilesSkippedByPadding) { bool countIncludePad = poolOp.getCountIncludePad() == 1; - size_t divisorNumber = - countIncludePad ? krn_size : krn_size - tilesSkippedByPadding; + size_t divisorNumber = countIncludePad ? krn_size : krn_size - tilesSkippedByPadding; - RankedTensorType scalarTensor = - RankedTensorType::get({1}, rewriter.getF32Type()); + RankedTensorType scalarTensor = RankedTensorType::get({1}, rewriter.getF32Type()); // Put a spat.const before the computeOp, and use its value. We do this to be // compatible with the current code generation, which assumes constant to be // loaded in global memory, which is allocated by adding a spat.const OP // directly under func.func (i.e. alongside ComputeOps) - auto computeOp = cast( - valueToDivide.getDefiningOp()->getParentOp()); + auto computeOp = cast(valueToDivide.getDefiningOp()->getParentOp()); rewriter.setInsertionPoint(computeOp); - auto divisorValue = rewriter.create(loc, scalarTensor, - rewriter.getI64IntegerAttr(divisorNumber), - /* should_allocate = */ rewriter.getBoolAttr(true)); + auto divisorValue = rewriter.create(loc, + scalarTensor, + rewriter.getI64IntegerAttr(divisorNumber), + /* should_allocate = */ rewriter.getBoolAttr(true)); rewriter.setInsertionPointAfterValue(valueToDivide); - return rewriter.create( - loc, valueToDivide.getType(), valueToDivide, divisorValue); + return rewriter.create(loc, valueToDivide.getType(), valueToDivide, divisorValue); } template struct PoolingBaseConverter : public OpConversionPattern { - PoolingBaseConverter(MLIRContext *ctx) : OpConversionPattern(ctx) {} + PoolingBaseConverter(MLIRContext* ctx) + : OpConversionPattern(ctx) {} - LogicalResult matchAndRewrite(PoolOp poolOp, PoolOpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const final { + LogicalResult matchAndRewrite(PoolOp poolOp, PoolOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final { Value X = adaptor.getX(); ShapedType xShape = mlir::cast(X.getType()); Value Y = poolOp.getResult(); @@ -218,17 +215,13 @@ struct PoolingBaseConverter : public OpConversionPattern { unpackOptionalPairVector(adaptor.getDilations(), dilation_x, dilation_y); unpackOptionalPairVector(adaptor.getKernelShape(), krn_w, krn_h); - if (adaptor.getAutoPad() != "NOTSET") { - return rewriter.notifyMatchFailure( - poolOp, "auto_pad != NOTSET is deprecated."); - } + if (adaptor.getAutoPad() != "NOTSET") + return rewriter.notifyMatchFailure(poolOp, "auto_pad != NOTSET is deprecated."); size_t pad_x, pad_y; - auto padUnpackError = - unpackOptionalPadsVector(adaptor.getPads(), pad_x, pad_y); - if (padUnpackError.has_value()) { + auto padUnpackError = unpackOptionalPadsVector(adaptor.getPads(), pad_x, pad_y); + if (padUnpackError.has_value()) return rewriter.notifyMatchFailure(poolOp, padUnpackError.value()); - } Location loc = poolOp.getLoc(); @@ -236,8 +229,7 @@ struct PoolingBaseConverter : public OpConversionPattern { size_t input_w = GET_IMAGE_WIDTH(xShape); size_t output_h = GET_IMAGE_HEIGHT(yShape); size_t output_w = GET_IMAGE_WIDTH(yShape); - size_t channelTileCount = - ceilIntegerDivide(GET_IMAGE_CHANNEL(xShape), crossbarSize.getValue()); + size_t channelTileCount = ceilIntegerDivide(GET_IMAGE_CHANNEL(xShape), crossbarSize.getValue()); size_t channelTileRest = GET_IMAGE_CHANNEL(xShape) % crossbarSize; // 1: Tile the input tensor @@ -249,14 +241,13 @@ struct PoolingBaseConverter : public OpConversionPattern { // Example complete input tensor: tensor<1x3x12x12xf32> (NxCxWxH) // Suppose that the input tensor is produced by concatenating the results of // many ComputeOps. Get the result tiles from these ComputeOps. - SmallVector>> inputTiles(channelTileCount, - SmallVector>(input_w, SmallVector(input_h))); + SmallVector>> inputTiles( + channelTileCount, SmallVector>(input_w, SmallVector(input_h))); - auto resolveErrorOpt = resolveImgInputTiles(X, inputTiles, channelTileCount, - channelTileRest, input_w, input_h, rewriter); - if (resolveErrorOpt.has_value()) { + auto resolveErrorOpt = + resolveImgInputTiles(X, inputTiles, channelTileCount, channelTileRest, input_w, input_h, rewriter); + if (resolveErrorOpt.has_value()) return rewriter.notifyMatchFailure(poolOp, *resolveErrorOpt); - } // TODO: This requires a core for each input tile, which is not ideal. We // can do better. @@ -265,18 +256,17 @@ struct PoolingBaseConverter : public OpConversionPattern { for (size_t t = 0; t < channelTileCount; t++) { for (size_t x = 0; x < input_w; x++) { for (size_t y = 0; y < input_h; y++) { - if (auto extractSliceOp = - inputTiles[t][x][y].getDefiningOp()) { + if (auto extractSliceOp = inputTiles[t][x][y].getDefiningOp()) { Location tileLoc = extractSliceOp.getLoc(); - auto tempComputeOp = rewriter.create( - tileLoc, extractSliceOp.getResultType(), - /* xbarWeights =*/ValueRange(), extractSliceOp.getResult()); + auto tempComputeOp = rewriter.create(tileLoc, + extractSliceOp.getResultType(), + /* xbarWeights =*/ValueRange(), + extractSliceOp.getResult()); - Block *tempComputeOpBlock = new Block(); + Block* tempComputeOpBlock = new Block(); tempComputeOp.getBody().push_back(tempComputeOpBlock); - auto tempComputeOpBlockArg = tempComputeOpBlock->addArgument( - extractSliceOp.getType(), tileLoc); + auto tempComputeOpBlockArg = tempComputeOpBlock->addArgument(extractSliceOp.getType(), tileLoc); rewriter.setInsertionPointToStart(tempComputeOpBlock); rewriter.create(tileLoc, tempComputeOpBlockArg); @@ -295,8 +285,7 @@ struct PoolingBaseConverter : public OpConversionPattern { // For example: outputTiles[channelTile][x][y] // Example complete output tensor: tensor<1x3x6x6xf32> (NxCxWxH) SmallVector>> outputTiles( - channelTileCount, SmallVector>( - output_w, SmallVector(output_h, nullptr))); + channelTileCount, SmallVector>(output_w, SmallVector(output_h, nullptr))); // List of values to pool for each output pixel SmallVector valuesToPool; @@ -312,15 +301,12 @@ struct PoolingBaseConverter : public OpConversionPattern { valuesToPool.clear(); size_t tilesSkippedByPadding = 0; - auto [start_x, end_x] = kernel_get_start_and_end( - outX, input_w, krn_w, stride_x, dilation_x, pad_x); - auto [start_y, end_y] = kernel_get_start_and_end( - outY, input_h, krn_h, stride_y, dilation_y, pad_y); + auto [start_x, end_x] = kernel_get_start_and_end(outX, input_w, krn_w, stride_x, dilation_x, pad_x); + auto [start_y, end_y] = kernel_get_start_and_end(outY, input_h, krn_h, stride_y, dilation_y, pad_y); for (size_t inX = start_x; inX < end_x; inX += dilation_x) { for (size_t inY = start_y; inY < end_y; inY += dilation_y) { - if (failed(verifyWithinBoundsAndPaddings( - input_w, input_h, inX, inY, pad_x, pad_y))) { + if (failed(verifyWithinBoundsAndPaddings(input_w, input_h, inX, inY, pad_x, pad_y))) { tilesSkippedByPadding++; continue; } @@ -328,78 +314,73 @@ struct PoolingBaseConverter : public OpConversionPattern { Value inputTile = inputTiles[outTile][inX][inY]; Value valueToPool; - if (auto computeProducer = - inputTile.getDefiningOp()) { + if (auto computeProducer = inputTile.getDefiningOp()) { int resultNumber = getResultIndex(computeProducer, inputTile); - auto yieldInComputeOp = cast( - computeProducer.getBody().front().getTerminator()); + auto yieldInComputeOp = cast(computeProducer.getBody().front().getTerminator()); valueToPool = yieldInComputeOp.getOperand(resultNumber); - } else if (auto receiveProducer = - inputTile - .getDefiningOp()) { - auto sendOpOpt = - getOtherEndOfChannel(receiveProducer, true, rewriter); + } + else if (auto receiveProducer = inputTile.getDefiningOp()) { + auto sendOpOpt = getOtherEndOfChannel(receiveProducer, true, rewriter); if (failed(sendOpOpt)) { return rewriter.notifyMatchFailure(poolOp, - "ChannelReceiveOp does not have a matching " - "ChannelSendOp."); + "ChannelReceiveOp does not have a matching " + "ChannelSendOp."); } auto sendOp = cast(*sendOpOpt); valueToPool = sendOp.getData(); - } else { + } + else { return rewriter.notifyMatchFailure(poolOp, - "Input tile for Pooling is not produced by a " - "WeightedComputeOp nor a receiveOp"); + "Input tile for Pooling is not produced by a " + "WeightedComputeOp nor a receiveOp"); } valuesToPool.push_back(valueToPool); } } - assert(valuesToPool.size() != 0 && - "Pooling computed on zero tiles make no sense."); + assert(valuesToPool.size() != 0 && "Pooling computed on zero tiles make no sense."); // assert(computeOpsForPooling.size() != 1 && // "Pooling computed on one tiles make no sense??? Or maybe // this " "should have been simplified earlier???"); - std::function postProcessFn = nullptr; + std::function postProcessFn = nullptr; if (hasPostProcessPoolingWindow()) { postProcessFn = [&](const Value prevFinalRes) { - return postProcessPoolingWindow(rewriter, loc, poolOp, - prevFinalRes, krn_h * krn_w, tilesSkippedByPadding); + return postProcessPoolingWindow( + rewriter, loc, poolOp, prevFinalRes, krn_h * krn_w, tilesSkippedByPadding); }; } Value reducedWithinCompute = applyReducePatternNew( - valuesToPool, rewriter, - [&](const Value lhs, const Value rhs) { - return rewriter.create(loc, lhs.getType(), lhs, rhs); - }, - nullptr, postProcessFn); + valuesToPool, + rewriter, + [&](const Value lhs, const Value rhs) { return rewriter.create(loc, lhs.getType(), lhs, rhs); }, + nullptr, + postProcessFn); // Send this value through a channel, and receive it in the // `func.func`. During lowering, we will need to "move it" into the // users computeOps - auto computeOpOfReduced = cast( - reducedWithinCompute.getDefiningOp()->getParentOp()); + auto computeOpOfReduced = + cast(reducedWithinCompute.getDefiningOp()->getParentOp()); // Create a new channel before the computeOp rewriter.setInsertionPoint(computeOpOfReduced); - auto reduceChannel = rewriter.create( - loc, spatial::SpatChannelType::get(rewriter.getContext())); + auto reduceChannel = + rewriter.create(loc, spatial::SpatChannelType::get(rewriter.getContext())); // Send value through the channel rewriter.setInsertionPointAfterValue(reducedWithinCompute); - rewriter.create( - loc, reduceChannel, reducedWithinCompute); + rewriter.create(loc, reduceChannel, reducedWithinCompute); // Receive after the computeOp rewriter.setInsertionPointAfter(computeOpOfReduced); - auto receivedValue = rewriter.create( - loc, reducedWithinCompute.getType(), reduceChannel); + auto receivedValue = + rewriter.create(loc, reducedWithinCompute.getType(), reduceChannel); outputTiles[outTile][outX][outY] = receivedValue; } @@ -409,9 +390,7 @@ struct PoolingBaseConverter : public OpConversionPattern { // TODO: outputTiles are not the results of the computeOps! We need to add // them! - std::unordered_map>> - computeOpNeedingResults; + std::unordered_map>> computeOpNeedingResults; // Iterate each output tile for (size_t outTile = 0; outTile < channelTileCount; outTile++) { @@ -422,18 +401,16 @@ struct PoolingBaseConverter : public OpConversionPattern { auto outputTileProducer = outputTile.getDefiningOp()->getParentOp(); if (!outputTileProducer) { return rewriter.notifyMatchFailure(poolOp, - "Output tile for Pooling is not produced by a " - "WeightedComputeOp."); + "Output tile for Pooling is not produced by a " + "WeightedComputeOp."); } - computeOpNeedingResults[outputTileProducer].push_back( - std::make_tuple(outTile, outX, outY, outputTile)); + computeOpNeedingResults[outputTileProducer].push_back(std::make_tuple(outTile, outX, outY, outputTile)); } } } - Value outputImage = - createImgConcatOp(outputTiles, rewriter, loc, poolOp.getType()); + Value outputImage = createImgConcatOp(outputTiles, rewriter, loc, poolOp.getType()); rewriter.replaceOp(poolOp, outputImage); @@ -441,12 +418,10 @@ struct PoolingBaseConverter : public OpConversionPattern { } }; -void populatePoolingTilingPattern( - RewritePatternSet &patterns, MLIRContext *ctx) { - patterns.insert>(ctx); - patterns.insert>(ctx); +void populatePoolingTilingPattern(RewritePatternSet& patterns, MLIRContext* ctx) { + patterns.insert>( + ctx); + patterns.insert>(ctx); } -} // namespace onnx_mlir \ No newline at end of file +} // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/NN/ReduceMean.cpp b/src/PIM/Conversion/ONNXToSpatial/NN/ReduceMean.cpp index f8e44ab..5906242 100644 --- a/src/PIM/Conversion/ONNXToSpatial/NN/ReduceMean.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/NN/ReduceMean.cpp @@ -1,20 +1,19 @@ - +#include "mlir/Transforms/DialectConversion.h" #include "Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp" -#include "mlir/Transforms/DialectConversion.h" #include "src/Dialect/ONNX/ONNXOps.hpp" using namespace mlir; namespace onnx_mlir { -struct ReduceMeanConversionPattern - : public OpConversionPattern { +struct ReduceMeanConversionPattern : public OpConversionPattern { - ReduceMeanConversionPattern(MLIRContext *ctx) : OpConversionPattern(ctx) {} + ReduceMeanConversionPattern(MLIRContext* ctx) + : OpConversionPattern(ctx) {} LogicalResult matchAndRewrite(ONNXReduceMeanV13Op reduceMean, - ONNXReduceMeanV13OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const final { + ONNXReduceMeanV13OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const final { // Get the input tensor. Value inputTensor = adaptor.getData(); @@ -38,42 +37,42 @@ struct ReduceMeanConversionPattern SmallVector padsVals = {0, 0, 0, 0}; // Create the ArrayAttrs - auto kernelShape = mlir::ArrayAttr::get(rewriter.getContext(), - llvm::to_vector( - llvm::map_range(kernelShapeVals, [&](int64_t v) -> mlir::Attribute { - return rewriter.getI64IntegerAttr(v); - }))); + auto kernelShape = mlir::ArrayAttr::get( + rewriter.getContext(), llvm::to_vector(llvm::map_range(kernelShapeVals, [&](int64_t v) -> mlir::Attribute { + return rewriter.getI64IntegerAttr(v); + }))); auto strides = mlir::ArrayAttr::get(rewriter.getContext(), - llvm::to_vector( - llvm::map_range(stridesVals, [&](int64_t v) -> mlir::Attribute { - return rewriter.getI64IntegerAttr(v); - }))); + llvm::to_vector(llvm::map_range(stridesVals, [&](int64_t v) -> mlir::Attribute { + return rewriter.getI64IntegerAttr(v); + }))); - auto dilations = mlir::ArrayAttr::get(rewriter.getContext(), - llvm::to_vector( - llvm::map_range(dilationsVals, [&](int64_t v) -> mlir::Attribute { - return rewriter.getI64IntegerAttr(v); - }))); + auto dilations = mlir::ArrayAttr::get( + rewriter.getContext(), llvm::to_vector(llvm::map_range(dilationsVals, [&](int64_t v) -> mlir::Attribute { + return rewriter.getI64IntegerAttr(v); + }))); auto pads = mlir::ArrayAttr::get(rewriter.getContext(), - llvm::to_vector( - llvm::map_range(padsVals, [&](int64_t v) -> mlir::Attribute { - return rewriter.getI64IntegerAttr(v); - }))); + llvm::to_vector(llvm::map_range(padsVals, [&](int64_t v) -> mlir::Attribute { + return rewriter.getI64IntegerAttr(v); + }))); // Create the resulting tensor type. auto resultType = RankedTensorType::get( - /*shape=*/{inputTensorType.getShape()[0], inputTensorType.getShape()[1], - 1, 1}, - /*elementType=*/inputTensorType.getElementType()); + /*shape=*/ {inputTensorType.getShape()[0], inputTensorType.getShape()[1], 1, 1}, + /*elementType=*/inputTensorType.getElementType()); // Create the ONNXAveragePoolOp. auto averagePool = rewriter.create(reduceMean.getLoc(), - resultType, inputTensor, /*auto_pad=*/"NOTSET", - /*ceil_mode=*/0, /*count_include_pad=*/1, dilations, - /*kernel_shape=*/kernelShape, - /*pads=*/pads, /*strides=*/strides); + resultType, + inputTensor, + /*auto_pad=*/"NOTSET", + /*ceil_mode=*/0, + /*count_include_pad=*/1, + dilations, + /*kernel_shape=*/kernelShape, + /*pads=*/pads, + /*strides=*/strides); // Replace the ONNXReduceMeanV13Op with the ONNXAveragePoolOp. rewriter.replaceOp(reduceMean, averagePool.getResult()); @@ -82,9 +81,8 @@ struct ReduceMeanConversionPattern } }; -void populateReduceMeanConversionPattern( - RewritePatternSet &patterns, MLIRContext *ctx) { +void populateReduceMeanConversionPattern(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.insert(ctx); } -} // namespace onnx_mlir \ No newline at end of file +} // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp index 32fd6ca..52ec881 100644 --- a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp +++ b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp @@ -6,11 +6,12 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/Support/LLVM.h" #include "mlir/Transforms/DialectConversion.h" -#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" -#include "src/Dialect/ONNX/ONNXOps.hpp" #include "llvm/Support/LogicalResult.h" +#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" +#include "src/Dialect/ONNX/ONNXOps.hpp" + #define DEFINE_MAP_OP(opname) opname, #define GET_IMAGE_WIDTH(shapedType) shapedType.getDimSize(2) diff --git a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp index 0981380..64fdb5e 100644 --- a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp @@ -6,7 +6,6 @@ #include "llvm/Support/Debug.h" #include "llvm/Support/raw_os_ostream.h" -#include #include #include "Common/PIMCommon.hpp" @@ -16,7 +15,6 @@ #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp" #include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" -#include "src/Accelerators/PIM/Pass/PimPasses.hpp" #include "src/Compiler/CompilerOptions.hpp" using namespace mlir; @@ -39,7 +37,7 @@ void ONNXToSpatialPass::runOnOperation() { mergeActivationPatterns.add(ctx); mergeActivationPatterns.add(ctx); - if (failed(applyPatternsAndFoldGreedily(moduleOp, std::move(mergeActivationPatterns)))) + if (failed(applyPatternsGreedily(moduleOp, std::move(mergeActivationPatterns)))) llvm::dbgs() << "Failed to merge activation patterns, continuing...\n"; IRRewriter rewriter(moduleOp); @@ -88,7 +86,7 @@ void ONNXToSpatialPass::runOnOperation() { if (coresCount != -1) { int computeOpsCount = 0; for (auto& op : funcOp.getFunctionBody().front().getOperations()) - if (isa(op)) + if (isa(op)) computeOpsCount++; if (computeOpsCount > coresCount) { diff --git a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp index 31fd82f..014cbd5 100644 --- a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp +++ b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp @@ -3,38 +3,26 @@ namespace onnx_mlir { -void populateLoweringONNXMatMulOpToSpatialPattern( - mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx); +void populateLoweringONNXMatMulOpToSpatialPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); -void populateTilingGemmOpPattern( - mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx); -void populateTilingConvOpPattern( - mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx); +void populateTilingGemmOpPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); +void populateTilingConvOpPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); -void populatePoolingTilingPattern( - mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx); +void populatePoolingTilingPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); -void populateDistributeReducePattern( - mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx); +void populateDistributeReducePattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); -void populateFoldComputePattern( - mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx); +void populateFoldComputePattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); -void populateONNXConcatToTensorConcatPattern( - mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx); +void populateONNXConcatToTensorConcatPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); -void populateRemoveUnusedHelperOpsPatterns( - mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx); +void populateRemoveUnusedHelperOpsPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); -void populateReduceMeanConversionPattern( - mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx); +void populateReduceMeanConversionPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); // Experimental patterns. -void populateExperimentalTilingConvOpPattern( - mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx); -void populateGemmToConvConversionPattern( - mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx); -void populateExperimentalPoolingTilingPattern( - mlir::RewritePatternSet &patterns, mlir::MLIRContext *ctx); +void populateExperimentalTilingConvOpPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); +void populateGemmToConvConversionPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); +void populateExperimentalPoolingTilingPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); -} // namespace onnx_mlir \ No newline at end of file +} // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/Tensor/ONNXConcatToTensorConcat.cpp b/src/PIM/Conversion/ONNXToSpatial/Tensor/ONNXConcatToTensorConcat.cpp index fa93f91..d640073 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Tensor/ONNXConcatToTensorConcat.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Tensor/ONNXConcatToTensorConcat.cpp @@ -1,19 +1,20 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/PatternMatch.h" -#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp" +#include "src/Dialect/ONNX/ONNXOps.hpp" using namespace mlir; namespace onnx_mlir { struct ONNXConcatToTensorConcat : public OpConversionPattern { - ONNXConcatToTensorConcat(MLIRContext *ctx) : OpConversionPattern(ctx) {} + ONNXConcatToTensorConcat(MLIRContext* ctx) + : OpConversionPattern(ctx) {} LogicalResult matchAndRewrite(ONNXConcatOp maxpoolOp, - ONNXConcatOpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const final { + ONNXConcatOpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const final { auto inputs = adaptor.getInputs(); int64_t axis = adaptor.getAxis(); @@ -23,9 +24,8 @@ struct ONNXConcatToTensorConcat : public OpConversionPattern { } }; -void populateONNXConcatToTensorConcatPattern( - RewritePatternSet &patterns, MLIRContext *ctx) { +void populateONNXConcatToTensorConcatPattern(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.insert(ctx); } -} // namespace onnx_mlir \ No newline at end of file +} // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/Tensor/RemoveUnusedHelperOps.cpp b/src/PIM/Conversion/ONNXToSpatial/Tensor/RemoveUnusedHelperOps.cpp index abc87ed..ec7a874 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Tensor/RemoveUnusedHelperOps.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Tensor/RemoveUnusedHelperOps.cpp @@ -1,5 +1,6 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/PatternMatch.h" + #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" diff --git a/src/PIM/Conversion/ONNXToSpatial/Utils/AnnotateReplication.cpp b/src/PIM/Conversion/ONNXToSpatial/Utils/AnnotateReplication.cpp index c335fd8..a9b0c70 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Utils/AnnotateReplication.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Utils/AnnotateReplication.cpp @@ -1,10 +1,10 @@ +#include + #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/AnnotateReplication.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" -#include - using namespace mlir; namespace onnx_mlir { diff --git a/src/PIM/Conversion/ONNXToSpatial/Utils/AnnotateReplication.hpp b/src/PIM/Conversion/ONNXToSpatial/Utils/AnnotateReplication.hpp index 108aa07..ebd859d 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Utils/AnnotateReplication.hpp +++ b/src/PIM/Conversion/ONNXToSpatial/Utils/AnnotateReplication.hpp @@ -5,7 +5,6 @@ namespace onnx_mlir { -mlir::LogicalResult annotateReplication( - mlir::func::FuncOp funcOp, mlir::IRRewriter &rewriter); +mlir::LogicalResult annotateReplication(mlir::func::FuncOp funcOp, mlir::IRRewriter& rewriter); -} // namespace onnx_mlir \ No newline at end of file +} // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.cpp b/src/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.cpp index ad0267a..faed00d 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.cpp @@ -1,32 +1,31 @@ - -#include "SpatialReducer.hpp" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/Value.h" -#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" + #include "llvm/Support/raw_ostream.h" + #include #include #include +#include "SpatialReducer.hpp" +#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" + #define GET_COMP(computeOpAndResNum) std::get<0>(computeOpAndResNum) #define GET_RES_NUM(computeOpAndResNum) std::get<1>(computeOpAndResNum) namespace onnx_mlir { -llvm::SmallPtrSet - onnx_mlir::SpatialReducer::oldComputeOpsReplaced; +llvm::SmallPtrSet onnx_mlir::SpatialReducer::oldComputeOpsReplaced; -ResNum SpatialReducer::applyResultProcessing( - ComputeAndResNum computeOpAndResNum, - std::function processFun, - ConversionPatternRewriter &rewriter) { +ResNum SpatialReducer::applyResultProcessing(ComputeAndResNum computeOpAndResNum, + std::function processFun, + ConversionPatternRewriter& rewriter) { assert(processFun); auto computeOp = GET_COMP(computeOpAndResNum); auto resultNum = GET_RES_NUM(computeOpAndResNum); - spatial::SpatYieldOp yieldOp = - cast(computeOp.getBody().front().getTerminator()); + spatial::SpatYieldOp yieldOp = cast(computeOp.getBody().front().getTerminator()); Value result = yieldOp->getOperand(resultNum); rewriter.setInsertionPointAfterValue(result); @@ -43,30 +42,24 @@ ResNum SpatialReducer::applyResultProcessing( return yieldOp.getNumOperands() - 1; } -OpAndResNum SpatialReducer::applyReducePattern( - SmallVector &computeOpsAndResNum, - std::function reduce, - std::function preprocess, - std::function postprocess) { +OpAndResNum SpatialReducer::applyReducePattern(SmallVector& computeOpsAndResNum, + std::function reduce, + std::function preprocess, + std::function postprocess) { - if (preprocess) { - for (auto &computeOpAndResNum : computeOpsAndResNum) { - GET_RES_NUM(computeOpAndResNum) = - applyResultProcessing(computeOpAndResNum, preprocess, rewriter); - } - } + if (preprocess) + for (auto& computeOpAndResNum : computeOpsAndResNum) + GET_RES_NUM(computeOpAndResNum) = applyResultProcessing(computeOpAndResNum, preprocess, rewriter); // It is possible that `computeOpsAndResNum` contains two entries for the same // computeOp. In this case, we need to apply the reduction within-computef // Keep a map between a computeOp and the last Value for this reduction - std::unordered_map lastValueForCompute; - for (auto &computeOpAndResNum : computeOpsAndResNum) { + std::unordered_map lastValueForCompute; + for (auto& computeOpAndResNum : computeOpsAndResNum) { auto computeOp = GET_COMP(computeOpAndResNum); - auto yieldOp = - cast(computeOp.getBody().front().getTerminator()); - Value valueWithinCompute = - yieldOp->getOperand(GET_RES_NUM(computeOpAndResNum)); + auto yieldOp = cast(computeOp.getBody().front().getTerminator()); + Value valueWithinCompute = yieldOp->getOperand(GET_RES_NUM(computeOpAndResNum)); auto it = lastValueForCompute.find(computeOp.getOperation()); @@ -75,15 +68,12 @@ OpAndResNum SpatialReducer::applyReducePattern( // within-compute Value lastWithinComputeValue = it->second; - assert(valueWithinCompute.getDefiningOp() && - lastWithinComputeValue.getDefiningOp()); + assert(valueWithinCompute.getDefiningOp() && lastWithinComputeValue.getDefiningOp()); - if (valueWithinCompute.getDefiningOp()->isBeforeInBlock( - lastWithinComputeValue.getDefiningOp())) { + if (valueWithinCompute.getDefiningOp()->isBeforeInBlock(lastWithinComputeValue.getDefiningOp())) rewriter.setInsertionPointAfterValue(lastWithinComputeValue); - } else { + else rewriter.setInsertionPointAfterValue(valueWithinCompute); - } valueWithinCompute = reduce(lastWithinComputeValue, valueWithinCompute); lastValueForCompute[computeOp.getOperation()] = valueWithinCompute; } @@ -94,16 +84,15 @@ OpAndResNum SpatialReducer::applyReducePattern( // Now, reconstruct from the map the computeOpsAndResNum list computeOpsAndResNum.clear(); computeOpsAndResNum.reserve(lastValueForCompute.size()); - for (auto &entry : lastValueForCompute) { + for (auto& entry : lastValueForCompute) { auto computeOp = cast(entry.first); auto valueWithinCompute = entry.second; // We check if `valueWithinCompute` is already used by the yieldOp, in that // case no need to add it - auto yieldOp = - cast(computeOp.getBody().front().getTerminator()); + auto yieldOp = cast(computeOp.getBody().front().getTerminator()); bool yieldOpUseFound = false; - for (auto &use : valueWithinCompute.getUses()) { + for (auto& use : valueWithinCompute.getUses()) { if (use.getOwner() == yieldOp.getOperation()) { // If the value is already used by the yieldOp, we can just use it computeOpsAndResNum.push_back({computeOp, use.getOperandNumber()}); @@ -111,9 +100,8 @@ OpAndResNum SpatialReducer::applyReducePattern( break; } } - if (yieldOpUseFound) { + if (yieldOpUseFound) continue; - } // If this result is not used within a yieldOp, then add it auto resultNum = yieldOp->getNumOperands(); @@ -147,23 +135,18 @@ OpAndResNum SpatialReducer::applyReducePattern( // the number of results) // See below `reducerChanges.push_back` and `finalizeReduceUpdates` - auto yieldOpFirstCompute = cast( - firstCompute.getBody().front().getTerminator()); + auto yieldOpFirstCompute = cast(firstCompute.getBody().front().getTerminator()); // Add a new operand to the block of the second computeOp - Block &secondBlock = secondCompute.getBody().front(); - Value formerRes1 = secondBlock.addArgument( - yieldOpFirstCompute->getOperand(firstResultNum).getType(), loc); + Block& secondBlock = secondCompute.getBody().front(); + Value formerRes1 = secondBlock.addArgument(yieldOpFirstCompute->getOperand(firstResultNum).getType(), loc); auto secondComputeWeightsNum = - secondCompute->getAttrOfType( - secondCompute.getOperandSegmentSizesAttrName())[0]; - auto secondComputeOperandNum = - secondComputeWeightsNum + secondBlock.getNumArguments() - 1; + secondCompute->getAttrOfType(secondCompute.getOperandSegmentSizesAttrName())[0]; + auto secondComputeOperandNum = secondComputeWeightsNum + secondBlock.getNumArguments() - 1; // Take the "former-result" from the second computeOp - spatial::SpatYieldOp secondYield = - cast(secondBlock.getTerminator()); + spatial::SpatYieldOp secondYield = cast(secondBlock.getTerminator()); Value formerRes2 = secondYield.getOperand(secondResultNum); // Apply reduction operation @@ -184,37 +167,31 @@ OpAndResNum SpatialReducer::applyReducePattern( // We should also add an entry for updating the results of the last // operation (the one which never becomes a `firstCompute`): because it is // not tracked by reducerChanges as `fromOp` - reducerChanges.push_back({firstCompute.getOperation(), firstResultNum, - secondCompute.getOperation(), secondComputeOperandNum}); + reducerChanges.push_back( + {firstCompute.getOperation(), firstResultNum, secondCompute.getOperation(), secondComputeOperandNum}); nextComputeOps.push_back(std::make_pair(secondCompute, secondResultNum)); } // If we have an odd number of inputs, we need to add the last one to the // newInputs list. - if (computeOpsRef.size() % 2 == 1) { + if (computeOpsRef.size() % 2 == 1) nextComputeOps.push_back(computeOpsRef.back()); - } // Replace the inputOps list with the new one. - computeOpsRef = - llvm::OwningArrayRef(std::move(nextComputeOps)); + computeOpsRef = llvm::OwningArrayRef(std::move(nextComputeOps)); } - assert(computeOpsRef.size() == 1 && - "Internal error: expected a single input at this point."); + assert(computeOpsRef.size() == 1 && "Internal error: expected a single input at this point."); auto finalComputeAndResNum = computeOpsRef[0]; // Force the update of the results of this computeOp, when finalizing computeOpNeedingResUpdate.push_back(GET_COMP(finalComputeAndResNum)); - if (postprocess) { - GET_RES_NUM(finalComputeAndResNum) = - applyResultProcessing(finalComputeAndResNum, postprocess, rewriter); - } + if (postprocess) + GET_RES_NUM(finalComputeAndResNum) = applyResultProcessing(finalComputeAndResNum, postprocess, rewriter); - return std::make_pair(GET_COMP(finalComputeAndResNum).getOperation(), - GET_RES_NUM(finalComputeAndResNum)); + return std::make_pair(GET_COMP(finalComputeAndResNum).getOperation(), GET_RES_NUM(finalComputeAndResNum)); } void SpatialReducer::finalizeReduceUpdates() { @@ -223,15 +200,13 @@ void SpatialReducer::finalizeReduceUpdates() { reducesFinalized = true; // First, add the results to the computeOps - for (auto &reduceChange : reducerChanges) { + for (auto& reduceChange : reducerChanges) updateResultsOfCompute(reduceChange.fromOp); - } - for (auto &c : computeOpNeedingResUpdate) { + for (auto& c : computeOpNeedingResUpdate) updateResultsOfCompute(c.getOperation()); - } - for (auto &reducerChange : this->reducerChanges) { + for (auto& reducerChange : this->reducerChanges) { auto fromOp = reducerChange.fromOp; auto toOp = reducerChange.toOp; auto fromOpResNum = reducerChange.fromOpResNum; @@ -243,16 +218,14 @@ void SpatialReducer::finalizeReduceUpdates() { // toComputeOp could be the existing pointer, or we have to remap it with // `opToReplacedCompute` auto toComputeOp = opToReplacedCompute[toOp]; - if (!toComputeOp) { + if (!toComputeOp) toComputeOp = cast(toOp); - } - assert(toComputeOp != fromComputeOp && - "Oops should have caught this earlier!"); + assert(toComputeOp != fromComputeOp && "Oops should have caught this earlier!"); - assert(toComputeOp->getNumOperands() == toOpOperandNum && - "toOpOperandNum should be the last operand of toComputeOp, are the " - "operations in the right order?"); + assert(toComputeOp->getNumOperands() == toOpOperandNum + && "toOpOperandNum should be the last operand of toComputeOp, are the " + "operations in the right order?"); // Add the new operand to `toComputeOp` auto fromResult = fromComputeOp.getResult(fromOpResNum); @@ -261,24 +234,22 @@ void SpatialReducer::finalizeReduceUpdates() { } } -Value SpatialReducer::resolveValueFromOpAndResNum(OpAndResNum &opAndResNum) { - assert(reducesFinalized && - "Cannot create resolve values before finalizing the reduce updates."); +Value SpatialReducer::resolveValueFromOpAndResNum(OpAndResNum& opAndResNum) { + assert(reducesFinalized && "Cannot create resolve values before finalizing the reduce updates."); - Operation *opToCast; + Operation* opToCast; auto it = opToReplacedCompute.find(opAndResNum.first); - if (it != opToReplacedCompute.end()) { + if (it != opToReplacedCompute.end()) opToCast = it->second; - } else { + else opToCast = opAndResNum.first; - } auto computeOp = cast(opToCast); return computeOp.getResult(opAndResNum.second); } -void SpatialReducer::updateResultsOfCompute(Operation *computeOp) { +void SpatialReducer::updateResultsOfCompute(Operation* computeOp) { if (opToReplacedCompute.find(computeOp) != opToReplacedCompute.end()) { // If we have already replaced the fromOp, we do not need to do it again return; @@ -287,8 +258,7 @@ void SpatialReducer::updateResultsOfCompute(Operation *computeOp) { auto oldComputeOpNum = oldComputeOp->getNumOperands(); - auto yieldOp = - cast(oldComputeOp.getBody().front().getTerminator()); + auto yieldOp = cast(oldComputeOp.getBody().front().getTerminator()); if (yieldOp.getNumOperands() == oldComputeOp->getNumResults()) { // No result was added, just add itself to the map @@ -301,9 +271,8 @@ void SpatialReducer::updateResultsOfCompute(Operation *computeOp) { // Create a new ComputeOp with the new result type, but same operands rewriter.setInsertionPoint(oldComputeOp); - auto newComputeOp = - rewriter.create(oldComputeOp->getLoc(), - newResultTypes, oldComputeOp.getWeights(), oldComputeOp.getInputs()); + auto newComputeOp = rewriter.create( + oldComputeOp->getLoc(), newResultTypes, oldComputeOp.getWeights(), oldComputeOp.getInputs()); newComputeOp.getBody().takeBody(oldComputeOp.getBody()); @@ -329,54 +298,49 @@ void SpatialReducer::updateResultsOfCompute(Operation *computeOp) { rewriter.eraseOp(oldComputeOp); } -Value SpatialReducer::createImgConcatOp( - SmallVector>> &outputTiles, - Location &loc, Type outputType) { +Value SpatialReducer::createImgConcatOp(SmallVector>>& outputTiles, + Location& loc, + Type outputType) { - assert(reducesFinalized && - "Cannot create ImgConcatOp before finalizing the reduce updates."); + assert(reducesFinalized && "Cannot create ImgConcatOp before finalizing the reduce updates."); // outputTiles are indexed like this: [channelTile][x][y] auto tilesCount = outputTiles.size(); auto width = outputTiles[0].size(); auto height = outputTiles[0][0].size(); - SmallVector>> remappedOutputTiles(tilesCount, - SmallVector>(width, SmallVector(height))); + SmallVector>> remappedOutputTiles( + tilesCount, SmallVector>(width, SmallVector(height))); for (size_t t = 0; t < tilesCount; t++) for (size_t x = 0; x < width; x++) for (size_t y = 0; y < height; y++) - remappedOutputTiles[t][x][y] = - resolveValueFromOpAndResNum(outputTiles[t][x][y]); + remappedOutputTiles[t][x][y] = resolveValueFromOpAndResNum(outputTiles[t][x][y]); - return ::onnx_mlir::createImgConcatOp( - remappedOutputTiles, rewriter, loc, outputType); + return ::onnx_mlir::createImgConcatOp(remappedOutputTiles, rewriter, loc, outputType); } -OpAndResNum SpatialReducer::applyAddMapReduction( - SmallVector &computeOps, - ConversionPatternRewriter &rewriter, Value biasTile, MapOperations mapOp) { +OpAndResNum SpatialReducer::applyAddMapReduction(SmallVector& computeOps, + ConversionPatternRewriter& rewriter, + Value biasTile, + MapOperations mapOp) { - std::function postprocessing = nullptr; + std::function postprocessing = nullptr; if (mapOp != MapOperations::None) { postprocessing = [&](const Value a) { Value mapOperand = a; - if (biasTile) { - mapOperand = rewriter.create( - a.getLoc(), a.getType(), a, biasTile); - } + if (biasTile) + mapOperand = rewriter.create(a.getLoc(), a.getType(), a, biasTile); return createMapOperation(rewriter, mapOp, mapOperand); }; } return this->applyReducePattern( - computeOps, - [&](Value a, Value b) { - return rewriter.create(a.getLoc(), a.getType(), a, b); - }, - /* preprocess = */ nullptr, postprocessing); + computeOps, + [&](Value a, Value b) { return rewriter.create(a.getLoc(), a.getType(), a, b); }, + /* preprocess = */ nullptr, + postprocessing); } -} // namespace onnx_mlir \ No newline at end of file +} // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp b/src/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp index 4dcfe67..f724691 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp +++ b/src/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp @@ -1,9 +1,10 @@ #pragma once +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/Support/Casting.h" + #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" -#include "llvm/ADT/SmallPtrSet.h" -#include "llvm/Support/Casting.h" namespace onnx_mlir { @@ -12,48 +13,48 @@ using ResNum = unsigned int; using ComputeAndResNum = std::pair; struct SpatialReducerChange { - Operation *fromOp; + Operation* fromOp; unsigned int fromOpResNum; - Operation *toOp; + Operation* toOp; unsigned int toOpOperandNum; }; -using OpAndResNum = std::pair; +using OpAndResNum = std::pair; class SpatialReducer { public: - SpatialReducer(ConversionPatternRewriter &rewriter) : rewriter(rewriter) {} + SpatialReducer(ConversionPatternRewriter& rewriter) + : rewriter(rewriter) {} - OpAndResNum applyReducePattern( - SmallVector &computeOpsAndResNum, - std::function reduce, - std::function preprocess, - std::function postprocess); + OpAndResNum applyReducePattern(SmallVector& computeOpsAndResNum, + std::function reduce, + std::function preprocess, + std::function postprocess); - OpAndResNum applyAddMapReduction(SmallVector &computeOps, - ConversionPatternRewriter &rewriter, Value biasTile, MapOperations mapOp); + OpAndResNum applyAddMapReduction(SmallVector& computeOps, + ConversionPatternRewriter& rewriter, + Value biasTile, + MapOperations mapOp); void finalizeReduceUpdates(); ~SpatialReducer() { - if (!reducesFinalized) { + if (!reducesFinalized) finalizeReduceUpdates(); - } } - Value createImgConcatOp( - llvm::SmallVector>> - &outputTiles, - Location &loc, Type outputType); + Value createImgConcatOp(llvm::SmallVector>>& outputTiles, + Location& loc, + Type outputType); - Value resolveValueFromOpAndResNum(OpAndResNum &opAndResNum); + Value resolveValueFromOpAndResNum(OpAndResNum& opAndResNum); private: [[nodiscard("computeOp result number gets updated")]] ResNum applyResultProcessing(ComputeAndResNum computeOpAndResNum, - std::function processFun, - ConversionPatternRewriter &rewriter); + std::function processFun, + ConversionPatternRewriter& rewriter); /** * @brief Update the results of a ComputeOp. @@ -65,9 +66,9 @@ private: * * @param computeOp The ComputeOp to update the results of. */ - void updateResultsOfCompute(Operation *computeOp); + void updateResultsOfCompute(Operation* computeOp); - ConversionPatternRewriter &rewriter; + ConversionPatternRewriter& rewriter; bool reducesFinalized = false; // List of changes to be applied after the reduction is finalized @@ -75,9 +76,9 @@ private: // List of computeOps that need to be replaced with new results SmallVector computeOpNeedingResUpdate; - std::unordered_map opToReplacedCompute; + std::unordered_map opToReplacedCompute; - static llvm::SmallPtrSet oldComputeOpsReplaced; + static llvm::SmallPtrSet oldComputeOpsReplaced; }; -} // namespace onnx_mlir \ No newline at end of file +} // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/Utils/WeightSubdivider.cpp b/src/PIM/Conversion/ONNXToSpatial/Utils/WeightSubdivider.cpp index 81314ac..e2466ec 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Utils/WeightSubdivider.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Utils/WeightSubdivider.cpp @@ -1,11 +1,11 @@ -#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/WeightSubdivider.hpp" #include +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/WeightSubdivider.hpp" + namespace onnx_mlir { -WeightSubdivider::WeightSubdivider( - map>> weights) - : weights(std::move(weights)) {} +WeightSubdivider::WeightSubdivider(map>> weights) +: weights(std::move(weights)) {} bool WeightSubdivider::isEmpty() const { return weights.empty(); } @@ -13,7 +13,7 @@ TaggedWeights WeightSubdivider::popGroup(size_t amount) { assert(!weights.empty() && "No weights to extract."); auto it = weights.begin(); - SmallVector &values = it->second.begin()->second; + SmallVector& values = it->second.begin()->second; long inputTile = it->first; long outputTile = it->second.begin()->first; @@ -26,11 +26,11 @@ TaggedWeights WeightSubdivider::popGroup(size_t amount) { if (n < values.size()) { values.erase(values.begin(), values.begin() + n); - } else { + } + else { it->second.erase(outputTile); - if (it->second.empty()) { + if (it->second.empty()) weights.erase(inputTile); - } } return {inputTile, outputTile, crossbarsUsed - n, result}; diff --git a/src/PIM/Conversion/ONNXToSpatial/Utils/WeightSubdivider.hpp b/src/PIM/Conversion/ONNXToSpatial/Utils/WeightSubdivider.hpp index 6e5c5f1..7c71986 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Utils/WeightSubdivider.hpp +++ b/src/PIM/Conversion/ONNXToSpatial/Utils/WeightSubdivider.hpp @@ -1,7 +1,9 @@ #pragma once #include "mlir/IR/Value.h" + #include "llvm/ADT/SmallVector.h" + #include using namespace mlir; diff --git a/src/PIM/Conversion/SpatialToGraphviz/SpatialToGraphviz.cpp b/src/PIM/Conversion/SpatialToGraphviz/SpatialToGraphviz.cpp index bdda694..0cab2b9 100644 --- a/src/PIM/Conversion/SpatialToGraphviz/SpatialToGraphviz.cpp +++ b/src/PIM/Conversion/SpatialToGraphviz/SpatialToGraphviz.cpp @@ -1,23 +1,21 @@ +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/IR/Block.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/Value.h" -#include "mlir/Support/LLVM.h" -#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" -#include "src/Accelerators/PIM/Pass/PimPasses.hpp" -#include "src/Dialect/ONNX/ONNXOps.hpp" - -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/Support/Casting.h" #include "llvm/Support/Format.h" -#define FORMAT_OPERATION(op) \ - 'x' << llvm::format_hex_no_prefix(reinterpret_cast(op), 0) -#define FORMAT_ARGUMENT(computeOpPointer, argumentNum) \ - llvm::format("Arg_%p_%u", computeOpPointer, argumentNum) +#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" +#include "src/Accelerators/PIM/Pass/PimPasses.hpp" +#include "src/Dialect/ONNX/ONNXOps.hpp" + +#define FORMAT_OPERATION(op) 'x' << llvm::format_hex_no_prefix(reinterpret_cast(op), 0) +#define FORMAT_ARGUMENT(computeOpPointer, argumentNum) llvm::format("Arg_%p_%u", computeOpPointer, argumentNum) using namespace mlir; @@ -25,26 +23,22 @@ namespace onnx_mlir { namespace { -struct SpatialToGraphvizPass - : public PassWrapper> { +struct SpatialToGraphvizPass : public PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SpatialToGraphvizPass) - StringRef getArgument() const override { - return "convert-spatial-to-graphviz"; - } + StringRef getArgument() const override { return "convert-spatial-to-graphviz"; } - StringRef getDescription() const override { - return "Lower ONNX ops to Spatial ops."; - } + StringRef getDescription() const override { return "Lower ONNX ops to Spatial ops."; } - SpatialToGraphvizPass(raw_ostream &os = llvm::errs()) : os(os) {} - SpatialToGraphvizPass(const SpatialToGraphvizPass &pass) - : SpatialToGraphvizPass(pass.os) {} + SpatialToGraphvizPass(raw_ostream& os = llvm::errs()) + : os(os) {} + SpatialToGraphvizPass(const SpatialToGraphvizPass& pass) + : SpatialToGraphvizPass(pass.os) {} void runOnOperation() final; private: - raw_ostream &os; + raw_ostream& os; /** * Draws the subgraph for a given spatial::SpatWeightedCompute, including: @@ -56,31 +50,27 @@ private: * @param computeNum The number of the compute operation. */ void drawComputeOpSubgraph(spatial::SpatWeightedCompute op, size_t computeNum) { - os << "\tsubgraph cluster" << computeNum << " {\n\t\tlabel=\"Compute" - << computeNum << "\";\n" + os << "\tsubgraph cluster" << computeNum << " {\n\t\tlabel=\"Compute" << computeNum << "\";\n" << "\t\tstyle=filled;\n" << "\t\tcolor=lightblue;\n"; - Block &block = op.getBody().front(); + Block& block = op.getBody().front(); // Inputs size_t inputNum = 0; - for (BlockArgument &input : block.getArguments()) { + for (BlockArgument& input : block.getArguments()) { auto fromOp = FORMAT_ARGUMENT(op.getOperation(), inputNum); - os << "\t\t" << fromOp << " [label=\"Arg" << inputNum - << "\",shape=box];\n"; - for (auto userOp : input.getUsers()) { + os << "\t\t" << fromOp << " [label=\"Arg" << inputNum << "\",shape=box];\n"; + for (auto userOp : input.getUsers()) os << "\t\t" << fromOp << " -> " << FORMAT_OPERATION(userOp) << ";\n"; - } inputNum++; } // Iterate operations - for (auto &childOp : block.getOperations()) { - os << "\t\t" << FORMAT_OPERATION(&childOp) << " [label=\"" - << childOp.getName() << "\"];\n"; + for (auto& childOp : block.getOperations()) { + os << "\t\t" << FORMAT_OPERATION(&childOp) << " [label=\"" << childOp.getName() << "\"];\n"; drawEdgesFromOpToItsUsers(&childOp); } @@ -88,7 +78,7 @@ private: os << "\t}\n"; // Draw edges from the yield to the users of this computeOp - Operation *yieldOp = block.getTerminator(); + Operation* yieldOp = block.getTerminator(); if (!isa(yieldOp)) { yieldOp->emitError("Terminator of block must be YieldOp ???"); signalPassFailure(); @@ -96,9 +86,8 @@ private: } for (auto computeOpResult : op->getResults()) { - for (auto &computeOpUse : computeOpResult.getUses()) { - auto toOp = FORMAT_ARGUMENT( - computeOpUse.getOwner(), computeOpUse.getOperandNumber()); + for (auto& computeOpUse : computeOpResult.getUses()) { + auto toOp = FORMAT_ARGUMENT(computeOpUse.getOwner(), computeOpUse.getOperandNumber()); os << "\t" << FORMAT_OPERATION(yieldOp) << " -> " << toOp << ";\n"; } } @@ -114,9 +103,8 @@ private: * @param concatOp The concatOp for which the subgraph is drawn. * @param concatOpNum The number of the concatOp. */ - void drawConcatOpSubgraph(Operation *concatOp, size_t concatOpNum) { - os << "\tsubgraph clusterconcat" << concatOpNum - << " {\n\t\tlabel=\"ConcatOp" << concatOpNum << "\";\n" + void drawConcatOpSubgraph(Operation* concatOp, size_t concatOpNum) { + os << "\tsubgraph clusterconcat" << concatOpNum << " {\n\t\tlabel=\"ConcatOp" << concatOpNum << "\";\n" << "\t\tstyle=filled;\n" << "\t\tcolor=orange;\n"; @@ -126,9 +114,8 @@ private: auto fromOp = FORMAT_ARGUMENT(concatOp, inputNum); os << "\t\t" << fromOp << " [label=\"Input" << inputNum << "\"];\n"; - for (auto userOp : input.getUsers()) { + for (auto userOp : input.getUsers()) os << "\t\t" << fromOp << " -> " << FORMAT_OPERATION(userOp) << ";\n"; - } inputNum++; } @@ -139,11 +126,9 @@ private: // Edges from output to users - for (auto &computeOpUse : concatOp->getResult(0).getUses()) { + for (auto& computeOpUse : concatOp->getResult(0).getUses()) { os << "\t" << FORMAT_OPERATION(concatOp) << " -> " - << FORMAT_ARGUMENT( - computeOpUse.getOwner(), computeOpUse.getOperandNumber()) - << ";\n"; + << FORMAT_ARGUMENT(computeOpUse.getOwner(), computeOpUse.getOperandNumber()) << ";\n"; } } @@ -164,10 +149,8 @@ private: sliceOp.getStaticOffsetsAttr().print(os); os << "\",color=lawngreen];\n"; - for (auto &computeOpUse : sliceOp.getResult().getUses()) { - os << "\t" << nodeId << " -> " - << FORMAT_ARGUMENT( - computeOpUse.getOwner(), computeOpUse.getOperandNumber()) + for (auto& computeOpUse : sliceOp.getResult().getUses()) { + os << "\t" << nodeId << " -> " << FORMAT_ARGUMENT(computeOpUse.getOwner(), computeOpUse.getOperandNumber()) << ";\n"; } } @@ -178,9 +161,8 @@ private: sliceOp.getStaticOffsetsAttr().print(os); os << "\",color=lightpink];\n"; - for (auto user : sliceOp.getResult().getUsers()) { + for (auto user : sliceOp.getResult().getUsers()) os << "\t" << nodeId << " -> " << FORMAT_OPERATION(user) << ";\n"; - } } /** @@ -188,13 +170,10 @@ private: * * @param fromOp The operation from which the edges are drawn. */ - void drawEdgesFromOpToItsUsers(mlir::Operation *fromOp) { - for (auto result : fromOp->getResults()) { - for (auto userOp : result.getUsers()) { - os << "\t\t" << FORMAT_OPERATION(fromOp) << " -> " - << FORMAT_OPERATION(userOp) << ";\n"; - } - } + void drawEdgesFromOpToItsUsers(mlir::Operation* fromOp) { + for (auto result : fromOp->getResults()) + for (auto userOp : result.getUsers()) + os << "\t\t" << FORMAT_OPERATION(fromOp) << " -> " << FORMAT_OPERATION(userOp) << ";\n"; } /** @@ -202,16 +181,15 @@ private: * * @param funcOp The `funcOp` for which to draw input nodes and edges. */ - void drawInputNodesAndEdges(func::FuncOp &funcOp) { + void drawInputNodesAndEdges(func::FuncOp& funcOp) { os << "\tinput [label=\"Module Input\",color=green];\n"; size_t funcOpArgNum = 0; - for (BlockArgument &arg : funcOp.getArguments()) { + for (BlockArgument& arg : funcOp.getArguments()) { - for (auto &useOp : arg.getUses()) { - os << "\tinput -> " - << FORMAT_ARGUMENT(useOp.getOwner(), useOp.getOperandNumber()) - << "[label=" << funcOpArgNum << "];\n"; + for (auto& useOp : arg.getUses()) { + os << "\tinput -> " << FORMAT_ARGUMENT(useOp.getOwner(), useOp.getOperandNumber()) << "[label=" << funcOpArgNum + << "];\n"; } funcOpArgNum++; } @@ -237,20 +215,22 @@ void SpatialToGraphvizPass::runOnOperation() { // Iterate over the ComputeOps within FuncOp: // 1. Print their subgraph // 2. Print the edges from its inputs to its outputs - for (Operation &op : func.getOps()) { + for (Operation& op : func.getOps()) { if (auto computeOp = dyn_cast(op)) { drawComputeOpSubgraph(computeOp, computeNum++); - } else if (auto concatOp = dyn_cast(op)) { + } + else if (auto concatOp = dyn_cast(op)) { drawConcatOpSubgraph(concatOp, concatNum++); - } else if (auto imgConcatOp = dyn_cast(op)) { + } + else if (auto imgConcatOp = dyn_cast(op)) { drawConcatOpSubgraph(imgConcatOp, concatNum++); - } else if (auto extractSliceOp = dyn_cast(op)) { + } + else if (auto extractSliceOp = dyn_cast(op)) { auto producerOp = extractSliceOp->getOperand(0).getDefiningOp(); if (producerOp) { // Skip extractSliceOp if producer is constant weights (ONNXConstantOp) - if (llvm::isa(producerOp)) { + if (llvm::isa(producerOp)) continue; - } // If produced by tosa::ReshapeOp (i.e. it is a bias tile) connect // directly to its user, which is not a ComputeOp argument. if (llvm::isa(producerOp)) { @@ -268,16 +248,13 @@ void SpatialToGraphvizPass::runOnOperation() { // Draw output node (use the return Operation - argument number=0 - as nodeId) auto returnOp = func.getBody().front().getTerminator(); - os << '\t' << FORMAT_ARGUMENT(returnOp, 0) - << " [label=\"Module Output\",color=green];\n"; + os << '\t' << FORMAT_ARGUMENT(returnOp, 0) << " [label=\"Module Output\",color=green];\n"; os << "}\n"; } } // namespace -std::unique_ptr createSpatialToGraphvizPass() { - return std::make_unique(); -} +std::unique_ptr createSpatialToGraphvizPass() { return std::make_unique(); } -} // namespace onnx_mlir \ No newline at end of file +} // namespace onnx_mlir diff --git a/src/PIM/Conversion/SpatialToPIM/SpatialToPIMPass.cpp b/src/PIM/Conversion/SpatialToPIM/SpatialToPIMPass.cpp index c4ed4d0..bbaeabd 100644 --- a/src/PIM/Conversion/SpatialToPIM/SpatialToPIMPass.cpp +++ b/src/PIM/Conversion/SpatialToPIM/SpatialToPIMPass.cpp @@ -148,7 +148,8 @@ void SpatialToPIMPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IR size_t resultIndexInConcat = resultUses.begin()->getOperandNumber(); size_t offset = 0; for (auto operand : concatOp->getOperands().take_front(resultIndexInConcat)) - offset += cast(operand.getType()).getNumElements() * cast(operand.getType()).getElementTypeBitWidth() / 8; + offset += cast(operand.getType()).getNumElements() + * cast(operand.getType()).getElementTypeBitWidth() / 8; size_t elementSize = yieldType.getElementTypeBitWidth() / 8; diff --git a/src/PIM/Conversion/SpatialToPIM/SpatialToPIMPatterns.hpp b/src/PIM/Conversion/SpatialToPIM/SpatialToPIMPatterns.hpp index b972265..365efe9 100644 --- a/src/PIM/Conversion/SpatialToPIM/SpatialToPIMPatterns.hpp +++ b/src/PIM/Conversion/SpatialToPIM/SpatialToPIMPatterns.hpp @@ -9,4 +9,4 @@ namespace spatial { } -} // namespace onnx_mlir \ No newline at end of file +} // namespace onnx_mlir diff --git a/src/PIM/Dialect/PIM/PimOps.hpp b/src/PIM/Dialect/PIM/PimOps.hpp index 24b6711..4811ddb 100644 --- a/src/PIM/Dialect/PIM/PimOps.hpp +++ b/src/PIM/Dialect/PIM/PimOps.hpp @@ -1,8 +1,5 @@ #pragma once -#include -#include - #include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -11,6 +8,9 @@ #include "mlir/IR/Dialect.h" #include "mlir/IR/OpDefinition.h" +#include +#include + /// Include the auto-generated header files containing the declarations #include "src/Accelerators/PIM/Dialect/PIM/PimDialect.hpp.inc" diff --git a/src/PIM/Dialect/PIM/Transforms/PimBufferizableOpInterface.cpp b/src/PIM/Dialect/PIM/Transforms/PimBufferizableOpInterface.cpp index dc09326..1a802b0 100644 --- a/src/PIM/Dialect/PIM/Transforms/PimBufferizableOpInterface.cpp +++ b/src/PIM/Dialect/PIM/Transforms/PimBufferizableOpInterface.cpp @@ -15,7 +15,10 @@ namespace pim { struct MemCopyHostToDevOpInterface : DstBufferizableOpInterfaceExternalModel { - LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const { + LogicalResult bufferize(Operation* op, + RewriterBase& rewriter, + const BufferizationOptions& options, + BufferizationState& state) const { auto memCopyHostToDevOp = cast(op); auto deviceDst = memCopyHostToDevOp.getDeviceDst(); auto hostSrc = memCopyHostToDevOp.getHostSrc(); @@ -44,7 +47,10 @@ struct MemCopyHostToDevOpInterface struct MemCopyDevToHostOpInterface : DstBufferizableOpInterfaceExternalModel { - LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const { + LogicalResult bufferize(Operation* op, + RewriterBase& rewriter, + const BufferizationOptions& options, + BufferizationState& state) const { auto memCopyDevToHostOp = cast(op); auto globalDst = memCopyDevToHostOp.getHostDst(); @@ -88,7 +94,10 @@ struct VMMOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel(op); auto vectorInputOpt = getBuffer(rewriter, vmmOp.getVectorInput(), options, state); @@ -110,7 +119,10 @@ struct MVMOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel(op).isDpsInit(&opOperand); } - LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const { + LogicalResult bufferize(Operation* op, + RewriterBase& rewriter, + const BufferizationOptions& options, + BufferizationState& state) const { auto mvmOp = cast(op); auto vectorInputOpt = getBuffer(rewriter, mvmOp.getVectorInput(), options, state); @@ -138,7 +150,10 @@ struct VAddOpBufferizeInterface : DstBufferizableOpInterfaceExternalModel(op); auto aOpt = getBuffer(rewriter, vaddOp.getA(), options, state); diff --git a/src/PIM/Dialect/PIM/Transforms/PimBufferizableOpInterface.hpp b/src/PIM/Dialect/PIM/Transforms/PimBufferizableOpInterface.hpp index aa9e47c..f4872e0 100644 --- a/src/PIM/Dialect/PIM/Transforms/PimBufferizableOpInterface.hpp +++ b/src/PIM/Dialect/PIM/Transforms/PimBufferizableOpInterface.hpp @@ -1,6 +1,7 @@ #pragma once #include "mlir/IR/DialectRegistry.h" + #include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp" using namespace mlir; @@ -8,7 +9,7 @@ using namespace mlir; namespace onnx_mlir { namespace pim { -void registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry); +void registerBufferizableOpInterfaceExternalModels(DialectRegistry& registry); } // namespace pim -} // namespace onnx_mlir \ No newline at end of file +} // namespace onnx_mlir diff --git a/src/PIM/Dialect/Spatial/SpatialOps.hpp b/src/PIM/Dialect/Spatial/SpatialOps.hpp index 2aca69e..15c6650 100644 --- a/src/PIM/Dialect/Spatial/SpatialOps.hpp +++ b/src/PIM/Dialect/Spatial/SpatialOps.hpp @@ -1,8 +1,5 @@ #pragma once -#include -#include - #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" @@ -10,6 +7,9 @@ #include "mlir/IR/OpDefinition.h" #include "mlir/IR/Types.h" +#include +#include + /// Include the auto-generated header files containing the declarations #include "src/Accelerators/PIM/Dialect/Spatial/SpatialDialect.hpp.inc" diff --git a/src/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.cpp b/src/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.cpp index 9daeadc..b8d3934 100644 --- a/src/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.cpp @@ -75,7 +75,10 @@ struct WComputeOpInterface : BufferizableOpInterface::ExternalModelgetRegion(0).front(); @@ -104,7 +107,10 @@ struct VariadicArgumentElementWiseOpInterface : BufferizableOpInterface::Externa } // Cast tensor values into memref values - LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const { + LogicalResult bufferize(Operation* op, + RewriterBase& rewriter, + const BufferizationOptions& options, + BufferizationState& state) const { // Turn Tensor Operands into Memref Operands SmallVector memrefOperands; @@ -151,7 +157,10 @@ struct WeightedMultiplicationsOpInterface : BufferizableOpInterface::ExternalMod } // Cast tensor value into memref value - LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const { + LogicalResult bufferize(Operation* op, + RewriterBase& rewriter, + const BufferizationOptions& options, + BufferizationState& state) const { auto memrefOperandOpt = getBuffer(rewriter, op->getOperand(0), options, state); if (failed(memrefOperandOpt)) return failure(); @@ -190,7 +199,10 @@ struct ChannelReceiveOpInterface /* * Turn the channel receive to pim.recv */ - LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const { + LogicalResult bufferize(Operation* op, + RewriterBase& rewriter, + const BufferizationOptions& options, + BufferizationState& state) const { auto outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter); @@ -235,7 +247,10 @@ struct ChannelSendOpInterface : BufferizableOpInterface::ExternalModelgetOperand(1); auto srcTensorOpt = getBuffer(rewriter, srcTensor, options, state); @@ -278,7 +293,10 @@ struct ChannelBroadcastReceiveOpInterface /* * Turn the channel receive to pim.load using by creating a new global buffer */ - LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const { + LogicalResult bufferize(Operation* op, + RewriterBase& rewriter, + const BufferizationOptions& options, + BufferizationState& state) const { auto outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter); @@ -340,7 +358,10 @@ struct ChannelBroadcastSendOpInterface /* * Turn the channel send to pim.send */ - LogicalResult bufferize(Operation* op, RewriterBase& rewriter, const BufferizationOptions& options, BufferizationState &state) const { + LogicalResult bufferize(Operation* op, + RewriterBase& rewriter, + const BufferizationOptions& options, + BufferizationState& state) const { auto srcTensor = op->getOperand(1); auto srcTensorOpt = getBuffer(rewriter, srcTensor, options, state); @@ -414,7 +435,10 @@ struct ApplyFiltersOpInterface : BufferizableOpInterface::ExternalModelgetOperand(0), options, state); diff --git a/src/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.hpp b/src/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.hpp index 455bbe7..3fa6ceb 100644 --- a/src/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.hpp +++ b/src/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.hpp @@ -1,6 +1,7 @@ #pragma once #include "mlir/IR/DialectRegistry.h" + #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" using namespace mlir; diff --git a/src/PIM/Pass/CountInstructionPass.cpp b/src/PIM/Pass/CountInstructionPass.cpp index edb72e2..a4e0095 100644 --- a/src/PIM/Pass/CountInstructionPass.cpp +++ b/src/PIM/Pass/CountInstructionPass.cpp @@ -1,5 +1,6 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Pass/Pass.h" + #include "src/Accelerators/PIM/Dialect/PIM/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Compiler/CompilerUtils.hpp" @@ -10,22 +11,19 @@ namespace onnx_mlir { namespace { -struct CountInstructionPass - : public PassWrapper> { +struct CountInstructionPass : public PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CountInstructionPass) StringRef getArgument() const override { return "count-instruction-pass"; } - StringRef getDescription() const override { - return "Count instructions for each core/compute in the module"; - } + StringRef getDescription() const override { return "Count instructions for each core/compute in the module"; } // Make sure that we have a valid default constructor and copy // constructor to make sure that the options are initialized properly. CountInstructionPass() {} - CountInstructionPass(const CountInstructionPass &pass) - : PassWrapper>() {} + CountInstructionPass(const CountInstructionPass& pass) + : PassWrapper>() {} void runOnOperation() final { ModuleOp module = getOperation(); @@ -37,8 +35,7 @@ struct CountInstructionPass for (auto computeOp : func.getOps()) { unsigned instructionCount = 0; instructionCount += computeOp.getBody().front().getOperations().size(); - llvm::outs() << "Compute " << computeId << ": " << instructionCount - << " instructions\n"; + llvm::outs() << "Compute " << computeId << ": " << instructionCount << " instructions\n"; totalInstructionCount += instructionCount; computeId++; } @@ -47,21 +44,17 @@ struct CountInstructionPass for (auto coreOp : func.getOps()) { unsigned instructionCount = 0; instructionCount += coreOp.getBody().front().getOperations().size(); - llvm::outs() << "Core " << coreId << ": " << instructionCount - << " instructions\n"; + llvm::outs() << "Core " << coreId << ": " << instructionCount << " instructions\n"; totalInstructionCount += instructionCount; coreId++; } - llvm::outs() << "Total instruction count: " << totalInstructionCount - << "\n"; + llvm::outs() << "Total instruction count: " << totalInstructionCount << "\n"; } }; } // namespace -std::unique_ptr createCountInstructionPass() { - return std::make_unique(); -} +std::unique_ptr createCountInstructionPass() { return std::make_unique(); } -} // namespace onnx_mlir \ No newline at end of file +} // namespace onnx_mlir diff --git a/src/PIM/Pass/PimPasses.hpp b/src/PIM/Pass/PimPasses.hpp index f5a320f..d548f40 100644 --- a/src/PIM/Pass/PimPasses.hpp +++ b/src/PIM/Pass/PimPasses.hpp @@ -1,6 +1,7 @@ #pragma once #include "mlir/Pass/Pass.h" + #include using namespace mlir; @@ -21,4 +22,4 @@ std::unique_ptr createMessagePass(std::string message); std::unique_ptr createCountInstructionPass(); -} // namespace onnx_mlir \ No newline at end of file +} // namespace onnx_mlir diff --git a/src/PIM/PimAccelerator.hpp b/src/PIM/PimAccelerator.hpp index f96cb60..d7f61df 100644 --- a/src/PIM/PimAccelerator.hpp +++ b/src/PIM/PimAccelerator.hpp @@ -1,6 +1,7 @@ #pragma once #include "mlir/IR/BuiltinTypes.h" + #include "src/Accelerators/Accelerator.hpp" namespace onnx_mlir { @@ -9,38 +10,37 @@ namespace accel { /// Singleton class to construct PIM accelerator. class PimAccelerator final : public Accelerator { private: - static PimAccelerator *instance; + static PimAccelerator* instance; PimAccelerator(); public: /// Singleton should not be clonable or assignable. - PimAccelerator(PimAccelerator &) = delete; - void operator=(const PimAccelerator &) = delete; + PimAccelerator(PimAccelerator&) = delete; + void operator=(const PimAccelerator&) = delete; ~PimAccelerator(); /// Creates an instance on the first invocation. Subsequent invocations /// return the existing instance. - static PimAccelerator *getInstance(); + static PimAccelerator* getInstance(); /// Define classof to be able to use isa<>, cast<>, dyn_cast<>, etc. - static bool classof(const Accelerator *accel) { - return accel->getKind() == Accelerator::Kind::PIM; - } - static bool classof(const PimAccelerator *) { return true; } + static bool classof(const Accelerator* accel) { return accel->getKind() == Accelerator::Kind::PIM; } + static bool classof(const PimAccelerator*) { return true; } uint64_t getVersionNumber() const final; //===--------------------------------------------------------------------===// // Hooks for onnx-mlir-opt driver //===--------------------------------------------------------------------===// - virtual void addPasses(mlir::OwningOpRef &module, - mlir::PassManager &pm, onnx_mlir::EmissionTargetType &emissionTarget, - std::string outputNameNoExt) const final; + virtual void addPasses(mlir::OwningOpRef& module, + mlir::PassManager& pm, + onnx_mlir::EmissionTargetType& emissionTarget, + std::string outputNameNoExt) const final; //===--------------------------------------------------------------------===// // Hooks for onnx-mlir-opt driver //===--------------------------------------------------------------------===// - virtual void registerDialects(mlir::DialectRegistry ®istry) const final; + virtual void registerDialects(mlir::DialectRegistry& registry) const final; virtual void registerPasses(int optLevel) const final; //===--------------------------------------------------------------------===// // Hooks for both onnx-mlir and onnx-mlir-opt drivers @@ -49,21 +49,19 @@ public: //===--------------------------------------------------------------------===// // Hooks for onnx-to-krnl pass //===--------------------------------------------------------------------===// - virtual mlir::MemRefType convertTensorTypeToMemRefType( - const mlir::TensorType tensorType) const final; - virtual void conversionTargetONNXToKrnl( - mlir::ConversionTarget &target) const final; - virtual void rewritePatternONNXToKrnl(mlir::RewritePatternSet &patterns, - mlir::TypeConverter &typeConverter, mlir::MLIRContext *ctx) const final; + virtual mlir::MemRefType convertTensorTypeToMemRefType(const mlir::TensorType tensorType) const final; + virtual void conversionTargetONNXToKrnl(mlir::ConversionTarget& target) const final; + virtual void rewritePatternONNXToKrnl(mlir::RewritePatternSet& patterns, + mlir::TypeConverter& typeConverter, + mlir::MLIRContext* ctx) const final; //===--------------------------------------------------------------------===// // Hooks for krnl-to-llvm pass //===--------------------------------------------------------------------===// - virtual void conversionTargetKrnlToLLVM( - mlir::ConversionTarget &target) const final; - virtual void rewritePatternKrnlToLLVM(mlir::RewritePatternSet &patterns, - mlir::LLVMTypeConverter &typeConverter, - mlir::MLIRContext *ctx) const final; + virtual void conversionTargetKrnlToLLVM(mlir::ConversionTarget& target) const final; + virtual void rewritePatternKrnlToLLVM(mlir::RewritePatternSet& patterns, + mlir::LLVMTypeConverter& typeConverter, + mlir::MLIRContext* ctx) const final; }; } // namespace accel diff --git a/src/PIM/Transforms/PimBufferizationPass.cpp b/src/PIM/Transforms/PimBufferizationPass.cpp index eb3496b..9173843 100644 --- a/src/PIM/Transforms/PimBufferizationPass.cpp +++ b/src/PIM/Transforms/PimBufferizationPass.cpp @@ -63,7 +63,8 @@ void PimBufferizationPass::runOnOperation() { void PimBufferizationPass::annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const { MLIRContext* ctx = funcOp.getContext(); funcOp.walk([&](memref::GetGlobalOp getGlobalOp) { - bool isAlwaysWeight = llvm::all_of(getGlobalOp->getUsers(), [](auto user) -> bool { return isa(user); }) && !getGlobalOp->getUsers().empty() ; + bool isAlwaysWeight = llvm::all_of(getGlobalOp->getUsers(), [](auto user) -> bool { return isa(user); }) + && !getGlobalOp->getUsers().empty(); if (isAlwaysWeight) { auto globalMemrefOp = moduleOp.lookupSymbol(getGlobalOp.getName()); assert("Weights must be constants" && globalMemrefOp.getConstant());