From 909c4acfdd9bed3074662182691a42d9a2e513d7 Mon Sep 17 00:00:00 2001 From: NiccoloN Date: Tue, 12 May 2026 10:35:44 +0200 Subject: [PATCH] huge refactor for high RewritePatterns usage and less ad-hoc cpp code remove Spatial many ops in favor of tensor ops like in pim --- src/PIM/Common/IR/CompactAsmUtils.hpp | 65 +- src/PIM/Compiler/CMakeLists.txt | 3 + src/PIM/Compiler/PimArtifactWriter.cpp | 123 ++ src/PIM/Compiler/PimArtifactWriter.hpp | 26 + src/PIM/Compiler/PimBatchEmission.cpp | 126 ++ src/PIM/Compiler/PimBatchEmission.hpp | 13 + src/PIM/Compiler/PimCodeGen.cpp | 400 +---- src/PIM/Compiler/PimWeightEmitter.cpp | 221 +++ src/PIM/Compiler/PimWeightEmitter.hpp | 16 + .../Conversion/ONNXToSpatial/CMakeLists.txt | 5 + .../ONNXToSpatial/Common/Common.hpp | 5 +- .../ONNXToSpatial/Common/ShapeTilingUtils.cpp | 25 +- .../ONNXToSpatial/Common/ShapeTilingUtils.hpp | 11 +- .../Common/WeightMaterialization.cpp | 6 +- .../ONNXToSpatial/ConversionPatterns.cpp | 32 + .../{Patterns.hpp => ConversionPatterns.hpp} | 2 + .../ONNXToSpatial/HostFoldability.cpp | 75 + .../ONNXToSpatial/HostFoldability.hpp | 12 + .../Conversion/ONNXToSpatial/HostLegality.cpp | 29 + .../Conversion/ONNXToSpatial/HostLegality.hpp | 10 + .../ONNXToSpatial/ONNXToSpatialPass.cpp | 490 +----- .../Patterns/Math/Elementwise.cpp | 4 +- .../ONNXToSpatial/Patterns/Math/Gemm.cpp | 112 +- .../ONNXToSpatial/Patterns/Math/MatMul.cpp | 140 +- .../Patterns/Math/ReduceMean.cpp | 36 +- .../ONNXToSpatial/Patterns/NN/Pool.cpp | 8 +- .../ONNXToSpatial/Patterns/NN/Softmax.cpp | 32 +- .../ONNXToSpatial/Patterns/Tensor/Concat.cpp | 14 +- .../ONNXToSpatial/Patterns/Tensor/Gather.cpp | 2 +- .../ONNXToSpatial/Patterns/Tensor/Reshape.cpp | 36 +- .../ONNXToSpatial/Patterns/Tensor/Resize.cpp | 2 +- .../ONNXToSpatial/Patterns/Tensor/Split.cpp | 36 +- .../Conversion/ONNXToSpatial/PostPatterns.cpp | 265 ++++ .../Conversion/ONNXToSpatial/PostPatterns.hpp | 14 + .../Conversion/ONNXToSpatial/PrePatterns.cpp | 25 + 
.../Conversion/ONNXToSpatial/PrePatterns.hpp | 10 + .../BatchCoreLoweringPatterns.cpp | 224 +++ .../BatchCoreLoweringPatterns.hpp | 10 + .../Conversion/SpatialToPim/CMakeLists.txt | 10 +- .../SpatialToPim/ChannelLoweringPatterns.cpp | 136 ++ .../SpatialToPim/ChannelLoweringPatterns.hpp | 9 + src/PIM/Conversion/SpatialToPim/Cleanup.cpp | 42 + src/PIM/Conversion/SpatialToPim/Cleanup.hpp | 11 + .../SpatialToPim/ComputeLikeRegionUtils.cpp | 44 + .../SpatialToPim/ComputeLikeRegionUtils.hpp | 17 + .../SpatialToPim/CoreLoweringPatterns.cpp | 213 +++ .../SpatialToPim/CoreLoweringPatterns.hpp | 21 + ...ns.cpp => GlobalTensorMaterialization.cpp} | 174 +-- .../GlobalTensorMaterialization.hpp | 9 + src/PIM/Conversion/SpatialToPim/Patterns.hpp | 10 - .../SpatialToPim/PhaseVerification.cpp | 20 + .../SpatialToPim/PhaseVerification.hpp | 9 + .../SpatialToPim/ReturnPathNormalization.cpp | 587 +++++++ .../SpatialToPim/ReturnPathNormalization.hpp | 37 + .../SpatialToPim/SpatialToPimPass.cpp | 1350 ++--------------- .../SpatialToPim/TensorPackingPatterns.cpp | 113 ++ .../SpatialToPim/TensorPackingPatterns.hpp | 13 + src/PIM/Dialect/Pim/Pim.td | 34 + src/PIM/Dialect/Pim/PimOpsAsm.cpp | 64 + src/PIM/Dialect/Pim/PimOpsVerify.cpp | 44 + .../Bufferization/BufferizationUtils.cpp | 40 + .../Bufferization/BufferizationUtils.hpp | 15 + .../Transforms/Bufferization/CMakeLists.txt | 2 + .../OpBufferizationInterfaces.cpp | 76 +- src/PIM/Dialect/Spatial/Channels.cpp | 4 +- src/PIM/Dialect/Spatial/Spatial.td | 61 +- src/PIM/Dialect/Spatial/SpatialOpsAsm.cpp | 227 +-- src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp | 143 +- .../DCPGraph/DCPAnalysis.cpp | 16 +- .../DCPGraph/DCPAnalysis.hpp | 8 +- .../MergeComputeNodes/DCPGraph/GraphDebug.cpp | 35 +- .../MergeComputeNodes/DCPGraph/GraphDebug.hpp | 15 +- .../MergeComputeNodesPass.cpp | 330 +--- .../MergeComputeNodes/RegularOpCompaction.cpp | 364 ++++- .../PimCodegen/HostConstantFolding/Common.cpp | 41 +- .../PimCodegen/HostConstantFolding/Common.hpp | 3 
+ .../HostConstantFolding/Patterns/Constant.cpp | 88 +- .../HostConstantFolding/Patterns/Subview.cpp | 4 + .../MaterializeHostConstantsPass.cpp | 158 +- src/PIM/Pass/PimCodegen/VerificationPass.cpp | 8 + validation/raptor.py | 8 +- validation/subprocess_utils.py | 12 +- validation/validate.py | 17 +- validation/validate_one.py | 51 +- 84 files changed, 4048 insertions(+), 3310 deletions(-) create mode 100644 src/PIM/Compiler/PimArtifactWriter.cpp create mode 100644 src/PIM/Compiler/PimArtifactWriter.hpp create mode 100644 src/PIM/Compiler/PimBatchEmission.cpp create mode 100644 src/PIM/Compiler/PimBatchEmission.hpp create mode 100644 src/PIM/Compiler/PimWeightEmitter.cpp create mode 100644 src/PIM/Compiler/PimWeightEmitter.hpp create mode 100644 src/PIM/Conversion/ONNXToSpatial/ConversionPatterns.cpp rename src/PIM/Conversion/ONNXToSpatial/{Patterns.hpp => ConversionPatterns.hpp} (93%) create mode 100644 src/PIM/Conversion/ONNXToSpatial/HostFoldability.cpp create mode 100644 src/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp create mode 100644 src/PIM/Conversion/ONNXToSpatial/HostLegality.cpp create mode 100644 src/PIM/Conversion/ONNXToSpatial/HostLegality.hpp create mode 100644 src/PIM/Conversion/ONNXToSpatial/PostPatterns.cpp create mode 100644 src/PIM/Conversion/ONNXToSpatial/PostPatterns.hpp create mode 100644 src/PIM/Conversion/ONNXToSpatial/PrePatterns.cpp create mode 100644 src/PIM/Conversion/ONNXToSpatial/PrePatterns.hpp create mode 100644 src/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.cpp create mode 100644 src/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.hpp create mode 100644 src/PIM/Conversion/SpatialToPim/ChannelLoweringPatterns.cpp create mode 100644 src/PIM/Conversion/SpatialToPim/ChannelLoweringPatterns.hpp create mode 100644 src/PIM/Conversion/SpatialToPim/Cleanup.cpp create mode 100644 src/PIM/Conversion/SpatialToPim/Cleanup.hpp create mode 100644 src/PIM/Conversion/SpatialToPim/ComputeLikeRegionUtils.cpp create mode 100644 
src/PIM/Conversion/SpatialToPim/ComputeLikeRegionUtils.hpp create mode 100644 src/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.cpp create mode 100644 src/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.hpp rename src/PIM/Conversion/SpatialToPim/{Patterns.cpp => GlobalTensorMaterialization.cpp} (69%) create mode 100644 src/PIM/Conversion/SpatialToPim/GlobalTensorMaterialization.hpp delete mode 100644 src/PIM/Conversion/SpatialToPim/Patterns.hpp create mode 100644 src/PIM/Conversion/SpatialToPim/PhaseVerification.cpp create mode 100644 src/PIM/Conversion/SpatialToPim/PhaseVerification.hpp create mode 100644 src/PIM/Conversion/SpatialToPim/ReturnPathNormalization.cpp create mode 100644 src/PIM/Conversion/SpatialToPim/ReturnPathNormalization.hpp create mode 100644 src/PIM/Conversion/SpatialToPim/TensorPackingPatterns.cpp create mode 100644 src/PIM/Conversion/SpatialToPim/TensorPackingPatterns.hpp create mode 100644 src/PIM/Dialect/Pim/Transforms/Bufferization/BufferizationUtils.cpp create mode 100644 src/PIM/Dialect/Pim/Transforms/Bufferization/BufferizationUtils.hpp diff --git a/src/PIM/Common/IR/CompactAsmUtils.hpp b/src/PIM/Common/IR/CompactAsmUtils.hpp index 1a1fc84..6bada90 100644 --- a/src/PIM/Common/IR/CompactAsmUtils.hpp +++ b/src/PIM/Common/IR/CompactAsmUtils.hpp @@ -72,9 +72,8 @@ inline ParseResult parseCompressedRepeatedList(OpAsmParser& parser, } template -inline ParseResult parseCompressedIntegerEntries(OpAsmParser& parser, - ListDelimiter delimiter, - SmallVectorImpl& values) { +inline ParseResult +parseCompressedIntegerEntries(OpAsmParser& parser, ListDelimiter delimiter, SmallVectorImpl& values) { if (succeeded(parseOptionalCloseDelimiter(parser, delimiter))) return success(); @@ -113,8 +112,8 @@ inline ParseResult parseCompressedIntegerEntries(OpAsmParser& parser, return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive"); } if ((last - first) % step != 0) { - return parser.emitError( - 
parser.getCurrentLocation(), "range end must be reachable from start using the given step"); + return parser.emitError(parser.getCurrentLocation(), + "range end must be reachable from start using the given step"); } for (int64_t value = first; value <= last; value += step) @@ -140,9 +139,8 @@ inline ParseResult parseCompressedIntegerEntries(OpAsmParser& parser, } template -inline ParseResult parseCompressedIntegerSequence(OpAsmParser& parser, - ListDelimiter delimiter, - SmallVectorImpl& values) { +inline ParseResult +parseCompressedIntegerSequence(OpAsmParser& parser, ListDelimiter delimiter, SmallVectorImpl& values) { if (parseOpenDelimiter(parser, delimiter)) return failure(); return parseCompressedIntegerEntries(parser, delimiter, values); @@ -166,9 +164,7 @@ inline void printCompressedEqualRuns(OpAsmPrinter& printer, RangeT entries, Prin } template -inline void printCompressedIntegerSequence(OpAsmPrinter& printer, - ArrayRef values, - ListDelimiter delimiter) { +inline void printCompressedIntegerSequence(OpAsmPrinter& printer, ArrayRef values, ListDelimiter delimiter) { struct FlatCompression { enum class Kind { Single, @@ -388,9 +384,7 @@ inline void printCompressedTypeList(OpAsmPrinter& printer, TypeRange types, List printCloseDelimiter(printer, delimiter); } -inline ParseResult parseCompressedTypeSequence(OpAsmParser& parser, - SmallVectorImpl& types, - bool allowEmpty) { +inline ParseResult parseCompressedTypeSequence(OpAsmParser& parser, SmallVectorImpl& types, bool allowEmpty) { Type firstType; OptionalParseResult firstTypeResult = parser.parseOptionalType(firstType); if (!firstTypeResult.has_value()) { @@ -422,10 +416,9 @@ inline ParseResult parseCompressedTypeSequence(OpAsmParser& parser, return success(); } -inline ParseResult parseCompressedOperandEntryWithFirst( - OpAsmParser& parser, - OpAsmParser::UnresolvedOperand firstOperand, - SmallVectorImpl& operands) { +inline ParseResult parseCompressedOperandEntryWithFirst(OpAsmParser& parser, + 
OpAsmParser::UnresolvedOperand firstOperand, + SmallVectorImpl& operands) { if (succeeded(parser.parseOptionalKeyword("to"))) { OpAsmParser::UnresolvedOperand lastOperand; if (parser.parseOperand(lastOperand)) @@ -447,9 +440,8 @@ inline ParseResult parseCompressedOperandEntryWithFirst( return success(); } -inline ParseResult parseOneCompressedOperandEntry( - OpAsmParser& parser, - SmallVectorImpl& operands) { +inline ParseResult parseOneCompressedOperandEntry(OpAsmParser& parser, + SmallVectorImpl& operands) { OpAsmParser::UnresolvedOperand firstOperand; if (parser.parseOperand(firstOperand)) return failure(); @@ -474,9 +466,8 @@ inline ParseResult parseCompressedOperandList(OpAsmParser& parser, } } -inline ParseResult parseCompressedOperandSequence( - OpAsmParser& parser, - SmallVectorImpl& operands) { +inline ParseResult parseCompressedOperandSequence(OpAsmParser& parser, + SmallVectorImpl& operands) { if (parseOneCompressedOperandEntry(parser, operands)) return failure(); while (succeeded(parser.parseOptionalComma())) @@ -485,9 +476,7 @@ inline ParseResult parseCompressedOperandSequence( return success(); } -inline ParseResult parseCompressedTypeList(OpAsmParser& parser, - ListDelimiter delimiter, - SmallVectorImpl& types) { +inline ParseResult parseCompressedTypeList(OpAsmParser& parser, ListDelimiter delimiter, SmallVectorImpl& types) { if (parseOpenDelimiter(parser, delimiter)) return failure(); if (succeeded(parseOptionalCloseDelimiter(parser, delimiter))) @@ -522,10 +511,7 @@ inline bool hasRepeatedTuple(TypeRange types, size_t tupleSize) { return true; } -inline void printValueTupleRun(OpAsmPrinter& printer, - ValueRange values, - size_t tupleSize, - ListDelimiter delimiter) { +inline void printValueTupleRun(OpAsmPrinter& printer, ValueRange values, size_t tupleSize, ListDelimiter delimiter) { printOpenDelimiter(printer, delimiter); printOpenDelimiter(printer, ListDelimiter::Paren); for (size_t index = 0; index < tupleSize; ++index) { @@ -538,10 +524,7 @@ 
inline void printValueTupleRun(OpAsmPrinter& printer, printCloseDelimiter(printer, delimiter); } -inline void printTypeTupleRun(OpAsmPrinter& printer, - TypeRange types, - size_t tupleSize, - ListDelimiter delimiter) { +inline void printTypeTupleRun(OpAsmPrinter& printer, TypeRange types, size_t tupleSize, ListDelimiter delimiter) { printOpenDelimiter(printer, delimiter); printOpenDelimiter(printer, ListDelimiter::Paren); for (size_t index = 0; index < tupleSize; ++index) { @@ -554,10 +537,9 @@ inline void printTypeTupleRun(OpAsmPrinter& printer, printCloseDelimiter(printer, delimiter); } -inline ParseResult parseCompressedOrTupleOperandList( - OpAsmParser& parser, - ListDelimiter delimiter, - SmallVectorImpl& operands) { +inline ParseResult parseCompressedOrTupleOperandList(OpAsmParser& parser, + ListDelimiter delimiter, + SmallVectorImpl& operands) { if (parseOpenDelimiter(parser, delimiter)) return failure(); if (succeeded(parseOptionalCloseDelimiter(parser, delimiter))) @@ -604,9 +586,8 @@ inline ParseResult parseCompressedOrTupleOperandList( } } -inline ParseResult parseCompressedOrTupleTypeList(OpAsmParser& parser, - ListDelimiter delimiter, - SmallVectorImpl& types) { +inline ParseResult +parseCompressedOrTupleTypeList(OpAsmParser& parser, ListDelimiter delimiter, SmallVectorImpl& types) { if (parseOpenDelimiter(parser, delimiter)) return failure(); if (succeeded(parseOptionalCloseDelimiter(parser, delimiter))) diff --git a/src/PIM/Compiler/CMakeLists.txt b/src/PIM/Compiler/CMakeLists.txt index 57fbfbb..5048f67 100644 --- a/src/PIM/Compiler/CMakeLists.txt +++ b/src/PIM/Compiler/CMakeLists.txt @@ -15,7 +15,10 @@ add_pim_library(OMPimCompilerOptions add_pim_library(OMPimCompilerUtils PimCompilerUtils.cpp + PimArtifactWriter.cpp + PimBatchEmission.cpp PimCodeGen.cpp + PimWeightEmitter.cpp EXCLUDE_FROM_OM_LIBS diff --git a/src/PIM/Compiler/PimArtifactWriter.cpp b/src/PIM/Compiler/PimArtifactWriter.cpp new file mode 100644 index 0000000..909abc0 --- /dev/null +++ 
b/src/PIM/Compiler/PimArtifactWriter.cpp @@ -0,0 +1,123 @@ +#include "mlir/Dialect/MemRef/IR/MemRef.h" + +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/Support/FileSystem.h" +#include "llvm/Support/raw_ostream.h" + +#include +#include +#include +#include + +#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp" +#include "src/Accelerators/PIM/Compiler/PimArtifactWriter.hpp" +#include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp" +#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" + +using namespace llvm; +using namespace mlir; + +namespace onnx_mlir { + +OnnxMlirCompilerErrorCodes writeHostCoreJson(StringRef outputDirPath) { + std::error_code errorCode; + std::string outputHostCorePath = outputDirPath.str() + "/core_0.json"; + raw_fd_ostream hostFileStream(outputHostCorePath, errorCode); + if (errorCode) { + errs() << "Error while opening host core file `" << outputHostCorePath << "`: " << errorCode.message() << '\n'; + return InvalidOutputFileAccess; + } + + // The host core json contains two no-op-like instructions to satisfy pimsim-nn. 
+ hostFileStream << "[{\"imm\":0,\"op\":\"sldi\",\"rd\":0},{\"imm\":0,\"op\":\"sldi\",\"rd\":0}]"; + hostFileStream.close(); + return CompilerSuccess; +} + +OnnxMlirCompilerErrorCodes +writeMemoryBinary(ModuleOp moduleOp, func::FuncOp funcOp, PimAcceleratorMemory& memory, StringRef outputDirPath) { + auto memoryFilePath = (outputDirPath + "/memory.bin").str(); + std::error_code errorCode; + raw_fd_ostream memoryFileStream(memoryFilePath, errorCode, sys::fs::OF_None); + if (errorCode) { + errs() << "Error while opening memory file " << memoryFilePath << ": " << errorCode.message() << '\n'; + return InvalidOutputFileAccess; + } + + std::vector memoryBuffer(memory.hostMem.getFirstAvailableAddress(), 0); + + SmallPtrSet writtenGlobals; + funcOp.walk([&](memref::GetGlobalOp getGlobalOp) { + if (hasWeightAlways(getGlobalOp)) + return; + auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp); + if (!globalOp) + return; + if (!writtenGlobals.insert(globalOp.getOperation()).second) + return; + auto initialValue = globalOp.getInitialValue(); + if (!initialValue) + return; + auto denseAttr = dyn_cast(*initialValue); + if (!denseAttr) + return; + + MemEntry memEntry = memory.hostMem.getMemEntry(getGlobalOp.getResult()); + ArrayRef rawData = denseAttr.getRawData(); + char* dst = memoryBuffer.data() + memEntry.address; + + if (denseAttr.isSplat()) { + size_t elementSize = rawData.size(); + assert(elementSize * getGlobalOp.getType().getNumElements() == memEntry.size && "Data size mismatch"); + for (size_t offset = 0; offset < memEntry.size; offset += elementSize) + std::memcpy(dst + offset, rawData.data(), std::min(elementSize, memEntry.size - offset)); + } + else { + assert(rawData.size() == memEntry.size && "Data size mismatch"); + std::memcpy(dst, rawData.data(), rawData.size()); + } + }); + + memoryFileStream.write(memoryBuffer.data(), memoryBuffer.size()); + memoryFileStream.close(); + return CompilerSuccess; +} + +OnnxMlirCompilerErrorCodes 
writeConfigJson(func::FuncOp funcOp, + PimAcceleratorMemory& memory, + size_t maxCoreId, + json::Object xbarsPerArrayGroup, + StringRef outputDirPath) { + json::Object configJson; + + configJson["core_cnt"] = maxCoreId + 1; + configJson["adc_count"] = 16; + configJson["cell_precision"] = 2; + configJson["xbar_array_count"] = crossbarCountInCore.getValue(); + configJson["xbar_size"] = {crossbarSize.getValue(), crossbarSize.getValue()}; + configJson["array_group_map"] = std::move(xbarsPerArrayGroup); + + json::Array inputsAddresses; + for (BlockArgument input : funcOp.getArguments()) + inputsAddresses.push_back(memory.getValueAddress(input)); + configJson["inputs_addresses"] = std::move(inputsAddresses); + + json::Array outputsAddresses; + for (func::ReturnOp returnOp : funcOp.getOps()) + for (mlir::Value output : returnOp.getOperands()) + outputsAddresses.push_back(memory.getValueAddress(output)); + configJson["outputs_addresses"] = std::move(outputsAddresses); + + auto configPath = (outputDirPath + "/config.json").str(); + std::error_code errorCode; + raw_fd_ostream jsonOS(configPath, errorCode); + if (errorCode) { + errs() << "Error while opening config file: " << errorCode.message() << '\n'; + return InvalidOutputFileAccess; + } + jsonOS << json::Value(std::move(configJson)) << '\n'; + jsonOS.close(); + return CompilerSuccess; +} + +} // namespace onnx_mlir diff --git a/src/PIM/Compiler/PimArtifactWriter.hpp b/src/PIM/Compiler/PimArtifactWriter.hpp new file mode 100644 index 0000000..346bdab --- /dev/null +++ b/src/PIM/Compiler/PimArtifactWriter.hpp @@ -0,0 +1,26 @@ +#pragma once + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BuiltinOps.h" + +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/JSON.h" + +#include "onnx-mlir/Compiler/OMCompilerTypes.h" + +namespace onnx_mlir { + +class PimAcceleratorMemory; + +OnnxMlirCompilerErrorCodes writeHostCoreJson(llvm::StringRef outputDirPath); +OnnxMlirCompilerErrorCodes 
writeMemoryBinary(mlir::ModuleOp moduleOp, + mlir::func::FuncOp funcOp, + PimAcceleratorMemory& memory, + llvm::StringRef outputDirPath); +OnnxMlirCompilerErrorCodes writeConfigJson(mlir::func::FuncOp funcOp, + PimAcceleratorMemory& memory, + size_t maxCoreId, + llvm::json::Object xbarsPerArrayGroup, + llvm::StringRef outputDirPath); + +} // namespace onnx_mlir diff --git a/src/PIM/Compiler/PimBatchEmission.cpp b/src/PIM/Compiler/PimBatchEmission.cpp new file mode 100644 index 0000000..752e79a --- /dev/null +++ b/src/PIM/Compiler/PimBatchEmission.cpp @@ -0,0 +1,126 @@ +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/IRMapping.h" + +#include "src/Accelerators/PIM/Common/PimCommon.hpp" +#include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp" + +using namespace mlir; + +namespace onnx_mlir { +namespace { + +static SmallVector getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp) { + auto coreIdsAttr = coreBatchOp->getAttrOfType(onnx_mlir::kCoreIdsAttrName); + assert(coreIdsAttr && "pim.core_batch requires coreIds array attribute"); + return SmallVector(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end()); +} + +static SmallVector getLaneChunkCoreIds(ArrayRef coreIds, size_t laneCount, unsigned lane) { + SmallVector laneCoreIds; + laneCoreIds.reserve(coreIds.size() / laneCount); + for (size_t chunkIndex = 0; chunkIndex < coreIds.size() / laneCount; ++chunkIndex) + laneCoreIds.push_back(coreIds[chunkIndex * laneCount + lane]); + return laneCoreIds; +} + +} // namespace + +LogicalResult withScalarCoreFromBatchLane(pim::PimCoreBatchOp coreBatchOp, + unsigned lane, + llvm::function_ref callback) { + OwningOpRef scratchModule = ModuleOp::create(coreBatchOp.getLoc()); + OpBuilder builder(scratchModule->getContext()); + builder.setInsertionPointToStart(scratchModule->getBody()); + + size_t laneCount = static_cast(coreBatchOp.getLaneCount()); + size_t weightsPerLane = coreBatchOp.getWeights().size() / laneCount; + SmallVector 
laneWeights; + laneWeights.reserve(weightsPerLane); + for (size_t weightIndex = 0; weightIndex < weightsPerLane; ++weightIndex) + laneWeights.push_back(coreBatchOp.getWeights()[lane * weightsPerLane + weightIndex]); + + auto coreIds = getBatchCoreIds(coreBatchOp); + auto scalarCore = pim::PimCoreOp::create( + builder, coreBatchOp.getLoc(), ValueRange(laneWeights), builder.getI32IntegerAttr(coreIds[lane])); + Block* block = builder.createBlock(&scalarCore.getBody(), scalarCore.getBody().end()); + IRMapping mapper; + if (coreBatchOp.getBody().front().getNumArguments() == 1) + mapper.map(coreBatchOp.getBody().front().getArgument(0), coreBatchOp.getInputs()[lane]); + + builder.setInsertionPointToEnd(block); + for (Operation& op : coreBatchOp.getBody().front()) { + if (isa(op)) { + pim::PimHaltOp::create(builder, op.getLoc()); + continue; + } + + if (auto sendBatchOp = dyn_cast(op)) { + pim::PimSendOp::create(builder, + sendBatchOp.getLoc(), + mapper.lookup(sendBatchOp.getInput()), + sendBatchOp.getSizeAttr(), + builder.getI32IntegerAttr(sendBatchOp.getTargetCoreIds()[lane])); + continue; + } + + if (auto sendTensorBatchOp = dyn_cast(op)) { + pim::PimSendTensorOp::create( + builder, + sendTensorBatchOp.getLoc(), + mapper.lookup(sendTensorBatchOp.getInput()), + builder.getDenseI32ArrayAttr(getLaneChunkCoreIds(sendTensorBatchOp.getTargetCoreIds(), laneCount, lane))); + continue; + } + + if (auto receiveBatchOp = dyn_cast(op)) { + auto scalarReceive = + pim::PimReceiveOp::create(builder, + receiveBatchOp.getLoc(), + receiveBatchOp.getOutput().getType(), + mapper.lookup(receiveBatchOp.getOutputBuffer()), + receiveBatchOp.getSizeAttr(), + builder.getI32IntegerAttr(receiveBatchOp.getSourceCoreIds()[lane])); + mapper.map(receiveBatchOp.getOutput(), scalarReceive.getOutput()); + continue; + } + + if (auto receiveTensorBatchOp = dyn_cast(op)) { + auto scalarReceive = pim::PimReceiveTensorOp::create( + builder, + receiveTensorBatchOp.getLoc(), + 
receiveTensorBatchOp.getOutput().getType(), + mapper.lookup(receiveTensorBatchOp.getOutputBuffer()), + builder.getDenseI32ArrayAttr(getLaneChunkCoreIds(receiveTensorBatchOp.getSourceCoreIds(), laneCount, lane))); + mapper.map(receiveTensorBatchOp.getOutput(), scalarReceive.getOutput()); + continue; + } + + if (auto memcpBatchOp = dyn_cast(op)) { + Value hostSource = mapper.lookupOrNull(memcpBatchOp.getHostSource()); + if (!hostSource) + hostSource = memcpBatchOp.getHostSource(); + + auto scalarCopy = pim::PimMemCopyHostToDevOp::create(builder, + memcpBatchOp.getLoc(), + memcpBatchOp.getOutput().getType(), + mapper.lookup(memcpBatchOp.getDeviceTarget()), + hostSource, + memcpBatchOp.getDeviceTargetOffsetAttr(), + memcpBatchOp.getHostSourceOffsetAttr(), + memcpBatchOp.getSizeAttr()); + mapper.map(memcpBatchOp.getOutput(), scalarCopy.getOutput()); + continue; + } + + Operation* cloned = builder.clone(op, mapper); + for (auto [originalResult, clonedResult] : llvm::zip(op.getResults(), cloned->getResults())) + mapper.map(originalResult, clonedResult); + } + + if (block->empty() || !isa(block->back())) + pim::PimHaltOp::create(builder, coreBatchOp.getLoc()); + return callback(scalarCore); +} + +} // namespace onnx_mlir diff --git a/src/PIM/Compiler/PimBatchEmission.hpp b/src/PIM/Compiler/PimBatchEmission.hpp new file mode 100644 index 0000000..62c4797 --- /dev/null +++ b/src/PIM/Compiler/PimBatchEmission.hpp @@ -0,0 +1,13 @@ +#pragma once + +#include "llvm/ADT/STLFunctionalExtras.h" + +#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" + +namespace onnx_mlir { + +mlir::LogicalResult withScalarCoreFromBatchLane(pim::PimCoreBatchOp coreBatchOp, + unsigned lane, + llvm::function_ref callback); + +} // namespace onnx_mlir diff --git a/src/PIM/Compiler/PimCodeGen.cpp b/src/PIM/Compiler/PimCodeGen.cpp index 288798c..2da1735 100644 --- a/src/PIM/Compiler/PimCodeGen.cpp +++ b/src/PIM/Compiler/PimCodeGen.cpp @@ -5,12 +5,10 @@ #include "mlir/IR/Attributes.h" #include 
"mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/IRMapping.h" #include "mlir/IR/Value.h" #include "mlir/IR/Verifier.h" #include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/StringExtras.h" #include "llvm/Support/Debug.h" #include "llvm/Support/FileSystem.h" @@ -21,7 +19,6 @@ #include #include #include -#include #include #include #include @@ -29,8 +26,11 @@ #include "Common/PimCommon.hpp" #include "Conversion/ONNXToSpatial/Common/Common.hpp" +#include "src/Accelerators/PIM/Compiler/PimArtifactWriter.hpp" +#include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp" #include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" +#include "src/Accelerators/PIM/Compiler/PimWeightEmitter.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" using namespace llvm; @@ -42,79 +42,6 @@ static size_t getValueSizeInBytes(mlir::Value value) { return type.getNumElements() * type.getElementTypeBitWidth() / 8; } -struct DenseWeightView { - DenseElementsAttr denseAttr; - SmallVector shape; - SmallVector strides; - int64_t offset = 0; -}; - -static SmallVector computeRowMajorStridesForShape(ArrayRef shape) { - SmallVector strides(shape.size(), 1); - for (int64_t index = static_cast(shape.size()) - 2; index >= 0; --index) - strides[index] = strides[index + 1] * shape[index + 1]; - return strides; -} - -static bool allStaticSubviewParts(memref::SubViewOp subview) { - return llvm::all_of(subview.getStaticOffsets(), [](int64_t value) { return !ShapedType::isDynamic(value); }) - && llvm::all_of(subview.getStaticSizes(), [](int64_t value) { return !ShapedType::isDynamic(value); }) - && llvm::all_of(subview.getStaticStrides(), [](int64_t value) { return !ShapedType::isDynamic(value); }); -} - -static FailureOr resolveDenseWeightView(ModuleOp moduleOp, mlir::Value weight) { - SmallVector subviews; - mlir::Value current = weight; - memref::GetGlobalOp getGlobalOp; 
- - while (true) { - Operation* defOp = current.getDefiningOp(); - if (!defOp) - return failure(); - if ((getGlobalOp = dyn_cast(defOp))) - break; - if (auto subview = dyn_cast(defOp)) { - if (!allStaticSubviewParts(subview)) - return failure(); - subviews.push_back(subview); - current = subview.getSource(); - continue; - } - if (auto cast = dyn_cast(defOp)) { - current = cast.getSource(); - continue; - } - return failure(); - } - - auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp); - if (!globalOp || !globalOp.getInitialValue()) - return failure(); - - auto denseAttr = dyn_cast(*globalOp.getInitialValue()); - if (!denseAttr) - return failure(); - - DenseWeightView view; - view.denseAttr = denseAttr; - view.shape.assign(denseAttr.getType().getShape().begin(), denseAttr.getType().getShape().end()); - view.strides = computeRowMajorStridesForShape(view.shape); - - for (memref::SubViewOp subview : llvm::reverse(subviews)) { - SmallVector nextStrides; - nextStrides.reserve(subview.getStaticStrides().size()); - for (auto [offset, stride, sourceStride] : - llvm::zip_equal(subview.getStaticOffsets(), subview.getStaticStrides(), view.strides)) { - view.offset += offset * sourceStride; - nextStrides.push_back(stride * sourceStride); - } - view.shape.assign(subview.getStaticSizes().begin(), subview.getStaticSizes().end()); - view.strides = std::move(nextStrides); - } - - return view; -} - MemEntry* PimMemory::gatherMemEntry(mlir::Value value) { auto type = cast(value.getType()); assert("Only static shape is supported" && type.hasStaticShape()); @@ -745,80 +672,6 @@ static SmallVector collectTopLevelCoreLikeOps(func::FuncOp funcOp) { return coreLikeOps; } -static pim::PimCoreOp materializeScalarCoreFromBatchLane(pim::PimCoreBatchOp coreBatchOp, unsigned lane) { - OpBuilder builder(coreBatchOp); - builder.setInsertionPointAfter(coreBatchOp); - - size_t laneCount = static_cast(coreBatchOp.getLaneCount()); - size_t weightsPerLane = coreBatchOp.getWeights().size() / 
laneCount; - SmallVector laneWeights; - laneWeights.reserve(weightsPerLane); - for (size_t weightIndex = 0; weightIndex < weightsPerLane; ++weightIndex) - laneWeights.push_back(coreBatchOp.getWeights()[lane * weightsPerLane + weightIndex]); - - auto coreIds = getBatchCoreIds(coreBatchOp); - auto scalarCore = pim::PimCoreOp::create( - builder, coreBatchOp.getLoc(), ValueRange(laneWeights), builder.getI32IntegerAttr(coreIds[lane])); - Block* block = builder.createBlock(&scalarCore.getBody(), scalarCore.getBody().end()); - IRMapping mapper; - if (coreBatchOp.getBody().front().getNumArguments() == 1) - mapper.map(coreBatchOp.getBody().front().getArgument(0), coreBatchOp.getInputs()[lane]); - - builder.setInsertionPointToEnd(block); - for (Operation& op : coreBatchOp.getBody().front()) { - if (isa(op)) { - pim::PimHaltOp::create(builder, op.getLoc()); - continue; - } - - if (auto sendBatchOp = dyn_cast(op)) { - pim::PimSendOp::create(builder, - sendBatchOp.getLoc(), - mapper.lookup(sendBatchOp.getInput()), - sendBatchOp.getSizeAttr(), - builder.getI32IntegerAttr(sendBatchOp.getTargetCoreIds()[lane])); - continue; - } - - if (auto receiveBatchOp = dyn_cast(op)) { - auto scalarReceive = - pim::PimReceiveOp::create(builder, - receiveBatchOp.getLoc(), - receiveBatchOp.getOutput().getType(), - mapper.lookup(receiveBatchOp.getOutputBuffer()), - receiveBatchOp.getSizeAttr(), - builder.getI32IntegerAttr(receiveBatchOp.getSourceCoreIds()[lane])); - mapper.map(receiveBatchOp.getOutput(), scalarReceive.getOutput()); - continue; - } - - if (auto memcpBatchOp = dyn_cast(op)) { - mlir::Value hostSource = mapper.lookupOrNull(memcpBatchOp.getHostSource()); - if (!hostSource) - hostSource = memcpBatchOp.getHostSource(); - - auto scalarCopy = pim::PimMemCopyHostToDevOp::create(builder, - memcpBatchOp.getLoc(), - memcpBatchOp.getOutput().getType(), - mapper.lookup(memcpBatchOp.getDeviceTarget()), - hostSource, - memcpBatchOp.getDeviceTargetOffsetAttr(), - 
memcpBatchOp.getHostSourceOffsetAttr(), - memcpBatchOp.getSizeAttr()); - mapper.map(memcpBatchOp.getOutput(), scalarCopy.getOutput()); - continue; - } - - Operation* cloned = builder.clone(op, mapper); - for (auto [originalResult, clonedResult] : llvm::zip(op.getResults(), cloned->getResults())) - mapper.map(originalResult, clonedResult); - } - - if (block->empty() || !isa(block->back())) - pim::PimHaltOp::create(builder, coreBatchOp.getLoc()); - return scalarCore; -} - static void aliasMaterializedHostGlobals(ModuleOp moduleOp, func::FuncOp funcOp, pim::PimCoreOp coreOp, @@ -844,56 +697,6 @@ static void aliasMaterializedHostGlobals(ModuleOp moduleOp, }); } -/// Write global constant data into a binary memory image at their allocated addresses. -static OnnxMlirCompilerErrorCodes -writeMemoryBinary(ModuleOp moduleOp, func::FuncOp funcOp, PimAcceleratorMemory& memory, StringRef outputDirPath) { - auto memoryFilePath = (outputDirPath + "/memory.bin").str(); - std::error_code errorCode; - raw_fd_ostream memoryFileStream(memoryFilePath, errorCode, sys::fs::OF_None); - if (errorCode) { - errs() << "Error while opening memory file " << memoryFilePath << ": " << errorCode.message() << '\n'; - return InvalidOutputFileAccess; - } - - std::vector memoryBuffer(memory.hostMem.getFirstAvailableAddress(), 0); - - SmallPtrSet writtenGlobals; - funcOp.walk([&](memref::GetGlobalOp getGlobalOp) { - if (hasWeightAlways(getGlobalOp)) - return; - auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp); - if (!globalOp) - return; - if (!writtenGlobals.insert(globalOp.getOperation()).second) - return; - auto initialValue = globalOp.getInitialValue(); - if (!initialValue) - return; - auto denseAttr = dyn_cast(*initialValue); - if (!denseAttr) - return; - - MemEntry memEntry = memory.hostMem.getMemEntry(getGlobalOp.getResult()); - ArrayRef rawData = denseAttr.getRawData(); - char* dst = memoryBuffer.data() + memEntry.address; - - if (denseAttr.isSplat()) { - size_t elementSize = 
rawData.size(); - assert(elementSize * getGlobalOp.getType().getNumElements() == memEntry.size && "Data size mismatch"); - for (size_t offset = 0; offset < memEntry.size; offset += elementSize) - std::memcpy(dst + offset, rawData.data(), std::min(elementSize, memEntry.size - offset)); - } - else { - assert(rawData.size() == memEntry.size && "Data size mismatch"); - std::memcpy(dst, rawData.data(), rawData.size()); - } - }); - - memoryFileStream.write(memoryBuffer.data(), memoryBuffer.size()); - memoryFileStream.close(); - return CompilerSuccess; -} - /// Dispatch all operations in a core region to the appropriate code generator. /// scf.for loops are statically unrolled via walkPimCoreBlock so that addressing is /// fully resolved before the JSON instructions are emitted. @@ -948,7 +751,6 @@ static int64_t codeGenCoreOps(Block& block, PimCodeGen& coreCodeGen) { coreCodeGen.codeGetGlobalOp(getGlobalOp, knowledge); else { op.emitError("Unsupported codegen for this operation"); - op.dump(); return failure(); } processedOperations++; @@ -957,154 +759,6 @@ static int64_t codeGenCoreOps(Block& block, PimCodeGen& coreCodeGen) { return failed(result) ? 
-1 : static_cast(processedOperations); } -llvm::DenseMap> -createAndPopulateWeightFolder(func::FuncOp funcOp, StringRef outputDirPath) { - ModuleOp moduleOp = funcOp->getParentOfType(); - auto coreWeightsDirPath = outputDirPath + "/weights"; - auto error = sys::fs::create_directory(coreWeightsDirPath); - assert(!error && "Error creating weights directory"); - size_t indexFileName = 0; - - int64_t xbarSize = crossbarSize.getValue(); - llvm::DenseMap> mapCoreWeightToFileName; - llvm::DenseMap mapGlobalOpToFileName; - - SmallVector coreLikeOps = collectTopLevelCoreLikeOps(funcOp); - - for (Operation* op : coreLikeOps) { - SmallVector scalarCores; - if (auto coreOp = dyn_cast(op)) { - scalarCores.push_back(coreOp); - } - else { - auto coreBatchOp = cast(op); - for (unsigned lane = 0; lane < static_cast(coreBatchOp.getLaneCount()); ++lane) - scalarCores.push_back(materializeScalarCoreFromBatchLane(coreBatchOp, lane)); - } - - for (pim::PimCoreOp coreOp : scalarCores) { - size_t coreId = static_cast(coreOp.getCoreId()); - for (unsigned index : getUsedWeightIndices(coreOp)) { - if (index >= coreOp.getWeights().size()) { - coreOp.emitWarning("Weight index " + std::to_string(index) + " is out of range"); - assert(index < coreOp.getWeights().size() && "Weight index is out of range"); - } - mlir::Value weight = coreOp.getWeights()[index]; - - auto weightView = resolveDenseWeightView(moduleOp, weight); - if (failed(weightView)) { - coreOp.emitWarning("Weight is not from a memref.get_global at index " + std::to_string(index)); - assert(succeeded(weightView) && "Weight is not from a dense memref.global view"); - } - - if (mapCoreWeightToFileName[coreId].contains(weight)) - continue; - - auto getGlobalOp = weight.getDefiningOp(); - auto globalOp = getGlobalOp ? 
lookupGlobalForGetGlobal(moduleOp, getGlobalOp) : memref::GlobalOp {}; - if (globalOp && mapGlobalOpToFileName.contains(globalOp)) { - auto& fileName = mapGlobalOpToFileName[globalOp]; - mapCoreWeightToFileName[coreId].insert({weight, fileName}); - continue; - } - - DenseElementsAttr denseAttr = weightView->denseAttr; - ArrayRef shape = weightView->shape; - assert(isMatrixShape(shape) && "Weight matrix must be 2-dimensional"); - int64_t numRows = shape[0]; - int64_t numCols = shape[1]; - assert(numRows <= xbarSize && numCols <= xbarSize && "Weight dimensions must not exceed crossbar size"); - - size_t elementByteWidth = denseAttr.getElementType().getIntOrFloatBitWidth() / 8; - - std::string newFileName = "crossbar_" + std::to_string(indexFileName++) + ".bin"; - auto weightFilePath = (coreWeightsDirPath + "/" + newFileName).str(); - std::error_code errorCode; - raw_fd_ostream weightFileStream(weightFilePath, errorCode, sys::fs::OF_None); - if (errorCode) { - errs() << "Error while opening weight file `" << weightFilePath << "`: " << errorCode.message() << '\n'; - assert(errorCode); - } - - uint64_t zero = 0; - for (int64_t row = 0; row < xbarSize; row++) { - for (int64_t col = 0; col < xbarSize; col++) { - if (row < numRows && col < numCols) { - int64_t elementIndex = weightView->offset + row * weightView->strides[0] + col * weightView->strides[1]; - APInt bits = denseAttr.getValues()[elementIndex].bitcastToAPInt(); - uint64_t word = bits.getZExtValue(); - weightFileStream.write(reinterpret_cast(&word), elementByteWidth); - } - else { - weightFileStream.write(reinterpret_cast(&zero), elementByteWidth); - } - } - } - - weightFileStream.close(); - if (globalOp) - mapGlobalOpToFileName.insert({globalOp, newFileName}); - mapCoreWeightToFileName[coreId].insert({weight, newFileName}); - } - } - - for (pim::PimCoreOp coreOp : scalarCores) - if (coreOp.getOperation() != op) - coreOp.erase(); - } - return mapCoreWeightToFileName; -} - -/// Write the top-level PIM 
configuration JSON (core count, crossbar config, I/O addresses). -static OnnxMlirCompilerErrorCodes writeConfigJson(func::FuncOp funcOp, - PimAcceleratorMemory& memory, - size_t maxCoreId, - json::Object xbarsPerArrayGroup, - StringRef outputDirPath) { - json::Object configJson; - - // pimsim-nn indexes cores directly by their numeric core ID, with the host - // occupying core 0. - configJson["core_cnt"] = maxCoreId + 1; - - // TODO: Should this be based on the floating point type used in the model? - // The 2 following values determine the bitwidth of the vectors' elements: bitwidth = adc_count * cell_precision - - // Number of ADC for MVM units - configJson["adc_count"] = 16; - // The bit precision of each ADC - configJson["cell_precision"] = 2; - - // Crossbar configuration - configJson["xbar_array_count"] = crossbarCountInCore.getValue(); - configJson["xbar_size"] = {crossbarSize.getValue(), crossbarSize.getValue()}; - configJson["array_group_map"] = std::move(xbarsPerArrayGroup); - - // Memory layout of inputs and outputs - json::Array inputsAddresses; - for (BlockArgument input : funcOp.getArguments()) - inputsAddresses.push_back(memory.getValueAddress(input)); - configJson["inputs_addresses"] = std::move(inputsAddresses); - - json::Array outputsAddresses; - for (func::ReturnOp returnOp : funcOp.getOps()) - for (mlir::Value output : returnOp.getOperands()) - outputsAddresses.push_back(memory.getValueAddress(output)); - configJson["outputs_addresses"] = std::move(outputsAddresses); - - auto configPath = (outputDirPath + "/config.json").str(); - std::error_code errorCode; - raw_fd_ostream jsonOS(configPath, errorCode); - if (errorCode) { - errs() << "Error while opening config file: " << errorCode.message() << '\n'; - return InvalidOutputFileAccess; - } - jsonOS << json::Value(std::move(configJson)) << '\n'; - jsonOS.close(); - - return CompilerSuccess; -} - OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::string& outputDirPath) { 
if (!outputDirPath.empty()) { if (auto error = sys::fs::create_directory(outputDirPath)) { @@ -1125,17 +779,8 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std:: if (auto err = writeMemoryBinary(moduleOp, funcOp, memory, outputDirPath)) return err; - // Write empty host core file - std::error_code errorCode; - auto outputHostCorePath = outputDirPath + "/core_0.json"; - raw_fd_ostream hostFileStream(outputHostCorePath, errorCode); - if (errorCode) { - errs() << "Error while opening host core file `" << outputHostCorePath << "`: " << errorCode.message() << '\n'; - return InvalidOutputFileAccess; - } - // The host core json contains 2 random instructions, just to make pimsim-nn happy - hostFileStream << "[{\"imm\":0,\"op\":\"sldi\",\"rd\":0},{\"imm\":0,\"op\":\"sldi\",\"rd\":0}]"; - hostFileStream.close(); + if (auto err = writeHostCoreJson(outputDirPath)) + return err; // For each core, specify the number of crossbar per array group. // This implementation always assigns one crossbar per group. 
@@ -1167,17 +812,7 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std:: } for (Operation* op : coreLikeOps) { - SmallVector scalarCores; - if (auto coreOp = dyn_cast(op)) { - scalarCores.push_back(coreOp); - } - else { - auto coreBatchOp = cast(op); - for (unsigned lane = 0; lane < static_cast(coreBatchOp.getLaneCount()); ++lane) - scalarCores.push_back(materializeScalarCoreFromBatchLane(coreBatchOp, lane)); - } - - for (pim::PimCoreOp coreOp : scalarCores) { + auto emitCore = [&](pim::PimCoreOp coreOp, bool temporaryCore) -> OnnxMlirCompilerErrorCodes { size_t originalCoreId = static_cast(coreOp.getCoreId()); size_t coreId = emittedCoreIds.lookup(originalCoreId); maxCoreId = std::max(maxCoreId, coreId); @@ -1232,13 +867,26 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std:: } xbarsPerArrayGroup["core" + std::to_string(coreId)] = std::move(xbarsPerGroup); + if (temporaryCore) + coreOp.walk([&memory](Operation* op) { memory.clean(op); }); + return CompilerSuccess; + }; + + if (auto coreOp = dyn_cast(op)) { + if (auto err = emitCore(coreOp, false)) + return err; + continue; } - for (pim::PimCoreOp coreOp : scalarCores) - if (coreOp.getOperation() != op) { - coreOp.walk([&memory](Operation* op) { memory.clean(op); }); - coreOp.erase(); - } + auto coreBatchOp = cast(op); + for (unsigned lane = 0; lane < static_cast(coreBatchOp.getLaneCount()); ++lane) { + OnnxMlirCompilerErrorCodes laneResult = CompilerSuccess; + if (failed(withScalarCoreFromBatchLane(coreBatchOp, lane, [&](pim::PimCoreOp coreOp) { + laneResult = emitCore(coreOp, true); + return laneResult == CompilerSuccess ? success() : failure(); + }))) + return laneResult == CompilerSuccess ? 
CompilerFailure : laneResult; + } } return writeConfigJson(funcOp, memory, maxCoreId, std::move(xbarsPerArrayGroup), outputDirPath); diff --git a/src/PIM/Compiler/PimWeightEmitter.cpp b/src/PIM/Compiler/PimWeightEmitter.cpp new file mode 100644 index 0000000..7545d5d --- /dev/null +++ b/src/PIM/Compiler/PimWeightEmitter.cpp @@ -0,0 +1,221 @@ +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" + +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/FileSystem.h" +#include "llvm/Support/raw_ostream.h" + +#include + +#include "Conversion/ONNXToSpatial/Common/Common.hpp" +#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp" +#include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp" +#include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp" +#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" +#include "src/Accelerators/PIM/Compiler/PimWeightEmitter.hpp" +#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" + +using namespace llvm; +using namespace mlir; + +namespace onnx_mlir { +namespace { + +struct DenseWeightView { + DenseElementsAttr denseAttr; + SmallVector shape; + SmallVector strides; + int64_t offset = 0; +}; + +SmallVector computeRowMajorStridesForShape(ArrayRef shape) { + SmallVector strides(shape.size(), 1); + for (int64_t index = static_cast(shape.size()) - 2; index >= 0; --index) + strides[index] = strides[index + 1] * shape[index + 1]; + return strides; +} + +bool allStaticSubviewParts(memref::SubViewOp subview) { + return llvm::all_of(subview.getStaticOffsets(), [](int64_t value) { return !ShapedType::isDynamic(value); }) + && llvm::all_of(subview.getStaticSizes(), [](int64_t value) { return !ShapedType::isDynamic(value); }) + && llvm::all_of(subview.getStaticStrides(), [](int64_t value) { return !ShapedType::isDynamic(value); }); +} + +FailureOr resolveDenseWeightView(ModuleOp moduleOp, mlir::Value weight) { + SmallVector subviews; 
+ mlir::Value current = weight; + memref::GetGlobalOp getGlobalOp; + + while (true) { + Operation* defOp = current.getDefiningOp(); + if (!defOp) + return failure(); + if ((getGlobalOp = dyn_cast(defOp))) + break; + if (auto subview = dyn_cast(defOp)) { + if (!allStaticSubviewParts(subview)) + return failure(); + subviews.push_back(subview); + current = subview.getSource(); + continue; + } + if (auto cast = dyn_cast(defOp)) { + current = cast.getSource(); + continue; + } + return failure(); + } + + auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp); + if (!globalOp || !globalOp.getInitialValue()) + return failure(); + + auto denseAttr = dyn_cast(*globalOp.getInitialValue()); + if (!denseAttr) + return failure(); + + DenseWeightView view; + view.denseAttr = denseAttr; + view.shape.assign(denseAttr.getType().getShape().begin(), denseAttr.getType().getShape().end()); + view.strides = computeRowMajorStridesForShape(view.shape); + + for (memref::SubViewOp subview : llvm::reverse(subviews)) { + SmallVector nextStrides; + nextStrides.reserve(subview.getStaticStrides().size()); + for (auto [offset, stride, sourceStride] : + llvm::zip_equal(subview.getStaticOffsets(), subview.getStaticStrides(), view.strides)) { + view.offset += offset * sourceStride; + nextStrides.push_back(stride * sourceStride); + } + view.shape.assign(subview.getStaticSizes().begin(), subview.getStaticSizes().end()); + view.strides = std::move(nextStrides); + } + + return view; +} + +SmallVector getUsedWeightIndices(Block& block) { + SmallVector indices; + auto addIndex = [&](unsigned weightIndex) { + if (!llvm::is_contained(indices, weightIndex)) + indices.push_back(weightIndex); + }; + + block.walk([&](pim::PimMVMOp mvmOp) { addIndex(mvmOp.getWeightIndex()); }); + block.walk([&](pim::PimVMMOp vmmOp) { addIndex(vmmOp.getWeightIndex()); }); + llvm::sort(indices); + return indices; +} + +SmallVector getUsedWeightIndices(pim::PimCoreOp coreOp) { + return 
getUsedWeightIndices(coreOp.getBody().front()); +} + +SmallVector collectTopLevelCoreLikeOps(func::FuncOp funcOp) { + SmallVector coreLikeOps; + for (Operation& op : funcOp.getBody().front()) + if (dyn_cast(&op) || dyn_cast(&op)) + coreLikeOps.push_back(&op); + return coreLikeOps; +} + +} // namespace + +llvm::DenseMap> +createAndPopulateWeightFolder(func::FuncOp funcOp, StringRef outputDirPath) { + ModuleOp moduleOp = funcOp->getParentOfType(); + auto coreWeightsDirPath = outputDirPath + "/weights"; + auto error = sys::fs::create_directory(coreWeightsDirPath); + assert(!error && "Error creating weights directory"); + size_t indexFileName = 0; + + int64_t xbarSize = crossbarSize.getValue(); + llvm::DenseMap> mapCoreWeightToFileName; + llvm::DenseMap mapGlobalOpToFileName; + + SmallVector coreLikeOps = collectTopLevelCoreLikeOps(funcOp); + + for (Operation* op : coreLikeOps) { + auto processCore = [&](pim::PimCoreOp coreOp) { + size_t coreId = static_cast(coreOp.getCoreId()); + for (unsigned index : getUsedWeightIndices(coreOp)) { + if (index >= coreOp.getWeights().size()) { + coreOp.emitWarning("Weight index " + std::to_string(index) + " is out of range"); + assert(index < coreOp.getWeights().size() && "Weight index is out of range"); + } + mlir::Value weight = coreOp.getWeights()[index]; + + auto weightView = resolveDenseWeightView(moduleOp, weight); + if (failed(weightView)) { + coreOp.emitWarning("Weight is not from a memref.get_global at index " + std::to_string(index)); + assert(succeeded(weightView) && "Weight is not from a dense memref.global view"); + } + + if (mapCoreWeightToFileName[coreId].contains(weight)) + continue; + + auto getGlobalOp = weight.getDefiningOp(); + auto globalOp = getGlobalOp ? 
lookupGlobalForGetGlobal(moduleOp, getGlobalOp) : memref::GlobalOp {}; + if (globalOp && mapGlobalOpToFileName.contains(globalOp)) { + auto& fileName = mapGlobalOpToFileName[globalOp]; + mapCoreWeightToFileName[coreId].insert({weight, fileName}); + continue; + } + + DenseElementsAttr denseAttr = weightView->denseAttr; + ArrayRef shape = weightView->shape; + assert(isMatrixShape(shape) && "Weight matrix must be 2-dimensional"); + int64_t numRows = shape[0]; + int64_t numCols = shape[1]; + assert(numRows <= xbarSize && numCols <= xbarSize && "Weight dimensions must not exceed crossbar size"); + + size_t elementByteWidth = denseAttr.getElementType().getIntOrFloatBitWidth() / 8; + + std::string newFileName = "crossbar_" + std::to_string(indexFileName++) + ".bin"; + auto weightFilePath = (coreWeightsDirPath + "/" + newFileName).str(); + std::error_code errorCode; + raw_fd_ostream weightFileStream(weightFilePath, errorCode, sys::fs::OF_None); + if (errorCode) { + errs() << "Error while opening weight file `" << weightFilePath << "`: " << errorCode.message() << '\n'; + assert(errorCode); + } + + uint64_t zero = 0; + for (int64_t row = 0; row < xbarSize; row++) { + for (int64_t col = 0; col < xbarSize; col++) { + if (row < numRows && col < numCols) { + int64_t elementIndex = weightView->offset + row * weightView->strides[0] + col * weightView->strides[1]; + APInt bits = denseAttr.getValues()[elementIndex].bitcastToAPInt(); + uint64_t word = bits.getZExtValue(); + weightFileStream.write(reinterpret_cast(&word), elementByteWidth); + } + else { + weightFileStream.write(reinterpret_cast(&zero), elementByteWidth); + } + } + } + + weightFileStream.close(); + if (globalOp) + mapGlobalOpToFileName.insert({globalOp, newFileName}); + mapCoreWeightToFileName[coreId].insert({weight, newFileName}); + } + return success(); + }; + + if (auto coreOp = dyn_cast(op)) { + (void) processCore(coreOp); + continue; + } + + auto coreBatchOp = cast(op); + for (unsigned lane = 0; lane < 
static_cast(coreBatchOp.getLaneCount()); ++lane) + if (failed(withScalarCoreFromBatchLane(coreBatchOp, lane, processCore))) + return mapCoreWeightToFileName; + } + return mapCoreWeightToFileName; +} + +} // namespace onnx_mlir diff --git a/src/PIM/Compiler/PimWeightEmitter.hpp b/src/PIM/Compiler/PimWeightEmitter.hpp new file mode 100644 index 0000000..a620028 --- /dev/null +++ b/src/PIM/Compiler/PimWeightEmitter.hpp @@ -0,0 +1,16 @@ +#pragma once + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/Value.h" + +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/StringRef.h" + +#include + +namespace onnx_mlir { + +llvm::DenseMap> +createAndPopulateWeightFolder(mlir::func::FuncOp funcOp, llvm::StringRef outputDirPath); + +} // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt b/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt index 530ad6b..bccab97 100644 --- a/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt +++ b/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt @@ -3,6 +3,11 @@ mlir_tablegen(ONNXToSpatial.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}") add_public_tablegen_target(ONNXToSpatialIncGen) add_pim_library(OMONNXToSpatial + ConversionPatterns.cpp + HostFoldability.cpp + HostLegality.cpp + PrePatterns.cpp + PostPatterns.cpp Patterns/Math/Conv.cpp Patterns/Math/Elementwise.cpp Patterns/Math/Gemm.cpp diff --git a/src/PIM/Conversion/ONNXToSpatial/Common/Common.hpp b/src/PIM/Conversion/ONNXToSpatial/Common/Common.hpp index c820073..9099e29 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Common/Common.hpp +++ b/src/PIM/Conversion/ONNXToSpatial/Common/Common.hpp @@ -1,8 +1,7 @@ #pragma once -#include "src/Accelerators/PIM/Common/PimCommon.hpp" -#include "src/Dialect/ONNX/ONNXOps.hpp" - #include "ComputeRegionBuilder.hpp" #include "ShapeTilingUtils.hpp" #include "WeightMaterialization.hpp" +#include "src/Accelerators/PIM/Common/PimCommon.hpp" +#include "src/Dialect/ONNX/ONNXOps.hpp" diff --git 
a/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.cpp b/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.cpp index 41c629d..b263a23 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.cpp @@ -5,6 +5,8 @@ #include "ShapeTilingUtils.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp" using namespace mlir; @@ -30,10 +32,29 @@ SmallVector sliceTensor( for (int64_t i = 0; i < numSlices; i++) { offsets[axis] = rewriter.getIndexAttr(i * sliceSize); - if (i == numSlices - 1 && lastSliceSize != 0) + int64_t currentSliceSize = sliceSize; + if (i == numSlices - 1 && lastSliceSize != 0) { + currentSliceSize = lastSliceSize; sizes[axis] = rewriter.getIndexAttr(lastSliceSize); + } - Value slice = tensor::ExtractSliceOp::create(rewriter, loc, tensorToSlice, offsets, sizes, strides); + SmallVector sliceShape(shape.begin(), shape.end()); + sliceShape[axis] = currentSliceSize; + auto sliceType = + RankedTensorType::get(sliceShape, cast(tensorToSlice.getType()).getElementType()); + + Value slice; + if (isHostFoldableValue(tensorToSlice)) { + slice = tensor::ExtractSliceOp::create(rewriter, loc, tensorToSlice, offsets, sizes, strides); + } + else { + auto sliceCompute = + createSpatCompute<1>(rewriter, loc, TypeRange {sliceType}, {}, ValueRange {tensorToSlice}, [&](Value input) { + Value computedSlice = tensor::ExtractSliceOp::create(rewriter, loc, input, offsets, sizes, strides); + spatial::SpatYieldOp::create(rewriter, loc, computedSlice); + }); + slice = sliceCompute.getResult(0); + } slices.push_back(slice); } diff --git a/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.hpp b/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.hpp index 6c433a9..b6d6182 100644 --- 
a/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.hpp +++ b/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.hpp @@ -5,15 +5,15 @@ #include "mlir/IR/Value.h" #include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallVector.h" + #include #include #include #include -#include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/SmallVector.h" - namespace onnx_mlir { template @@ -105,7 +105,8 @@ inline auto getTensorShape(mlir::Value tensor) { inline bool haveSameStaticShape(mlir::Value lhs, mlir::Value rhs) { auto lhsType = mlir::dyn_cast(lhs.getType()); auto rhsType = mlir::dyn_cast(rhs.getType()); - return lhsType && rhsType && lhsType.hasStaticShape() && rhsType.hasStaticShape() && lhsType.getShape() == rhsType.getShape(); + return lhsType && rhsType && lhsType.hasStaticShape() && rhsType.hasStaticShape() + && lhsType.getShape() == rhsType.getShape(); } /// Slices a statically shaped tensor along one axis into contiguous pieces of diff --git a/src/PIM/Conversion/ONNXToSpatial/Common/WeightMaterialization.cpp b/src/PIM/Conversion/ONNXToSpatial/Common/WeightMaterialization.cpp index 4645aa0..75931bf 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Common/WeightMaterialization.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Common/WeightMaterialization.cpp @@ -5,12 +5,12 @@ #include "mlir/IR/Value.h" #include "mlir/Support/LogicalResult.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/STLExtras.h" -#include "WeightMaterialization.hpp" #include "ShapeTilingUtils.hpp" +#include "WeightMaterialization.hpp" #include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" @@ -28,7 +28,7 @@ bool isWeightLikeComputeOperand(Value value) { while (auto* definingOp = value.getDefiningOp()) { if (!visited.insert(definingOp).second) return false; - if 
(hasWeightAlways(definingOp)) + if (isa(definingOp) || hasWeightAlways(definingOp)) return true; if (auto extractSliceOp = dyn_cast(definingOp)) { diff --git a/src/PIM/Conversion/ONNXToSpatial/ConversionPatterns.cpp b/src/PIM/Conversion/ONNXToSpatial/ConversionPatterns.cpp new file mode 100644 index 0000000..67d0273 --- /dev/null +++ b/src/PIM/Conversion/ONNXToSpatial/ConversionPatterns.cpp @@ -0,0 +1,32 @@ +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp" + +using namespace mlir; + +namespace onnx_mlir { + +namespace { + +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.hpp.inc" + +} // namespace + +void populateConversionPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx) { + patterns.add(ctx); + + populateElementwisePatterns(patterns, ctx); + populateGemmPatterns(patterns, ctx); + populateConvPatterns(patterns, ctx); + populatePoolPatterns(patterns, ctx); + populateReduceMeanPatterns(patterns, ctx); + populateReluPatterns(patterns, ctx); + populateSigmoidPatterns(patterns, ctx); + populateSoftmaxPatterns(patterns, ctx); + populateConcatPatterns(patterns, ctx); + populateGatherPatterns(patterns, ctx); + populateResizePatterns(patterns, ctx); + populateReshapePatterns(patterns, ctx); + populateSplitPatterns(patterns, ctx); +} + +} // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns.hpp b/src/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp similarity index 93% rename from src/PIM/Conversion/ONNXToSpatial/Patterns.hpp rename to src/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp index 7c44286..892c18d 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns.hpp +++ b/src/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp @@ -5,6 +5,8 @@ namespace onnx_mlir { +void populateConversionPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); + void 
populateConvPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populateElementwisePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); diff --git a/src/PIM/Conversion/ONNXToSpatial/HostFoldability.cpp b/src/PIM/Conversion/ONNXToSpatial/HostFoldability.cpp new file mode 100644 index 0000000..4c73c52 --- /dev/null +++ b/src/PIM/Conversion/ONNXToSpatial/HostFoldability.cpp @@ -0,0 +1,75 @@ +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/BuiltinTypes.h" + +#include "llvm/ADT/SmallPtrSet.h" + +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp" +#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" +#include "src/Dialect/ONNX/ONNXOps.hpp" + +using namespace mlir; + +namespace onnx_mlir { + +namespace { + +static bool hasStaticUnitStrides(tensor::ExtractSliceOp extractSliceOp) { + return llvm::all_of(extractSliceOp.getStaticStrides(), [](int64_t stride) { return stride == 1; }); +} + +static bool isStaticTensorResult(Operation* op) { + return llvm::all_of(op->getResultTypes(), [](Type type) { + auto shapedType = dyn_cast(type); + return shapedType && shapedType.hasStaticShape(); + }); +} + +static bool isHostFoldableOpImpl(Operation* op, llvm::SmallPtrSetImpl& visited) { + if (!op || !visited.insert(op).second) + return false; + + if (isa(op)) + return true; + + if (!isStaticTensorResult(op)) + return false; + + if (auto transposeOp = dyn_cast(op)) + return isHostFoldableValue(transposeOp.getData()); + + if (auto collapseShapeOp = dyn_cast(op)) + return isHostFoldableValue(collapseShapeOp.getSrc()); + + if (auto expandShapeOp = dyn_cast(op)) + return isHostFoldableValue(expandShapeOp.getSrc()); + + if (auto extractSliceOp = dyn_cast(op)) + return hasStaticUnitStrides(extractSliceOp) && isHostFoldableValue(extractSliceOp.getSource()); + + if (auto extractRowsOp = dyn_cast(op)) + return isHostFoldableValue(extractRowsOp.getInput()); + + if (auto 
concatOp = dyn_cast(op)) + return llvm::all_of(concatOp.getInputs(), isHostFoldableValue); + + return false; +} + +} // namespace + +bool isHostFoldableValue(Value value) { + auto* definingOp = value.getDefiningOp(); + if (!definingOp) + return false; + + llvm::SmallPtrSet visited; + return isHostFoldableOpImpl(definingOp, visited); +} + +bool isHostFoldableOp(Operation* op) { + llvm::SmallPtrSet visited; + return isHostFoldableOpImpl(op, visited); +} + +} // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp b/src/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp new file mode 100644 index 0000000..0479987 --- /dev/null +++ b/src/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp @@ -0,0 +1,12 @@ +#pragma once + +#include "mlir/IR/Operation.h" +#include "mlir/IR/Value.h" + +namespace onnx_mlir { + +bool isHostFoldableValue(mlir::Value value); + +bool isHostFoldableOp(mlir::Operation* op); + +} // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/HostLegality.cpp b/src/PIM/Conversion/ONNXToSpatial/HostLegality.cpp new file mode 100644 index 0000000..b16d513 --- /dev/null +++ b/src/PIM/Conversion/ONNXToSpatial/HostLegality.cpp @@ -0,0 +1,29 @@ +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" + +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostLegality.hpp" +#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" + +using namespace mlir; + +namespace onnx_mlir { + +LogicalResult verifyONNXToSpatialHostLegality(func::FuncOp funcOp) { + bool hasFailure = false; + + for (Operation& op : funcOp.getFunctionBody().front()) { + if (isa(&op)) + continue; + if (isHostFoldableOp(&op)) + continue; + + op.emitOpError("non-foldable top-level runtime op remains after ONNX-to-Spatial; lower it inside spat.compute"); + hasFailure = true; + } + + return 
success(!hasFailure); +} + +} // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/HostLegality.hpp b/src/PIM/Conversion/ONNXToSpatial/HostLegality.hpp new file mode 100644 index 0000000..3521eae --- /dev/null +++ b/src/PIM/Conversion/ONNXToSpatial/HostLegality.hpp @@ -0,0 +1,10 @@ +#pragma once + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Support/LogicalResult.h" + +namespace onnx_mlir { + +mlir::LogicalResult verifyONNXToSpatialHostLegality(mlir::func::FuncOp funcOp); + +} // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp index c08fa4e..12e8010 100644 --- a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp @@ -8,21 +8,17 @@ #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/Passes.h" -#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" -#include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" -#include "llvm/Support/raw_os_ostream.h" - -#include -#include -#include #include "Common/Common.hpp" #include "Common/PimCommon.hpp" +#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" -#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" -#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostLegality.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PostPatterns.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PrePatterns.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Compiler/CompilerOptions.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" @@ -33,8 +29,6 @@ namespace onnx_mlir { namespace { -#include 
"src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.hpp.inc" - struct ONNXToSpatialPass : PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ONNXToSpatialPass) StringRef getArgument() const override { return "convert-onnx-to-spatial"; } @@ -44,71 +38,64 @@ struct ONNXToSpatialPass : PassWrapper batchOps; - funcOp.walk([&](spatial::SpatComputeBatch batchOp) { batchOps.push_back(batchOp); }); + IRMapping mapper; + SmallVector computes(funcOp.getOps()); + if (!computes.empty()) + return; - for (auto batchOp : batchOps) { - if (batchOp.getLaneCount() != 1) - continue; + auto returnOp = cast(funcOp.getFunctionBody().front().getTerminator()); + rewriter.setInsertionPoint(returnOp); - auto loc = batchOp.getLoc(); - rewriter.setInsertionPoint(batchOp); - auto computeOp = - spatial::SpatCompute::create(rewriter, loc, batchOp.getResultTypes(), batchOp.getWeights(), batchOp.getInputs()); - computeOp.getProperties().setOperandSegmentSizes( - {static_cast(batchOp.getWeights().size()), static_cast(batchOp.getInputs().size())}); + SmallVector sourceTypes; + SmallVector sourceLocs; + sourceTypes.reserve(funcOp.getNumArguments()); + sourceLocs.reserve(funcOp.getNumArguments()); + for (Value source : funcOp.getArguments()) { + sourceTypes.push_back(source.getType()); + sourceLocs.push_back(source.getLoc()); + } - Block& templateBlock = batchOp.getBody().front(); - SmallVector blockArgTypes; - SmallVector blockArgLocs; - for (BlockArgument arg : templateBlock.getArguments()) { - blockArgTypes.push_back(arg.getType()); - blockArgLocs.push_back(loc); - } - auto* newBlock = - rewriter.createBlock(&computeOp.getBody(), computeOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs); + auto newCompute = spatial::SpatCompute::create( + rewriter, returnOp.getLoc(), returnOp.getOperandTypes(), funcOp.getArguments(), {}, {}); + auto* newBlock = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), sourceTypes, sourceLocs); + for (auto [blockArg, 
computeArg] : llvm::zip(newBlock->getArguments(), newCompute.getOperands())) + mapper.map(computeArg, blockArg); + newCompute.getProperties().setOperandSegmentSizes({0, static_cast(sourceTypes.size())}); - IRMapping mapper; - for (auto [oldArg, newArg] : llvm::zip(templateBlock.getArguments(), newBlock->getArguments())) - mapper.map(oldArg, newArg); - rewriter.setInsertionPointToEnd(newBlock); - for (Operation& op : templateBlock) + rewriter.setInsertionPointToEnd(newBlock); + for (Operation& op : funcOp.getOps()) + if (!isa(&op)) rewriter.clone(op, mapper); - batchOp.replaceAllUsesWith(computeOp.getResults()); - rewriter.eraseOp(batchOp); - } + auto yield = spatial::SpatYieldOp::create(rewriter, funcOp.getLoc(), returnOp.getOperands()); + for (size_t i = 0; i < yield.getNumOperands(); ++i) + yield.setOperand(i, mapper.lookupOrDefault(yield.getOperand(i))); + + for (Operation& op : llvm::make_early_inc_range(funcOp.getOps())) + if (!isa(&op)) { + op.dropAllUses(); + rewriter.eraseOp(&op); + } + + for (auto [index, computeResult] : llvm::enumerate(newCompute.getResults())) + returnOp.setOperand(index, computeResult); } void ONNXToSpatialPass::runOnOperation() { ModuleOp moduleOp = getOperation(); MLIRContext* ctx = &getContext(); - RewritePatternSet mergeActivationPatterns(ctx); - mergeActivationPatterns.add(ctx); - mergeActivationPatterns.add(ctx); - mergeActivationPatterns.add(ctx); - mergeActivationPatterns.add(ctx); - mergeActivationPatterns.add(ctx); - mergeActivationPatterns.add(ctx); - populateMatMulRewritePatterns(mergeActivationPatterns, ctx); + RewritePatternSet prePatterns(ctx); + populatePrePatterns(prePatterns, ctx); + if (failed(applyPatternsGreedily(moduleOp, std::move(prePatterns)))) + llvm::dbgs() << "Failed to apply pre-patterns, continuing...\n"; - if (failed(applyPatternsGreedily(moduleOp, std::move(mergeActivationPatterns)))) - llvm::dbgs() << "Failed to merge activation patterns, continuing...\n"; - - IRRewriter rewriter(moduleOp); auto 
entryFunc = getPimEntryFunc(moduleOp); if (failed(entryFunc)) { signalPassFailure(); @@ -140,34 +127,23 @@ void ONNXToSpatialPass::runOnOperation() { target.addIllegalOp(); target.addIllegalOp(); - RewritePatternSet patterns(ctx); - patterns.add(ctx); - - populateElementwisePatterns(patterns, ctx); - populateGemmPatterns(patterns, ctx); - populateConvPatterns(patterns, ctx); - populatePoolPatterns(patterns, ctx); - populateReduceMeanPatterns(patterns, ctx); - populateReluPatterns(patterns, ctx); - populateSigmoidPatterns(patterns, ctx); - populateSoftmaxPatterns(patterns, ctx); - populateConcatPatterns(patterns, ctx); - populateGatherPatterns(patterns, ctx); - populateResizePatterns(patterns, ctx); - populateReshapePatterns(patterns, ctx); - populateSplitPatterns(patterns, ctx); - - if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) { + RewritePatternSet conversionPatterns(ctx); + populateConversionPatterns(conversionPatterns, ctx); + if (failed(applyPartialConversion(moduleOp, target, std::move(conversionPatterns)))) { signalPassFailure(); return; } - foldSingleLaneComputeBatches(*entryFunc); + RewritePatternSet earlyPostPatterns(ctx); + populateEarlyPostPatterns(earlyPostPatterns, ctx); + if (failed(applyPatternsGreedily(*entryFunc, std::move(earlyPostPatterns)))) { + signalPassFailure(); + return; + } - // Count the number of compute ops and check they do not exceed the core count if (coresCount != -1) { int computeOpsCount = 0; - for (auto& op : entryFunc->getFunctionBody().front().getOperations()) + for (Operation& op : entryFunc->getFunctionBody().front().getOperations()) if (isa(op)) computeOpsCount++; @@ -185,355 +161,23 @@ void ONNXToSpatialPass::runOnOperation() { annotateWeightsConstants(*entryFunc); + RewritePatternSet postPatterns(ctx); + populatePostPatterns(postPatterns, ctx); + if (failed(applyPatternsGreedily(*entryFunc, std::move(postPatterns)))) { + signalPassFailure(); + return; + } + + if 
(failed(verifyONNXToSpatialHostLegality(*entryFunc))) { + signalPassFailure(); + return; + } + populateEmptyFunction(*entryFunc); - if (failed(encapsulateGlobalInstruction(*entryFunc))) { - signalPassFailure(); - return; - } - - if (failed(promoteConstantInputsToWeights(*entryFunc))) { - signalPassFailure(); - return; - } - - // Dump to file for debug dumpModule(moduleOp, "spatial0"); } -template -bool encapsulator(IRRewriter& rewriter, Location loc, Operation* inst, std::function funcSource) { - if (T toRemoveOp = llvm::dyn_cast_if_present(inst)) { - Value source = funcSource(toRemoveOp); - rewriter.setInsertionPointAfter(toRemoveOp); - auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), source); - auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), {source.getType()}, {loc}); - newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) 1}); - rewriter.setInsertionPointToEnd(BB); - IRMapping mapper; - mapper.map(source, BB->getArgument(0)); - auto newInst = rewriter.clone(*inst, mapper); - spatial::SpatYieldOp::create(rewriter, loc, newInst->getResults()); - inst->replaceAllUsesWith(newCompute->getResults()); - inst->erase(); - return true; - } - return false; -} - -bool encapsulateSlice(IRRewriter& rewriter, Location loc, Operation* inst) { - if (tensor::ExtractSliceOp toRemoveOp = llvm::dyn_cast_if_present(inst)) { - auto source = toRemoveOp.getSource(); - rewriter.setInsertionPointAfter(toRemoveOp); - auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), source); - auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), {source.getType()}, {loc}); - newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) 1}); - rewriter.setInsertionPointToEnd(BB); - IRMapping mapper; - mapper.map(source, BB->getArgument(0)); - auto newInst = rewriter.clone(*inst, mapper); - spatial::SpatYieldOp::create(rewriter, loc, newInst->getResults()); 
- inst->replaceAllUsesWith(newCompute->getResults()); - inst->erase(); - return true; - } - return false; -} - -bool encapsulateConcat(IRRewriter& rewriter, Location loc, Operation* inst) { - if (auto toRemoveOp = llvm::dyn_cast_if_present(inst)) { - auto sources = toRemoveOp.getInputs(); - rewriter.setInsertionPointAfter(toRemoveOp); - if (llvm::any_of(sources, - [](auto source) { return isa_and_present(source.getDefiningOp()); })) { - auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), sources); - SmallVector sourceTypes; - SmallVector sourceLoc; - for (auto source : sources) { - sourceTypes.push_back(source.getType()); - sourceLoc.push_back(loc); - } - auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), sourceTypes, sourceLoc); - newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) sources.size()}); - rewriter.setInsertionPointToEnd(BB); - IRMapping mapper; - for (auto [source, bbArg] : llvm::zip(sources, BB->getArguments())) - mapper.map(source, bbArg); - auto newConcat = spatial::SpatConcatOp::create(rewriter, - loc, - toRemoveOp.getType(), - rewriter.getI64IntegerAttr(toRemoveOp.getDim()), - ValueRange(BB->getArguments())); - spatial::SpatYieldOp::create(rewriter, loc, newConcat.getOutput()); - inst->replaceAllUsesWith(newCompute->getResults()); - inst->erase(); - return true; - } - auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), sources); - SmallVector sourceTypes; - SmallVector sourceLoc; - for (auto source : sources) { - sourceTypes.push_back(source.getType()); - sourceLoc.push_back(loc); - } - auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), sourceTypes, sourceLoc); - newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) sources.size()}); - rewriter.setInsertionPointToEnd(BB); - IRMapping mapper; - for (auto [source, bbArg] : llvm::zip(sources, BB->getArguments())) - mapper.map(source, bbArg); - 
auto newConcat = rewriter.clone(*inst, mapper); - spatial::SpatYieldOp::create(rewriter, loc, newConcat->getResults()); - inst->replaceAllUsesWith(newCompute->getResults()); - inst->erase(); - return true; - } - return false; -} - -static FailureOr sourceOperandHasWeightAlways(Operation* op) { - if (op == nullptr) - return false; - - Operation* source = nullptr; - do { - - if (isa(*op)) { - return false; - } - else if (auto extractSliceOp = dyn_cast(*op)) { - auto tmpSource = extractSliceOp.getSource(); - auto definingOp = tmpSource.getDefiningOp(); - if (definingOp) - op = definingOp; - else - return false; - } - else if (auto extractRowsOp = dyn_cast(*op)) { - auto tmpSource = extractRowsOp.getInput(); - auto definingOp = tmpSource.getDefiningOp(); - if (definingOp) - op = definingOp; - else - return false; - } - else if (auto expandShapeOp = dyn_cast(*op)) { - auto tmpSource = expandShapeOp.getSrc(); - auto definingOp = tmpSource.getDefiningOp(); - if (definingOp) - op = definingOp; - else - return false; - } - else if (auto transposeOp = dyn_cast(*op)) { - auto tmpSource = transposeOp.getData(); - auto definingOp = tmpSource.getDefiningOp(); - if (definingOp) - op = definingOp; - else - return false; - } - else if (auto collapseShapeOp = dyn_cast(*op)) { - auto tmpSource = collapseShapeOp.getSrc(); - auto definingOp = tmpSource.getDefiningOp(); - if (definingOp) - op = definingOp; - else - return false; - } - else if (auto constantOp = dyn_cast(*op)) { - source = constantOp; - } - else if (auto concatOp = dyn_cast(*op)) { - bool res = false; - for (auto operand : concatOp.getOperands()) { - res |= hasWeightAlways(operand.getDefiningOp()); - if (res) - return res; - } - return res; - } - else if (auto concatOp = dyn_cast(*op)) { - bool res = false; - for (auto operand : concatOp.getOperands()) { - res |= hasWeightAlways(operand.getDefiningOp()); - if (res) - return res; - } - return res; - } - else { - op->emitOpError("unsupported global instruction while 
promoting weight-backed operands into Spatial computes"); - return failure(); - } - } - while (source == nullptr); - - return hasWeightAlways(source); -} - -// TODO what we want to keep in global? -LogicalResult ONNXToSpatialPass::encapsulateGlobalInstruction(func::FuncOp funcOp) { - Location loc = funcOp.getLoc(); - IRRewriter rewriter(&getContext()); - bool keep = true; - while (keep) { - keep = false; - for (auto& instruction : llvm::make_early_inc_range(funcOp.getOps())) { - if (isa(instruction) - || isa(instruction)) - continue; - - auto weightBacked = sourceOperandHasWeightAlways(&instruction); - if (failed(weightBacked)) - return failure(); - if (*weightBacked) - continue; - - keep |= encapsulateSlice(rewriter, loc, &instruction); - - keep |= encapsulator( - rewriter, loc, &instruction, [](tensor::ExpandShapeOp expand) { return expand.getSrc(); }); - - keep |= encapsulator( - rewriter, loc, &instruction, [](ONNXTransposeOp transpose) { return transpose.getData(); }); - - keep |= encapsulator( - rewriter, loc, &instruction, [](tensor::CollapseShapeOp collapse) { return collapse.getSrc(); }); - - keep |= encapsulateConcat(rewriter, loc, &instruction); - } - } - return success(); -} - -void ONNXToSpatialPass::annotateWeightsConstants(func::FuncOp funcOp) const { - funcOp.walk([&](arith::ConstantOp constantOp) { - if (hasOnlySpatialMvmVmmWeightUses(constantOp.getResult())) - markWeightAlways(constantOp); - }); -} - -LogicalResult ONNXToSpatialPass::promoteConstantInputsToWeights(func::FuncOp funcOp) { - IRRewriter rewriter(&getContext()); - SmallVector computes(funcOp.getOps()); - - for (auto compute : computes) { - SmallVector promoteInput(compute.getInputs().size(), false); - bool needsRewrite = false; - for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) { - if (!isWeightLikeComputeOperand(input)) - continue; - promoteInput[inputIdx] = true; - needsRewrite = true; - } - if (!needsRewrite) - continue; - - 
rewriter.setInsertionPointAfter(compute); - - SmallVector newWeights(compute.getWeights().begin(), compute.getWeights().end()); - SmallVector newInputs; - SmallVector newInputTypes; - SmallVector newInputLocs; - newWeights.reserve(compute.getWeights().size() + compute.getInputs().size()); - newInputs.reserve(compute.getInputs().size()); - newInputTypes.reserve(compute.getInputs().size()); - newInputLocs.reserve(compute.getInputs().size()); - - for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) { - if (promoteInput[inputIdx]) { - newWeights.push_back(input); - continue; - } - newInputs.push_back(input); - newInputTypes.push_back(input.getType()); - newInputLocs.push_back(input.getLoc()); - } - - auto newCompute = - spatial::SpatCompute::create(rewriter, compute.getLoc(), compute.getResultTypes(), newWeights, newInputs); - auto* newBlock = - rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), newInputTypes, newInputLocs); - newCompute.getProperties().setOperandSegmentSizes( - {static_cast(newWeights.size()), static_cast(newInputs.size())}); - rewriter.setInsertionPointToStart(newBlock); - - IRMapping mapper; - auto& oldBlock = compute.getBody().front(); - size_t newInputIdx = 0; - for (auto [oldInputIdx, oldArg] : llvm::enumerate(oldBlock.getArguments())) { - if (!promoteInput[oldInputIdx]) { - mapper.map(oldArg, newBlock->getArgument(newInputIdx++)); - continue; - } - - auto clonedValue = materializeWeightLikeValueInBlock(compute.getInputs()[oldInputIdx], rewriter, mapper); - if (failed(clonedValue)) - return compute.emitError("failed to materialize promoted weight-like operand inside compute body"); - mapper.map(oldArg, *clonedValue); - } - - for (auto& op : oldBlock.without_terminator()) - rewriter.clone(op, mapper); - - auto oldYield = cast(oldBlock.getTerminator()); - SmallVector newYieldOperands; - newYieldOperands.reserve(oldYield.getOutputs().size()); - for (Value operand : oldYield.getOutputs()) { - auto mapped = 
mapper.lookupOrNull(operand); - newYieldOperands.push_back(mapped ? cast(mapped) : operand); - } - spatial::SpatYieldOp::create(rewriter, oldYield.getLoc(), newYieldOperands); - - compute.replaceAllUsesWith(newCompute); - compute.erase(); - } - - return success(); -} - -void ONNXToSpatialPass::populateEmptyFunction(func::FuncOp funcOp) { - IRRewriter rewriter(&getContext()); - IRMapping mapper; - SmallVector computes(funcOp.getOps()); - if (!computes.empty()) - return; - auto returnOp = llvm::cast(funcOp.getRegion().front().getTerminator()); - rewriter.setInsertionPoint(returnOp); - - SmallVector sourceTypes; - SmallVector sourceLoc; - for (auto source : funcOp.getArguments()) { - sourceTypes.push_back(source.getType()); - sourceLoc.push_back(source.getLoc()); - } - - auto newCompute = spatial::SpatCompute::create( - rewriter, returnOp.getLoc(), returnOp.getOperandTypes(), funcOp.getArguments(), {}, {}); - auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), sourceTypes, sourceLoc); - for (auto [bbArg, computeArg] : llvm::zip(BB->getArguments(), newCompute.getOperands())) - mapper.map(computeArg, bbArg); - newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) sourceTypes.size()}); - rewriter.setInsertionPointToEnd(BB); - for (Operation& inst : funcOp.getOps()) - if (!isa(&inst)) - rewriter.clone(inst, mapper); - - auto yield = spatial::SpatYieldOp::create(rewriter, funcOp.getLoc(), returnOp.getOperands()); - for (size_t i = 0; i < yield.getNumOperands(); ++i) - yield.setOperand(i, mapper.lookupOrDefault(yield.getOperand(i))); - - for (Operation& inst : llvm::make_early_inc_range(funcOp.getOps())) - if (!isa(&inst)){ - inst.dropAllUses(); - rewriter.eraseOp(&inst); - } - - for (auto [index, computeResult] : llvm::enumerate(newCompute.getResults())) - returnOp.setOperand(index, computeResult); -} - std::unique_ptr createONNXToSpatialPass() { return std::make_unique(); } } // namespace onnx_mlir diff --git 
a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Elementwise.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Elementwise.cpp index d3ac767..44c2fc1 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Elementwise.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Elementwise.cpp @@ -5,9 +5,9 @@ #include "llvm/ADT/SmallVector.h" -#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp" -#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp index 77052ba..b4a2e5e 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp @@ -11,6 +11,7 @@ #include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" @@ -49,6 +50,45 @@ materializeScaledConstantTensor(Value value, float factor, ConversionPatternRewr return arith::ConstantOp::create(rewriter, loc, denseAttr.getType(), scaledAttr).getResult(); } +static Value transposeForSpatial(Value value, + RankedTensorType resultType, + ArrayRef permutation, + ConversionPatternRewriter& rewriter, + Location loc) { + if (isHostFoldableValue(value)) + return ONNXTransposeOp::create(rewriter, loc, resultType, value, rewriter.getI64ArrayAttr(permutation)); + + auto computeOp = 
createSpatCompute<1>(rewriter, loc, TypeRange {resultType}, {}, value, [&](Value input) { + Value transposed = ONNXTransposeOp::create(rewriter, loc, resultType, input, rewriter.getI64ArrayAttr(permutation)); + spatial::SpatYieldOp::create(rewriter, loc, transposed); + }); + return computeOp.getResult(0); +} + +static Value +expandRankOneBias(Value value, RankedTensorType resultType, ConversionPatternRewriter& rewriter, Location loc) { + if (isHostFoldableValue(value)) + return tensor::ExpandShapeOp::create(rewriter, + loc, + resultType, + value, + SmallVector { + {0, 1} + }); + + auto computeOp = createSpatCompute<1>(rewriter, loc, TypeRange {resultType}, {}, value, [&](Value input) { + Value expanded = tensor::ExpandShapeOp::create(rewriter, + loc, + resultType, + input, + SmallVector { + {0, 1} + }); + spatial::SpatYieldOp::create(rewriter, loc, expanded); + }); + return computeOp.getResult(0); +} + struct GemmToManyGemv : OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -81,6 +121,11 @@ static SmallVector materializeBatchRowSlices(Value matrix, auto rowType = RankedTensorType::get({1, matrixType.getDimSize(1)}, matrixType.getElementType()); SmallVector resultTypes(static_cast(numRows), rowType); + if (isHostFoldableValue(matrix)) { + auto extractRowsOp = spatial::SpatExtractRowsOp::create(rewriter, loc, TypeRange(resultTypes), matrix); + return SmallVector(extractRowsOp->result_begin(), extractRowsOp->result_end()); + } + auto buildRowSlices = [&](Value matrixArg) { auto extractRowsOp = spatial::SpatExtractRowsOp::create(rewriter, loc, TypeRange(resultTypes), matrixArg); return SmallVector(extractRowsOp->result_begin(), extractRowsOp->result_end()); @@ -122,7 +167,8 @@ static SmallVector materializeBatchRowSlices(Value matrix, rootValue = definingOp->getOperand(0); } - return buildRowSlices(matrix); + SmallVector reversedChainOps(chainOps.rbegin(), chainOps.rend()); + return cloneBatchInputChainIntoSliceCompute(rootValue, 
reversedChainOps, rootValue); } } // namespace @@ -175,13 +221,7 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp, // Expand rank-1 bias [N] to rank-2 [1, N] for uniform handling if (cType.getRank() == 1) { auto expandedType = RankedTensorType::get({1, cType.getDimSize(0)}, cType.getElementType()); - c = tensor::ExpandShapeOp::create(rewriter, - loc, - expandedType, - c, - SmallVector { - {0, 1} - }); + c = expandRankOneBias(c, expandedType, rewriter, loc); cType = expandedType; } if (!cType.hasStaticShape()) { @@ -196,25 +236,18 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp, } auto outRowType = RankedTensorType::get({1, outType.getDimSize(1)}, outType.getElementType()); + SmallVector aSlices = materializeBatchRowSlices(a, aType, rewriter, loc); + SmallVector cSlices; + if (hasC && cHasNumOutRows) + cSlices = materializeBatchRowSlices(c, cType, rewriter, loc); SmallVector gemvOps; - gemvOps.reserve(numOutRows); + gemvOps.reserve(static_cast(numOutRows)); for (int64_t rowIdx = 0; rowIdx < numOutRows; rowIdx++) { - SmallVector offsets = {rewriter.getIndexAttr(rowIdx), rewriter.getIndexAttr(0)}; - SmallVector sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(aType.getDimSize(1))}; - SmallVector strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; - auto aSliceType = RankedTensorType::get({1, aType.getDimSize(1)}, aType.getElementType()); - auto aSlice = tensor::ExtractSliceOp::create(rewriter, loc, aSliceType, a, offsets, sizes, strides).getResult(); - Value cSlice = c; if (hasC) { - if (cHasNumOutRows) { - SmallVector offsets = {rewriter.getIndexAttr(rowIdx), rewriter.getIndexAttr(0)}; - SmallVector sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(cType.getDimSize(1))}; - SmallVector strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; - auto cSliceType = RankedTensorType::get({1, cType.getDimSize(1)}, cType.getElementType()); - cSlice = tensor::ExtractSliceOp::create(rewriter, loc, 
cSliceType, c, offsets, sizes, strides).getResult(); - } + if (cHasNumOutRows) + cSlice = cSlices[static_cast(rowIdx)]; else if (!isVectorShape(getTensorShape(c))) { gemmOp.emitOpError("requires Gemm bias C to be vector-like when shared across decomposed rows"); return failure(); @@ -224,7 +257,7 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp, auto gemvOp = ONNXGemmOp::create(rewriter, loc, outRowType, - aSlice, + aSlices[static_cast(rowIdx)], b, cSlice, rewriter.getF32FloatAttr(1.0f), @@ -267,13 +300,7 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp, // Expand rank-1 bias [N] to rank-2 [1, N] for uniform handling if (cType.getRank() == 1) { auto expandedType = RankedTensorType::get({1, cType.getDimSize(0)}, cType.getElementType()); - c = tensor::ExpandShapeOp::create(rewriter, - gemmLoc, - expandedType, - c, - SmallVector { - {0, 1} - }); + c = expandRankOneBias(c, expandedType, rewriter, gemmLoc); cType = expandedType; } if (!cType.hasStaticShape()) { @@ -305,13 +332,14 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp, if (transA) { auto aShape = aType.getShape(); - auto transposedType = aType.cloneWith(ArrayRef({aShape[1], aShape[0]}), aType.getElementType()); - a = ONNXTransposeOp::create(rewriter, gemmLoc, transposedType, a, rewriter.getI64ArrayAttr({1, 0})); + auto transposedType = RankedTensorType::get({aShape[1], aShape[0]}, aType.getElementType()); + a = transposeForSpatial(a, transposedType, {1, 0}, rewriter, gemmLoc); + aType = cast(a.getType()); } if (transB) { auto bShape = bType.getShape(); - auto transposedType = bType.cloneWith(ArrayRef({bShape[1], bShape[0]}), bType.getElementType()); - b = ONNXTransposeOp::create(rewriter, gemmLoc, transposedType, b, rewriter.getI64ArrayAttr({1, 0})); + auto transposedType = RankedTensorType::get({bShape[1], bShape[0]}, bType.getElementType()); + b = transposeForSpatial(b, transposedType, {1, 0}, rewriter, gemmLoc); bType = cast(b.getType()); } 
@@ -335,7 +363,6 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp, auto [aNumHSlices, aLastHSliceSize] = ceilIntegerDivideWithRemainder(aType.getDimSize(1), crossbarSize.getValue()); auto [bNumHSlices, bLastHSliceSize] = ceilIntegerDivideWithRemainder(bType.getDimSize(1), crossbarSize.getValue()); auto bNumVSlices = aNumHSlices; - auto bLastVSliceSize = aLastHSliceSize; auto cNumHSlices = bNumHSlices; auto cLastHSliceSize = bLastHSliceSize; auto outNumHSlices = cNumHSlices; @@ -469,12 +496,15 @@ LogicalResult GemmToSpatialComputeBatch::matchAndRewrite(ONNXGemmOp gemmOp, if (gemmOpAdaptor.getTransB()) { auto bShape = bType.getShape(); - auto transposedType = bType.cloneWith(ArrayRef({bShape[1], bShape[0]}), bType.getElementType()); - b = ONNXTransposeOp::create(rewriter, loc, transposedType, b, rewriter.getI64ArrayAttr({1, 0})); + auto transposedType = RankedTensorType::get({bShape[1], bShape[0]}, bType.getElementType()); + b = transposeForSpatial(b, transposedType, {1, 0}, rewriter, loc); bType = cast(b.getType()); } (void) bType; + if (!isHostFoldableValue(b)) + return failure(); + Value sharedBias; if (hasC) { auto scaledC = materializeScaledConstantTensor(c, gemmOpAdaptor.getBeta().convertToFloat(), rewriter, loc); @@ -484,13 +514,7 @@ LogicalResult GemmToSpatialComputeBatch::matchAndRewrite(ONNXGemmOp gemmOp, auto cType = cast(c.getType()); if (cType.getRank() == 1) { auto expandedType = RankedTensorType::get({1, cType.getDimSize(0)}, cType.getElementType()); - c = tensor::ExpandShapeOp::create(rewriter, - loc, - expandedType, - c, - SmallVector { - {0, 1} - }); + c = expandRankOneBias(c, expandedType, rewriter, loc); cType = cast(c.getType()); } if (!cType.hasStaticShape()) { diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/MatMul.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/MatMul.cpp index a82adcc..5dd9a2d 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/MatMul.cpp +++ 
b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/MatMul.cpp @@ -2,11 +2,11 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/PatternMatch.h" -#include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" -#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" @@ -36,49 +36,27 @@ static Value extractBatchMatrix(Value value, SmallVector sizes = { rewriter.getIndexAttr(1), rewriter.getIndexAttr(rows), rewriter.getIndexAttr(cols)}; SmallVector strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; - Value slice = tensor::ExtractSliceOp::create(rewriter, loc, sliceType, value, offsets, sizes, strides); - auto matrixType = RankedTensorType::get({rows, cols}, type.getElementType()); - return tensor::CollapseShapeOp::create(rewriter, - loc, - matrixType, - slice, - SmallVector { - {0, 1}, - {2} - }); -} + auto buildMatrix = [&](Value input) -> Value { + Value slice = tensor::ExtractSliceOp::create(rewriter, loc, sliceType, input, offsets, sizes, strides); + return tensor::CollapseShapeOp::create(rewriter, + loc, + matrixType, + slice, + SmallVector { + {0, 1}, + {2} + }); + }; -static bool isConstantLikeOperand(Value value) { - llvm::SmallPtrSet visited; + if (isHostFoldableValue(value)) + return buildMatrix(value); - while (auto* definingOp = value.getDefiningOp()) { - if (!visited.insert(definingOp).second) - return false; - if (definingOp->hasTrait()) - return true; - - if (auto extractSliceOp = dyn_cast(definingOp)) { - value = extractSliceOp.getSource(); - continue; - } - if (auto expandShapeOp = dyn_cast(definingOp)) { - value = expandShapeOp.getSrc(); - continue; - } - if 
(auto collapseShapeOp = dyn_cast(definingOp)) { - value = collapseShapeOp.getSrc(); - continue; - } - if (auto transposeOp = dyn_cast(definingOp)) { - value = transposeOp.getData(); - continue; - } - - return false; - } - - return false; + auto batchMatrixCompute = + createSpatCompute<1>(rewriter, loc, TypeRange {matrixType}, {}, ValueRange {value}, [&](Value input) { + spatial::SpatYieldOp::create(rewriter, loc, buildMatrix(input)); + }); + return batchMatrixCompute.getResult(0); } static Value transposeLastTwoDims(Value value, PatternRewriter& rewriter, Location loc) { @@ -107,15 +85,31 @@ static Value transposeLastTwoDimsInCompute(Value value, PatternRewriter& rewrite perm = {0, 2, 1}; } - auto transposeCompute = - createSpatCompute<1>(rewriter, loc, transposedType, {}, ValueRange {value}, [&](Value input) { - Value transposed = - ONNXTransposeOp::create(rewriter, loc, transposedType, input, rewriter.getI64ArrayAttr(perm)); - spatial::SpatYieldOp::create(rewriter, loc, transposed); - }); + auto transposeCompute = createSpatCompute<1>(rewriter, loc, transposedType, {}, ValueRange {value}, [&](Value input) { + Value transposed = ONNXTransposeOp::create(rewriter, loc, transposedType, input, rewriter.getI64ArrayAttr(perm)); + spatial::SpatYieldOp::create(rewriter, loc, transposed); + }); return transposeCompute.getResult(0); } +static Value concatValues(ValueRange inputs, int64_t axis, PatternRewriter& rewriter, Location loc) { + auto firstType = cast(inputs.front().getType()); + SmallVector outputShape(firstType.getShape().begin(), firstType.getShape().end()); + int64_t concatDimSize = 0; + for (Value input : inputs) + concatDimSize += cast(input.getType()).getDimSize(axis); + outputShape[axis] = concatDimSize; + auto resultType = RankedTensorType::get(outputShape, firstType.getElementType(), firstType.getEncoding()); + + if (llvm::all_of(inputs, isHostFoldableValue)) + return createSpatConcat(rewriter, loc, axis, inputs); + + auto concatCompute = 
createSpatCompute(rewriter, loc, TypeRange {resultType}, {}, inputs, [&](ValueRange args) { + spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, axis, args)); + }); + return concatCompute.getResult(0); +} + struct MatMulToGemm : OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -157,7 +151,7 @@ struct MatMulToGemm : OpRewritePattern { } Location loc = matmulOp.getLoc(); - bool useTransposedForm = isConstantLikeOperand(matmulOp.getA()) && !isConstantLikeOperand(matmulOp.getB()); + bool useTransposedForm = isHostFoldableValue(matmulOp.getA()) && !isHostFoldableValue(matmulOp.getB()); Value lhs = matmulOp.getA(); Value rhs = matmulOp.getB(); @@ -193,8 +187,14 @@ struct MatMulToGemm : OpRewritePattern { rewriter.getBoolAttr(false), rewriter.getBoolAttr(false)) .getY(); - if (useTransposedForm) - gemmResult = ONNXTransposeOp::create(rewriter, loc, outType, gemmResult, rewriter.getI64ArrayAttr({1, 0})); + if (useTransposedForm) { + auto transposeCompute = + createSpatCompute<1>(rewriter, loc, TypeRange {outType}, {}, gemmResult, [&](Value input) { + Value transposed = ONNXTransposeOp::create(rewriter, loc, outType, input, rewriter.getI64ArrayAttr({1, 0})); + spatial::SpatYieldOp::create(rewriter, loc, transposed); + }); + gemmResult = transposeCompute.getResult(0); + } rewriter.replaceOp(matmulOp, gemmResult); return success(); } @@ -215,24 +215,30 @@ struct MatMulToGemm : OpRewritePattern { rewriter.getBoolAttr(false), rewriter.getBoolAttr(false)) .getY(); - if (useTransposedForm) - gemmResult = ONNXTransposeOp::create( - rewriter, - loc, - RankedTensorType::get({m, n}, outType.getElementType()), - gemmResult, - rewriter.getI64ArrayAttr({1, 0})); - batchResults.push_back(tensor::ExpandShapeOp::create(rewriter, - loc, - batchedOutType, - gemmResult, - SmallVector { - {0, 1}, - {2} - })); + auto batchResultCompute = + createSpatCompute<1>(rewriter, loc, TypeRange {batchedOutType}, {}, gemmResult, [&](Value input) { + Value 
resultMatrix = input; + if (useTransposedForm) { + resultMatrix = ONNXTransposeOp::create(rewriter, + loc, + RankedTensorType::get({m, n}, outType.getElementType()), + input, + rewriter.getI64ArrayAttr({1, 0})); + } + Value expanded = tensor::ExpandShapeOp::create(rewriter, + loc, + batchedOutType, + resultMatrix, + SmallVector { + {0, 1}, + {2} + }); + spatial::SpatYieldOp::create(rewriter, loc, expanded); + }); + batchResults.push_back(batchResultCompute.getResult(0)); } - Value result = createSpatConcat(rewriter, loc, /*axis=*/0, batchResults); + Value result = concatValues(batchResults, /*axis=*/0, rewriter, loc); rewriter.replaceOp(matmulOp, result); return success(); } diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/ReduceMean.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/ReduceMean.cpp index 252ee9d..fa96e2e 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/ReduceMean.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/ReduceMean.cpp @@ -6,7 +6,8 @@ #include #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" -#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" @@ -81,6 +82,24 @@ createAverageCompute(Value input, RankedTensorType resultType, ConversionPattern return computeOp.getResult(0); } +static Value concatValues(ValueRange inputs, int64_t axis, ConversionPatternRewriter& rewriter, Location loc) { + auto firstType = cast(inputs.front().getType()); + SmallVector outputShape(firstType.getShape().begin(), firstType.getShape().end()); + int64_t concatDimSize = 0; + for (Value input : inputs) + concatDimSize += cast(input.getType()).getDimSize(axis); + outputShape[axis] = concatDimSize; + auto resultType = 
RankedTensorType::get(outputShape, firstType.getElementType(), firstType.getEncoding()); + + if (llvm::all_of(inputs, isHostFoldableValue)) + return createSpatConcat(rewriter, loc, axis, inputs); + + auto concatCompute = createSpatCompute(rewriter, loc, TypeRange {resultType}, {}, inputs, [&](ValueRange args) { + spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, axis, args)); + }); + return concatCompute.getResult(0); +} + static Value buildReduceMeanKeepdims(Value input, ArrayRef reducedAxes, int64_t axis, @@ -100,7 +119,7 @@ static Value buildReduceMeanKeepdims(Value input, for (Value slice : slices) reducedSlices.push_back(buildReduceMeanKeepdims(slice, reducedAxes, axis + 1, leafType, rewriter, loc)); - return createSpatConcat(rewriter, loc, axis, reducedSlices); + return concatValues(reducedSlices, axis, rewriter, loc); } static Value squeezeReducedAxes(Value keepdimsValue, @@ -115,9 +134,16 @@ static Value squeezeReducedAxes(Value keepdimsValue, return tensor::FromElementsOp::create(rewriter, loc, resultType, ValueRange {element}); } - return tensor::CollapseShapeOp::create( - rewriter, loc, resultType, keepdimsValue, buildCollapseReassociation(reducedAxes)) - .getResult(); + auto reassociation = buildCollapseReassociation(reducedAxes); + if (isHostFoldableValue(keepdimsValue)) + return tensor::CollapseShapeOp::create(rewriter, loc, resultType, keepdimsValue, reassociation).getResult(); + + auto squeezeCompute = + createSpatCompute<1>(rewriter, loc, TypeRange {resultType}, {}, ValueRange {keepdimsValue}, [&](Value input) { + Value collapsed = tensor::CollapseShapeOp::create(rewriter, loc, resultType, input, reassociation); + spatial::SpatYieldOp::create(rewriter, loc, collapsed); + }); + return squeezeCompute.getResult(0); } struct ReduceMeanToSpatialCompute : OpConversionPattern { diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Pool.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Pool.cpp index f74b7a0..dbee081 100644 
--- a/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Pool.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Pool.cpp @@ -31,8 +31,8 @@ static int64_t getOptionalI64(std::optional arrayAttr, size_t index, } template -static FailureOr -concatAlongAxis(ConversionPatternRewriter& rewriter, Location loc, PoolOp poolOp, int64_t axis, ArrayRef values) { +static FailureOr concatAlongAxis( + ConversionPatternRewriter& rewriter, Location loc, PoolOp poolOp, int64_t axis, ArrayRef values) { if (values.empty()) { poolOp.emitOpError("failed to build pooled output because an intermediate concatenation input list was empty"); return failure(); @@ -68,8 +68,8 @@ reduceWindowValues(ConversionPatternRewriter& rewriter, Location loc, Operation* return reduced; } -static FailureOr -scaleAverageWindow(ConversionPatternRewriter& rewriter, Location loc, Operation* op, Value reducedWindow, int64_t divisor) { +static FailureOr scaleAverageWindow( + ConversionPatternRewriter& rewriter, Location loc, Operation* op, Value reducedWindow, int64_t divisor) { if (divisor <= 0) { op->emitOpError("AveragePool divisor must be positive"); return failure(); diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Softmax.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Softmax.cpp index 1e58276..278ddc1 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Softmax.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Softmax.cpp @@ -2,7 +2,8 @@ #include "mlir/Transforms/DialectConversion.h" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" -#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" @@ -32,6 +33,24 @@ static Value createSoftmaxCompute(Value input, ConversionPatternRewriter& rewrit 
return computeOp.getResult(0); } +static Value concatValues(ValueRange inputs, int64_t axis, ConversionPatternRewriter& rewriter, Location loc) { + auto firstType = cast(inputs.front().getType()); + SmallVector outputShape(firstType.getShape().begin(), firstType.getShape().end()); + int64_t concatDimSize = 0; + for (Value input : inputs) + concatDimSize += cast(input.getType()).getDimSize(axis); + outputShape[axis] = concatDimSize; + auto resultType = RankedTensorType::get(outputShape, firstType.getElementType(), firstType.getEncoding()); + + if (llvm::all_of(inputs, isHostFoldableValue)) + return createSpatConcat(rewriter, loc, axis, inputs); + + auto concatCompute = createSpatCompute(rewriter, loc, TypeRange {resultType}, {}, inputs, [&](ValueRange args) { + spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, axis, args)); + }); + return concatCompute.getResult(0); +} + static Value buildSoftmax(Value input, int64_t softmaxAxis, int64_t axis, ConversionPatternRewriter& rewriter, Location loc) { auto inputType = cast(input.getType()); @@ -47,7 +66,7 @@ buildSoftmax(Value input, int64_t softmaxAxis, int64_t axis, ConversionPatternRe for (Value slice : slices) rebuiltSlices.push_back(buildSoftmax(slice, softmaxAxis, axis + 1, rewriter, loc)); - return createSpatConcat(rewriter, loc, axis, rebuiltSlices); + return concatValues(rebuiltSlices, axis, rewriter, loc); } struct SoftmaxToSpatialCompute : OpConversionPattern { @@ -92,8 +111,13 @@ struct SoftmaxToSpatialCompute : OpConversionPattern { Value transposedInput = preTransposeCompute.getResult(0); Value transposedResult = buildSoftmax( transposedInput, /*softmaxAxis=*/inputType.getRank() - 1, /*axis=*/0, rewriter, softmaxOp.getLoc()); - result = ONNXTransposeOp::create( - rewriter, softmaxOp.getLoc(), inputType, transposedResult, rewriter.getI64ArrayAttr(inversePermutation)); + auto postTransposeCompute = + createSpatCompute<1>(rewriter, softmaxOp.getLoc(), TypeRange {inputType}, {}, 
transposedResult, [&](Value x) { + Value transposed = ONNXTransposeOp::create( + rewriter, softmaxOp.getLoc(), inputType, x, rewriter.getI64ArrayAttr(inversePermutation)); + spatial::SpatYieldOp::create(rewriter, softmaxOp.getLoc(), transposed); + }); + result = postTransposeCompute.getResult(0); } rewriter.replaceOp(softmaxOp, result); diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Concat.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Concat.cpp index 64f8805..4d616e8 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Concat.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Concat.cpp @@ -2,6 +2,8 @@ #include "mlir/IR/PatternMatch.h" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/ComputeRegionBuilder.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" @@ -18,7 +20,17 @@ struct Concat : public OpConversionPattern { auto inputs = adaptor.getInputs(); int64_t axis = adaptor.getAxis(); - rewriter.replaceOp(maxpoolOp, createSpatConcat(rewriter, maxpoolOp.getLoc(), axis, inputs)); + if (llvm::all_of(inputs, isHostFoldableValue)) { + rewriter.replaceOp(maxpoolOp, createSpatConcat(rewriter, maxpoolOp.getLoc(), axis, inputs)); + return success(); + } + + auto computeOp = createSpatCompute( + rewriter, maxpoolOp.getLoc(), TypeRange {maxpoolOp.getResult().getType()}, {}, inputs, [&](ValueRange args) { + spatial::SpatYieldOp::create( + rewriter, maxpoolOp.getLoc(), createSpatConcat(rewriter, maxpoolOp.getLoc(), axis, args)); + }); + rewriter.replaceOp(maxpoolOp, computeOp.getResults()); return success(); } diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Gather.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Gather.cpp index e388b83..84cef43 100644 --- 
a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Gather.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Gather.cpp @@ -6,7 +6,7 @@ #include "llvm/ADT/SmallVector.h" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" -#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Reshape.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Reshape.cpp index 4499c7c..d670838 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Reshape.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Reshape.cpp @@ -3,7 +3,10 @@ #include "llvm/ADT/SmallVector.h" -#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp" +#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" using namespace mlir; @@ -95,18 +98,33 @@ struct Reshape : OpConversionPattern { return success(); } + auto replaceWithReshape = [&](auto buildReshape) -> LogicalResult { + if (isHostFoldableValue(adaptor.getData())) { + rewriter.replaceOp(reshapeOp, buildReshape(adaptor.getData())); + return success(); + } + + auto computeOp = createSpatCompute<1>( + rewriter, reshapeOp.getLoc(), TypeRange {resultType}, {}, adaptor.getData(), [&](Value data) { + Value reshaped = buildReshape(data); + spatial::SpatYieldOp::create(rewriter, reshapeOp.getLoc(), reshaped); + }); + rewriter.replaceOp(reshapeOp, computeOp.getResults()); + return success(); + }; + SmallVector reassociation; if (sourceType.getRank() > resultType.getRank() - && 
inferCollapseReassociation(sourceType.getShape(), resultType.getShape(), reassociation)) { - rewriter.replaceOpWithNewOp(reshapeOp, resultType, adaptor.getData(), reassociation); - return success(); - } + && inferCollapseReassociation(sourceType.getShape(), resultType.getShape(), reassociation)) + return replaceWithReshape([&](Value data) { + return tensor::CollapseShapeOp::create(rewriter, reshapeOp.getLoc(), resultType, data, reassociation); + }); if (sourceType.getRank() < resultType.getRank() - && inferExpandReassociation(sourceType.getShape(), resultType.getShape(), reassociation)) { - rewriter.replaceOpWithNewOp(reshapeOp, resultType, adaptor.getData(), reassociation); - return success(); - } + && inferExpandReassociation(sourceType.getShape(), resultType.getShape(), reassociation)) + return replaceWithReshape([&](Value data) { + return tensor::ExpandShapeOp::create(rewriter, reshapeOp.getLoc(), resultType, data, reassociation); + }); return failure(); } diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Resize.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Resize.cpp index 39eebe2..d3977d8 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Resize.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Resize.cpp @@ -6,7 +6,7 @@ #include #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" -#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Split.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Split.cpp index 1e6c93b..9a3e6b2 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Split.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Split.cpp @@ -2,7 +2,9 @@ #include "mlir/Transforms/DialectConversion.h" #include 
"src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" -#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp" +#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" using namespace mlir; @@ -47,16 +49,40 @@ struct Split : OpConversionPattern { outputs.reserve(splitOp.getNumResults()); int64_t offset = 0; + SmallVector resultTypes; + resultTypes.reserve(splitOp.getNumResults()); + SmallVector sliceSizes; + sliceSizes.reserve(splitOp.getNumResults()); for (Value result : splitOp.getResults()) { auto resultType = dyn_cast(result.getType()); if (!resultType || !resultType.hasStaticShape()) return failure(); - int64_t sliceSize = resultType.getShape()[axis]; - outputs.push_back(extractSliceAt(adaptor.getInput(), axis, offset, sliceSize, rewriter, splitOp.getLoc())); - offset += sliceSize; + resultTypes.push_back(resultType); + sliceSizes.push_back(resultType.getShape()[axis]); } - rewriter.replaceOp(splitOp, outputs); + if (isHostFoldableValue(adaptor.getInput())) { + for (int64_t sliceSize : sliceSizes) { + outputs.push_back(extractSliceAt(adaptor.getInput(), axis, offset, sliceSize, rewriter, splitOp.getLoc())); + offset += sliceSize; + } + rewriter.replaceOp(splitOp, outputs); + return success(); + } + + auto computeOp = createSpatCompute<1>( + rewriter, splitOp.getLoc(), TypeRange(splitOp.getResultTypes()), {}, adaptor.getInput(), [&](Value input) { + SmallVector runtimeOutputs; + runtimeOutputs.reserve(resultTypes.size()); + int64_t runtimeOffset = 0; + for (int64_t sliceSize : sliceSizes) { + runtimeOutputs.push_back(extractSliceAt(input, axis, runtimeOffset, sliceSize, rewriter, splitOp.getLoc())); + runtimeOffset += sliceSize; + } + spatial::SpatYieldOp::create(rewriter, splitOp.getLoc(), runtimeOutputs); + }); + + 
rewriter.replaceOp(splitOp, computeOp.getResults()); return success(); } }; diff --git a/src/PIM/Conversion/ONNXToSpatial/PostPatterns.cpp b/src/PIM/Conversion/ONNXToSpatial/PostPatterns.cpp new file mode 100644 index 0000000..4a3861f --- /dev/null +++ b/src/PIM/Conversion/ONNXToSpatial/PostPatterns.cpp @@ -0,0 +1,265 @@ +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/PatternMatch.h" + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" + +#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/WeightMaterialization.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PostPatterns.hpp" +#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" +#include "src/Dialect/ONNX/ONNXOps.hpp" + +using namespace mlir; + +namespace onnx_mlir { + +namespace { + +static bool isWeightMaterializationHelperUser(Operation* op) { + return isa(op); +} + +static bool canPromoteInputBlockArgument(BlockArgument arg) { + return !arg.use_empty() && llvm::all_of(arg.getUsers(), isWeightMaterializationHelperUser); +} + +static bool isDirectConstantValue(Value value) { + return isa_and_nonnull(value.getDefiningOp()); +} + +// Collapses one-lane batches so later phases do not carry batch-only structure unnecessarily. 
+struct FoldSingleLaneComputeBatchPattern : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(spatial::SpatComputeBatch batchOp, PatternRewriter& rewriter) const override { + if (batchOp.getLaneCount() != 1) + return rewriter.notifyMatchFailure(batchOp, "requires a single lane"); + + auto loc = batchOp.getLoc(); + rewriter.setInsertionPoint(batchOp); + auto computeOp = + spatial::SpatCompute::create(rewriter, loc, batchOp.getResultTypes(), batchOp.getWeights(), batchOp.getInputs()); + computeOp.getProperties().setOperandSegmentSizes( + {static_cast(batchOp.getWeights().size()), static_cast(batchOp.getInputs().size())}); + + Block& templateBlock = batchOp.getBody().front(); + SmallVector blockArgTypes; + SmallVector blockArgLocs; + blockArgTypes.reserve(templateBlock.getNumArguments()); + blockArgLocs.reserve(templateBlock.getNumArguments()); + for (BlockArgument arg : templateBlock.getArguments()) { + blockArgTypes.push_back(arg.getType()); + blockArgLocs.push_back(loc); + } + + auto* newBlock = + rewriter.createBlock(&computeOp.getBody(), computeOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs); + IRMapping mapper; + for (auto [oldArg, newArg] : llvm::zip(templateBlock.getArguments(), newBlock->getArguments())) + mapper.map(oldArg, newArg); + + rewriter.setInsertionPointToEnd(newBlock); + for (Operation& op : templateBlock) + rewriter.clone(op, mapper); + + batchOp->replaceAllUsesWith(computeOp->getResults()); + rewriter.eraseOp(batchOp); + return success(); + } +}; + +// Promotes foldable helper chains from runtime inputs to weights to avoid artificial compute inputs. 
+struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(spatial::SpatCompute compute, PatternRewriter& rewriter) const override { + SmallVector promoteInput(compute.getInputs().size(), false); + bool needsRewrite = false; + Block& oldBlock = compute.getBody().front(); + for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) { + if (inputIdx >= oldBlock.getNumArguments()) + continue; + if (!isWeightLikeComputeOperand(input)) + continue; + if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(oldBlock.getArgument(inputIdx))) + continue; + promoteInput[inputIdx] = true; + needsRewrite = true; + } + if (!needsRewrite) + return rewriter.notifyMatchFailure(compute, "no weight-like inputs to promote"); + + rewriter.setInsertionPointAfter(compute); + + SmallVector newWeights(compute.getWeights().begin(), compute.getWeights().end()); + SmallVector newInputs; + SmallVector newInputTypes; + SmallVector newInputLocs; + newWeights.reserve(compute.getWeights().size() + compute.getInputs().size()); + newInputs.reserve(compute.getInputs().size()); + newInputTypes.reserve(compute.getInputs().size()); + newInputLocs.reserve(compute.getInputs().size()); + + for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) { + if (promoteInput[inputIdx]) { + newWeights.push_back(input); + continue; + } + newInputs.push_back(input); + newInputTypes.push_back(input.getType()); + newInputLocs.push_back(input.getLoc()); + } + + auto newCompute = + spatial::SpatCompute::create(rewriter, compute.getLoc(), compute.getResultTypes(), newWeights, newInputs); + auto* newBlock = + rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), newInputTypes, newInputLocs); + newCompute.getProperties().setOperandSegmentSizes( + {static_cast(newWeights.size()), static_cast(newInputs.size())}); + rewriter.setInsertionPointToStart(newBlock); + + IRRewriter 
bodyRewriter(rewriter.getContext()); + bodyRewriter.setInsertionPointToStart(newBlock); + + IRMapping mapper; + size_t newInputIdx = 0; + for (auto [oldInputIdx, oldArg] : llvm::enumerate(oldBlock.getArguments())) { + if (!promoteInput[oldInputIdx]) { + mapper.map(oldArg, newBlock->getArgument(newInputIdx++)); + continue; + } + + auto clonedValue = materializeWeightLikeValueInBlock(compute.getInputs()[oldInputIdx], bodyRewriter, mapper); + if (failed(clonedValue)) + return rewriter.notifyMatchFailure(compute, "failed to materialize promoted weight-like operand"); + mapper.map(oldArg, *clonedValue); + } + + for (Operation& op : oldBlock.without_terminator()) + rewriter.clone(op, mapper); + + auto oldYield = cast(oldBlock.getTerminator()); + SmallVector newYieldOperands; + newYieldOperands.reserve(oldYield.getOutputs().size()); + for (Value operand : oldYield.getOutputs()) { + auto mapped = mapper.lookupOrNull(operand); + newYieldOperands.push_back(mapped ? cast(mapped) : operand); + } + spatial::SpatYieldOp::create(rewriter, oldYield.getLoc(), newYieldOperands); + + rewriter.replaceOp(compute, newCompute.getResults()); + return success(); + } +}; + +// Promotes foldable batch helper chains to weights while preserving compact compute_batch IR. 
+struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(spatial::SpatComputeBatch compute, PatternRewriter& rewriter) const override { + SmallVector promoteInput(compute.getInputs().size(), false); + bool needsRewrite = false; + Block& oldBlock = compute.getBody().front(); + for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) { + if (inputIdx >= oldBlock.getNumArguments()) + continue; + if (!isWeightLikeComputeOperand(input)) + continue; + if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(oldBlock.getArgument(inputIdx))) + continue; + promoteInput[inputIdx] = true; + needsRewrite = true; + } + if (!needsRewrite) + return rewriter.notifyMatchFailure(compute, "no weight-like batch inputs to promote"); + + rewriter.setInsertionPointAfter(compute); + + SmallVector newWeights(compute.getWeights().begin(), compute.getWeights().end()); + SmallVector newInputs; + SmallVector newInputTypes; + SmallVector newInputLocs; + newWeights.reserve(compute.getWeights().size() + compute.getInputs().size()); + newInputs.reserve(compute.getInputs().size()); + newInputTypes.reserve(compute.getInputs().size()); + newInputLocs.reserve(compute.getInputs().size()); + + for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) { + if (promoteInput[inputIdx]) { + newWeights.push_back(input); + continue; + } + newInputs.push_back(input); + newInputTypes.push_back(input.getType()); + newInputLocs.push_back(input.getLoc()); + } + + auto newCompute = + spatial::SpatComputeBatch::create(rewriter, + compute.getLoc(), + compute.getResultTypes(), + rewriter.getI32IntegerAttr(static_cast(compute.getLaneCount())), + newWeights, + newInputs); + auto* newBlock = + rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), newInputTypes, newInputLocs); + newCompute.getProperties().setOperandSegmentSizes( + {static_cast(newWeights.size()), 
static_cast(newInputs.size())}); + rewriter.setInsertionPointToStart(newBlock); + + IRRewriter bodyRewriter(rewriter.getContext()); + bodyRewriter.setInsertionPointToStart(newBlock); + + IRMapping mapper; + size_t newInputIdx = 0; + for (auto [oldInputIdx, oldArg] : llvm::enumerate(oldBlock.getArguments())) { + if (!promoteInput[oldInputIdx]) { + mapper.map(oldArg, newBlock->getArgument(newInputIdx++)); + continue; + } + + auto clonedValue = materializeWeightLikeValueInBlock(compute.getInputs()[oldInputIdx], bodyRewriter, mapper); + if (failed(clonedValue)) + return rewriter.notifyMatchFailure(compute, "failed to materialize promoted batch weight-like operand"); + mapper.map(oldArg, *clonedValue); + } + + for (Operation& op : oldBlock.without_terminator()) + rewriter.clone(op, mapper); + + auto oldYield = cast(oldBlock.getTerminator()); + SmallVector newYieldOperands; + newYieldOperands.reserve(oldYield.getOutputs().size()); + for (Value operand : oldYield.getOutputs()) { + auto mapped = mapper.lookupOrNull(operand); + newYieldOperands.push_back(mapped ? 
cast(mapped) : operand); + } + spatial::SpatYieldOp::create(rewriter, oldYield.getLoc(), newYieldOperands); + + rewriter.replaceOp(compute, newCompute.getResults()); + return success(); + } +}; + +} // namespace + +void populateEarlyPostPatterns(RewritePatternSet& patterns, MLIRContext* ctx) { + patterns.add(ctx); +} + +void populatePostPatterns(RewritePatternSet& patterns, MLIRContext* ctx) { + patterns.add(ctx); +} + +void annotateWeightsConstants(func::FuncOp funcOp) { + funcOp.walk([&](arith::ConstantOp constantOp) { + if (hasOnlySpatialMvmVmmWeightUses(constantOp.getResult())) + markWeightAlways(constantOp); + }); +} + +} // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/PostPatterns.hpp b/src/PIM/Conversion/ONNXToSpatial/PostPatterns.hpp new file mode 100644 index 0000000..b094373 --- /dev/null +++ b/src/PIM/Conversion/ONNXToSpatial/PostPatterns.hpp @@ -0,0 +1,14 @@ +#pragma once + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/MLIRContext.h" + +namespace onnx_mlir { + +void populateEarlyPostPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); + +void populatePostPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); + +void annotateWeightsConstants(mlir::func::FuncOp funcOp); + +} // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/PrePatterns.cpp b/src/PIM/Conversion/ONNXToSpatial/PrePatterns.cpp new file mode 100644 index 0000000..783733e --- /dev/null +++ b/src/PIM/Conversion/ONNXToSpatial/PrePatterns.cpp @@ -0,0 +1,25 @@ +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PrePatterns.hpp" + +using namespace mlir; + +namespace onnx_mlir { + +namespace { + +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.hpp.inc" + +} // namespace + +void populatePrePatterns(mlir::RewritePatternSet& patterns, 
mlir::MLIRContext* ctx) { + patterns.add(ctx); + patterns.add(ctx); + patterns.add(ctx); + patterns.add(ctx); + patterns.add(ctx); + patterns.add(ctx); + populateMatMulRewritePatterns(patterns, ctx); +} + +} // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/PrePatterns.hpp b/src/PIM/Conversion/ONNXToSpatial/PrePatterns.hpp new file mode 100644 index 0000000..47085af --- /dev/null +++ b/src/PIM/Conversion/ONNXToSpatial/PrePatterns.hpp @@ -0,0 +1,10 @@ +#pragma once + +#include "mlir/IR/MLIRContext.h" +#include "mlir/Transforms/DialectConversion.h" + +namespace onnx_mlir { + +void populatePrePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); + +} // namespace onnx_mlir diff --git a/src/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.cpp b/src/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.cpp new file mode 100644 index 0000000..f66fc8b --- /dev/null +++ b/src/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.cpp @@ -0,0 +1,224 @@ +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/IRMapping.h" + +#include "Conversion/ONNXToSpatial/Common/Common.hpp" +#include "src/Accelerators/PIM/Common/PimCommon.hpp" +#include "src/Accelerators/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.hpp" +#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp" +#include "src/Accelerators/PIM/Conversion/SpatialToPim/TensorPackingPatterns.hpp" +#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" + +using namespace mlir; +using namespace onnx_mlir::pim; + +namespace onnx_mlir { +namespace { + +static int32_t translateSpatialCoreIdToPimCoreId(size_t spatialCoreId) { return static_cast(spatialCoreId); } + +static SmallVector getPimCoreIdsForBatchOp(spatial::SpatComputeBatch computeBatchOp, size_t& fallbackCoreId) { + if (auto coreIdsAttr = computeBatchOp->getAttrOfType(onnx_mlir::kCoreIdsAttrName)) + return 
SmallVector(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end()); + + SmallVector coreIds; + coreIds.reserve(static_cast(computeBatchOp.getLaneCount())); + for (uint32_t lane = 0; lane < computeBatchOp.getLaneCount(); ++lane) + coreIds.push_back(static_cast(fallbackCoreId++)); + return coreIds; +} + +static void lowerChannelSendTensorBatch(spatial::SpatChannelSendTensorBatchOp sendTensorBatchOp, + IRMapping& mapper, + IRRewriter& rewriter) { + SmallVector targetCoreIds; + targetCoreIds.reserve(sendTensorBatchOp.getTargetCoreIds().size()); + for (int32_t targetCoreId : sendTensorBatchOp.getTargetCoreIds()) + targetCoreIds.push_back(translateSpatialCoreIdToPimCoreId(targetCoreId)); + + Value input = mapper.lookup(sendTensorBatchOp.getInput()); + if (auto concatOp = input.getDefiningOp()) + if (concatOp.getDim() == 0) + if (Value packedInput = + createPackedExtractSliceTensor(concatOp.getInputs(), rewriter, sendTensorBatchOp.getLoc())) + input = packedInput; + + pim::PimSendTensorBatchOp::create( + rewriter, sendTensorBatchOp.getLoc(), input, rewriter.getDenseI32ArrayAttr(targetCoreIds)); +} + +static void lowerChannelReceiveTensorBatch(spatial::SpatChannelReceiveTensorBatchOp receiveTensorBatchOp, + IRMapping& mapper, + IRRewriter& rewriter) { + SmallVector sourceCoreIds; + sourceCoreIds.reserve(receiveTensorBatchOp.getSourceCoreIds().size()); + for (int32_t sourceCoreId : receiveTensorBatchOp.getSourceCoreIds()) + sourceCoreIds.push_back(translateSpatialCoreIdToPimCoreId(sourceCoreId)); + + auto outputType = cast(receiveTensorBatchOp.getOutput().getType()); + auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveTensorBatchOp.getLoc(), outputType); + Value received = pim::PimReceiveTensorBatchOp::create(rewriter, + receiveTensorBatchOp.getLoc(), + outputBuffer.getType(), + outputBuffer, + rewriter.getDenseI32ArrayAttr(sourceCoreIds)) + .getOutput(); + mapper.map(receiveTensorBatchOp.getOutput(), received); +} + +} // namespace + 
+LogicalResult +lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState& state, IRRewriter& rewriter) { + if (computeBatchOp.getNumResults() != 0) + return computeBatchOp.emitOpError( + "batched Spatial-to-PIM lowering currently requires channelized compute_batch with no results"); + + Location loc = computeBatchOp.getLoc(); + Block& oldBlock = computeBatchOp.getBody().front(); + auto oldYield = cast(oldBlock.getTerminator()); + if (oldYield.getNumOperands() != 0) + return computeBatchOp.emitOpError("batched Spatial-to-PIM lowering currently requires empty spat.yield"); + + SmallVector coreIds = getPimCoreIdsForBatchOp(computeBatchOp, state.nextCoreId); + SmallVector batchWeights(computeBatchOp.getWeights().begin(), computeBatchOp.getWeights().end()); + SmallVector batchInputs; + if (!computeBatchOp.getInputs().empty()) + batchInputs.append(computeBatchOp.getInputs().begin(), computeBatchOp.getInputs().end()); + + rewriter.setInsertionPointAfter(computeBatchOp); + auto coreBatchOp = pim::PimCoreBatchOp::create(rewriter, + loc, + rewriter.getI32IntegerAttr(computeBatchOp.getLaneCount()), + ValueRange(batchWeights), + ValueRange(batchInputs)); + coreBatchOp.getProperties().setOperandSegmentSizes( + {static_cast(batchWeights.size()), static_cast(batchInputs.size())}); + coreBatchOp->setAttr(onnx_mlir::kCoreIdsAttrName, rewriter.getDenseI32ArrayAttr(coreIds)); + + SmallVector blockArgTypes; + SmallVector blockArgLocs; + for (BlockArgument arg : oldBlock.getArguments()) { + blockArgTypes.push_back(arg.getType()); + blockArgLocs.push_back(arg.getLoc()); + } + Block* newBlock = + rewriter.createBlock(&coreBatchOp.getBody(), coreBatchOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs); + + IRMapping mapper; + rewriter.setInsertionPointToStart(newBlock); + for (auto [oldArg, newArg] : llvm::zip(oldBlock.getArguments(), newBlock->getArguments())) { + auto newArgType = cast(newArg.getType()); + auto outputBuffer = 
createEmptyTensorFromShaped(rewriter, loc, newArgType); + auto copied = pim::PimMemCopyHostToDevBatchOp::create(rewriter, + loc, + outputBuffer.getType(), + outputBuffer, + newArg, + rewriter.getI32IntegerAttr(0), + rewriter.getI32IntegerAttr(0), + getTensorSizeInBytesAttr(rewriter, newArg)) + .getOutput(); + mapper.map(oldArg, copied); + } + + auto materializeCapturedTensor = [&](Value capturedTensor) -> Value { + if (auto mapped = mapper.lookupOrNull(capturedTensor)) + return mapped; + + auto capturedType = cast(capturedTensor.getType()); + auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, capturedType); + auto copied = pim::PimMemCopyHostToDevBatchOp::create(rewriter, + loc, + outputBuffer.getType(), + outputBuffer, + capturedTensor, + rewriter.getI32IntegerAttr(0), + rewriter.getI32IntegerAttr(0), + getTensorSizeInBytesAttr(rewriter, capturedTensor)) + .getOutput(); + mapper.map(capturedTensor, copied); + return copied; + }; + + rewriter.setInsertionPointToEnd(newBlock); + for (Operation& op : oldBlock) { + if (isa(op)) + continue; + + if (auto sendBatchOp = dyn_cast(op)) { + pim::PimSendBatchOp::create(rewriter, + loc, + mapper.lookup(sendBatchOp.getInput()), + getTensorSizeInBytesAttr(rewriter, mapper.lookup(sendBatchOp.getInput())), + sendBatchOp.getTargetCoreIdsAttr()); + continue; + } + + if (auto sendTensorBatchOp = dyn_cast(op)) { + lowerChannelSendTensorBatch(sendTensorBatchOp, mapper, rewriter); + continue; + } + + if (auto receiveBatchOp = dyn_cast(op)) { + auto outputType = cast(receiveBatchOp.getOutput().getType()); + auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, outputType); + auto received = pim::PimReceiveBatchOp::create(rewriter, + loc, + outputBuffer.getType(), + outputBuffer, + getTensorSizeInBytesAttr(rewriter, receiveBatchOp.getOutput()), + receiveBatchOp.getSourceCoreIdsAttr()) + .getOutput(); + mapper.map(receiveBatchOp.getOutput(), received); + continue; + } + + if (auto receiveTensorBatchOp = dyn_cast(op)) { 
+ lowerChannelReceiveTensorBatch(receiveTensorBatchOp, mapper, rewriter); + continue; + } + + if (auto toTensorOp = dyn_cast(op)) { + if (isa_and_present(toTensorOp.getBuffer().getDefiningOp())) { + Operation* cloned = rewriter.clone(op, mapper); + auto clonedTensor = cloned->getResult(0); + auto clonedType = cast(clonedTensor.getType()); + auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, clonedType); + auto copied = pim::PimMemCopyHostToDevBatchOp::create(rewriter, + loc, + outputBuffer.getType(), + outputBuffer, + clonedTensor, + rewriter.getI32IntegerAttr(0), + rewriter.getI32IntegerAttr(0), + getTensorSizeInBytesAttr(rewriter, clonedTensor)) + .getOutput(); + mapper.map(toTensorOp.getResult(), copied); + continue; + } + } + + for (Value operand : op.getOperands()) { + if (!isa(operand.getType()) || mapper.contains(operand)) + continue; + + Operation* definingOp = operand.getDefiningOp(); + if (definingOp && definingOp->getBlock() == &oldBlock) + continue; + + materializeCapturedTensor(operand); + } + + Operation* cloned = rewriter.clone(op, mapper); + for (auto [originalResult, clonedResult] : llvm::zip(op.getResults(), cloned->getResults())) + mapper.map(originalResult, clonedResult); + } + + rewriter.setInsertionPointToEnd(newBlock); + PimHaltOp::create(rewriter, loc); + return success(); +} + +} // namespace onnx_mlir diff --git a/src/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.hpp b/src/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.hpp new file mode 100644 index 0000000..3afc4b0 --- /dev/null +++ b/src/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.hpp @@ -0,0 +1,10 @@ +#pragma once + +#include "src/Accelerators/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.hpp" + +namespace onnx_mlir { + +mlir::LogicalResult +lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState& state, mlir::IRRewriter& rewriter); + +} // namespace onnx_mlir diff --git a/src/PIM/Conversion/SpatialToPim/CMakeLists.txt 
b/src/PIM/Conversion/SpatialToPim/CMakeLists.txt index 351d96d..a0bb5d2 100644 --- a/src/PIM/Conversion/SpatialToPim/CMakeLists.txt +++ b/src/PIM/Conversion/SpatialToPim/CMakeLists.txt @@ -4,8 +4,16 @@ add_public_tablegen_target(SpatialToPimIncGen) add_pim_library(OMSpatialToPim SpatialToPimPass.cpp + BatchCoreLoweringPatterns.cpp + ChannelLoweringPatterns.cpp + Cleanup.cpp Common.cpp - Patterns.cpp + ComputeLikeRegionUtils.cpp + CoreLoweringPatterns.cpp + GlobalTensorMaterialization.cpp + PhaseVerification.cpp + ReturnPathNormalization.cpp + TensorPackingPatterns.cpp EXCLUDE_FROM_OM_LIBS diff --git a/src/PIM/Conversion/SpatialToPim/ChannelLoweringPatterns.cpp b/src/PIM/Conversion/SpatialToPim/ChannelLoweringPatterns.cpp new file mode 100644 index 0000000..bb526be --- /dev/null +++ b/src/PIM/Conversion/SpatialToPim/ChannelLoweringPatterns.cpp @@ -0,0 +1,136 @@ +#include "mlir/Dialect/Tensor/IR/Tensor.h" + +#include "src/Accelerators/PIM/Conversion/SpatialToPim/ChannelLoweringPatterns.hpp" +#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp" +#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" +#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" + +using namespace mlir; + +namespace onnx_mlir { +namespace { + /* Maps a Spatial core id to a PIM core id; at present this is the identity mapping. */ +static int32_t toPimCoreId(int32_t spatialCoreId) { return spatialCoreId; } + /* Rewrites spatial::SpatChannelSendOp into pim::PimSendOp (same input, size attribute from getTensorSizeInBytesAttr, target core id passed through toPimCoreId) and then erases the original op. */ +struct ChannelSendLowering : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(spatial::SpatChannelSendOp op, PatternRewriter& rewriter) const override { + pim::PimSendOp::create(rewriter, + op.getLoc(), + op.getInput(), + getTensorSizeInBytesAttr(rewriter, op.getInput()), + rewriter.getI32IntegerAttr(toPimCoreId(op.getTargetCoreId()))); + rewriter.eraseOp(op); + return success(); + } +}; + /* Rewrites spatial::SpatChannelReceiveOp: a receive without uses is simply erased; otherwise a tensor.empty of the result shape and element type is created as output buffer and the op is replaced by the output of pim::PimReceiveOp. */ +struct ChannelReceiveLowering : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(spatial::SpatChannelReceiveOp op, PatternRewriter& rewriter) const override { + if
(op->use_empty()) { + rewriter.eraseOp(op); + return success(); + } + auto outputType = cast(op.getResult().getType()); + Value outputBuffer = + tensor::EmptyOp::create(rewriter, op.getLoc(), outputType.getShape(), outputType.getElementType()).getResult(); + Value received = pim::PimReceiveOp::create(rewriter, + op.getLoc(), + op.getResult().getType(), + outputBuffer, + getTensorSizeInBytesAttr(rewriter, op.getResult()), + rewriter.getI32IntegerAttr(toPimCoreId(op.getSourceCoreId()))) + .getOutput(); + rewriter.replaceOp(op, received); + return success(); + } +}; + /* Rewrites spatial::SpatChannelSendTensorOp into pim::PimSendTensorOp; every target core id is translated with toPimCoreId and passed as a dense i32 array attribute, then the original op is erased. */ +struct ChannelSendTensorLowering : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(spatial::SpatChannelSendTensorOp op, PatternRewriter& rewriter) const override { + SmallVector targetCoreIds; + targetCoreIds.reserve(op.getTargetCoreIds().size()); + for (int32_t targetCoreId : op.getTargetCoreIds()) + targetCoreIds.push_back(toPimCoreId(targetCoreId)); + pim::PimSendTensorOp::create(rewriter, op.getLoc(), op.getInput(), rewriter.getDenseI32ArrayAttr(targetCoreIds)); + rewriter.eraseOp(op); + return success(); + } +}; + /* Rewrites spatial::SpatChannelReceiveTensorOp into pim::PimReceiveTensorOp writing into a fresh tensor.empty of the output type; source core ids are translated with toPimCoreId. NOTE(review): unlike ChannelReceiveLowering there is no shortcut for an op without uses - confirm whether that is intentional. */ +struct ChannelReceiveTensorLowering : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(spatial::SpatChannelReceiveTensorOp op, PatternRewriter& rewriter) const override { + SmallVector sourceCoreIds; + sourceCoreIds.reserve(op.getSourceCoreIds().size()); + for (int32_t sourceCoreId : op.getSourceCoreIds()) + sourceCoreIds.push_back(toPimCoreId(sourceCoreId)); + auto outputType = cast(op.getOutput().getType()); + Value outputBuffer = + tensor::EmptyOp::create(rewriter, op.getLoc(), outputType.getShape(), outputType.getElementType()).getResult(); + Value received = + pim::PimReceiveTensorOp::create( + rewriter, op.getLoc(), op.getOutput().getType(), outputBuffer, rewriter.getDenseI32ArrayAttr(sourceCoreIds)) + .getOutput(); + rewriter.replaceOp(op, received); + return success(); + } +}; +
+struct ExtractRowsLowering : OpRewritePattern { /* Replaces spatial::SpatExtractRowsOp by one tensor.extract_slice per result: result i starts at row i * outputType.getDimSize(0), spans outputType.getDimSize(0) rows and inputType.getDimSize(1) columns, with unit strides. Presumably input and outputs are rank-2 with static dimensions - TODO confirm. */ + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(spatial::SpatExtractRowsOp op, PatternRewriter& rewriter) const override { + auto inputType = cast(op.getInput().getType()); + SmallVector replacements; + replacements.reserve(op.getNumResults()); + for (auto [rowIndex, output] : llvm::enumerate(op.getOutputs())) { + auto outputType = cast(output.getType()); + SmallVector offsets = { + rewriter.getIndexAttr(static_cast(rowIndex) * outputType.getDimSize(0)), rewriter.getIndexAttr(0)}; + SmallVector sizes = {rewriter.getIndexAttr(outputType.getDimSize(0)), + rewriter.getIndexAttr(inputType.getDimSize(1))}; + SmallVector strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; + replacements.push_back( + tensor::ExtractSliceOp::create(rewriter, op.getLoc(), outputType, op.getInput(), offsets, sizes, strides) + .getResult()); + } + rewriter.replaceOp(op, replacements); + return success(); + } +}; + /* Replaces spatial::SpatConcatOp by pim::PimConcatOp (same axis attribute and inputs) that writes into a fresh tensor.empty of the output type. */ +struct ConcatLowering : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(spatial::SpatConcatOp op, PatternRewriter& rewriter) const override { + auto outputType = cast(op.getOutput().getType()); + Value outputBuffer = + tensor::EmptyOp::create(rewriter, op.getLoc(), outputType.getShape(), outputType.getElementType()).getResult(); + Value concatenated = + pim::PimConcatOp::create( + rewriter, op.getLoc(), op.getOutput().getType(), op.getAxisAttr(), op.getInputs(), outputBuffer) + .getOutput(); + rewriter.replaceOp(op, concatenated); + return success(); + } +}; + +} // namespace + /* Adds rewrite patterns to the given set, constructed with the set's MLIRContext. The list of pattern types is not visible in this view of the patch. */ +void populateChannelLoweringPatterns(RewritePatternSet& patterns) { + patterns.add(patterns.getContext()); +} + +} // namespace onnx_mlir diff --git a/src/PIM/Conversion/SpatialToPim/ChannelLoweringPatterns.hpp b/src/PIM/Conversion/SpatialToPim/ChannelLoweringPatterns.hpp new file mode 100644 index 0000000..068b2a3 --- /dev/null +++
b/src/PIM/Conversion/SpatialToPim/ChannelLoweringPatterns.hpp @@ -0,0 +1,9 @@ +#pragma once + +#include "mlir/IR/PatternMatch.h" + +namespace onnx_mlir { + /* Adds rewrite patterns to the given pattern set; defined in ChannelLoweringPatterns.cpp. */ +void populateChannelLoweringPatterns(mlir::RewritePatternSet& patterns); + +} // namespace onnx_mlir diff --git a/src/PIM/Conversion/SpatialToPim/Cleanup.cpp b/src/PIM/Conversion/SpatialToPim/Cleanup.cpp new file mode 100644 index 0000000..2da4aa8 --- /dev/null +++ b/src/PIM/Conversion/SpatialToPim/Cleanup.cpp @@ -0,0 +1,42 @@ +#include "llvm/ADT/STLExtras.h" + +#include "src/Accelerators/PIM/Conversion/SpatialToPim/Cleanup.hpp" + +using namespace mlir; + +namespace onnx_mlir { + /* Erases the operations in pendingOps. Each sweep erases every op that has no remaining uses and drops it from the list; sweeps repeat until the list is empty (success). If a whole sweep erases nothing, an error is emitted on every remaining op plus one remark per remaining user, and failure is returned with the stuck ops still in pendingOps. */ +LogicalResult erasePendingOps(SmallVectorImpl& pendingOps, IRRewriter& rewriter) { + while (!pendingOps.empty()) { + bool erasedAnyOp = false; + for (auto it = pendingOps.begin(); it != pendingOps.end();) { + Operation* opToRemove = *it; + if (!opToRemove->use_empty()) { + ++it; + continue; + } + /* No uses left: erase the op and drop it from the pending list; erase() returns the next iterator. */ + rewriter.eraseOp(opToRemove); + it = pendingOps.erase(it); + erasedAnyOp = true; + } + + if (erasedAnyOp) + continue; + /* No progress in a full sweep: every remaining op still has users, so report them and give up. */ + for (Operation* opToRemove : pendingOps) { + InFlightDiagnostic diag = opToRemove->emitError("pending Spatial-to-PIM cleanup could not erase operation"); + diag << "; op has " << llvm::range_size(opToRemove->getUsers()) << " remaining user(s)"; + for (Operation* user : opToRemove->getUsers()) { + bool userPendingRemoval = llvm::is_contained(pendingOps, user); /* NOTE(review): diag is still in flight while the remarks below are emitted, so they are presumably reported before the error itself; diag.attachNote() may be the intended form - confirm. */ + opToRemove->emitRemark() << "remaining user `" << user->getName() << "`" + << (userPendingRemoval ?
" is also pending removal" : " is not pending removal"); + } + } + return failure(); + } + + return success(); +} + +} // namespace onnx_mlir diff --git a/src/PIM/Conversion/SpatialToPim/Cleanup.hpp b/src/PIM/Conversion/SpatialToPim/Cleanup.hpp new file mode 100644 index 0000000..8935fe7 --- /dev/null +++ b/src/PIM/Conversion/SpatialToPim/Cleanup.hpp @@ -0,0 +1,11 @@ +#pragma once + +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Support/LLVM.h" + +namespace onnx_mlir { + /* Erases the operations listed in pendingOps through the rewriter and returns a LogicalResult; defined in Cleanup.cpp. */ +mlir::LogicalResult erasePendingOps(llvm::SmallVectorImpl& pendingOps, mlir::IRRewriter& rewriter); + +} // namespace onnx_mlir diff --git a/src/PIM/Conversion/SpatialToPim/ComputeLikeRegionUtils.cpp b/src/PIM/Conversion/SpatialToPim/ComputeLikeRegionUtils.cpp new file mode 100644 index 0000000..a94a7b6 --- /dev/null +++ b/src/PIM/Conversion/SpatialToPim/ComputeLikeRegionUtils.cpp @@ -0,0 +1,44 @@ +#include "src/Accelerators/PIM/Conversion/SpatialToPim/ComputeLikeRegionUtils.hpp" +#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" + +using namespace mlir; + +namespace onnx_mlir { + /* Translates operand number operandNumber of owner into an index into owner's inputs. The inputs are treated as the trailing getInputs().size() operands. Returns std::nullopt when owner matches neither dyn_cast (op types not visible in this view; presumably the Spatial compute and compute-batch ops), when it has no inputs, or when the operand lies before the input segment. */ +std::optional getDirectComputeLikeInputIndex(Operation* owner, unsigned operandNumber) { + auto getInputIndex = [operandNumber](Operation* op, unsigned inputCount) -> std::optional { + if (inputCount == 0) + return std::nullopt; + unsigned inputBegin = op->getNumOperands() - inputCount; + if (operandNumber < inputBegin) + return std::nullopt; + return operandNumber - inputBegin; + }; + + if (auto compute = dyn_cast(owner)) + return getInputIndex(owner, compute.getInputs().size()); + + if (auto computeBatch = dyn_cast(owner)) + return getInputIndex(owner, computeBatch.getInputs().size()); + + return std::nullopt; +} + /* Redirects every use of block argument inputIndex (first block of owner's region 0) to replacement, then erases the matching input operand and the block argument, bracketed by startOpModification / finalizeOpModification. owner must be one of the two op kinds handled below; the else branch uses an unconditional cast. */ +void replaceAndEraseDirectComputeLikeInput(PatternRewriter& rewriter, + Operation* owner, + unsigned inputIndex, + Value replacement) { + Block& body = owner->getRegion(0).front(); + BlockArgument bodyArgument = body.getArgument(inputIndex); + +
rewriter.startOpModification(owner); + bodyArgument.replaceAllUsesWith(replacement); + if (auto compute = dyn_cast(owner)) + compute.getInputsMutable().erase(inputIndex); + else + cast(owner).getInputsMutable().erase(inputIndex); + body.eraseArgument(inputIndex); + rewriter.finalizeOpModification(owner); +} + +} // namespace onnx_mlir diff --git a/src/PIM/Conversion/SpatialToPim/ComputeLikeRegionUtils.hpp b/src/PIM/Conversion/SpatialToPim/ComputeLikeRegionUtils.hpp new file mode 100644 index 0000000..c0ac20f --- /dev/null +++ b/src/PIM/Conversion/SpatialToPim/ComputeLikeRegionUtils.hpp @@ -0,0 +1,17 @@ +#pragma once + +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" + +#include + +namespace onnx_mlir { + /* Maps an operand number of owner to an input index, or std::nullopt; defined in ComputeLikeRegionUtils.cpp. */ +std::optional getDirectComputeLikeInputIndex(mlir::Operation* owner, unsigned operandNumber); + /* Replaces the uses of the body argument at inputIndex with replacement and erases that input and argument; defined in ComputeLikeRegionUtils.cpp. */ +void replaceAndEraseDirectComputeLikeInput(mlir::PatternRewriter& rewriter, + mlir::Operation* owner, + unsigned inputIndex, + mlir::Value replacement); + +} // namespace onnx_mlir diff --git a/src/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.cpp b/src/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.cpp new file mode 100644 index 0000000..a1cba21 --- /dev/null +++ b/src/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.cpp @@ -0,0 +1,213 @@ +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/IR/IRMapping.h" + +#include "Conversion/ONNXToSpatial/Common/Common.hpp" +#include "src/Accelerators/PIM/Common/PimCommon.hpp" +#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp" +#include "src/Accelerators/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.hpp" +#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" +#include "src/Dialect/ONNX/ONNXOps.hpp" + +using namespace mlir; +using namespace onnx_mlir::pim; + +namespace onnx_mlir { +namespace { + /* Predicate used by collectHelperComputeChain to accept the ops that may sit on the chain between the block argument and the yielded value. The accepted op types are not visible in this view of the patch. */ +static bool isChannelUseChainOp(Operation* op) { + return
isa(op); +} + /* For every operand of op that is not yet in mapping and is produced by an op passing the isa check below (types not visible in this view), clones the producer at the rewriter's insertion point, maps its results to the clone's results and moves the insertion point after the clone. */ +static void cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewriter) { + for (Value operand : op->getOperands()) { + if (mapping.lookupOrNull(operand)) + continue; + + Operation* definingOp = operand.getDefiningOp(); + if (!definingOp) + continue; + + if (!isa(definingOp)) + continue; + + Operation* clonedOp = rewriter.clone(*definingOp, mapping); + for (auto [originalResult, newResult] : llvm::zip(definingOp->getResults(), clonedOp->getResults())) + mapping.map(originalResult, newResult); + rewriter.setInsertionPointAfter(clonedOp); + } +} + /* Converts a Spatial core id to a PIM core id; currently a plain static_cast with no remapping. */ +static int32_t translateSpatialCoreIdToPimCoreId(size_t spatialCoreId) { return static_cast(spatialCoreId); } + /* Returns the core id stored on computeOp under kCoreIdAttrName when that attribute is present; otherwise returns the current fallbackCoreId and post-increments it. */ +static int32_t getPimCoreIdForComputeOp(spatial::SpatCompute computeOp, size_t& fallbackCoreId) { + if (auto spatialCoreIdAttr = computeOp->getAttrOfType(onnx_mlir::kCoreIdAttrName)) + return static_cast(spatialCoreIdAttr.getInt()); + return static_cast(fallbackCoreId++); +} + /* Succeeds only for a compute op with exactly one input, one result, one block argument and a single-operand terminator, whose yielded value is reached from the block argument through ops accepted by isChannelUseChainOp. With requireReturnUse the result must also have exactly one use whose user passes an isa check (type not visible in this view). On success helperChain holds the chain in argument-to-yield order. */ +static LogicalResult collectHelperComputeChain(spatial::SpatCompute computeOp, + SmallVectorImpl& helperChain, + bool requireReturnUse = true) { + if (computeOp.getInputs().size() != 1 || computeOp.getNumResults() != 1) + return failure(); + if (requireReturnUse + && (!computeOp.getResult(0).hasOneUse() || !isa(*computeOp.getResult(0).getUsers().begin()))) + return failure(); + + Block& block = computeOp.getBody().front(); + if (block.getNumArguments() != 1) + return failure(); + + auto yieldOp = dyn_cast(block.getTerminator()); + if (!yieldOp || yieldOp.getNumOperands() != 1) + return failure(); + + SmallVector reverseChain; + Value currentValue = yieldOp.getOperands().front(); + Value blockArg = block.getArgument(0); + /* Walk backwards from the yielded value to the block argument, always following operand 0. NOTE(review): assumes every op accepted by isChannelUseChainOp has at least one operand - confirm. */ + while (currentValue != blockArg) { + Operation* definingOp = currentValue.getDefiningOp(); + if (!definingOp || definingOp->getBlock() != &block || !isChannelUseChainOp(definingOp)) + return failure(); + reverseChain.push_back(definingOp); + currentValue = definingOp->getOperand(0);
+ } + /* Any non-terminator op in the block that is not part of the chain must pass the isa check below (types not visible in this view); otherwise the block is rejected. */ + SmallPtrSet chainSet(reverseChain.begin(), reverseChain.end()); + for (Operation& op : llvm::make_early_inc_range(block.without_terminator())) + if (!chainSet.contains(&op) && !isa(op)) + return failure(); + + helperChain.assign(reverseChain.rbegin(), reverseChain.rend()); + return success(); +} + /* For a compute op with no inputs, no block arguments and one result whose users all pass an isa check (types not visible in this view): clones the body, terminator excluded, right before the compute op and redirects the result's uses to the clone of the yielded value. Returns true when inlining happened. NOTE(review): the compute op is neither erased nor marked for removal here, and uses are replaced directly on the Value instead of through the rewriter - confirm the caller cleans it up. */ +static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatCompute computeOp, IRRewriter& rewriter) { + if (!computeOp.getInputs().empty() || computeOp.getNumResults() != 1) + return false; + if (!llvm::all_of(computeOp.getResult(0).getUsers(), [](Operation* user) { + return isa(user); + })) + return false; + + Block& block = computeOp.getBody().front(); + if (block.getNumArguments() != 0) + return false; + + auto yieldOp = dyn_cast(block.getTerminator()); + if (!yieldOp || yieldOp.getNumOperands() != 1) + return false; + + rewriter.setInsertionPoint(computeOp); + IRMapping mapping; + for (Operation& op : block.without_terminator()) { + cloneMappedHelperOperands(&op, mapping, rewriter); + Operation* clonedOp = rewriter.clone(op, mapping); + for (auto [originalResult, newResult] : llvm::zip(op.getResults(), clonedOp->getResults())) + mapping.map(originalResult, newResult); + rewriter.setInsertionPointAfter(clonedOp); + } + /* lookupOrDefault falls back to the original value when the yielded value was not produced by a cloned op. */ + Value replacement = mapping.lookupOrDefault(yieldOp.getOperand(0)); + computeOp.getResult(0).replaceAllUsesWith(replacement); + return true; +} + +} // namespace + /* Appends op to state.operationsToRemove unless it is already listed. */ +void markOpToRemove(CoreLoweringState& state, Operation* op) { + if (!llvm::is_contained(state.operationsToRemove, op)) + state.operationsToRemove.push_back(op); +} + /* Lowers one spatial::SpatCompute into a PimCoreOp. Returns success early when the op was inlined by inlineInputlessHelperComputeForWeightLikeUsers or when it matches collectHelperComputeChain (such ops are left untouched here). Otherwise it materializes receives for block arguments, routes every used result through lowerComputeResultReturnPath, and moves the body into a new PimCoreOp. Fails when a result keeps an unsupported use. */ +LogicalResult lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState& state, IRRewriter& rewriter) { + Location loc = computeOp->getLoc(); + + if (inlineInputlessHelperComputeForWeightLikeUsers(computeOp, rewriter)) + return success(); + + SmallVector helperChain; + if (succeeded(collectHelperComputeChain(computeOp, helperChain))) + return success(); + + auto& block =
computeOp.getRegion().front(); + auto yieldOp = cast(block.getTerminator()); + /* Block arguments fed by a receive op (type not visible in this view) are replaced by a PimReceiveOp created before their earliest user in the block; the feeding op is marked for removal. */ + for (auto [argIndex, blockArg] : llvm::enumerate(block.getArguments())) { + auto receiveOp = dyn_cast_or_null(computeOp.getInputs()[argIndex].getDefiningOp()); + if (!receiveOp || blockArg.use_empty()) + continue; + + rewriter.setInsertionPoint(getEarliestUserWithinBlock(blockArg)); + auto outputType = cast(blockArg.getType()); + auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveOp.getLoc(), outputType); + auto sizeAttr = getTensorSizeInBytesAttr(rewriter, blockArg); + auto sourceCoreIdAttr = rewriter.getI32IntegerAttr(translateSpatialCoreIdToPimCoreId(receiveOp.getSourceCoreId())); + Value received = PimReceiveOp::create( + rewriter, receiveOp.getLoc(), outputBuffer.getType(), outputBuffer, sizeAttr, sourceCoreIdAttr) + .getOutput(); + blockArg.replaceAllUsesWith(received); + markOpToRemove(state, receiveOp); + } + + if (computeOp.getNumResults() != yieldOp.getNumOperands()) + llvm_unreachable("ComputeOp must have same number of results as yieldOp operands"); + /* Each used result goes through the return-path lowering. A result it does not handle is tolerated only when it has exactly one use whose owner passes the isa check below (type not visible in this view); anything else is an error. */ + for (auto [result, yieldValue] : llvm::zip(computeOp.getResults(), yieldOp.getOperands())) { + if (result.use_empty()) + continue; + + ReturnPathState returnPathState {state.outputTensors, state.operationsToRemove}; + ReturnPathLoweringResult returnPathResult = + lowerComputeResultReturnPath(computeOp, cast(result), yieldValue, returnPathState, rewriter); + if (returnPathResult == ReturnPathLoweringResult::Failure) + return failure(); + if (returnPathResult == ReturnPathLoweringResult::Handled) + continue; + + auto resultUses = result.getUses(); + if (rangeLength(resultUses) == 1) { + OpOperand& resultUse = *resultUses.begin(); + Operation* resultUser = resultUse.getOwner(); + if (isa(resultUser)) + continue; + } + + return computeOp.emitOpError("has an unsupported remaining result use during Spatial-to-PIM lowering"); + } + /* The yield terminator is replaced by a new op whose type is not visible in this view (presumably PimHaltOp - TODO confirm). */ + rewriter.setInsertionPoint(yieldOp); + rewriter.replaceOpWithNewOp(yieldOp); + +
SmallVector computeWeights; + if (!computeOp.getWeights().empty()) + computeWeights.append(computeOp.getWeights().begin(), computeOp.getWeights().end()); + rewriter.setInsertionPointAfter(computeOp); + auto coreOp = PimCoreOp::create(rewriter, + loc, + ValueRange(computeWeights), + rewriter.getI32IntegerAttr(getPimCoreIdForComputeOp(computeOp, state.nextCoreId))); + auto& coreOpBlocks = coreOp.getBody().getBlocks(); /* Remaining block-argument uses are pointed at the outer input values, the arguments are erased, and the compute body blocks are spliced into the new core op. */ + for (auto [argIndex, blockArg] : llvm::enumerate(block.getArguments())) + if (!blockArg.use_empty()) + blockArg.replaceAllUsesWith(computeOp.getInputs()[argIndex]); + block.eraseArguments(0, block.getNumArguments()); + coreOpBlocks.splice(coreOpBlocks.begin(), computeOp.getBody().getBlocks()); /* The emptied compute op gets a placeholder block holding only a PimHaltOp; presumably this keeps it well-formed until it is erased later - TODO confirm. NOTE(review): the block is created with new and the splice bypasses the rewriter. */ + Block* tempComputeBlock = new Block(); + computeOp.getBody().push_back(tempComputeBlock); + rewriter.setInsertionPointToEnd(tempComputeBlock); + PimHaltOp::create(rewriter, computeOp.getLoc()); + return success(); +} + +} // namespace onnx_mlir diff --git a/src/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.hpp b/src/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.hpp new file mode 100644 index 0000000..74304ed --- /dev/null +++ b/src/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.hpp @@ -0,0 +1,21 @@ +#pragma once + +#include "mlir/IR/PatternMatch.h" + +#include "src/Accelerators/PIM/Conversion/SpatialToPim/ReturnPathNormalization.hpp" +#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" + +namespace onnx_mlir { + /* State handed to the core lowering functions. All three members are references, so the storage is owned by whoever constructs the struct and updates made during lowering are visible to that owner. */ +struct CoreLoweringState { + size_t& nextCoreId; + llvm::SmallVectorImpl& outputTensors; + llvm::SmallVectorImpl& operationsToRemove; +}; + /* Records op in state.operationsToRemove; defined in CoreLoweringPatterns.cpp. */ +void markOpToRemove(CoreLoweringState& state, mlir::Operation* op); + /* Lowers a single spatial::SpatCompute and returns a LogicalResult; defined in CoreLoweringPatterns.cpp. */ +mlir::LogicalResult +lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState& state, mlir::IRRewriter& rewriter); + +} // namespace onnx_mlir diff --git a/src/PIM/Conversion/SpatialToPim/Patterns.cpp b/src/PIM/Conversion/SpatialToPim/GlobalTensorMaterialization.cpp similarity index 69% rename from
src/PIM/Conversion/SpatialToPim/Patterns.cpp rename to src/PIM/Conversion/SpatialToPim/GlobalTensorMaterialization.cpp index 2eed336..2e6b456 100644 --- a/src/PIM/Conversion/SpatialToPim/Patterns.cpp +++ b/src/PIM/Conversion/SpatialToPim/GlobalTensorMaterialization.cpp @@ -6,16 +6,17 @@ #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/IR/SymbolTable.h" #include "mlir/IR/Value.h" -#include "mlir/Support/LLVM.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Support/Casting.h" -#include "llvm/Support/ErrorHandling.h" #include "llvm/Support/LogicalResult.h" #include "Common/PimCommon.hpp" +#include "src/Accelerators/PIM/Conversion/SpatialToPim/ComputeLikeRegionUtils.hpp" +#include "src/Accelerators/PIM/Conversion/SpatialToPim/GlobalTensorMaterialization.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" using namespace mlir; @@ -23,33 +24,33 @@ using namespace mlir; namespace onnx_mlir { namespace { - -static std::optional getDirectComputeInputIndex(Operation* owner, unsigned operandNumber) { - if (auto compute = dyn_cast(owner)) { - unsigned inputCount = compute.getInputs().size(); - if (inputCount == 0) - return std::nullopt; - - unsigned inputBegin = compute->getNumOperands() - inputCount; - if (operandNumber < inputBegin) - return std::nullopt; - return operandNumber - inputBegin; - } - - if (auto computeBatch = dyn_cast(owner)) { - unsigned inputCount = computeBatch.getInputs().size(); - if (inputCount == 0) - return std::nullopt; - - unsigned inputBegin = computeBatch->getNumOperands() - inputCount; - if (operandNumber < inputBegin) - return std::nullopt; - return operandNumber - inputBegin; - } - - return std::nullopt; +static std::string makeUniqueSymbolName(Operation* symbolTableOp, StringRef baseName) { + std::string name = baseName.str(); + unsigned suffix = 0; + while (SymbolTable::lookupSymbolIn(symbolTableOp, name)) + name = (baseName + "_" 
+ Twine(suffix++)).str(); + return name; } +static memref::GlobalOp createPrivateMemrefGlobalWithUniqueName(PatternRewriter& rewriter, + Location loc, + ModuleOp moduleOp, + StringRef baseName, + MemRefType type, + Attribute initialValue = {}, + UnitAttr constant = {}) { + std::string symbolName = makeUniqueSymbolName(moduleOp, baseName); + return memref::GlobalOp::create(rewriter, + loc, + rewriter.getStringAttr(symbolName), + rewriter.getStringAttr("private"), + TypeAttr::get(type), + initialValue, + constant, + IntegerAttr {}); +} + +// Sinks top-level tensor slices into compute regions so later lowering sees local runtime work. struct MoveExtractSliceIntoCompute final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -59,7 +60,7 @@ struct MoveExtractSliceIntoCompute final : OpRewritePatterngetUses()) { if (isa(uses.getOwner())) { - if (!getDirectComputeInputIndex(uses.getOwner(), uses.getOperandNumber())) + if (!getDirectComputeLikeInputIndex(uses.getOwner(), uses.getOperandNumber())) return failure(); } else if (isa_and_present(uses.getOwner()->getParentOp())) { @@ -72,7 +73,7 @@ struct MoveExtractSliceIntoCompute final : OpRewritePatterngetUses())) { if (auto spatCompute = dyn_cast(uses.getOwner())) { - auto inputIndex = getDirectComputeInputIndex(spatCompute, uses.getOperandNumber()); + auto inputIndex = getDirectComputeLikeInputIndex(spatCompute, uses.getOperandNumber()); if (!inputIndex) return failure(); auto BBArgIndex = *inputIndex; @@ -87,14 +88,11 @@ struct MoveExtractSliceIntoCompute final : OpRewritePatterngetResult(0)}); } - rewriter.startOpModification(spatCompute.getOperation()); - BBArgValue.replaceAllUsesWith(mapSpatToExtract[spatCompute.getOperation()]); - spatCompute.getInputsMutable().erase(BBArgIndex); - spatCompute.getBody().front().eraseArgument(BBArgIndex); - rewriter.finalizeOpModification(spatCompute.getOperation()); + replaceAndEraseDirectComputeLikeInput( + rewriter, spatCompute.getOperation(), BBArgIndex, 
mapSpatToExtract[spatCompute.getOperation()]); } else if (auto spatComputeBatch = dyn_cast(uses.getOwner())) { - auto inputIndex = getDirectComputeInputIndex(spatComputeBatch, uses.getOperandNumber()); + auto inputIndex = getDirectComputeLikeInputIndex(spatComputeBatch, uses.getOperandNumber()); if (!inputIndex) return failure(); auto BBArgIndex = *inputIndex; @@ -109,11 +107,8 @@ struct MoveExtractSliceIntoCompute final : OpRewritePatterngetResult(0)}); } - rewriter.startOpModification(spatComputeBatch.getOperation()); - BBArgValue.replaceAllUsesWith(mapSpatToExtract[spatComputeBatch.getOperation()]); - spatComputeBatch.getInputsMutable().erase(BBArgIndex); - spatComputeBatch.getBody().front().eraseArgument(BBArgIndex); - rewriter.finalizeOpModification(spatComputeBatch.getOperation()); + replaceAndEraseDirectComputeLikeInput( + rewriter, spatComputeBatch.getOperation(), BBArgIndex, mapSpatToExtract[spatComputeBatch.getOperation()]); } else { { @@ -148,11 +143,11 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(mlir::arith::ConstantOp constantOp, PatternRewriter& rewriter) const override { - static int i = 0; Location loc = constantOp.getLoc(); if (hasWeightAlways(constantOp)) @@ -177,15 +172,14 @@ struct ArithConstToGlobalMemoryPattern final : OpRewritePatterngetParentOfType(), + "const", + memRefType, + constantOp.getValueAttr(), + rewriter.getUnitAttr()); + std::string argName = globalOp.getSymName().str(); llvm::DenseMap mapSpatComputeToConst; @@ -193,11 +187,10 @@ struct ArithConstToGlobalMemoryPattern final : OpRewritePattern(constUsers)) { - auto inputIndex = getDirectComputeInputIndex(spatCompute, constUses.getOperandNumber()); + auto inputIndex = getDirectComputeLikeInputIndex(spatCompute, constUses.getOperandNumber()); if (!inputIndex) return failure(); auto BBArgIndex = *inputIndex; - auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex); 
rewriter.setInsertionPoint(&spatCompute.getBody().front().front()); if (!mapSpatComputeToConst.contains(spatCompute.getOperation())) { auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName); @@ -206,18 +199,14 @@ struct ArithConstToGlobalMemoryPattern final : OpRewritePattern(constUsers)) { - auto inputIndex = getDirectComputeInputIndex(spatComputeBatch, constUses.getOperandNumber()); + auto inputIndex = getDirectComputeLikeInputIndex(spatComputeBatch, constUses.getOperandNumber()); if (!inputIndex) return failure(); auto BBArgIndex = *inputIndex; - auto BBArgValue = spatComputeBatch.getBody().front().getArgument(BBArgIndex); rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front()); if (!mapSpatComputeToConst.contains(spatComputeBatch.getOperation())) { auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName); @@ -226,11 +215,10 @@ struct ArithConstToGlobalMemoryPattern final : OpRewritePattern(constUsers)) { - auto inputIndex = getDirectComputeInputIndex(spatCompute, constUses.getOperandNumber()); + auto inputIndex = getDirectComputeLikeInputIndex(spatCompute, constUses.getOperandNumber()); if (!inputIndex) return failure(); auto BBArgIndex = *inputIndex; - auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex); rewriter.setInsertionPoint(&spatCompute.getBody().front().front()); auto newConst = rewriter.clone(*constantOp); - rewriter.startOpModification(spatCompute.getOperation()); - BBArgValue.replaceAllUsesWith(newConst->getResult(0)); - spatCompute.getInputsMutable().erase(BBArgIndex); - spatCompute.getBody().front().eraseArgument(BBArgIndex); - rewriter.finalizeOpModification(spatCompute.getOperation()); + replaceAndEraseDirectComputeLikeInput( + rewriter, spatCompute.getOperation(), BBArgIndex, newConst->getResult(0)); } else if (auto spatComputeBatch = llvm::dyn_cast(constUsers)) { - auto inputIndex = getDirectComputeInputIndex(spatComputeBatch, 
constUses.getOperandNumber()); + auto inputIndex = getDirectComputeLikeInputIndex(spatComputeBatch, constUses.getOperandNumber()); if (!inputIndex) return failure(); auto BBArgIndex = *inputIndex; - auto BBArgValue = spatComputeBatch.getBody().front().getArgument(BBArgIndex); rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front()); auto newConst = rewriter.clone(*constantOp); - rewriter.startOpModification(spatComputeBatch.getOperation()); - BBArgValue.replaceAllUsesWith(newConst->getResult(0)); - spatComputeBatch.getInputsMutable().erase(BBArgIndex); - spatComputeBatch.getBody().front().eraseArgument(BBArgIndex); - rewriter.finalizeOpModification(spatComputeBatch.getOperation()); + replaceAndEraseDirectComputeLikeInput( + rewriter, spatComputeBatch.getOperation(), BBArgIndex, newConst->getResult(0)); } else if (auto parent = constUsers->getParentOfType()) { if (!mapSpatComputeToConst.contains(parent)) { @@ -321,11 +301,13 @@ struct ArithConstToGlobalMemoryPattern final : OpRewritePatternuse_empty()) + rewriter.eraseOp(constantOp); return success(); } }; +// Materializes public function tensor inputs as globals so compute bodies can load them uniformly. 
struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -352,52 +334,36 @@ struct FuncOpArgToGlobalMemoryPattern final : OpRewritePatterngetParentOfType(), baseName, memRefType); + std::string argName = globalOp.getSymName().str(); for (auto& argUses : llvm::make_early_inc_range(arg.getUses())) { auto argUser = argUses.getOwner(); if (auto spatCompute = dyn_cast(argUser)) { - auto inputIndex = getDirectComputeInputIndex(spatCompute, argUses.getOperandNumber()); + auto inputIndex = getDirectComputeLikeInputIndex(spatCompute, argUses.getOperandNumber()); if (!inputIndex) return failure(); auto BBArgIndex = *inputIndex; - auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex); rewriter.setInsertionPoint(&spatCompute.getBody().front().front()); auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName); auto toTensor = bufferization::ToTensorOp::create( rewriter, loc, argRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr()); - rewriter.startOpModification(spatCompute.getOperation()); - BBArgValue.replaceAllUsesWith(toTensor); - spatCompute.getInputsMutable().erase(BBArgIndex); - spatCompute.getBody().front().eraseArgument(BBArgIndex); - rewriter.finalizeOpModification(spatCompute.getOperation()); + replaceAndEraseDirectComputeLikeInput(rewriter, spatCompute.getOperation(), BBArgIndex, toTensor); } else if (auto spatComputeBatch = dyn_cast(argUser)) { - auto inputIndex = getDirectComputeInputIndex(spatComputeBatch, argUses.getOperandNumber()); + auto inputIndex = getDirectComputeLikeInputIndex(spatComputeBatch, argUses.getOperandNumber()); if (!inputIndex) return failure(); auto BBArgIndex = *inputIndex; - auto BBArgValue = spatComputeBatch.getBody().front().getArgument(BBArgIndex); rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front()); auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName); auto toTensor = 
bufferization::ToTensorOp::create( rewriter, loc, argRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr()); - rewriter.startOpModification(spatComputeBatch.getOperation()); - BBArgValue.replaceAllUsesWith(toTensor); - spatComputeBatch.getInputsMutable().erase(BBArgIndex); - spatComputeBatch.getBody().front().eraseArgument(BBArgIndex); - rewriter.finalizeOpModification(spatComputeBatch.getOperation()); + replaceAndEraseDirectComputeLikeInput(rewriter, spatComputeBatch.getOperation(), BBArgIndex, toTensor); } else { rewriter.setInsertionPoint(argUser); @@ -416,7 +382,7 @@ struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern( patterns.getContext()); } diff --git a/src/PIM/Conversion/SpatialToPim/GlobalTensorMaterialization.hpp b/src/PIM/Conversion/SpatialToPim/GlobalTensorMaterialization.hpp new file mode 100644 index 0000000..7464dec --- /dev/null +++ b/src/PIM/Conversion/SpatialToPim/GlobalTensorMaterialization.hpp @@ -0,0 +1,9 @@ +#pragma once + +#include "mlir/IR/PatternMatch.h" + +namespace onnx_mlir { + +void populateGlobalTensorMaterializationPatterns(mlir::RewritePatternSet& patterns); + +} diff --git a/src/PIM/Conversion/SpatialToPim/Patterns.hpp b/src/PIM/Conversion/SpatialToPim/Patterns.hpp deleted file mode 100644 index e34f6ab..0000000 --- a/src/PIM/Conversion/SpatialToPim/Patterns.hpp +++ /dev/null @@ -1,10 +0,0 @@ -#pragma once - -#include "mlir/IR/PatternMatch.h" - - -namespace onnx_mlir { - -void populateGlobalTensorToMemrefPatterns(mlir::RewritePatternSet& patterns); - -} diff --git a/src/PIM/Conversion/SpatialToPim/PhaseVerification.cpp b/src/PIM/Conversion/SpatialToPim/PhaseVerification.cpp new file mode 100644 index 0000000..4c18886 --- /dev/null +++ b/src/PIM/Conversion/SpatialToPim/PhaseVerification.cpp @@ -0,0 +1,20 @@ +#include "src/Accelerators/PIM/Conversion/SpatialToPim/PhaseVerification.hpp" +#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" + +using namespace mlir; + +namespace onnx_mlir { + 
+LogicalResult verifySpatialToPimBoundary(ModuleOp moduleOp) { + bool hasFailure = false; + moduleOp.walk([&](Operation* op) { + if (op->getDialect()->getNamespace() != "spat") + return; + + op->emitError("illegal Spatial operation remains after Spatial-to-PIM lowering"); + hasFailure = true; + }); + return success(!hasFailure); +} + +} // namespace onnx_mlir diff --git a/src/PIM/Conversion/SpatialToPim/PhaseVerification.hpp b/src/PIM/Conversion/SpatialToPim/PhaseVerification.hpp new file mode 100644 index 0000000..d17da32 --- /dev/null +++ b/src/PIM/Conversion/SpatialToPim/PhaseVerification.hpp @@ -0,0 +1,9 @@ +#pragma once + +#include "mlir/IR/BuiltinOps.h" + +namespace onnx_mlir { + +mlir::LogicalResult verifySpatialToPimBoundary(mlir::ModuleOp moduleOp); + +} // namespace onnx_mlir diff --git a/src/PIM/Conversion/SpatialToPim/ReturnPathNormalization.cpp b/src/PIM/Conversion/SpatialToPim/ReturnPathNormalization.cpp new file mode 100644 index 0000000..3a5f755 --- /dev/null +++ b/src/PIM/Conversion/SpatialToPim/ReturnPathNormalization.cpp @@ -0,0 +1,587 @@ +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/SymbolTable.h" + +#include "Conversion/ONNXToSpatial/Common/Common.hpp" +#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp" +#include "src/Accelerators/PIM/Conversion/SpatialToPim/ReturnPathNormalization.hpp" +#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" +#include "src/Dialect/ONNX/ONNXOps.hpp" + +using namespace mlir; +using namespace onnx_mlir::pim; + +namespace onnx_mlir { +namespace { + +struct ReturnUseInfo { + size_t returnIndex; + SmallVector helperChain; +}; + +struct ConcatReturnUseInfo { + size_t returnIndex; + SmallVector sliceOffsets; + SmallVector concatShape; + 
SmallVector concatChain; + SmallVector helperChain; +}; + +static bool isReturnHelperChainOp(Operation* op) { + return isa(op); +} + +static void markOpToRemove(ReturnPathState& state, Operation* op) { + if (!llvm::is_contained(state.operationsToRemove, op)) + state.operationsToRemove.push_back(op); +} + +static std::string makeUniqueSymbolName(Operation* symbolTableOp, StringRef baseName) { + std::string name = baseName.str(); + unsigned suffix = 0; + while (SymbolTable::lookupSymbolIn(symbolTableOp, name)) + name = (baseName + "_" + Twine(suffix++)).str(); + return name; +} + +static int64_t computeFlatElementIndex(ArrayRef indices, ArrayRef shape) { + int64_t flatIndex = 0; + for (size_t i = 0; i < shape.size(); ++i) { + flatIndex *= shape[i]; + flatIndex += indices[i]; + } + return flatIndex; +} + +static SmallVector expandFlatElementIndex(int64_t flatIndex, ArrayRef shape) { + SmallVector indices(shape.size(), 0); + for (int64_t dim = static_cast(shape.size()) - 1; dim >= 0; --dim) { + indices[dim] = flatIndex % shape[dim]; + flatIndex /= shape[dim]; + } + return indices; +} + +static LogicalResult collectHelperComputeChain(spatial::SpatCompute computeOp, + SmallVectorImpl& helperChain) { + if (computeOp.getInputs().size() != 1 || computeOp.getNumResults() != 1) + return failure(); + if (!computeOp.getResult(0).hasOneUse() || !isa(*computeOp.getResult(0).getUsers().begin())) + return failure(); + + Block& block = computeOp.getBody().front(); + if (block.getNumArguments() != 1) + return failure(); + + auto yieldOp = dyn_cast(block.getTerminator()); + if (!yieldOp || yieldOp.getNumOperands() != 1) + return failure(); + + SmallVector reverseChain; + Value currentValue = yieldOp.getOperands().front(); + Value blockArg = block.getArgument(0); + + while (currentValue != blockArg) { + Operation* definingOp = currentValue.getDefiningOp(); + if (!definingOp || definingOp->getBlock() != &block || !isReturnHelperChainOp(definingOp)) + return failure(); + 
reverseChain.push_back(definingOp); + currentValue = definingOp->getOperand(0); + } + + SmallPtrSet chainSet(reverseChain.begin(), reverseChain.end()); + for (Operation& op : llvm::make_early_inc_range(block.without_terminator())) + if (!chainSet.contains(&op) && !isa(op)) + return failure(); + + helperChain.assign(reverseChain.rbegin(), reverseChain.rend()); + return success(); +} + +static std::optional analyzeReturnUse(Value value) { + auto uses = value.getUses(); + if (rangeLength(uses) != 1) + return std::nullopt; + + SmallVector helperChain; + Value currentValue = value; + Operation* currentUser = uses.begin()->getOwner(); + + while (isReturnHelperChainOp(currentUser)) { + helperChain.push_back(currentUser); + auto currentUses = currentUser->getResult(0).getUses(); + if (rangeLength(currentUses) != 1) + return std::nullopt; + currentValue = currentUser->getResult(0); + currentUser = currentUses.begin()->getOwner(); + } + + if (!isa(currentUser)) + return std::nullopt; + + return ReturnUseInfo { + currentValue.getUses().begin()->getOperandNumber(), + std::move(helperChain), + }; +} + +static std::optional analyzeConcatReturnUse(Value value) { + auto getConcatResult = [](Operation* op) -> Value { + if (auto tensorConcat = dyn_cast(op)) + return tensorConcat.getResult(); + if (auto spatialConcat = dyn_cast(op)) + return spatialConcat.getOutput(); + if (auto pimConcat = dyn_cast(op)) + return pimConcat.getOutput(); + return {}; + }; + auto getConcatAxis = [](Operation* op) -> std::optional { + if (auto tensorConcat = dyn_cast(op)) + return tensorConcat.getDim(); + if (auto spatialConcat = dyn_cast(op)) + return spatialConcat.getAxis(); + if (auto pimConcat = dyn_cast(op)) + return pimConcat.getAxis(); + return std::nullopt; + }; + auto getConcatOperands = [](Operation* op) -> OperandRange { + if (auto tensorConcat = dyn_cast(op)) + return tensorConcat.getOperands(); + if (auto spatialConcat = dyn_cast(op)) + return spatialConcat.getInputs(); + return 
cast(op).getInputs(); + }; + + auto uses = value.getUses(); + if (rangeLength(uses) != 1 + || !isa(uses.begin()->getOwner())) + return std::nullopt; + + auto valueType = dyn_cast(value.getType()); + if (!valueType || !valueType.hasStaticShape()) + return std::nullopt; + + SmallVector sliceOffsets(valueType.getRank(), 0); + SmallVector concatShape(valueType.getShape().begin(), valueType.getShape().end()); + SmallVector concatChain; + Value currentValue = value; + Operation* currentUser = uses.begin()->getOwner(); + + while (isa(currentUser)) { + concatChain.push_back(currentUser); + size_t operandIndex = currentValue.getUses().begin()->getOperandNumber(); + int64_t axis = *getConcatAxis(currentUser); + for (Value operand : getConcatOperands(currentUser).take_front(operandIndex)) + sliceOffsets[axis] += cast(operand.getType()).getShape()[axis]; + + Value concatResult = getConcatResult(currentUser); + auto concatType = dyn_cast(concatResult.getType()); + if (!concatType || !concatType.hasStaticShape()) + return std::nullopt; + concatShape.assign(concatType.getShape().begin(), concatType.getShape().end()); + + currentValue = concatResult; + auto currentUses = currentValue.getUses(); + if (rangeLength(currentUses) != 1) + return std::nullopt; + currentUser = currentUses.begin()->getOwner(); + } + + SmallVector helperChain; + if (auto helperCompute = dyn_cast(currentUser)) { + if (helperCompute.getInputs().size() != 1 || helperCompute.getInputs().front() != currentValue) + return std::nullopt; + + if (failed(collectHelperComputeChain(helperCompute, helperChain))) + return std::nullopt; + + currentValue = helperCompute.getResult(0); + auto currentUses = currentValue.getUses(); + if (rangeLength(currentUses) != 1) + return std::nullopt; + currentUser = currentUses.begin()->getOwner(); + } + + while (isReturnHelperChainOp(currentUser)) { + helperChain.push_back(currentUser); + auto currentUses = currentUser->getResult(0).getUses(); + if (rangeLength(currentUses) != 1) + 
return std::nullopt; + currentValue = currentUser->getResult(0); + currentUser = currentUses.begin()->getOwner(); + } + + if (!isa(currentUser)) + return std::nullopt; + + return ConcatReturnUseInfo { + currentValue.getUses().begin()->getOperandNumber(), + std::move(sliceOffsets), + std::move(concatShape), + std::move(concatChain), + std::move(helperChain), + }; +} + +static LogicalResult mapIndicesThroughHelperChain(ArrayRef sourceIndices, + ArrayRef sourceShape, + ArrayRef helperChain, + SmallVectorImpl& mappedIndices) { + SmallVector currentIndices(sourceIndices.begin(), sourceIndices.end()); + SmallVector currentShape(sourceShape.begin(), sourceShape.end()); + + auto reshapeToResultShape = [&](Operation* op) -> LogicalResult { + auto resultType = dyn_cast(op->getResult(0).getType()); + if (!resultType || !resultType.hasStaticShape()) + return failure(); + int64_t flatIndex = computeFlatElementIndex(currentIndices, currentShape); + currentShape.assign(resultType.getShape().begin(), resultType.getShape().end()); + currentIndices = expandFlatElementIndex(flatIndex, currentShape); + return success(); + }; + + for (Operation* op : helperChain) { + if (auto extractSliceOp = dyn_cast(op)) { + auto hasStaticValues = [](ArrayRef values) { + return llvm::all_of(values, [](int64_t value) { return !ShapedType::isDynamic(value); }); + }; + if (!hasStaticValues(extractSliceOp.getStaticOffsets()) || !hasStaticValues(extractSliceOp.getStaticSizes()) + || !hasStaticValues(extractSliceOp.getStaticStrides())) + return failure(); + + SmallVector nextIndices; + nextIndices.reserve(currentIndices.size()); + for (auto [index, offset, size, stride] : llvm::zip_equal(currentIndices, + extractSliceOp.getStaticOffsets(), + extractSliceOp.getStaticSizes(), + extractSliceOp.getStaticStrides())) { + if (stride != 1 || index < offset || index >= offset + size) + return failure(); + nextIndices.push_back(index - offset); + } + + auto resultType = 
dyn_cast(extractSliceOp.getResult().getType()); + if (!resultType || !resultType.hasStaticShape()) + return failure(); + currentIndices = std::move(nextIndices); + currentShape.assign(resultType.getShape().begin(), resultType.getShape().end()); + continue; + } + + if (auto transposeOp = dyn_cast(op)) { + SmallVector nextIndices(currentIndices.size()); + SmallVector nextShape(currentShape.size()); + for (auto [destIndex, attr] : llvm::enumerate(transposeOp.getPermAttr().getAsRange())) { + int64_t sourceIndex = attr.getInt(); + nextIndices[destIndex] = currentIndices[sourceIndex]; + nextShape[destIndex] = currentShape[sourceIndex]; + } + currentIndices = std::move(nextIndices); + currentShape = std::move(nextShape); + continue; + } + + if (auto transposeOp = dyn_cast(op)) { + SmallVector nextIndices(currentIndices.size()); + SmallVector nextShape(currentShape.size()); + for (auto [destIndex, attr] : llvm::enumerate(transposeOp.getPermutation().getAsRange())) { + int64_t sourceIndex = attr.getInt(); + nextIndices[destIndex] = currentIndices[sourceIndex]; + nextShape[destIndex] = currentShape[sourceIndex]; + } + currentIndices = std::move(nextIndices); + currentShape = std::move(nextShape); + continue; + } + + if (isa(op)) { + if (failed(reshapeToResultShape(op))) + return failure(); + continue; + } + + return failure(); + } + + mappedIndices.assign(currentIndices.begin(), currentIndices.end()); + return success(); +} + +static void cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewriter) { + for (Value operand : op->getOperands()) { + if (mapping.lookupOrNull(operand)) + continue; + + Operation* definingOp = operand.getDefiningOp(); + if (!definingOp) + continue; + + if (!isa(definingOp)) + continue; + + Operation* clonedOp = rewriter.clone(*definingOp, mapping); + for (auto [originalResult, newResult] : llvm::zip(definingOp->getResults(), clonedOp->getResults())) + mapping.map(originalResult, newResult); + 
rewriter.setInsertionPointAfter(clonedOp); + } +} + +static void +cloneHelperChain(Value sourceValue, ArrayRef helperChain, IRRewriter& rewriter, Value& clonedValue) { + IRMapping mapping; + mapping.map(sourceValue, sourceValue); + clonedValue = sourceValue; + + rewriter.setInsertionPointAfterValue(sourceValue); + for (Operation* op : helperChain) { + cloneMappedHelperOperands(op, mapping, rewriter); + Operation* clonedOp = rewriter.clone(*op, mapping); + for (auto [originalResult, newResult] : llvm::zip(op->getResults(), clonedOp->getResults())) + mapping.map(originalResult, newResult); + clonedValue = clonedOp->getResult(0); + rewriter.setInsertionPointAfter(clonedOp); + } +} + +static Value emitHostCopy(IRRewriter& rewriter, + Location loc, + Value outputTensor, + Value sourceValue, + int32_t hostTargetOffset, + int32_t deviceSourceOffset, + int32_t sizeInBytes) { + return PimMemCopyDevToHostOp::create(rewriter, + loc, + outputTensor.getType(), + outputTensor, + sourceValue, + rewriter.getI32IntegerAttr(hostTargetOffset), + rewriter.getI32IntegerAttr(deviceSourceOffset), + rewriter.getI32IntegerAttr(sizeInBytes)) + .getOutput(); +} + +} // namespace + +void addReturnOutputBuffers(func::ReturnOp returnOp, + IRRewriter& rewriter, + SmallVectorImpl& outputTensors) { + outputTensors.reserve(returnOp->getNumOperands()); + for (auto [index, returnValue] : llvm::enumerate(returnOp->getOperands())) { + Value currentReturnValue = returnValue; + Operation* returnValueDefiningOp = currentReturnValue.getDefiningOp(); + if (returnValueDefiningOp->hasTrait()) { + assert(!hasWeightAlways(returnValueDefiningOp)); + outputTensors.push_back( + [currentReturnValue](IRRewriter& rewriter, Location loc) -> Value { return currentReturnValue; }); + } + else { + auto outRankedTensorType = llvm::dyn_cast(currentReturnValue.getType()); + auto memRefType = MemRefType::get(outRankedTensorType.getShape(), outRankedTensorType.getElementType()); + + std::string outputBaseName = ("output_" + 
Twine(index)).str(); + std::string outputName = makeUniqueSymbolName(returnOp->getParentOfType(), outputBaseName); + rewriter.setInsertionPoint(returnOp.getParentOp()); + memref::GlobalOp::create(rewriter, + returnOp.getLoc(), + rewriter.getStringAttr(outputName), + rewriter.getStringAttr("private"), + TypeAttr::get(memRefType), + {}, + {}, + {}); + outputTensors.push_back([memRefType, outputName, outRankedTensorType](IRRewriter& rewriter, Location loc) { + auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, outputName); + auto toTensor = bufferization::ToTensorOp::create( + rewriter, loc, outRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr()); + return toTensor.getResult(); + }); + } + } +} + +ReturnPathLoweringResult lowerComputeResultReturnPath( + spatial::SpatCompute computeOp, OpResult result, Value yieldValue, ReturnPathState& state, IRRewriter& rewriter) { + Location loc = computeOp->getLoc(); + auto yieldType = cast(yieldValue.getType()); + + if (auto returnUse = analyzeReturnUse(result)) { + Value storedValue = yieldValue; + cloneHelperChain(yieldValue, returnUse->helperChain, rewriter, storedValue); + for (Operation* op : returnUse->helperChain) + markOpToRemove(state, op); + + auto storedType = cast(storedValue.getType()); + size_t elementSize = storedType.getElementTypeBitWidth() / 8; + if (auto storedOp = storedValue.getDefiningOp()) + rewriter.setInsertionPointAfter(storedOp); + Value outputTensor = state.outputTensors[returnUse->returnIndex](rewriter, loc); + emitHostCopy( + rewriter, loc, outputTensor, storedValue, 0, 0, static_cast(storedType.getNumElements() * elementSize)); + return ReturnPathLoweringResult::Handled; + } + + auto resultUses = result.getUses(); + if (rangeLength(resultUses) == 1) { + OpOperand& resultUse = *resultUses.begin(); + Operation* resultUser = resultUse.getOwner(); + + if (isa(resultUser)) { + size_t resultIndexInReturn = resultUse.getOperandNumber(); + size_t elementSize = 
yieldType.getElementType().getIntOrFloatBitWidth() / 8; + rewriter.setInsertionPointAfterValue(yieldValue); + Value outputTensor = state.outputTensors[resultIndexInReturn](rewriter, loc); + emitHostCopy( + rewriter, loc, outputTensor, yieldValue, 0, 0, static_cast(yieldType.getNumElements() * elementSize)); + return ReturnPathLoweringResult::Handled; + } + } + + if (auto concatReturnUse = analyzeConcatReturnUse(result)) { + size_t elementSize = yieldType.getElementTypeBitWidth() / 8; + for (Operation* concatOp : concatReturnUse->concatChain) + markOpToRemove(state, concatOp); + + if (concatReturnUse->helperChain.empty()) { + rewriter.setInsertionPointAfterValue(yieldValue); + Value outputTensor = state.outputTensors[concatReturnUse->returnIndex](rewriter, loc); + auto outputType = cast(outputTensor.getType()); + int64_t flatOffset = computeFlatElementIndex(concatReturnUse->sliceOffsets, outputType.getShape()); + emitHostCopy(rewriter, + loc, + outputTensor, + yieldValue, + static_cast(flatOffset * elementSize), + 0, + static_cast(yieldType.getNumElements() * elementSize)); + return ReturnPathLoweringResult::Handled; + } + + auto storedType = dyn_cast(yieldValue.getType()); + if (!storedType) { + computeOp.emitOpError("has an unsupported non-ranked concat-return helper yield during Spatial-to-PIM lowering"); + return ReturnPathLoweringResult::Failure; + } + rewriter.setInsertionPointAfterValue(yieldValue); + Value outputTensor = state.outputTensors[concatReturnUse->returnIndex](rewriter, loc); + auto outputType = cast(outputTensor.getType()); + for (int64_t linearIndex = 0; linearIndex < storedType.getNumElements(); ++linearIndex) { + SmallVector sourceIndices = expandFlatElementIndex(linearIndex, storedType.getShape()); + for (auto [dim, idx] : llvm::enumerate(sourceIndices)) + sourceIndices[dim] = concatReturnUse->sliceOffsets[dim] + idx; + + SmallVector destinationIndices; + if (failed(mapIndicesThroughHelperChain( + sourceIndices, concatReturnUse->concatShape, 
concatReturnUse->helperChain, destinationIndices))) { + computeOp.emitOpError("has an unsupported concat-return helper chain during Spatial-to-PIM lowering"); + return ReturnPathLoweringResult::Failure; + } + + SmallVector extractOffsets; + SmallVector extractSizes; + SmallVector extractStrides; + extractOffsets.reserve(storedType.getRank()); + extractSizes.reserve(storedType.getRank()); + extractStrides.reserve(storedType.getRank()); + for (int64_t idx : expandFlatElementIndex(linearIndex, storedType.getShape())) { + extractOffsets.push_back(rewriter.getIndexAttr(idx)); + extractSizes.push_back(rewriter.getIndexAttr(1)); + extractStrides.push_back(rewriter.getIndexAttr(1)); + } + + auto scalarTensorType = + RankedTensorType::get(SmallVector(storedType.getRank(), 1), storedType.getElementType()); + auto elementSlice = tensor::ExtractSliceOp::create( + rewriter, loc, scalarTensorType, yieldValue, extractOffsets, extractSizes, extractStrides); + rewriter.setInsertionPointAfter(elementSlice); + + int64_t destinationFlatOffset = computeFlatElementIndex(destinationIndices, outputType.getShape()); + outputTensor = emitHostCopy(rewriter, + loc, + outputTensor, + elementSlice.getResult(), + static_cast(destinationFlatOffset * elementSize), + 0, + static_cast(elementSize)); + } + return ReturnPathLoweringResult::Handled; + } + + return ReturnPathLoweringResult::NotReturnPath; +} + +void replaceReturnWithOutputBuffers(func::ReturnOp returnOp, IRRewriter& rewriter, ReturnPathState& state) { + auto markOwnedReturnChain = [&](Operation* op, auto&& markOwnedReturnChain) -> void { + if (!op) + return; + + bool isExclusivelyOwnedByReturnChain = op->use_empty(); + if (!isExclusivelyOwnedByReturnChain && op->hasOneUse()) { + Operation* onlyUser = *op->getUsers().begin(); + isExclusivelyOwnedByReturnChain = + isa(onlyUser) + || isReturnHelperChainOp(onlyUser); + } + if (!isExclusivelyOwnedByReturnChain) + return; + + if (isReturnHelperChainOp(op)) { + Value source = 
op->getOperand(0); + markOpToRemove(state, op); + markOwnedReturnChain(source.getDefiningOp(), markOwnedReturnChain); + return; + } + + if (auto computeOp = dyn_cast(op)) { + markOpToRemove(state, computeOp); + if (!computeOp.getInputs().empty()) + for (Value input : computeOp.getInputs()) + markOwnedReturnChain(input.getDefiningOp(), markOwnedReturnChain); + return; + } + + if (auto concatOp = dyn_cast(op)) { + markOpToRemove(state, concatOp); + for (Value operand : concatOp.getOperands()) + markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain); + return; + } + + if (auto concatOp = dyn_cast(op)) { + markOpToRemove(state, concatOp); + for (Value operand : concatOp.getInputs()) + markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain); + return; + } + + if (auto concatOp = dyn_cast(op)) { + markOpToRemove(state, concatOp); + for (Value operand : concatOp.getInputs()) + markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain); + } + }; + + SmallVector originalOperands(returnOp.getOperands().begin(), returnOp.getOperands().end()); + auto loc = returnOp.getLoc(); + for (auto it : llvm::enumerate(originalOperands)) { + size_t orderWithinReturn = it.index(); + Operation* returnOperand = it.value().getDefiningOp(); + rewriter.setInsertionPoint(returnOp); + Value outputTensor = state.outputTensors[orderWithinReturn](rewriter, loc); + rewriter.modifyOpInPlace(returnOp, [&] { returnOp.setOperand(orderWithinReturn, outputTensor); }); + markOwnedReturnChain(returnOperand, markOwnedReturnChain); + } +} + +} // namespace onnx_mlir diff --git a/src/PIM/Conversion/SpatialToPim/ReturnPathNormalization.hpp b/src/PIM/Conversion/SpatialToPim/ReturnPathNormalization.hpp new file mode 100644 index 0000000..6a1c78c --- /dev/null +++ b/src/PIM/Conversion/SpatialToPim/ReturnPathNormalization.hpp @@ -0,0 +1,37 @@ +#pragma once + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/PatternMatch.h" + +#include + +#include 
"src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" + +namespace onnx_mlir { + +using OutputTensorFactory = std::function; + +struct ReturnPathState { + llvm::SmallVectorImpl& outputTensors; + llvm::SmallVectorImpl& operationsToRemove; +}; + +enum class ReturnPathLoweringResult { + Handled, + NotReturnPath, + Failure +}; + +void addReturnOutputBuffers(mlir::func::ReturnOp returnOp, + mlir::IRRewriter& rewriter, + llvm::SmallVectorImpl& outputTensors); + +ReturnPathLoweringResult lowerComputeResultReturnPath(spatial::SpatCompute computeOp, + mlir::OpResult result, + mlir::Value yieldValue, + ReturnPathState& state, + mlir::IRRewriter& rewriter); + +void replaceReturnWithOutputBuffers(mlir::func::ReturnOp returnOp, mlir::IRRewriter& rewriter, ReturnPathState& state); + +} // namespace onnx_mlir diff --git a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp index 0ce4618..1475c2f 100644 --- a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp +++ b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp @@ -1,18 +1,16 @@ #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/IRMapping.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/Value.h" #include "mlir/Pass/Pass.h" -#include "mlir/Support/LLVM.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/WalkPatternRewriteDriver.h" @@ -20,14 +18,19 @@ #include "llvm/Support/Casting.h" #include -#include -#include #include #include "Conversion/ONNXToSpatial/Common/Common.hpp" -#include "Patterns.hpp" #include 
"src/Accelerators/PIM/Common/PimCommon.hpp" +#include "src/Accelerators/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.hpp" +#include "src/Accelerators/PIM/Conversion/SpatialToPim/ChannelLoweringPatterns.hpp" +#include "src/Accelerators/PIM/Conversion/SpatialToPim/Cleanup.hpp" #include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp" +#include "src/Accelerators/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.hpp" +#include "src/Accelerators/PIM/Conversion/SpatialToPim/GlobalTensorMaterialization.hpp" +#include "src/Accelerators/PIM/Conversion/SpatialToPim/PhaseVerification.hpp" +#include "src/Accelerators/PIM/Conversion/SpatialToPim/ReturnPathNormalization.hpp" +#include "src/Accelerators/PIM/Conversion/SpatialToPim/TensorPackingPatterns.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Pass/PIMPasses.h" @@ -53,74 +56,21 @@ struct SpatialToPimPass : PassWrapper> void runOnOperation() final; private: - SmallVector> outputTensors; + SmallVector outputTensors; size_t coreId = 0; SmallVector operationsToRemove; - void addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rewriter); - LogicalResult allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter); void markOpToRemove(Operation* op); - void runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter& rewriter); - void runOnComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, IRRewriter& rewriter); - void enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter); - - void replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewriter& rewriter); }; } // namespace -static bool isChannelUseChainOp(Operation* op) { - return isa(op); -} - -static void cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewriter) { - for (Value operand : op->getOperands()) { - if (mapping.lookupOrNull(operand)) - continue; - - Operation* definingOp = 
operand.getDefiningOp(); - if (!definingOp) - continue; - - if (!isa(definingOp)) - continue; - - Operation* clonedOp = rewriter.clone(*definingOp, mapping); - for (auto [originalResult, newResult] : llvm::zip(definingOp->getResults(), clonedOp->getResults())) - mapping.map(originalResult, newResult); - rewriter.setInsertionPointAfter(clonedOp); - } -} - static int32_t translateSpatialCoreIdToPimCoreId(size_t spatialCoreId) { return static_cast(spatialCoreId); } -static int32_t getPimCoreIdForComputeOp(spatial::SpatCompute computeOp, size_t& fallbackCoreId) { - if (auto spatialCoreIdAttr = computeOp->getAttrOfType(onnx_mlir::kCoreIdAttrName)) - return static_cast(spatialCoreIdAttr.getInt()); - return static_cast(fallbackCoreId++); -} - -static SmallVector getPimCoreIdsForBatchOp(spatial::SpatComputeBatch computeBatchOp, size_t& fallbackCoreId) { - if (auto coreIdsAttr = computeBatchOp->getAttrOfType(onnx_mlir::kCoreIdsAttrName)) - return SmallVector(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end()); - - SmallVector coreIds; - coreIds.reserve(static_cast(computeBatchOp.getLaneCount())); - for (uint32_t lane = 0; lane < computeBatchOp.getLaneCount(); ++lane) - coreIds.push_back(static_cast(fallbackCoreId++)); - return coreIds; -} - static void lowerChannelSend(spatial::SpatChannelSendOp sendOp, IRRewriter& rewriter) { auto sizeAttr = getTensorSizeInBytesAttr(rewriter, sendOp.getInput()); auto targetCoreIdAttr = rewriter.getI32IntegerAttr(translateSpatialCoreIdToPimCoreId(sendOp.getTargetCoreId())); @@ -148,88 +98,38 @@ static void lowerChannelReceive(spatial::SpatChannelReceiveOp receiveOp, IRRewri rewriter.replaceOp(receiveOp, received); } -static void lowerChannelSendMany(spatial::SpatChannelSendManyOp sendManyOp, IRRewriter& rewriter) { - rewriter.setInsertionPoint(sendManyOp); - for (auto [input, targetCoreId] : llvm::zip(sendManyOp.getInputs(), sendManyOp.getTargetCoreIds())) { - PimSendOp::create(rewriter, - sendManyOp.getLoc(), - input, - 
getTensorSizeInBytesAttr(rewriter, input), - rewriter.getI32IntegerAttr(translateSpatialCoreIdToPimCoreId(targetCoreId))); - } - rewriter.eraseOp(sendManyOp); -} - -static void lowerChannelReceiveMany(spatial::SpatChannelReceiveManyOp receiveManyOp, IRRewriter& rewriter) { - rewriter.setInsertionPoint(receiveManyOp); - SmallVector replacements; - replacements.reserve(receiveManyOp.getNumResults()); - for (auto [output, sourceCoreId] : llvm::zip(receiveManyOp.getOutputs(), receiveManyOp.getSourceCoreIds())) { - auto outputType = cast(output.getType()); - Value outputBuffer = createEmptyTensorFromShaped(rewriter, receiveManyOp.getLoc(), outputType).getResult(); - replacements.push_back( - PimReceiveOp::create(rewriter, - receiveManyOp.getLoc(), - output.getType(), - outputBuffer, - getTensorSizeInBytesAttr(rewriter, output), - rewriter.getI32IntegerAttr(translateSpatialCoreIdToPimCoreId(sourceCoreId))) - .getOutput()); - } - rewriter.replaceOp(receiveManyOp, replacements); -} - -static void lowerChannelSendManyBatch(spatial::SpatChannelSendManyBatchOp sendManyBatchOp, - int32_t laneCount, - IRMapping& mapper, - IRRewriter& rewriter) { +static void lowerChannelSendTensor(spatial::SpatChannelSendTensorOp sendTensorOp, IRRewriter& rewriter) { SmallVector targetCoreIds; - targetCoreIds.reserve(sendManyBatchOp.getTargetCoreIds().size()); - for (int32_t targetCoreId : sendManyBatchOp.getTargetCoreIds()) + targetCoreIds.reserve(sendTensorOp.getTargetCoreIds().size()); + for (int32_t targetCoreId : sendTensorOp.getTargetCoreIds()) targetCoreIds.push_back(translateSpatialCoreIdToPimCoreId(targetCoreId)); - SmallVector mappedInputs; - mappedInputs.reserve(sendManyBatchOp.getInputs().size()); - for (Value input : sendManyBatchOp.getInputs()) - mappedInputs.push_back(mapper.lookup(input)); - for (auto [valueIndex, input] : llvm::enumerate(mappedInputs)) { - SmallVector laneTargetCoreIds; - laneTargetCoreIds.reserve(laneCount); - for (int32_t lane = 0; lane < laneCount; ++lane) - 
laneTargetCoreIds.push_back(targetCoreIds[valueIndex * laneCount + lane]); - pim::PimSendBatchOp::create(rewriter, - sendManyBatchOp.getLoc(), - input, - getTensorSizeInBytesAttr(rewriter, input), - rewriter.getDenseI32ArrayAttr(laneTargetCoreIds)); - } + + rewriter.setInsertionPoint(sendTensorOp); + Value input = sendTensorOp.getInput(); + if (auto concatOp = input.getDefiningOp()) + if (concatOp.getDim() == 0) + if (Value packedInput = createPackedExtractSliceTensor(concatOp.getInputs(), rewriter, sendTensorOp.getLoc())) + input = packedInput; + PimSendTensorOp::create(rewriter, sendTensorOp.getLoc(), input, rewriter.getDenseI32ArrayAttr(targetCoreIds)); + rewriter.eraseOp(sendTensorOp); } -static void lowerChannelReceiveManyBatch(spatial::SpatChannelReceiveManyBatchOp receiveManyBatchOp, - int32_t laneCount, - IRMapping& mapper, - IRRewriter& rewriter) { +static void lowerChannelReceiveTensor(spatial::SpatChannelReceiveTensorOp receiveTensorOp, IRRewriter& rewriter) { SmallVector sourceCoreIds; - sourceCoreIds.reserve(receiveManyBatchOp.getSourceCoreIds().size()); - for (int32_t sourceCoreId : receiveManyBatchOp.getSourceCoreIds()) + sourceCoreIds.reserve(receiveTensorOp.getSourceCoreIds().size()); + for (int32_t sourceCoreId : receiveTensorOp.getSourceCoreIds()) sourceCoreIds.push_back(translateSpatialCoreIdToPimCoreId(sourceCoreId)); - for (auto [valueIndex, output] : llvm::enumerate(receiveManyBatchOp.getOutputs())) { - auto outputType = cast(output.getType()); - Value outputBuffer = createEmptyTensorFromShaped(rewriter, receiveManyBatchOp.getLoc(), outputType).getResult(); - SmallVector laneSourceCoreIds; - laneSourceCoreIds.reserve(laneCount); - for (int32_t lane = 0; lane < laneCount; ++lane) - laneSourceCoreIds.push_back(sourceCoreIds[valueIndex * laneCount + lane]); - - auto received = pim::PimReceiveBatchOp::create(rewriter, - receiveManyBatchOp.getLoc(), - output.getType(), - outputBuffer, - getTensorSizeInBytesAttr(rewriter, output), - 
rewriter.getDenseI32ArrayAttr(laneSourceCoreIds)) - .getOutput(); - mapper.map(output, received); - } + rewriter.setInsertionPoint(receiveTensorOp); + auto outputType = cast(receiveTensorOp.getOutput().getType()); + Value outputBuffer = createEmptyTensorFromShaped(rewriter, receiveTensorOp.getLoc(), outputType).getResult(); + Value received = PimReceiveTensorOp::create(rewriter, + receiveTensorOp.getLoc(), + receiveTensorOp.getOutput().getType(), + outputBuffer, + rewriter.getDenseI32ArrayAttr(sourceCoreIds)) + .getOutput(); + rewriter.replaceOp(receiveTensorOp, received); } static void lowerExtractRows(spatial::SpatExtractRowsOp extractRowsOp, IRRewriter& rewriter) { @@ -252,77 +152,6 @@ static void lowerExtractRows(spatial::SpatExtractRowsOp extractRowsOp, IRRewrite rewriter.replaceOp(extractRowsOp, replacements); } -static void lowerConcat(spatial::SpatConcatOp concatOp, IRRewriter& rewriter) { - rewriter.setInsertionPoint(concatOp); - auto outputType = cast(concatOp.getOutput().getType()); - Value outputBuffer = createEmptyTensorFromShaped(rewriter, concatOp.getLoc(), outputType).getResult(); - Value concatenated = pim::PimConcatOp::create(rewriter, - concatOp.getLoc(), - concatOp.getOutput().getType(), - rewriter.getI64IntegerAttr(concatOp.getAxis()), - concatOp.getInputs(), - outputBuffer) - .getOutput(); - rewriter.replaceOp(concatOp, concatenated); -} - -static void lowerMapOps(func::FuncOp funcOp, IRRewriter& rewriter) { - SmallVector mapOps; - funcOp.walk([&](spatial::SpatMapOp mapOp) { - if (mapOp->getParentOfType() || mapOp->getParentOfType()) - mapOps.push_back(mapOp); - }); - - for (auto mapOp : mapOps) { - Block& body = mapOp.getBody().front(); - auto yieldOp = cast(body.getTerminator()); - SmallVector replacements; - replacements.reserve(mapOp.getInputs().size()); - rewriter.setInsertionPoint(mapOp); - - for (Value input : mapOp.getInputs()) { - IRMapping mapping; - mapping.map(body.getArgument(0), input); - - for (Operation& bodyOp : 
body.without_terminator()) { - Operation* cloned = rewriter.clone(bodyOp, mapping); - for (auto [originalResult, clonedResult] : llvm::zip(bodyOp.getResults(), cloned->getResults())) - mapping.map(originalResult, clonedResult); - rewriter.setInsertionPointAfter(cloned); - } - - replacements.push_back(mapping.lookupOrDefault(yieldOp.getOperand(0))); - } - - rewriter.replaceOp(mapOp, replacements); - } -} - -static RankedTensorType getPackedTensorType(RankedTensorType elementType, int64_t count) { - SmallVector packedShape(elementType.getShape().begin(), elementType.getShape().end()); - packedShape[0] *= count; - return RankedTensorType::get(packedShape, elementType.getElementType()); -} - -static bool getContiguousOpResults(ValueRange values, Operation*& owner, unsigned& startIndex) { - if (values.empty()) - return false; - - auto firstResult = dyn_cast(values.front()); - if (!firstResult) - return false; - - owner = firstResult.getOwner(); - startIndex = firstResult.getResultNumber(); - for (auto [index, value] : llvm::enumerate(values)) { - auto result = dyn_cast(value); - if (!result || result.getOwner() != owner || result.getResultNumber() != startIndex + index) - return false; - } - - return true; -} - static Value createPackedExtractRowsSlice( spatial::SpatExtractRowsOp extractRowsOp, unsigned startIndex, unsigned count, IRRewriter& rewriter, Location loc) { auto rowType = dyn_cast(extractRowsOp.getOutputs()[startIndex].getType()); @@ -355,150 +184,7 @@ static Value createPackedExtractRowsSlice( .getResult(); } -static Value createPackedTensorForValues(ValueRange values, IRRewriter& rewriter, Location loc) { - Operation* owner = nullptr; - unsigned startIndex = 0; - if (!getContiguousOpResults(values, owner, startIndex)) - return {}; - - if (auto extractRowsOp = dyn_cast(owner)) - return createPackedExtractRowsSlice(extractRowsOp, startIndex, static_cast(values.size()), rewriter, loc); - - return {}; -} - -static Value 
createPackedReceiveTensor(spatial::SpatChannelReceiveManyOp receiveManyOp, - unsigned startIndex, - unsigned count, - IRRewriter& rewriter, - Location loc) { - auto rowType = dyn_cast(receiveManyOp.getOutputs()[startIndex].getType()); - if (!rowType || !rowType.hasStaticShape() || rowType.getRank() == 0) - return {}; - - auto packedType = getPackedTensorType(rowType, static_cast(count)); - auto outputBuffer = tensor::EmptyOp::create(rewriter, loc, packedType.getShape(), packedType.getElementType()); - - SmallVector sourceCoreIds; - sourceCoreIds.reserve(count); - ArrayRef allSourceCoreIds = receiveManyOp.getSourceCoreIds(); - for (unsigned index = 0; index < count; ++index) - sourceCoreIds.push_back(translateSpatialCoreIdToPimCoreId(allSourceCoreIds[startIndex + index])); - - return pim::PimReceiveTensorOp::create( - rewriter, loc, packedType, outputBuffer.getResult(), rewriter.getDenseI32ArrayAttr(sourceCoreIds)) - .getOutput(); -} - -static Value createPackedMapTensor( - spatial::SpatMapOp mapOp, unsigned startIndex, unsigned count, IRRewriter& rewriter, Location loc) { - Value packedInput = createPackedTensorForValues(mapOp.getInputs().slice(startIndex, count), rewriter, loc); - if (!packedInput) - return {}; - - auto inputType = dyn_cast(mapOp.getInputs()[startIndex].getType()); - auto outputType = dyn_cast(mapOp.getOutputs()[startIndex].getType()); - if (!inputType || !outputType || !inputType.hasStaticShape() || !outputType.hasStaticShape() - || inputType.getRank() == 0 || outputType.getRank() == 0) - return {}; - - auto packedOutputType = getPackedTensorType(outputType, static_cast(count)); - auto packedInit = - tensor::EmptyOp::create(rewriter, loc, packedOutputType.getShape(), packedOutputType.getElementType()); - auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0); - auto upper = arith::ConstantIndexOp::create(rewriter, loc, count); - auto step = arith::ConstantIndexOp::create(rewriter, loc, 1); - auto loop = scf::ForOp::create(rewriter, loc, 
zero, upper, step, ValueRange {packedInit.getResult()}); - - { - OpBuilder::InsertionGuard guard(rewriter); - Block* loopBlock = loop.getBody(); - rewriter.setInsertionPointToStart(loopBlock); - Value iv = loopBlock->getArgument(0); - Value acc = loopBlock->getArgument(1); - - int64_t inputRowsPerValue = inputType.getDimSize(0); - Value inputRowOffset = iv; - if (inputRowsPerValue != 1) { - auto rowsPerValue = arith::ConstantIndexOp::create(rewriter, loc, inputRowsPerValue); - inputRowOffset = arith::MulIOp::create(rewriter, loc, iv, rowsPerValue); - } - - SmallVector extractOffsets; - SmallVector extractSizes; - SmallVector extractStrides; - extractOffsets.push_back(inputRowOffset); - extractSizes.push_back(rewriter.getIndexAttr(inputRowsPerValue)); - extractStrides.push_back(rewriter.getIndexAttr(1)); - for (int64_t dim = 1; dim < inputType.getRank(); ++dim) { - extractOffsets.push_back(rewriter.getIndexAttr(0)); - extractSizes.push_back(rewriter.getIndexAttr(inputType.getDimSize(dim))); - extractStrides.push_back(rewriter.getIndexAttr(1)); - } - auto inputSlice = tensor::ExtractSliceOp::create( - rewriter, loc, inputType, packedInput, extractOffsets, extractSizes, extractStrides); - - IRMapping mapping; - Block& body = mapOp.getBody().front(); - mapping.map(body.getArgument(0), inputSlice.getResult()); - for (Operation& bodyOp : body.without_terminator()) { - Operation* cloned = rewriter.clone(bodyOp, mapping); - for (auto [originalResult, clonedResult] : llvm::zip(bodyOp.getResults(), cloned->getResults())) - mapping.map(originalResult, clonedResult); - rewriter.setInsertionPointAfter(cloned); - } - - auto yieldOp = cast(body.getTerminator()); - Value mappedOutput = mapping.lookupOrDefault(yieldOp.getOperand(0)); - - int64_t outputRowsPerValue = outputType.getDimSize(0); - Value outputRowOffset = iv; - if (outputRowsPerValue != 1) { - auto rowsPerValue = arith::ConstantIndexOp::create(rewriter, loc, outputRowsPerValue); - outputRowOffset = 
arith::MulIOp::create(rewriter, loc, iv, rowsPerValue); - } - - SmallVector insertOffsets; - SmallVector insertSizes; - SmallVector insertStrides; - insertOffsets.push_back(outputRowOffset); - insertSizes.push_back(rewriter.getIndexAttr(outputRowsPerValue)); - insertStrides.push_back(rewriter.getIndexAttr(1)); - for (int64_t dim = 1; dim < outputType.getRank(); ++dim) { - insertOffsets.push_back(rewriter.getIndexAttr(0)); - insertSizes.push_back(rewriter.getIndexAttr(outputType.getDimSize(dim))); - insertStrides.push_back(rewriter.getIndexAttr(1)); - } - - auto inserted = - tensor::InsertSliceOp::create(rewriter, loc, mappedOutput, acc, insertOffsets, insertSizes, insertStrides); - scf::YieldOp::create(rewriter, loc, inserted.getResult()); - } - - return loop.getResult(0); -} - static void compactSpatialTensorGroups(func::FuncOp funcOp, IRRewriter& rewriter) { - SmallVector sendManyOps; - funcOp.walk([&](spatial::SpatChannelSendManyOp sendManyOp) { sendManyOps.push_back(sendManyOp); }); - for (auto sendManyOp : sendManyOps) { - if (sendManyOp.getInputs().empty()) - continue; - - rewriter.setInsertionPoint(sendManyOp); - Value packedInput = createPackedTensorForValues(sendManyOp.getInputs(), rewriter, sendManyOp.getLoc()); - if (!packedInput) - continue; - - SmallVector targetCoreIds; - targetCoreIds.reserve(sendManyOp.getTargetCoreIds().size()); - for (int32_t targetCoreId : sendManyOp.getTargetCoreIds()) - targetCoreIds.push_back(translateSpatialCoreIdToPimCoreId(targetCoreId)); - pim::PimSendTensorOp::create( - rewriter, sendManyOp.getLoc(), packedInput, rewriter.getDenseI32ArrayAttr(targetCoreIds)); - rewriter.eraseOp(sendManyOp); - } - SmallVector concatOps; funcOp.walk([&](spatial::SpatConcatOp concatOp) { concatOps.push_back(concatOp); }); for (auto concatOp : concatOps) { @@ -511,6 +197,23 @@ static void compactSpatialTensorGroups(func::FuncOp funcOp, IRRewriter& rewriter for (unsigned index = 0; index < concatOp.getInputs().size();) { Value input = 
concatOp.getInputs()[index]; + + if (input.getDefiningOp()) { + unsigned endIndex = index + 1; + while (endIndex < concatOp.getInputs().size() + && concatOp.getInputs()[endIndex].getDefiningOp()) + ++endIndex; + + Value packedInput = createPackedExtractSliceTensor( + concatOp.getInputs().slice(index, endIndex - index), rewriter, concatOp.getLoc()); + if (packedInput) { + packedInputs.push_back(packedInput); + changed = true; + index = endIndex; + continue; + } + } + auto result = dyn_cast(input); if (!result) { packedInputs.push_back(input); @@ -531,11 +234,7 @@ static void compactSpatialTensorGroups(func::FuncOp funcOp, IRRewriter& rewriter unsigned count = endIndex - index; Value packedInput; - if (auto mapOp = dyn_cast(owner)) - packedInput = createPackedMapTensor(mapOp, startIndex, count, rewriter, concatOp.getLoc()); - else if (auto receiveManyOp = dyn_cast(owner)) - packedInput = createPackedReceiveTensor(receiveManyOp, startIndex, count, rewriter, concatOp.getLoc()); - else if (auto extractRowsOp = dyn_cast(owner)) + if (auto extractRowsOp = dyn_cast(owner)) packedInput = createPackedExtractRowsSlice(extractRowsOp, startIndex, count, rewriter, concatOp.getLoc()); if (packedInput) { @@ -564,6 +263,10 @@ static void compactSpatialTensorGroups(func::FuncOp funcOp, IRRewriter& rewriter rewriter.replaceOp(concatOp, newConcat.getOutput()); } + RewritePatternSet tensorPackingPatterns(funcOp.getContext()); + populateTensorPackingPatterns(tensorPackingPatterns); + (void) applyPatternsGreedily(funcOp, std::move(tensorPackingPatterns)); + auto eraseUnusedOps = [&](auto tag) { using OpTy = decltype(tag); SmallVector ops; @@ -572,355 +275,11 @@ static void compactSpatialTensorGroups(func::FuncOp funcOp, IRRewriter& rewriter if (op->use_empty()) rewriter.eraseOp(op); }; - eraseUnusedOps(spatial::SpatMapOp {}); - eraseUnusedOps(spatial::SpatChannelReceiveManyOp {}); + eraseUnusedOps(tensor::ConcatOp {}); + eraseUnusedOps(tensor::ExtractSliceOp {}); 
eraseUnusedOps(spatial::SpatExtractRowsOp {}); } -static LogicalResult collectHelperComputeChain(spatial::SpatCompute computeOp, - SmallVectorImpl& helperChain, - bool requireReturnUse = true) { - if (computeOp.getInputs().size() != 1 || computeOp.getNumResults() != 1) - return failure(); - if (requireReturnUse - && (!computeOp.getResult(0).hasOneUse() || !isa(*computeOp.getResult(0).getUsers().begin()))) - return failure(); - - Block& block = computeOp.getBody().front(); - if (block.getNumArguments() != 1) - return failure(); - - auto yieldOp = dyn_cast(block.getTerminator()); - if (!yieldOp || yieldOp.getNumOperands() != 1) - return failure(); - - SmallVector reverseChain; - Value currentValue = yieldOp.getOperands().front(); - Value blockArg = block.getArgument(0); - - while (currentValue != blockArg) { - Operation* definingOp = currentValue.getDefiningOp(); - if (!definingOp || definingOp->getBlock() != &block || !isChannelUseChainOp(definingOp)) - return failure(); - reverseChain.push_back(definingOp); - currentValue = definingOp->getOperand(0); - } - - SmallPtrSet chainSet(reverseChain.begin(), reverseChain.end()); - for (Operation& op : llvm::make_early_inc_range(block.without_terminator())) - if (!chainSet.contains(&op) && !isa(op)) - return failure(); - - helperChain.assign(reverseChain.rbegin(), reverseChain.rend()); - return success(); -} - -static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatCompute computeOp, IRRewriter& rewriter) { - if (!computeOp.getInputs().empty() || computeOp.getNumResults() != 1) - return false; - if (!llvm::all_of(computeOp.getResult(0).getUsers(), [](Operation* user) { - return isa(user); - })) - return false; - - Block& block = computeOp.getBody().front(); - if (block.getNumArguments() != 0) - return false; - - auto yieldOp = dyn_cast(block.getTerminator()); - if (!yieldOp || yieldOp.getNumOperands() != 1) - return false; - - rewriter.setInsertionPoint(computeOp); - IRMapping mapping; - for (Operation& op 
: block.without_terminator()) { - cloneMappedHelperOperands(&op, mapping, rewriter); - Operation* clonedOp = rewriter.clone(op, mapping); - for (auto [originalResult, newResult] : llvm::zip(op.getResults(), clonedOp->getResults())) - mapping.map(originalResult, newResult); - rewriter.setInsertionPointAfter(clonedOp); - } - - Value replacement = mapping.lookupOrDefault(yieldOp.getOperand(0)); - computeOp.getResult(0).replaceAllUsesWith(replacement); - return true; -} - -struct ReturnUseInfo { - size_t returnIndex; - SmallVector helperChain; -}; - -struct ConcatReturnUseInfo { - size_t returnIndex; - SmallVector sliceOffsets; - SmallVector concatShape; - SmallVector concatChain; - SmallVector helperChain; -}; - -static int64_t computeFlatElementIndex(ArrayRef indices, ArrayRef shape) { - int64_t flatIndex = 0; - for (size_t i = 0; i < shape.size(); ++i) { - flatIndex *= shape[i]; - flatIndex += indices[i]; - } - return flatIndex; -} - -static SmallVector expandFlatElementIndex(int64_t flatIndex, ArrayRef shape) { - SmallVector indices(shape.size(), 0); - for (int64_t dim = static_cast(shape.size()) - 1; dim >= 0; --dim) { - indices[dim] = flatIndex % shape[dim]; - flatIndex /= shape[dim]; - } - return indices; -} - -static std::optional analyzeReturnUse(Value value) { - auto uses = value.getUses(); - if (rangeLength(uses) != 1) - return std::nullopt; - - SmallVector helperChain; - Value currentValue = value; - Operation* currentUser = uses.begin()->getOwner(); - - while (isChannelUseChainOp(currentUser)) { - helperChain.push_back(currentUser); - auto currentUses = currentUser->getResult(0).getUses(); - if (rangeLength(currentUses) != 1) - return std::nullopt; - currentValue = currentUser->getResult(0); - currentUser = currentUses.begin()->getOwner(); - } - - if (!isa(currentUser)) - return std::nullopt; - - return ReturnUseInfo { - currentValue.getUses().begin()->getOperandNumber(), - std::move(helperChain), - }; -} - -static std::optional 
analyzeConcatReturnUse(Value value) { - auto getConcatResult = [](Operation* op) -> Value { - if (auto tensorConcat = dyn_cast(op)) - return tensorConcat.getResult(); - if (auto spatialConcat = dyn_cast(op)) - return spatialConcat.getOutput(); - if (auto pimConcat = dyn_cast(op)) - return pimConcat.getOutput(); - return {}; - }; - auto getConcatAxis = [](Operation* op) -> std::optional { - if (auto tensorConcat = dyn_cast(op)) - return tensorConcat.getDim(); - if (auto spatialConcat = dyn_cast(op)) - return spatialConcat.getAxis(); - if (auto pimConcat = dyn_cast(op)) - return pimConcat.getAxis(); - return std::nullopt; - }; - auto getConcatOperands = [](Operation* op) -> OperandRange { - if (auto tensorConcat = dyn_cast(op)) - return tensorConcat.getOperands(); - if (auto spatialConcat = dyn_cast(op)) - return spatialConcat.getInputs(); - return cast(op).getInputs(); - }; - - auto uses = value.getUses(); - if (rangeLength(uses) != 1 - || !isa(uses.begin()->getOwner())) - return std::nullopt; - - auto valueType = dyn_cast(value.getType()); - if (!valueType || !valueType.hasStaticShape()) - return std::nullopt; - - SmallVector sliceOffsets(valueType.getRank(), 0); - SmallVector concatShape(valueType.getShape().begin(), valueType.getShape().end()); - SmallVector concatChain; - Value currentValue = value; - Operation* currentUser = uses.begin()->getOwner(); - - while (isa(currentUser)) { - concatChain.push_back(currentUser); - size_t operandIndex = currentValue.getUses().begin()->getOperandNumber(); - int64_t axis = *getConcatAxis(currentUser); - for (Value operand : getConcatOperands(currentUser).take_front(operandIndex)) - sliceOffsets[axis] += cast(operand.getType()).getShape()[axis]; - - Value concatResult = getConcatResult(currentUser); - auto concatType = dyn_cast(concatResult.getType()); - if (!concatType || !concatType.hasStaticShape()) - return std::nullopt; - concatShape.assign(concatType.getShape().begin(), concatType.getShape().end()); - - currentValue = 
concatResult; - auto currentUses = currentValue.getUses(); - if (rangeLength(currentUses) != 1) - return std::nullopt; - currentUser = currentUses.begin()->getOwner(); - } - - SmallVector helperChain; - if (auto helperCompute = dyn_cast(currentUser)) { - if (helperCompute.getInputs().size() != 1 || helperCompute.getInputs().front() != currentValue) - return std::nullopt; - - if (failed(collectHelperComputeChain(helperCompute, helperChain))) - return std::nullopt; - - currentValue = helperCompute.getResult(0); - auto currentUses = currentValue.getUses(); - if (rangeLength(currentUses) != 1) - return std::nullopt; - currentUser = currentUses.begin()->getOwner(); - } - - while (isChannelUseChainOp(currentUser)) { - helperChain.push_back(currentUser); - auto currentUses = currentUser->getResult(0).getUses(); - if (rangeLength(currentUses) != 1) - return std::nullopt; - currentValue = currentUser->getResult(0); - currentUser = currentUses.begin()->getOwner(); - } - - if (!isa(currentUser)) - return std::nullopt; - - return ConcatReturnUseInfo { - currentValue.getUses().begin()->getOperandNumber(), - std::move(sliceOffsets), - std::move(concatShape), - std::move(concatChain), - std::move(helperChain), - }; -} - -static LogicalResult mapIndicesThroughHelperChain(ArrayRef sourceIndices, - ArrayRef sourceShape, - ArrayRef helperChain, - SmallVectorImpl& mappedIndices) { - SmallVector currentIndices(sourceIndices.begin(), sourceIndices.end()); - SmallVector currentShape(sourceShape.begin(), sourceShape.end()); - - auto reshapeToResultShape = [&](Operation* op) -> LogicalResult { - auto resultType = dyn_cast(op->getResult(0).getType()); - if (!resultType || !resultType.hasStaticShape()) - return failure(); - int64_t flatIndex = computeFlatElementIndex(currentIndices, currentShape); - currentShape.assign(resultType.getShape().begin(), resultType.getShape().end()); - currentIndices = expandFlatElementIndex(flatIndex, currentShape); - return success(); - }; - - for (Operation* 
op : helperChain) { - if (auto extractSliceOp = dyn_cast(op)) { - auto hasStaticValues = [](ArrayRef values) { - return llvm::all_of(values, [](int64_t value) { return !ShapedType::isDynamic(value); }); - }; - if (!hasStaticValues(extractSliceOp.getStaticOffsets()) || !hasStaticValues(extractSliceOp.getStaticSizes()) - || !hasStaticValues(extractSliceOp.getStaticStrides())) - return failure(); - - SmallVector nextIndices; - nextIndices.reserve(currentIndices.size()); - for (auto [index, offset, size, stride] : llvm::zip_equal(currentIndices, - extractSliceOp.getStaticOffsets(), - extractSliceOp.getStaticSizes(), - extractSliceOp.getStaticStrides())) { - if (stride != 1 || index < offset || index >= offset + size) - return failure(); - nextIndices.push_back(index - offset); - } - - auto resultType = dyn_cast(extractSliceOp.getResult().getType()); - if (!resultType || !resultType.hasStaticShape()) - return failure(); - currentIndices = std::move(nextIndices); - currentShape.assign(resultType.getShape().begin(), resultType.getShape().end()); - continue; - } - - if (auto transposeOp = dyn_cast(op)) { - SmallVector nextIndices(currentIndices.size()); - SmallVector nextShape(currentShape.size()); - for (auto [destIndex, attr] : llvm::enumerate(transposeOp.getPermAttr().getAsRange())) { - int64_t sourceIndex = attr.getInt(); - nextIndices[destIndex] = currentIndices[sourceIndex]; - nextShape[destIndex] = currentShape[sourceIndex]; - } - currentIndices = std::move(nextIndices); - currentShape = std::move(nextShape); - continue; - } - - if (auto transposeOp = dyn_cast(op)) { - SmallVector nextIndices(currentIndices.size()); - SmallVector nextShape(currentShape.size()); - for (auto [destIndex, attr] : llvm::enumerate(transposeOp.getPermutation().getAsRange())) { - int64_t sourceIndex = attr.getInt(); - nextIndices[destIndex] = currentIndices[sourceIndex]; - nextShape[destIndex] = currentShape[sourceIndex]; - } - currentIndices = std::move(nextIndices); - currentShape = 
std::move(nextShape); - continue; - } - - if (isa(op)) { - if (failed(reshapeToResultShape(op))) - return failure(); - continue; - } - - return failure(); - } - - mappedIndices.assign(currentIndices.begin(), currentIndices.end()); - return success(); -} - -static void -cloneHelperChain(Value sourceValue, ArrayRef helperChain, IRRewriter& rewriter, Value& clonedValue) { - IRMapping mapping; - mapping.map(sourceValue, sourceValue); - clonedValue = sourceValue; - - rewriter.setInsertionPointAfterValue(sourceValue); - for (Operation* op : helperChain) { - cloneMappedHelperOperands(op, mapping, rewriter); - Operation* clonedOp = rewriter.clone(*op, mapping); - for (auto [originalResult, newResult] : llvm::zip(op->getResults(), clonedOp->getResults())) - mapping.map(originalResult, newResult); - clonedValue = clonedOp->getResult(0); - rewriter.setInsertionPointAfter(clonedOp); - } -} - -static Value emitHostCopy(IRRewriter& rewriter, - Location loc, - Value outputTensor, - Value sourceValue, - int32_t hostTargetOffset, - int32_t deviceSourceOffset, - int32_t sizeInBytes) { - return PimMemCopyDevToHostOp::create(rewriter, - loc, - outputTensor.getType(), - outputTensor, - sourceValue, - rewriter.getI32IntegerAttr(hostTargetOffset), - rewriter.getI32IntegerAttr(deviceSourceOffset), - rewriter.getI32IntegerAttr(sizeInBytes)) - .getOutput(); -} - void SpatialToPimPass::runOnOperation() { coreId = 1; ModuleOp moduleOp = getOperation(); @@ -939,9 +298,19 @@ void SpatialToPimPass::runOnOperation() { target.addLegalDialect(); + target.addLegalOp(); { RewritePatternSet patterns(ctx); @@ -955,30 +324,36 @@ void SpatialToPimPass::runOnOperation() { { RewritePatternSet patterns(ctx); - populateGlobalTensorToMemrefPatterns(patterns); + populateGlobalTensorMaterializationPatterns(patterns); walkAndApplyPatterns(moduleOp, std::move(patterns)); } auto returnOp = cast(funcOp.front().getTerminator()); - addResultBuffer(returnOp, rewriter); + addReturnOutputBuffers(returnOp, rewriter, 
outputTensors); if (failed(allocateAndInitializeCoreLocalVariables(funcOp, rewriter))) { signalPassFailure(); return; } + CoreLoweringState coreLoweringState {coreId, outputTensors, operationsToRemove}; for (auto computeOp : funcOp.getOps()) { markOpToRemove(computeOp); - runOnComputeOp(computeOp, rewriter); + if (failed(lowerComputeOp(computeOp, coreLoweringState, rewriter))) { + signalPassFailure(); + return; + } } for (auto computeBatchOp : funcOp.getOps()) { markOpToRemove(computeBatchOp); - runOnComputeBatchOp(computeBatchOp, rewriter); + if (failed(lowerComputeBatchOp(computeBatchOp, coreLoweringState, rewriter))) { + signalPassFailure(); + return; + } } compactSpatialTensorGroups(funcOp, rewriter); - lowerMapOps(funcOp, rewriter); SmallVector receiveOps; for (auto op : funcOp.getOps()) @@ -997,11 +372,11 @@ void SpatialToPimPass::runOnOperation() { lowerChannelReceive(receiveOp, rewriter); } - SmallVector receiveManyOps; - for (auto op : funcOp.getOps()) - receiveManyOps.push_back(op); - for (auto receiveManyOp : receiveManyOps) - lowerChannelReceiveMany(receiveManyOp, rewriter); + SmallVector receiveTensorOps; + for (auto op : funcOp.getOps()) + receiveTensorOps.push_back(op); + for (auto receiveTensorOp : receiveTensorOps) + lowerChannelReceiveTensor(receiveTensorOp, rewriter); SmallVector sendOps; for (auto op : funcOp.getOps()) @@ -1009,11 +384,11 @@ void SpatialToPimPass::runOnOperation() { for (auto sendOp : sendOps) lowerChannelSend(sendOp, rewriter); - SmallVector sendManyOps; - for (auto op : funcOp.getOps()) - sendManyOps.push_back(op); - for (auto sendManyOp : sendManyOps) - lowerChannelSendMany(sendManyOp, rewriter); + SmallVector sendTensorOps; + for (auto op : funcOp.getOps()) + sendTensorOps.push_back(op); + for (auto sendTensorOp : sendTensorOps) + lowerChannelSendTensor(sendTensorOp, rewriter); SmallVector extractRowsOps; for (auto op : funcOp.getOps()) @@ -1029,7 +404,7 @@ void SpatialToPimPass::runOnOperation() { SmallVector coreOps; 
funcOp.walk([&](pim::PimCoreOp coreOp) { coreOps.push_back(coreOp); }); for (auto coreOp : coreOps) { - if (failed(applyPartialConversion(coreOp.getOperation(), target, frozenCoreBodyPatterns))) { + if (failed(applyFullConversion(coreOp.getOperation(), target, frozenCoreBodyPatterns))) { signalPassFailure(); return; } @@ -1038,89 +413,52 @@ void SpatialToPimPass::runOnOperation() { SmallVector coreBatchOps; funcOp.walk([&](pim::PimCoreBatchOp coreBatchOp) { coreBatchOps.push_back(coreBatchOp); }); for (auto coreBatchOp : coreBatchOps) { - if (failed(applyPartialConversion(coreBatchOp.getOperation(), target, frozenCoreBodyPatterns))) { + if (failed(applyFullConversion(coreBatchOp.getOperation(), target, frozenCoreBodyPatterns))) { signalPassFailure(); return; } } } - RewritePatternSet channelPatterns(ctx); - populateWithGenerated(channelPatterns); - if (failed(applyPatternsGreedily(funcOp, std::move(channelPatterns)))) { + enlargeVMMOutTensorsToCrossbarSize(funcOp, rewriter); + ReturnPathState returnPathState {outputTensors, operationsToRemove}; + replaceReturnWithOutputBuffers(returnOp, rewriter, returnPathState); + + SmallVector pendingRemovals(operationsToRemove.begin(), operationsToRemove.end()); + if (failed(erasePendingOps(pendingRemovals, rewriter))) { signalPassFailure(); return; } - enlargeVMMOutTensorsToCrossbarSize(funcOp, rewriter); - replaceReturnOpOperands(returnOp, rewriter); - - SmallVector pendingRemovals(operationsToRemove.begin(), operationsToRemove.end()); - while (!pendingRemovals.empty()) { - bool erasedAnyOp = false; - for (auto it = pendingRemovals.begin(); it != pendingRemovals.end();) { - Operation* opToRemove = *it; - if (!opToRemove->use_empty()) { - ++it; - continue; - } - - rewriter.eraseOp(opToRemove); - it = pendingRemovals.erase(it); - erasedAnyOp = true; - } - - if (erasedAnyOp) - continue; - - for (auto opToRemove : pendingRemovals) { - opToRemove->dump(); - for (auto user : opToRemove->getUsers()) - user->dump(); - } - 
assert(false && "tracked op removal reached a cycle or missed dependency"); - } - compactSpatialTensorGroups(funcOp, rewriter); - SmallVector remainingConcatOps; - funcOp.walk([&](spatial::SpatConcatOp op) { remainingConcatOps.push_back(op); }); - for (auto concatOp : remainingConcatOps) - lowerConcat(concatOp, rewriter); + { + ConversionTarget communicationTarget(*ctx); + communicationTarget.addLegalDialect(); + communicationTarget.addLegalOp(); + communicationTarget.addIllegalOp(); - SmallVector remainingReceiveOps; - funcOp.walk([&](spatial::SpatChannelReceiveOp op) { remainingReceiveOps.push_back(op); }); - for (auto receiveOp : remainingReceiveOps) - lowerChannelReceive(receiveOp, rewriter); + RewritePatternSet communicationPatterns(ctx); + populateChannelLoweringPatterns(communicationPatterns); + if (failed(applyFullConversion(funcOp, communicationTarget, std::move(communicationPatterns)))) { + signalPassFailure(); + return; + } + } - SmallVector remainingReceiveManyOps; - funcOp.walk([&](spatial::SpatChannelReceiveManyOp op) { remainingReceiveManyOps.push_back(op); }); - for (auto receiveManyOp : remainingReceiveManyOps) - lowerChannelReceiveMany(receiveManyOp, rewriter); - - SmallVector remainingSendOps; - funcOp.walk([&](spatial::SpatChannelSendOp op) { remainingSendOps.push_back(op); }); - for (auto sendOp : remainingSendOps) - lowerChannelSend(sendOp, rewriter); - - SmallVector remainingSendManyOps; - funcOp.walk([&](spatial::SpatChannelSendManyOp op) { remainingSendManyOps.push_back(op); }); - for (auto sendManyOp : remainingSendManyOps) - lowerChannelSendMany(sendManyOp, rewriter); - - SmallVector remainingExtractRowsOps; - funcOp.walk([&](spatial::SpatExtractRowsOp op) { remainingExtractRowsOps.push_back(op); }); - for (auto extractRowsOp : remainingExtractRowsOps) - lowerExtractRows(extractRowsOp, rewriter); - - // Dump to file for debug - bool hasSpatialOps = false; - moduleOp.walk([&](Operation* op) { - if (op->getDialect()->getNamespace() == 
"spat") - hasSpatialOps = true; - }); - if (hasSpatialOps) { - moduleOp.emitError("SpatialToPim left illegal Spatial operations in the module"); + if (failed(verifySpatialToPimBoundary(moduleOp))) { signalPassFailure(); return; } @@ -1129,350 +467,6 @@ void SpatialToPimPass::runOnOperation() { dumpModule(moduleOp, "pim0"); } -void SpatialToPimPass::runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter& rewriter) { - Location loc = computeOp->getLoc(); - - if (inlineInputlessHelperComputeForWeightLikeUsers(computeOp, rewriter)) - return; - - SmallVector helperChain; - if (succeeded(collectHelperComputeChain(computeOp, helperChain))) - return; - - auto& block = computeOp.getRegion().front(); - auto yieldOp = cast(block.getTerminator()); - - for (auto [argIndex, blockArg] : llvm::enumerate(block.getArguments())) { - auto receiveOp = dyn_cast_or_null(computeOp.getInputs()[argIndex].getDefiningOp()); - if (!receiveOp || blockArg.use_empty()) - continue; - - rewriter.setInsertionPoint(getEarliestUserWithinBlock(blockArg)); - auto outputType = cast(blockArg.getType()); - auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveOp.getLoc(), outputType); - auto sizeAttr = getTensorSizeInBytesAttr(rewriter, blockArg); - auto sourceCoreIdAttr = rewriter.getI32IntegerAttr(translateSpatialCoreIdToPimCoreId(receiveOp.getSourceCoreId())); - Value received = PimReceiveOp::create( - rewriter, receiveOp.getLoc(), outputBuffer.getType(), outputBuffer, sizeAttr, sourceCoreIdAttr) - .getOutput(); - blockArg.replaceAllUsesWith(received); - markOpToRemove(receiveOp); - } - - if (computeOp.getNumResults() != yieldOp.getNumOperands()) - llvm_unreachable("ComputeOp must have same number of results as yieldOp operands"); - - for (auto [result, yieldValue] : llvm::zip(computeOp.getResults(), yieldOp.getOperands())) { - if (result.use_empty()) - continue; - - auto yieldType = cast(yieldValue.getType()); - - if (auto returnUse = analyzeReturnUse(result)) { - Value storedValue = 
yieldValue; - cloneHelperChain(yieldValue, returnUse->helperChain, rewriter, storedValue); - for (Operation* op : returnUse->helperChain) - markOpToRemove(op); - - auto storedType = cast(storedValue.getType()); - size_t elementSize = storedType.getElementTypeBitWidth() / 8; - if (auto storedOp = storedValue.getDefiningOp()) - rewriter.setInsertionPointAfter(storedOp); - Value outputTensor = outputTensors[returnUse->returnIndex](rewriter, loc); - emitHostCopy(rewriter, - loc, - outputTensor, - storedValue, - 0, - 0, - static_cast(storedType.getNumElements() * elementSize)); - continue; - } - - auto resultUses = result.getUses(); - if (rangeLength(resultUses) == 1) { - OpOperand& resultUse = *resultUses.begin(); - Operation* resultUser = resultUse.getOwner(); - - if (isa(resultUser)) { - size_t resultIndexInReturn = resultUse.getOperandNumber(); - size_t elementSize = yieldType.getElementType().getIntOrFloatBitWidth() / 8; - rewriter.setInsertionPointAfterValue(yieldValue); - Value outputTensor = outputTensors[resultIndexInReturn](rewriter, loc); - emitHostCopy(rewriter, - loc, - outputTensor, - yieldValue, - 0, - 0, - static_cast(yieldType.getNumElements() * elementSize)); - continue; - } - - if (isa(resultUser)) - continue; - } - - if (auto concatReturnUse = analyzeConcatReturnUse(result)) { - size_t elementSize = yieldType.getElementTypeBitWidth() / 8; - for (Operation* concatOp : concatReturnUse->concatChain) - markOpToRemove(concatOp); - - if (concatReturnUse->helperChain.empty()) { - rewriter.setInsertionPointAfterValue(yieldValue); - Value outputTensor = outputTensors[concatReturnUse->returnIndex](rewriter, loc); - auto outputType = cast(outputTensor.getType()); - int64_t flatOffset = computeFlatElementIndex(concatReturnUse->sliceOffsets, outputType.getShape()); - emitHostCopy(rewriter, - loc, - outputTensor, - yieldValue, - static_cast(flatOffset * elementSize), - 0, - static_cast(yieldType.getNumElements() * elementSize)); - continue; - } - - auto storedType 
= dyn_cast(yieldValue.getType()); - if (!storedType) { - computeOp.emitOpError( - "has an unsupported non-ranked concat-return helper yield during Spatial-to-PIM lowering"); - signalPassFailure(); - return; - } - rewriter.setInsertionPointAfterValue(yieldValue); - Value outputTensor = outputTensors[concatReturnUse->returnIndex](rewriter, loc); - auto outputType = cast(outputTensor.getType()); - for (int64_t linearIndex = 0; linearIndex < storedType.getNumElements(); ++linearIndex) { - SmallVector sourceIndices = expandFlatElementIndex(linearIndex, storedType.getShape()); - for (auto [dim, idx] : llvm::enumerate(sourceIndices)) - sourceIndices[dim] = concatReturnUse->sliceOffsets[dim] + idx; - - SmallVector destinationIndices; - if (failed(mapIndicesThroughHelperChain( - sourceIndices, concatReturnUse->concatShape, concatReturnUse->helperChain, destinationIndices))) { - computeOp.emitOpError("has an unsupported concat-return helper chain during Spatial-to-PIM lowering"); - signalPassFailure(); - return; - } - - SmallVector extractOffsets; - SmallVector extractSizes; - SmallVector extractStrides; - extractOffsets.reserve(storedType.getRank()); - extractSizes.reserve(storedType.getRank()); - extractStrides.reserve(storedType.getRank()); - for (int64_t idx : expandFlatElementIndex(linearIndex, storedType.getShape())) { - extractOffsets.push_back(rewriter.getIndexAttr(idx)); - extractSizes.push_back(rewriter.getIndexAttr(1)); - extractStrides.push_back(rewriter.getIndexAttr(1)); - } - - auto scalarTensorType = - RankedTensorType::get(SmallVector(storedType.getRank(), 1), storedType.getElementType()); - auto elementSlice = tensor::ExtractSliceOp::create( - rewriter, loc, scalarTensorType, yieldValue, extractOffsets, extractSizes, extractStrides); - rewriter.setInsertionPointAfter(elementSlice); - - int64_t destinationFlatOffset = computeFlatElementIndex(destinationIndices, outputType.getShape()); - outputTensor = emitHostCopy(rewriter, - loc, - outputTensor, - 
elementSlice.getResult(), - static_cast(destinationFlatOffset * elementSize), - 0, - static_cast(elementSize)); - } - continue; - } - - computeOp.emitOpError("has an unsupported remaining result use during Spatial-to-PIM lowering"); - signalPassFailure(); - return; - } - - // Use `HaltOp` instead of `YieldOp` - rewriter.setInsertionPoint(yieldOp); - rewriter.replaceOpWithNewOp(yieldOp); - - // Replace `spat.compute` with `pim.core` - SmallVector computeWeights; - if (!computeOp.getWeights().empty()) - computeWeights.append(computeOp.getWeights().begin(), computeOp.getWeights().end()); - rewriter.setInsertionPointAfter(computeOp); - auto coreOp = PimCoreOp::create( - rewriter, loc, ValueRange(computeWeights), rewriter.getI32IntegerAttr(getPimCoreIdForComputeOp(computeOp, coreId))); - auto& coreOpBlocks = coreOp.getBody().getBlocks(); - for (auto [argIndex, blockArg] : llvm::enumerate(block.getArguments())) - if (!blockArg.use_empty()) - blockArg.replaceAllUsesWith(computeOp.getInputs()[argIndex]); - block.eraseArguments(0, block.getNumArguments()); - coreOpBlocks.splice(coreOpBlocks.begin(), computeOp.getBody().getBlocks()); - Block* tempComputeBlock = new Block(); - computeOp.getBody().push_back(tempComputeBlock); - rewriter.setInsertionPointToEnd(tempComputeBlock); - PimHaltOp::create(rewriter, computeOp.getLoc()); -} - -void SpatialToPimPass::runOnComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, IRRewriter& rewriter) { - if (std::getenv("PIM_BATCH_LOWER_DEBUG")) - llvm::errs() << "lowering compute_batch lanes=" << computeBatchOp.getLaneCount() << "\n"; - - if (computeBatchOp.getNumResults() != 0) { - computeBatchOp.emitOpError( - "batched Spatial-to-PIM lowering currently requires channelized compute_batch with no results"); - signalPassFailure(); - return; - } - - Location loc = computeBatchOp.getLoc(); - Block& oldBlock = computeBatchOp.getBody().front(); - auto oldYield = cast(oldBlock.getTerminator()); - if (oldYield.getNumOperands() != 0) { - 
computeBatchOp.emitOpError("batched Spatial-to-PIM lowering currently requires empty spat.yield"); - signalPassFailure(); - return; - } - - SmallVector coreIds = getPimCoreIdsForBatchOp(computeBatchOp, coreId); - SmallVector batchWeights(computeBatchOp.getWeights().begin(), computeBatchOp.getWeights().end()); - SmallVector batchInputs; - if (!computeBatchOp.getInputs().empty()) - batchInputs.append(computeBatchOp.getInputs().begin(), computeBatchOp.getInputs().end()); - - rewriter.setInsertionPointAfter(computeBatchOp); - auto coreBatchOp = pim::PimCoreBatchOp::create(rewriter, - loc, - rewriter.getI32IntegerAttr(computeBatchOp.getLaneCount()), - ValueRange(batchWeights), - ValueRange(batchInputs)); - coreBatchOp.getProperties().setOperandSegmentSizes( - {static_cast(batchWeights.size()), static_cast(batchInputs.size())}); - coreBatchOp->setAttr(onnx_mlir::kCoreIdsAttrName, rewriter.getDenseI32ArrayAttr(coreIds)); - - SmallVector blockArgTypes; - SmallVector blockArgLocs; - for (BlockArgument arg : oldBlock.getArguments()) { - blockArgTypes.push_back(arg.getType()); - blockArgLocs.push_back(arg.getLoc()); - } - Block* newBlock = - rewriter.createBlock(&coreBatchOp.getBody(), coreBatchOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs); - - IRMapping mapper; - rewriter.setInsertionPointToStart(newBlock); - for (auto [oldArg, newArg] : llvm::zip(oldBlock.getArguments(), newBlock->getArguments())) { - auto newArgType = cast(newArg.getType()); - auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, newArgType); - auto copied = pim::PimMemCopyHostToDevBatchOp::create(rewriter, - loc, - outputBuffer.getType(), - outputBuffer, - newArg, - rewriter.getI32IntegerAttr(0), - rewriter.getI32IntegerAttr(0), - getTensorSizeInBytesAttr(rewriter, newArg)) - .getOutput(); - mapper.map(oldArg, copied); - } - - auto materializeCapturedTensor = [&](Value capturedTensor) -> Value { - if (auto mapped = mapper.lookupOrNull(capturedTensor)) - return mapped; - - auto 
capturedType = cast(capturedTensor.getType()); - auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, capturedType); - auto copied = pim::PimMemCopyHostToDevBatchOp::create(rewriter, - loc, - outputBuffer.getType(), - outputBuffer, - capturedTensor, - rewriter.getI32IntegerAttr(0), - rewriter.getI32IntegerAttr(0), - getTensorSizeInBytesAttr(rewriter, capturedTensor)) - .getOutput(); - mapper.map(capturedTensor, copied); - return copied; - }; - - rewriter.setInsertionPointToEnd(newBlock); - for (Operation& op : oldBlock) { - if (isa(op)) - continue; - - if (auto sendBatchOp = dyn_cast(op)) { - pim::PimSendBatchOp::create(rewriter, - loc, - mapper.lookup(sendBatchOp.getInput()), - getTensorSizeInBytesAttr(rewriter, mapper.lookup(sendBatchOp.getInput())), - sendBatchOp.getTargetCoreIdsAttr()); - continue; - } - - if (auto sendManyBatchOp = dyn_cast(op)) { - lowerChannelSendManyBatch(sendManyBatchOp, computeBatchOp.getLaneCount(), mapper, rewriter); - continue; - } - - if (auto receiveBatchOp = dyn_cast(op)) { - auto outputType = cast(receiveBatchOp.getOutput().getType()); - auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, outputType); - auto received = pim::PimReceiveBatchOp::create(rewriter, - loc, - outputBuffer.getType(), - outputBuffer, - getTensorSizeInBytesAttr(rewriter, receiveBatchOp.getOutput()), - receiveBatchOp.getSourceCoreIdsAttr()) - .getOutput(); - mapper.map(receiveBatchOp.getOutput(), received); - continue; - } - - if (auto receiveManyBatchOp = dyn_cast(op)) { - lowerChannelReceiveManyBatch(receiveManyBatchOp, computeBatchOp.getLaneCount(), mapper, rewriter); - continue; - } - - if (auto toTensorOp = dyn_cast(op)) { - if (isa_and_present(toTensorOp.getBuffer().getDefiningOp())) { - Operation* cloned = rewriter.clone(op, mapper); - auto clonedTensor = cloned->getResult(0); - auto clonedType = cast(clonedTensor.getType()); - auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, clonedType); - auto copied = 
pim::PimMemCopyHostToDevBatchOp::create(rewriter, - loc, - outputBuffer.getType(), - outputBuffer, - clonedTensor, - rewriter.getI32IntegerAttr(0), - rewriter.getI32IntegerAttr(0), - getTensorSizeInBytesAttr(rewriter, clonedTensor)) - .getOutput(); - mapper.map(toTensorOp.getResult(), copied); - continue; - } - } - - for (Value operand : op.getOperands()) { - if (!isa(operand.getType()) || mapper.contains(operand)) - continue; - - Operation* definingOp = operand.getDefiningOp(); - if (definingOp && definingOp->getBlock() == &oldBlock) - continue; - - materializeCapturedTensor(operand); - } - - Operation* cloned = rewriter.clone(op, mapper); - for (auto [originalResult, clonedResult] : llvm::zip(op.getResults(), cloned->getResults())) - mapper.map(originalResult, clonedResult); - } - - rewriter.setInsertionPointToEnd(newBlock); - PimHaltOp::create(rewriter, loc); -} - void SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter) { auto enlargeTiedDpsChain = [&](Value value, RankedTensorType newType, auto& self) -> void { auto* definingOp = value.getDefiningOp(); @@ -1525,41 +519,6 @@ void SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, I }); } -void SpatialToPimPass::addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rewriter) { - outputTensors.reserve(returnOp->getNumOperands()); - for (auto [index, returnValue] : llvm::enumerate(returnOp->getOperands())) { - Value currentReturnValue = returnValue; - Operation* returnValueDefiningOp = currentReturnValue.getDefiningOp(); - if (returnValueDefiningOp->hasTrait()) { - assert(!hasWeightAlways(returnValueDefiningOp)); - outputTensors.push_back( - [currentReturnValue](IRRewriter& rewriter, Location loc) -> Value { return currentReturnValue; }); - } - else { - auto outRankedTensorType = llvm::dyn_cast(currentReturnValue.getType()); - auto memRefType = mlir::MemRefType::get(outRankedTensorType.getShape(), outRankedTensorType.getElementType()); - - 
std::string outputName = "output_" + std::to_string(index); - rewriter.setInsertionPoint(returnOp.getParentOp()); - memref::GlobalOp::create(rewriter, - returnOp.getLoc(), - rewriter.getStringAttr(outputName), - rewriter.getStringAttr("private"), - TypeAttr::get(memRefType), - {}, - {}, - {}); - outputTensors.push_back( - [memRefType, outputName, outRankedTensorType](IRRewriter& rewriter, Location loc) -> Value { - auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, outputName); - auto toTensor = bufferization::ToTensorOp::create( - rewriter, loc, outRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr()); - return toTensor.getResult(); - }); - } - } -} - LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter) { Location loc = funcOp.getLoc(); @@ -1605,69 +564,6 @@ void SpatialToPimPass::markOpToRemove(Operation* op) { operationsToRemove.push_back(op); } -void SpatialToPimPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewriter& rewriter) { - auto markOwnedReturnChain = [&](Operation* op, auto&& markOwnedReturnChain) -> void { - if (!op) - return; - - bool isExclusivelyOwnedByReturnChain = op->use_empty(); - if (!isExclusivelyOwnedByReturnChain && op->hasOneUse()) { - Operation* onlyUser = *op->getUsers().begin(); - isExclusivelyOwnedByReturnChain = - isa(onlyUser) - || isChannelUseChainOp(onlyUser); - } - if (!isExclusivelyOwnedByReturnChain) - return; - - if (isChannelUseChainOp(op)) { - Value source = op->getOperand(0); - markOpToRemove(op); - markOwnedReturnChain(source.getDefiningOp(), markOwnedReturnChain); - return; - } - - if (auto computeOp = dyn_cast(op)) { - markOpToRemove(computeOp); - if (!computeOp.getInputs().empty()) - for (Value input : computeOp.getInputs()) - markOwnedReturnChain(input.getDefiningOp(), markOwnedReturnChain); - return; - } - - if (auto concatOp = dyn_cast(op)) { - markOpToRemove(concatOp); - for (Value operand : 
concatOp.getOperands()) - markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain); - return; - } - - if (auto concatOp = dyn_cast(op)) { - markOpToRemove(concatOp); - for (Value operand : concatOp.getInputs()) - markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain); - return; - } - - if (auto concatOp = dyn_cast(op)) { - markOpToRemove(concatOp); - for (Value operand : concatOp.getInputs()) - markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain); - } - }; - - SmallVector originalOperands(returnOp.getOperands().begin(), returnOp.getOperands().end()); - auto loc = returnOp.getLoc(); - for (auto it : llvm::enumerate(originalOperands)) { - size_t orderWithinReturn = it.index(); - Operation* returnOperand = it.value().getDefiningOp(); - rewriter.setInsertionPoint(returnOp); - Value outputTensor = outputTensors[orderWithinReturn](rewriter, loc); - rewriter.modifyOpInPlace(returnOp, [&] { returnOp.setOperand(orderWithinReturn, outputTensor); }); - markOwnedReturnChain(returnOperand, markOwnedReturnChain); - } -} - std::unique_ptr createSpatialToPimPass() { return std::make_unique(); } } // namespace onnx_mlir diff --git a/src/PIM/Conversion/SpatialToPim/TensorPackingPatterns.cpp b/src/PIM/Conversion/SpatialToPim/TensorPackingPatterns.cpp new file mode 100644 index 0000000..4307cdf --- /dev/null +++ b/src/PIM/Conversion/SpatialToPim/TensorPackingPatterns.cpp @@ -0,0 +1,113 @@ +#include "src/Accelerators/PIM/Conversion/SpatialToPim/TensorPackingPatterns.hpp" + +using namespace mlir; + +namespace onnx_mlir { +namespace { + +// Replaces concat-of-adjacent-slices with one packed slice to keep batch sends compact. 
+struct FoldConcatOfContiguousSlices : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tensor::ConcatOp op, PatternRewriter& rewriter) const override { + if (op.getDim() != 0) + return failure(); + + Value packed = createPackedExtractSliceTensor(op.getInputs(), rewriter, op.getLoc()); + if (!packed) + return failure(); + + rewriter.replaceOp(op, packed); + return success(); + } +}; + +} // namespace + +RankedTensorType getPackedTensorType(RankedTensorType elementType, int64_t count) { + SmallVector packedShape(elementType.getShape().begin(), elementType.getShape().end()); + packedShape[0] *= count; + return RankedTensorType::get(packedShape, elementType.getElementType()); +} + +Value createPackedExtractSliceTensor(ValueRange values, OpBuilder& builder, Location loc) { + if (values.empty()) + return {}; + if (values.size() == 1) + return values.front(); + + auto firstSliceOp = values.front().getDefiningOp(); + if (!firstSliceOp) + return {}; + + auto firstType = dyn_cast(firstSliceOp.getResult().getType()); + auto sourceType = dyn_cast(firstSliceOp.getSource().getType()); + if (!firstType || !sourceType || !firstType.hasStaticShape() || !sourceType.hasStaticShape() + || firstType.getRank() == 0) + return {}; + + auto hasStaticValues = [](ArrayRef values) { + return llvm::all_of(values, [](int64_t value) { return !ShapedType::isDynamic(value); }); + }; + if (!hasStaticValues(firstSliceOp.getStaticOffsets()) || !hasStaticValues(firstSliceOp.getStaticSizes()) + || !hasStaticValues(firstSliceOp.getStaticStrides())) + return {}; + + ArrayRef firstOffsets = firstSliceOp.getStaticOffsets(); + ArrayRef firstSizes = firstSliceOp.getStaticSizes(); + ArrayRef firstStrides = firstSliceOp.getStaticStrides(); + int64_t rowsPerValue = firstSizes[0]; + if (ShapedType::isDynamic(rowsPerValue)) + return {}; + + for (size_t index = 1; index < values.size(); ++index) { + auto sliceOp = values[index].getDefiningOp(); + if (!sliceOp || 
sliceOp.getSource() != firstSliceOp.getSource() + || sliceOp.getResult().getType() != firstSliceOp.getResult().getType() + || !hasStaticValues(sliceOp.getStaticOffsets()) || !hasStaticValues(sliceOp.getStaticSizes()) + || !hasStaticValues(sliceOp.getStaticStrides())) + return {}; + + if (sliceOp.getStaticSizes() != firstSizes || sliceOp.getStaticStrides() != firstStrides) + return {}; + + if (sliceOp.getStaticOffsets()[0] != firstOffsets[0] + static_cast(index) * rowsPerValue) + return {}; + + for (int64_t dim = 1; dim < firstType.getRank(); ++dim) + if (sliceOp.getStaticOffsets()[dim] != firstOffsets[dim]) + return {}; + } + + auto packedType = getPackedTensorType(firstType, static_cast(values.size())); + SmallVector offsets; + SmallVector sizes; + SmallVector strides; + offsets.reserve(firstType.getRank()); + sizes.reserve(firstType.getRank()); + strides.reserve(firstType.getRank()); + + offsets.push_back(builder.getIndexAttr(firstOffsets[0])); + sizes.push_back(builder.getIndexAttr(rowsPerValue * static_cast(values.size()))); + strides.push_back(builder.getIndexAttr(firstStrides[0])); + for (int64_t dim = 1; dim < firstType.getRank(); ++dim) { + offsets.push_back(builder.getIndexAttr(firstOffsets[dim])); + sizes.push_back(builder.getIndexAttr(firstSizes[dim])); + strides.push_back(builder.getIndexAttr(firstStrides[dim])); + } + + bool coversWholeSource = packedType == sourceType; + for (int64_t dim = 0; coversWholeSource && dim < sourceType.getRank(); ++dim) + coversWholeSource = firstOffsets[dim] == 0 && firstStrides[dim] == 1; + if (coversWholeSource) + return firstSliceOp.getSource(); + + return tensor::ExtractSliceOp::create(builder, loc, packedType, firstSliceOp.getSource(), offsets, sizes, strides) + .getResult(); +} + +void populateTensorPackingPatterns(RewritePatternSet& patterns) { + patterns.add(patterns.getContext()); +} + +} // namespace onnx_mlir diff --git a/src/PIM/Conversion/SpatialToPim/TensorPackingPatterns.hpp 
b/src/PIM/Conversion/SpatialToPim/TensorPackingPatterns.hpp new file mode 100644 index 0000000..7b34544 --- /dev/null +++ b/src/PIM/Conversion/SpatialToPim/TensorPackingPatterns.hpp @@ -0,0 +1,13 @@ +#pragma once + +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/PatternMatch.h" + +namespace onnx_mlir { + +mlir::RankedTensorType getPackedTensorType(mlir::RankedTensorType elementType, int64_t count); +mlir::Value createPackedExtractSliceTensor(mlir::ValueRange values, mlir::OpBuilder& builder, mlir::Location loc); + +void populateTensorPackingPatterns(mlir::RewritePatternSet& patterns); + +} // namespace onnx_mlir diff --git a/src/PIM/Dialect/Pim/Pim.td b/src/PIM/Dialect/Pim/Pim.td index f351708..429b5a1 100644 --- a/src/PIM/Dialect/Pim/Pim.td +++ b/src/PIM/Dialect/Pim/Pim.td @@ -113,6 +113,18 @@ def PimSendBatchOp : PimOp<"send_batch", []> { let hasCustomAssemblyFormat = 1; } +def PimSendTensorBatchOp : PimOp<"send_tensor_batch", []> { + let summary = "Send equal contiguous chunks of one per-lane tensor from a batched core"; + + let arguments = (ins + PimTensor:$input, + DenseI32ArrayAttr:$targetCoreIds + ); + + let hasVerifier = 1; + let hasCustomAssemblyFormat = 1; +} + def PimReceiveOp : PimOp<"receive", [DestinationStyleOpInterface]> { let summary = "Receive a tensor from another core"; @@ -181,6 +193,28 @@ def PimReceiveBatchOp : PimOp<"receive_batch", [DestinationStyleOpInterface]> { let hasCustomAssemblyFormat = 1; } +def PimReceiveTensorBatchOp : PimOp<"receive_tensor_batch", [DestinationStyleOpInterface]> { + let summary = "Receive equal contiguous chunks into one per-lane tensor inside a batched core"; + + let arguments = (ins + PimTensor:$outputBuffer, + DenseI32ArrayAttr:$sourceCoreIds + ); + + let results = (outs + PimTensor:$output + ); + + let extraClassDeclaration = [{ + mlir::MutableOperandRange getDpsInitsMutable() { + return getOutputBufferMutable(); + } + }]; + + let hasVerifier = 1; + let hasCustomAssemblyFormat = 1; +} + def 
PimMemCopyHostToDevOp : PimOp<"memcp_hd", [DestinationStyleOpInterface]> { let summary = "Copy a memory region from host memory into device memory"; diff --git a/src/PIM/Dialect/Pim/PimOpsAsm.cpp b/src/PIM/Dialect/Pim/PimOpsAsm.cpp index 8bee1b4..db2f6d0 100644 --- a/src/PIM/Dialect/Pim/PimOpsAsm.cpp +++ b/src/PIM/Dialect/Pim/PimOpsAsm.cpp @@ -174,6 +174,33 @@ ParseResult PimSendBatchOp::parse(OpAsmParser& parser, OperationState& result) { return parser.resolveOperand(input, inputType, result.operands); } +void PimSendTensorBatchOp::print(OpAsmPrinter& printer) { + printer << " "; + printer.printOperand(getInput()); + printCoreIdList(printer, "to", getTargetCoreIds()); + printer.printOptionalAttrDict((*this)->getAttrs(), {getTargetCoreIdsAttrName().getValue()}); + printer << " : "; + printer.printType(getInput().getType()); +} + +ParseResult PimSendTensorBatchOp::parse(OpAsmParser& parser, OperationState& result) { + OpAsmParser::UnresolvedOperand input; + Type inputType; + SmallVector targetCoreIds; + + if (parser.parseOperand(input) || parseOptionalCoreIdList(parser, "to", targetCoreIds) + || parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(inputType)) + return failure(); + + if (!targetCoreIds.empty() && result.attributes.get("targetCoreIds")) + return parser.emitError(parser.getCurrentLocation(), + "targetCoreIds cannot be specified both positionally and in attr-dict"); + if (!targetCoreIds.empty()) + result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds)); + + return parser.resolveOperand(input, inputType, result.operands); +} + void PimSendTensorOp::print(OpAsmPrinter& printer) { printer << " "; printer.printOperand(getInput()); @@ -275,6 +302,43 @@ ParseResult PimReceiveBatchOp::parse(OpAsmParser& parser, OperationState& result return success(); } +void PimReceiveTensorBatchOp::print(OpAsmPrinter& printer) { + printCoreIdList(printer, "from", getSourceCoreIds()); + printer << " into "; + 
printOpenDelimiter(printer, ListDelimiter::Paren); + printer.printOperand(getOutputBuffer()); + printCloseDelimiter(printer, ListDelimiter::Paren); + printer.printOptionalAttrDict((*this)->getAttrs(), {getSourceCoreIdsAttrName().getValue()}); + printer << " : "; + printer.printType(getOutputBuffer().getType()); + printer << " -> "; + printer.printType(getOutput().getType()); +} + +ParseResult PimReceiveTensorBatchOp::parse(OpAsmParser& parser, OperationState& result) { + OpAsmParser::UnresolvedOperand outputBuffer; + Type outputBufferType; + Type outputType; + SmallVector sourceCoreIds; + + if (parseOptionalCoreIdList(parser, "from", sourceCoreIds) || parser.parseKeyword("into") || parser.parseLParen() + || parser.parseOperand(outputBuffer) || parser.parseRParen() || parser.parseOptionalAttrDict(result.attributes) + || parser.parseColon() || parser.parseType(outputBufferType) || parser.parseArrow() + || parser.parseType(outputType)) + return failure(); + + if (!sourceCoreIds.empty() && result.attributes.get("sourceCoreIds")) + return parser.emitError(parser.getCurrentLocation(), + "sourceCoreIds cannot be specified both positionally and in attr-dict"); + if (!sourceCoreIds.empty()) + result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds)); + + if (parser.resolveOperand(outputBuffer, outputBufferType, result.operands)) + return failure(); + result.addTypes(outputType); + return success(); +} + void PimConcatOp::print(OpAsmPrinter& printer) { printer << " axis " << getAxis() << " "; printCompressedValueSequence(printer, getInputs()); diff --git a/src/PIM/Dialect/Pim/PimOpsVerify.cpp b/src/PIM/Dialect/Pim/PimOpsVerify.cpp index 3d39f4c..05cef60 100644 --- a/src/PIM/Dialect/Pim/PimOpsVerify.cpp +++ b/src/PIM/Dialect/Pim/PimOpsVerify.cpp @@ -46,12 +46,47 @@ static LogicalResult verifyTensorCommunication(Operation* op, Type type, ArrayRe return success(); } +static LogicalResult +verifyTensorBatchCommunication(Operation* op, Type type, ArrayRef 
coreIds, StringRef kind) { + if (coreIds.empty()) + return op->emitError() << kind << " must carry at least one chunk"; + + auto coreBatchOp = op->getParentOfType(); + if (!coreBatchOp) + return op->emitError() << kind << " must be nested inside pim.core_batch"; + + int32_t laneCount = coreBatchOp.getLaneCount(); + if (laneCount <= 0) + return op->emitError() << kind << " requires a positive parent laneCount"; + if (coreIds.size() % static_cast(laneCount) != 0) + return op->emitError() << kind << " core id count must be divisible by the parent laneCount"; + + auto shapedType = dyn_cast(type); + if (!shapedType || !shapedType.hasStaticShape()) + return op->emitError() << kind << " requires a static shaped tensor or memref"; + + int64_t elementBits = shapedType.getElementTypeBitWidth(); + if (elementBits <= 0 || elementBits % 8 != 0) + return op->emitError() << kind << " requires byte-sized elements"; + + int64_t chunkCount = static_cast(coreIds.size()) / laneCount; + int64_t totalBytes = shapedType.getNumElements() * elementBits / 8; + if (totalBytes % chunkCount != 0) + return op->emitError() << kind << " tensor byte size must be divisible by the chunk count per lane"; + + return success(); +} + } // namespace LogicalResult PimSendTensorOp::verify() { return verifyTensorCommunication(getOperation(), getInput().getType(), getTargetCoreIds(), "send_tensor"); } +LogicalResult PimSendTensorBatchOp::verify() { + return verifyTensorBatchCommunication(getOperation(), getInput().getType(), getTargetCoreIds(), "send_tensor_batch"); +} + LogicalResult PimReceiveTensorOp::verify() { if (failed(verifyCompatibleShapedTypes( getOperation(), getOutputBuffer().getType(), getOutput().getType(), "output buffer and output must match"))) @@ -60,6 +95,15 @@ LogicalResult PimReceiveTensorOp::verify() { return verifyTensorCommunication(getOperation(), getOutput().getType(), getSourceCoreIds(), "receive_tensor"); } +LogicalResult PimReceiveTensorBatchOp::verify() { + if 
(failed(verifyCompatibleShapedTypes( + getOperation(), getOutputBuffer().getType(), getOutput().getType(), "output buffer and output must match"))) + return failure(); + + return verifyTensorBatchCommunication( + getOperation(), getOutput().getType(), getSourceCoreIds(), "receive_tensor_batch"); +} + LogicalResult PimConcatOp::verify() { if (getInputs().empty()) return emitError("requires at least one input"); diff --git a/src/PIM/Dialect/Pim/Transforms/Bufferization/BufferizationUtils.cpp b/src/PIM/Dialect/Pim/Transforms/Bufferization/BufferizationUtils.cpp new file mode 100644 index 0000000..d0c90f7 --- /dev/null +++ b/src/PIM/Dialect/Pim/Transforms/Bufferization/BufferizationUtils.cpp @@ -0,0 +1,40 @@ +#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" + +#include "src/Accelerators/PIM/Common/PimCommon.hpp" +#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" +#include "src/Accelerators/PIM/Dialect/Pim/Transforms/Bufferization/BufferizationUtils.hpp" + +using namespace mlir; +using namespace bufferization; + +namespace onnx_mlir::pim { + +Value materializeContiguousMemRef(Value memrefValue, Location loc, RewriterBase& rewriter) { + if (succeeded(resolveContiguousAddress(memrefValue))) + return memrefValue; + + auto shapedType = cast(memrefValue.getType()); + auto contiguousType = MemRefType::get(shapedType.getShape(), shapedType.getElementType()); + Value contiguousBuffer = memref::AllocOp::create(rewriter, loc, contiguousType); + auto sizeInBytes = shapedType.getNumElements() * shapedType.getElementTypeBitWidth() / 8; + + return PimMemCopyOp::create(rewriter, + loc, + contiguousType, + contiguousBuffer, + memrefValue, + rewriter.getI32IntegerAttr(0), + rewriter.getI32IntegerAttr(0), + rewriter.getI32IntegerAttr(sizeInBytes)) + .getOutput(); +} + +FailureOr +getBufferOrValue(RewriterBase& rewriter, Value value, const BufferizationOptions& options, BufferizationState& state) { + if (isa(value.getType())) + 
return value; + return getBuffer(rewriter, value, options, state); +} + +} // namespace onnx_mlir::pim diff --git a/src/PIM/Dialect/Pim/Transforms/Bufferization/BufferizationUtils.hpp b/src/PIM/Dialect/Pim/Transforms/Bufferization/BufferizationUtils.hpp new file mode 100644 index 0000000..d4c5d48 --- /dev/null +++ b/src/PIM/Dialect/Pim/Transforms/Bufferization/BufferizationUtils.hpp @@ -0,0 +1,15 @@ +#pragma once + +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" +#include "mlir/IR/PatternMatch.h" + +namespace onnx_mlir::pim { + +mlir::Value materializeContiguousMemRef(mlir::Value memrefValue, mlir::Location loc, mlir::RewriterBase& rewriter); + +llvm::FailureOr getBufferOrValue(mlir::RewriterBase& rewriter, + mlir::Value value, + const mlir::bufferization::BufferizationOptions& options, + mlir::bufferization::BufferizationState& state); + +} // namespace onnx_mlir::pim diff --git a/src/PIM/Dialect/Pim/Transforms/Bufferization/CMakeLists.txt b/src/PIM/Dialect/Pim/Transforms/Bufferization/CMakeLists.txt index cb74785..5011fa9 100644 --- a/src/PIM/Dialect/Pim/Transforms/Bufferization/CMakeLists.txt +++ b/src/PIM/Dialect/Pim/Transforms/Bufferization/CMakeLists.txt @@ -4,6 +4,8 @@ add_public_tablegen_target(PimBufferizationIncGen) add_pim_library(OMPimBufferization PimBufferizationPass.cpp + BufferizationUtils.hpp + BufferizationUtils.cpp OpBufferizationInterfaces.hpp OpBufferizationInterfaces.cpp Common.hpp diff --git a/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp b/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp index d07bd9d..35d3d94 100644 --- a/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp +++ b/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp @@ -7,6 +7,7 @@ #include "OpBufferizationInterfaces.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" +#include 
"src/Accelerators/PIM/Dialect/Pim/Transforms/Bufferization/BufferizationUtils.hpp" using namespace mlir; using namespace bufferization; @@ -14,33 +15,6 @@ using namespace bufferization; namespace onnx_mlir { namespace pim { -static Value materializeContiguousMemRef(Value memrefValue, Location loc, RewriterBase& rewriter) { - if (succeeded(resolveContiguousAddress(memrefValue))) - return memrefValue; - - auto shapedType = cast(memrefValue.getType()); - auto contiguousType = MemRefType::get(shapedType.getShape(), shapedType.getElementType()); - Value contiguousBuffer = memref::AllocOp::create(rewriter, loc, contiguousType); - auto sizeInBytes = shapedType.getNumElements() * shapedType.getElementTypeBitWidth() / 8; - - return PimMemCopyOp::create(rewriter, - loc, - contiguousType, - contiguousBuffer, - memrefValue, - rewriter.getI32IntegerAttr(0), - rewriter.getI32IntegerAttr(0), - rewriter.getI32IntegerAttr(sizeInBytes)) - .getOutput(); -} - -static FailureOr -getBufferOrValue(RewriterBase& rewriter, Value value, const BufferizationOptions& options, BufferizationState& state) { - if (isa(value.getType())) - return value; - return getBuffer(rewriter, value, options, state); -} - struct MemCopyHostToDevOpInterface : DstBufferizableOpInterfaceExternalModel { LogicalResult bufferize(Operation* op, @@ -201,6 +175,27 @@ struct ReceiveTensorOpInterface } }; +struct ReceiveTensorBatchOpInterface +: DstBufferizableOpInterfaceExternalModel { + bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { + return !cast(op).isDpsInit(&opOperand); + } + + LogicalResult bufferize(Operation* op, + RewriterBase& rewriter, + const BufferizationOptions& options, + BufferizationState& state) const { + auto receiveOp = cast(op); + auto outputBufferOpt = getBufferOrValue(rewriter, receiveOp.getOutputBuffer(), options, state); + if (failed(outputBufferOpt)) + return failure(); + + replaceOpWithNewBufferizedOp( + rewriter, op, 
outputBufferOpt->getType(), *outputBufferOpt, receiveOp.getSourceCoreIdsAttr()); + return success(); + } +}; + struct ConcatOpInterface : DstBufferizableOpInterfaceExternalModel { bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return !cast(op).isDpsInit(&opOperand); @@ -308,6 +303,31 @@ struct SendBatchOpInterface : BufferizableOpInterface::ExternalModel { + bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; } + + bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return false; } + + AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { + return {}; + } + + LogicalResult bufferize(Operation* op, + RewriterBase& rewriter, + const BufferizationOptions& options, + BufferizationState& state) const { + auto sendOp = cast(op); + auto inputOpt = getBufferOrValue(rewriter, sendOp.getInput(), options, state); + if (failed(inputOpt)) + return failure(); + + replaceOpWithNewBufferizedOp( + rewriter, op, materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter), sendOp.getTargetCoreIdsAttr()); + return success(); + } +}; + struct CoreOpInterface : BufferizableOpInterface::ExternalModel { bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return true; } @@ -623,9 +643,11 @@ void registerOpBufferizationInterfaces(DialectRegistry& registry) { PimCoreBatchOp::attachInterface(*ctx); PimReceiveOp::attachInterface(*ctx); PimReceiveTensorOp::attachInterface(*ctx); + PimReceiveTensorBatchOp::attachInterface(*ctx); PimReceiveBatchOp::attachInterface(*ctx); PimSendOp::attachInterface(*ctx); PimSendBatchOp::attachInterface(*ctx); + PimSendTensorBatchOp::attachInterface(*ctx); PimSendTensorOp::attachInterface(*ctx); PimConcatOp::attachInterface(*ctx); PimMemCopyHostToDevOp::attachInterface(*ctx); diff --git 
a/src/PIM/Dialect/Spatial/Channels.cpp b/src/PIM/Dialect/Spatial/Channels.cpp index 44a6ff1..59847e6 100644 --- a/src/PIM/Dialect/Spatial/Channels.cpp +++ b/src/PIM/Dialect/Spatial/Channels.cpp @@ -1,8 +1,8 @@ -#include "src/Accelerators/PIM/Dialect/Spatial/Channels.hpp" - #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/Diagnostics.h" +#include "src/Accelerators/PIM/Dialect/Spatial/Channels.hpp" + using namespace mlir; namespace onnx_mlir::spatial { diff --git a/src/PIM/Dialect/Spatial/Spatial.td b/src/PIM/Dialect/Spatial/Spatial.td index 43cda62..9414609 100644 --- a/src/PIM/Dialect/Spatial/Spatial.td +++ b/src/PIM/Dialect/Spatial/Spatial.td @@ -102,23 +102,6 @@ def SpatConcatOp : SpatOp<"concat", []> { let hasCustomAssemblyFormat = 1; } -def SpatMapOp : SpatOp<"map", [SingleBlock]> { - let summary = "Apply the same lane-local region to many independent tensors"; - - let arguments = (ins - Variadic:$inputs - ); - - let results = (outs - Variadic:$outputs - ); - - let regions = (region SizedRegion<1>:$body); - - let hasVerifier = 1; - let hasCustomAssemblyFormat = 1; -} - //===----------------------------------------------------------------------===// // Communication //===----------------------------------------------------------------------===// @@ -156,22 +139,25 @@ def SpatChannelReceiveOp : SpatOp<"channel_receive", []> { }]; } -def SpatChannelSendManyOp : SpatOp<"channel_send_many", []> { - let summary = "Send multiple tensors through logical channels"; +def SpatChannelSendTensorOp : SpatOp<"channel_send_tensor", []> { + let summary = "Send equal contiguous chunks of one tensor through logical channels"; let arguments = (ins DenseI64ArrayAttr:$channelIds, DenseI32ArrayAttr:$sourceCoreIds, DenseI32ArrayAttr:$targetCoreIds, - Variadic:$inputs + SpatTensor:$input ); let hasVerifier = 1; - let hasCustomAssemblyFormat = 1; + + let assemblyFormat = [{ + $input attr-dict `:` type($input) + }]; } -def SpatChannelReceiveManyOp : SpatOp<"channel_receive_many", 
[]> { - let summary = "Receive multiple tensors from logical channels"; +def SpatChannelReceiveTensorOp : SpatOp<"channel_receive_tensor", []> { + let summary = "Receive equal contiguous chunks of one tensor from logical channels"; let arguments = (ins DenseI64ArrayAttr:$channelIds, @@ -180,11 +166,14 @@ def SpatChannelReceiveManyOp : SpatOp<"channel_receive_many", []> { ); let results = (outs - Variadic:$outputs + SpatTensor:$output ); let hasVerifier = 1; - let hasCustomAssemblyFormat = 1; + + let assemblyFormat = [{ + attr-dict `:` type($output) + }]; } def SpatChannelSendBatchOp : SpatOp<"channel_send_batch", []> { @@ -201,18 +190,21 @@ def SpatChannelSendBatchOp : SpatOp<"channel_send_batch", []> { let hasCustomAssemblyFormat = 1; } -def SpatChannelSendManyBatchOp : SpatOp<"channel_send_many_batch", []> { - let summary = "Send multiple per-lane tensors through logical channels in a batch body"; +def SpatChannelSendTensorBatchOp : SpatOp<"channel_send_tensor_batch", []> { + let summary = "Send equal contiguous chunks of one per-lane tensor through logical channels in a batch body"; let arguments = (ins DenseI64ArrayAttr:$channelIds, DenseI32ArrayAttr:$sourceCoreIds, DenseI32ArrayAttr:$targetCoreIds, - Variadic:$inputs + SpatTensor:$input ); let hasVerifier = 1; - let hasCustomAssemblyFormat = 1; + + let assemblyFormat = [{ + $input attr-dict `:` type($input) + }]; } def SpatChannelReceiveBatchOp : SpatOp<"channel_receive_batch", []> { @@ -232,8 +224,8 @@ def SpatChannelReceiveBatchOp : SpatOp<"channel_receive_batch", []> { let hasCustomAssemblyFormat = 1; } -def SpatChannelReceiveManyBatchOp : SpatOp<"channel_receive_many_batch", []> { - let summary = "Receive multiple per-lane tensors through logical channels in a batch body"; +def SpatChannelReceiveTensorBatchOp : SpatOp<"channel_receive_tensor_batch", []> { + let summary = "Receive equal contiguous chunks of one per-lane tensor through logical channels in a batch body"; let arguments = (ins 
DenseI64ArrayAttr:$channelIds, @@ -242,11 +234,14 @@ def SpatChannelReceiveManyBatchOp : SpatOp<"channel_receive_many_batch", []> { ); let results = (outs - Variadic:$outputs + SpatTensor:$output ); let hasVerifier = 1; - let hasCustomAssemblyFormat = 1; + + let assemblyFormat = [{ + attr-dict `:` type($output) + }]; } //===----------------------------------------------------------------------===// diff --git a/src/PIM/Dialect/Spatial/SpatialOpsAsm.cpp b/src/PIM/Dialect/Spatial/SpatialOpsAsm.cpp index 28c0dd9..f86ffa9 100644 --- a/src/PIM/Dialect/Spatial/SpatialOpsAsm.cpp +++ b/src/PIM/Dialect/Spatial/SpatialOpsAsm.cpp @@ -7,8 +7,8 @@ #include -#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/IR/CompactAsmUtils.hpp" +#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" using namespace mlir; @@ -129,9 +129,8 @@ ParseResult SpatConcatOp::parse(OpAsmParser& parser, OperationState& result) { if (parser.parseKeyword("axis") || parser.parseInteger(axis)) return failure(); - if (parseCompressedOperandSequence(parser, inputs)) { + if (parseCompressedOperandSequence(parser, inputs)) return failure(); - } if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parseCompressedRepeatedList( @@ -151,46 +150,6 @@ ParseResult SpatConcatOp::parse(OpAsmParser& parser, OperationState& result) { return success(); } -void SpatMapOp::print(OpAsmPrinter& printer) { - printer << " "; - printArgumentBindings(printer, getBody().front(), getInputs()); - printer.printOptionalAttrDict((*this)->getAttrs()); - printer << " : "; - printer.printType(getInputs().front().getType()); - printer << " -> "; - printer.printType(getOutputs().front().getType()); - printer << " "; - printer.printRegion(getBody(), /*printEntryBlockArgs=*/false); -} - -ParseResult SpatMapOp::parse(OpAsmParser& parser, OperationState& result) { - SmallVector regionArgs; - SmallVector inputs; - Type 
inputType; - Type outputType; - - if (parseArgumentBindings(parser, regionArgs, inputs)) - return failure(); - if (inputs.empty()) - return parser.emitError(parser.getCurrentLocation(), "map requires at least one input"); - - if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(inputType) - || parser.parseArrow() || parser.parseType(outputType)) - return failure(); - - SmallVector inputTypes(inputs.size(), inputType); - SmallVector outputTypes(inputs.size(), outputType); - if (regionArgs.size() != inputs.size()) - return parser.emitError(parser.getCurrentLocation(), "number of argument bindings and input operands must match"); - if (parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands)) - return failure(); - result.addTypes(outputTypes); - - applyArgumentTypes(inputTypes, regionArgs); - Region* body = result.addRegion(); - return parser.parseRegion(*body, regionArgs); -} - void SpatCompute::print(OpAsmPrinter& printer) { printer << " "; printCompressedValueList(printer, getWeights(), ListDelimiter::Square); @@ -357,97 +316,6 @@ ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result) return parser.parseRegion(*body, regionArgs); } -void SpatChannelSendManyOp::print(OpAsmPrinter& printer) { - printer << " "; - printCompressedValueSequence(printer, getInputs()); - printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds()); - printer.printOptionalAttrDict( - (*this)->getAttrs(), - {getChannelIdsAttrName().getValue(), getSourceCoreIdsAttrName().getValue(), getTargetCoreIdsAttrName().getValue()}); - printer << " : "; - printCompressedTypeSequence(printer, TypeRange(getInputs())); -} - -ParseResult SpatChannelSendManyOp::parse(OpAsmParser& parser, OperationState& result) { - SmallVector inputs; - SmallVector inputTypes; - SmallVector channelIds; - SmallVector sourceCoreIds; - SmallVector targetCoreIds; - - if 
(parseCompressedOperandSequence(parser, inputs)) - return failure(); - - bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels")); - if (hasMetadata) { - if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from") - || parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to") - || parseCompressedIntegerList(parser, targetCoreIds)) - return failure(); - } - - if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() - || parseCompressedTypeSequence(parser, inputTypes, /*allowEmpty=*/false)) - return failure(); - - if (inputs.size() != inputTypes.size()) - return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match"); - if (hasMetadata - && (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds") - || result.attributes.get("targetCoreIds"))) - return parser.emitError(parser.getCurrentLocation(), - "channel metadata cannot be specified both positionally and in attr-dict"); - if (hasMetadata) { - result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds)); - result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds)); - result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds)); - } - - return parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands); -} - -void SpatChannelReceiveManyOp::print(OpAsmPrinter& printer) { - printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds()); - printer.printOptionalAttrDict( - (*this)->getAttrs(), - {getChannelIdsAttrName().getValue(), getSourceCoreIdsAttrName().getValue(), getTargetCoreIdsAttrName().getValue()}); - printer << " : "; - printCompressedTypeSequence(printer, getResultTypes()); -} - -ParseResult SpatChannelReceiveManyOp::parse(OpAsmParser& parser, OperationState& result) { - SmallVector outputTypes; - SmallVector channelIds; - SmallVector sourceCoreIds; - SmallVector 
targetCoreIds; - - bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels")); - if (hasMetadata) { - if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from") - || parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to") - || parseCompressedIntegerList(parser, targetCoreIds)) - return failure(); - } - - if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() - || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/false)) - return failure(); - - if (hasMetadata - && (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds") - || result.attributes.get("targetCoreIds"))) - return parser.emitError(parser.getCurrentLocation(), - "channel metadata cannot be specified both positionally and in attr-dict"); - if (hasMetadata) { - result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds)); - result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds)); - result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds)); - } - - result.addTypes(outputTypes); - return success(); -} - void SpatChannelSendBatchOp::print(OpAsmPrinter& printer) { printer << " "; printer.printOperand(getInput()); @@ -494,55 +362,6 @@ ParseResult SpatChannelSendBatchOp::parse(OpAsmParser& parser, OperationState& r return parser.resolveOperand(input, inputType, result.operands); } -void SpatChannelSendManyBatchOp::print(OpAsmPrinter& printer) { - printer << " "; - printCompressedValueSequence(printer, getInputs()); - printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds()); - printer.printOptionalAttrDict( - (*this)->getAttrs(), - {getChannelIdsAttrName().getValue(), getSourceCoreIdsAttrName().getValue(), getTargetCoreIdsAttrName().getValue()}); - printer << " : "; - printCompressedTypeSequence(printer, TypeRange(getInputs())); -} - -ParseResult SpatChannelSendManyBatchOp::parse(OpAsmParser& parser, 
OperationState& result) { - SmallVector inputs; - SmallVector inputTypes; - SmallVector channelIds; - SmallVector sourceCoreIds; - SmallVector targetCoreIds; - - if (parseCompressedOperandSequence(parser, inputs)) - return failure(); - - bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels")); - if (hasMetadata) { - if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from") - || parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to") - || parseCompressedIntegerList(parser, targetCoreIds)) - return failure(); - } - - if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() - || parseCompressedTypeSequence(parser, inputTypes, /*allowEmpty=*/false)) - return failure(); - - if (inputs.size() != inputTypes.size()) - return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match"); - if (hasMetadata - && (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds") - || result.attributes.get("targetCoreIds"))) - return parser.emitError(parser.getCurrentLocation(), - "channel metadata cannot be specified both positionally and in attr-dict"); - if (hasMetadata) { - result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds)); - result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds)); - result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds)); - } - - return parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands); -} - void SpatChannelReceiveBatchOp::print(OpAsmPrinter& printer) { printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds()); printer.printOptionalAttrDict( @@ -584,47 +403,5 @@ ParseResult SpatChannelReceiveBatchOp::parse(OpAsmParser& parser, OperationState return success(); } -void SpatChannelReceiveManyBatchOp::print(OpAsmPrinter& printer) { - printChannelMetadata(printer, getChannelIds(), 
getSourceCoreIds(), getTargetCoreIds()); - printer.printOptionalAttrDict( - (*this)->getAttrs(), - {getChannelIdsAttrName().getValue(), getSourceCoreIdsAttrName().getValue(), getTargetCoreIdsAttrName().getValue()}); - printer << " : "; - printCompressedTypeSequence(printer, getResultTypes()); -} - -ParseResult SpatChannelReceiveManyBatchOp::parse(OpAsmParser& parser, OperationState& result) { - SmallVector outputTypes; - SmallVector channelIds; - SmallVector sourceCoreIds; - SmallVector targetCoreIds; - - bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels")); - if (hasMetadata) { - if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from") - || parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to") - || parseCompressedIntegerList(parser, targetCoreIds)) - return failure(); - } - - if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() - || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/false)) - return failure(); - - if (hasMetadata - && (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds") - || result.attributes.get("targetCoreIds"))) - return parser.emitError(parser.getCurrentLocation(), - "channel metadata cannot be specified both positionally and in attr-dict"); - if (hasMetadata) { - result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds)); - result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds)); - result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds)); - } - - result.addTypes(outputTypes); - return success(); -} - } // namespace spatial } // namespace onnx_mlir diff --git a/src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp b/src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp index 528d2e7..e240f17 100644 --- a/src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp +++ b/src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp @@ -105,26 +105,28 @@ static FailureOr 
getParentBatchLaneCount(Operation* op) { return batchOp.getLaneCount(); } -static LogicalResult verifyManyChannelSizes(Operation* op, - ArrayRef channelIds, - ArrayRef sourceCoreIds, - ArrayRef targetCoreIds, - size_t valueCount) { +static LogicalResult verifyTensorChannelSizes(Operation* op, + Type type, + ArrayRef channelIds, + ArrayRef sourceCoreIds, + ArrayRef targetCoreIds, + StringRef kind) { if (channelIds.size() != sourceCoreIds.size() || channelIds.size() != targetCoreIds.size()) return op->emitError("channelIds, sourceCoreIds, and targetCoreIds must have the same length"); - if (channelIds.size() != valueCount) - return op->emitError("channel metadata length must match the number of values"); - return success(); -} + if (channelIds.empty()) + return op->emitError() << kind << " must carry at least one chunk"; -static LogicalResult verifyManyChannelTypes(Operation* op, TypeRange types, StringRef kind) { - if (types.empty()) - return op->emitError() << kind << " must carry at least one value"; + auto shapedType = dyn_cast(type); + if (!shapedType || !shapedType.hasStaticShape()) + return op->emitError() << kind << " requires a static shaped tensor"; - Type firstType = types.front(); - for (Type type : types.drop_front()) - if (type != firstType) - return op->emitError() << kind << " values must all have the same type"; + int64_t elementBits = shapedType.getElementTypeBitWidth(); + if (elementBits <= 0 || elementBits % 8 != 0) + return op->emitError() << kind << " requires byte-sized elements"; + + int64_t totalBytes = shapedType.getNumElements() * elementBits / 8; + if (totalBytes % static_cast(channelIds.size()) != 0) + return op->emitError() << kind << " tensor byte size must be divisible by the number of channel ids"; return success(); } @@ -144,19 +146,33 @@ static LogicalResult verifyBatchChannelSizes(Operation* op, return success(); } -static LogicalResult verifyManyBatchChannelSizes(Operation* op, - ArrayRef channelIds, - ArrayRef sourceCoreIds, - 
ArrayRef targetCoreIds, - size_t valueCount) { +static LogicalResult verifyTensorBatchChannelSizes(Operation* op, + Type type, + ArrayRef channelIds, + ArrayRef sourceCoreIds, + ArrayRef targetCoreIds, + StringRef kind) { if (channelIds.size() != sourceCoreIds.size() || channelIds.size() != targetCoreIds.size()) return op->emitError("channelIds, sourceCoreIds, and targetCoreIds must have the same length"); auto laneCount = getParentBatchLaneCount(op); if (failed(laneCount)) return op->emitError("must be nested inside spat.compute_batch"); - if (channelIds.size() != valueCount * static_cast(*laneCount)) - return op->emitError("channel metadata length must match the number of values times parent laneCount"); + if (channelIds.empty() || channelIds.size() % static_cast(*laneCount) != 0) + return op->emitError() << kind << " channel metadata length must be a positive multiple of parent laneCount"; + + auto shapedType = dyn_cast(type); + if (!shapedType || !shapedType.hasStaticShape()) + return op->emitError() << kind << " requires a static shaped tensor"; + + int64_t elementBits = shapedType.getElementTypeBitWidth(); + if (elementBits <= 0 || elementBits % 8 != 0) + return op->emitError() << kind << " requires byte-sized elements"; + + int64_t chunkCount = static_cast(channelIds.size()) / *laneCount; + int64_t totalBytes = shapedType.getNumElements() * elementBits / 8; + if (totalBytes % chunkCount != 0) + return op->emitError() << kind << " tensor byte size must be divisible by the chunk count per lane"; return success(); } @@ -323,39 +339,6 @@ LogicalResult SpatConcatOp::verify() { return success(); } -LogicalResult SpatMapOp::verify() { - if (getInputs().empty()) - return emitError("requires at least one input"); - if (getOutputs().size() != getInputs().size()) - return emitError("number of outputs must match number of inputs"); - - Type inputType = getInputs().front().getType(); - for (Value input : getInputs().drop_front()) - if (input.getType() != inputType) - 
return emitError("all inputs must have the same type"); - - Type outputType = getOutputs().front().getType(); - for (Value output : getOutputs().drop_front()) - if (output.getType() != outputType) - return emitError("all outputs must have the same type"); - - Block& block = getBody().front(); - if (block.getNumArguments() != 1) - return emitError("body must have exactly one block argument"); - if (block.getArgument(0).getType() != inputType) - return emitError("body block argument type must match input type"); - - auto yieldOp = dyn_cast_or_null(block.getTerminator()); - if (!yieldOp) - return emitError("body must terminate with spat.yield"); - if (yieldOp.getNumOperands() != 1) - return emitError("body yield must produce exactly one value"); - if (yieldOp.getOperand(0).getType() != outputType) - return emitError("body yield type must match output type"); - - return success(); -} - LogicalResult SpatCompute::verify() { auto& block = getBody().front(); if (block.mightHaveTerminator()) { @@ -397,40 +380,48 @@ LogicalResult SpatCompute::verify() { return success(); } -LogicalResult SpatChannelSendManyOp::verify() { - if (failed(verifyManyChannelSizes( - getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds(), getInputs().size()))) - return failure(); - return verifyManyChannelTypes(getOperation(), getInputs().getTypes(), "channel_send_many"); +LogicalResult SpatChannelSendTensorOp::verify() { + return verifyTensorChannelSizes(getOperation(), + getInput().getType(), + getChannelIds(), + getSourceCoreIds(), + getTargetCoreIds(), + "channel_send_tensor"); } -LogicalResult SpatChannelReceiveManyOp::verify() { - if (failed(verifyManyChannelSizes( - getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds(), getOutputs().size()))) - return failure(); - return verifyManyChannelTypes(getOperation(), getOperation()->getResultTypes(), "channel_receive_many"); +LogicalResult SpatChannelReceiveTensorOp::verify() { + return 
verifyTensorChannelSizes(getOperation(), + getOutput().getType(), + getChannelIds(), + getSourceCoreIds(), + getTargetCoreIds(), + "channel_receive_tensor"); } LogicalResult SpatChannelSendBatchOp::verify() { return verifyBatchChannelSizes(getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds()); } -LogicalResult SpatChannelSendManyBatchOp::verify() { - if (failed(verifyManyBatchChannelSizes( - getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds(), getInputs().size()))) - return failure(); - return verifyManyChannelTypes(getOperation(), getInputs().getTypes(), "channel_send_many_batch"); +LogicalResult SpatChannelSendTensorBatchOp::verify() { + return verifyTensorBatchChannelSizes(getOperation(), + getInput().getType(), + getChannelIds(), + getSourceCoreIds(), + getTargetCoreIds(), + "channel_send_tensor_batch"); } LogicalResult SpatChannelReceiveBatchOp::verify() { return verifyBatchChannelSizes(getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds()); } -LogicalResult SpatChannelReceiveManyBatchOp::verify() { - if (failed(verifyManyBatchChannelSizes( - getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds(), getOutputs().size()))) - return failure(); - return verifyManyChannelTypes(getOperation(), getOperation()->getResultTypes(), "channel_receive_many_batch"); +LogicalResult SpatChannelReceiveTensorBatchOp::verify() { + return verifyTensorBatchChannelSizes(getOperation(), + getOutput().getType(), + getChannelIds(), + getSourceCoreIds(), + getTargetCoreIds(), + "channel_receive_tensor_batch"); } LogicalResult SpatComputeBatch::verify() { diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.cpp index 6d1cea3..863b8f8 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.cpp +++ 
b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.cpp @@ -10,6 +10,7 @@ #include "llvm/Support/raw_ostream.h" #include +#include #include #include #include @@ -31,6 +32,8 @@ namespace { using SpatCompute = onnx_mlir::spatial::SpatCompute; using SpatComputeBatch = onnx_mlir::spatial::SpatComputeBatch; +bool isDcpCoarsenDebugEnabled() { return std::getenv("DCP_COARSEN_DEBUG") != nullptr; } + struct VirtualNode { SmallVector originalComputeIndices; Weight weight = 0; @@ -719,11 +722,12 @@ DCPAnalysisResult DCPAnalysis::run() { VirtualGraph virtualGraph = buildInitialVirtualGraph(computeInstances, edges); size_t iteration = 0; + bool debugCoarsening = isDcpCoarsenDebugEnabled(); auto tryCoarsenSelectedNodes = [&](ArrayRef selectedNodes) { size_t oldNodeCount = virtualGraph.nodes.size(); WindowScheduleResult windowSchedule = scheduleWindow(virtualGraph, selectedNodes, entryOp->getContext()); if (windowSchedule.mergeGroups.empty()) { - if (oldNodeCount >= 200) + if (debugCoarsening && oldNodeCount >= 200) llvm::errs() << llvm::formatv("[DCP-COARSEN] iter={0} old={1} selected={2} windowCpus={3} " "groups=0 mergedNodes=0 maxGroup=0 new={1} changed=0\n", iteration, @@ -737,7 +741,7 @@ DCPAnalysisResult DCPAnalysis::run() { std::vector oldToNewNode; if (!coarsenGraph(virtualGraph, windowSchedule.mergeGroups, coarsenedGraph, oldToNewNode)) return false; - if (oldNodeCount >= 200 || coarsenedGraph.nodes.size() >= 200) + if (debugCoarsening && (oldNodeCount >= 200 || coarsenedGraph.nodes.size() >= 200)) llvm::errs() << llvm::formatv("[DCP-COARSEN] iter={0} old={1} selected={2} windowCpus={3} " "groups={4} mergedNodes={5} maxGroup={6} new={7} changed={8}\n", iteration, @@ -755,7 +759,7 @@ DCPAnalysisResult DCPAnalysis::run() { while (virtualGraph.nodes.size() > 1) { if (virtualGraph.nodes.size() <= getSchedulingCpuBudget()) { - if (virtualGraph.nodes.size() >= 200) + if (debugCoarsening && virtualGraph.nodes.size() >= 200) llvm::errs() << llvm::formatv( 
"[DCP-COARSEN] iter={0} old={1} stop=cpu-budget\n", iteration, virtualGraph.nodes.size()); break; @@ -764,7 +768,7 @@ DCPAnalysisResult DCPAnalysis::run() { iteration++; TimingInfo timing = computeTiming(virtualGraph); if (!timing.valid) { - if (virtualGraph.nodes.size() >= 200) + if (debugCoarsening && virtualGraph.nodes.size() >= 200) llvm::errs() << llvm::formatv( "[DCP-COARSEN] iter={0} old={1} invalid-timing\n", iteration, virtualGraph.nodes.size()); break; @@ -776,7 +780,7 @@ DCPAnalysisResult DCPAnalysis::run() { selectedNodes.append(criticalWindow.begin(), criticalWindow.end()); if (selectedNodes.size() < 2) { - if (virtualGraph.nodes.size() >= 200) + if (debugCoarsening && virtualGraph.nodes.size() >= 200) llvm::errs() << llvm::formatv("[DCP-COARSEN] iter={0} old={1} selected={2} stop=small-window\n", iteration, virtualGraph.nodes.size(), @@ -786,7 +790,7 @@ DCPAnalysisResult DCPAnalysis::run() { if (tryCoarsenSelectedNodes(selectedNodes)) continue; - if (virtualGraph.nodes.size() >= 200) + if (debugCoarsening && virtualGraph.nodes.size() >= 200) llvm::errs() << llvm::formatv( "[DCP-COARSEN] iter={0} old={1} stop=no-merge\n", iteration, virtualGraph.nodes.size()); break; diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.hpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.hpp index 0a54a6c..defcebf 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.hpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.hpp @@ -59,11 +59,7 @@ struct DenseMapInfo { static ComputeInstance getTombstoneKey() { return {DenseMapInfo::getTombstoneKey(), UINT32_MAX, UINT32_MAX}; } - static unsigned getHashValue(const ComputeInstance& v) { - return llvm::hash_combine(v.op, v.laneStart, v.laneCount); - } - static bool isEqual(const ComputeInstance& a, const ComputeInstance& b) { - return a == b; - } + static unsigned getHashValue(const ComputeInstance& v) { 
return llvm::hash_combine(v.op, v.laneStart, v.laneCount); } + static bool isEqual(const ComputeInstance& a, const ComputeInstance& b) { return a == b; } }; } // namespace llvm diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/GraphDebug.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/GraphDebug.cpp index 1f907d1..dbaaa82 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/GraphDebug.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/GraphDebug.cpp @@ -38,9 +38,11 @@ void DcpProgressLogger::advanceCompleted(size_t taskCount) { completedTasks += t void DcpProgressLogger::printStart(size_t readyCount, int maxCpuCount, size_t xbarsCapacity) const { if (!logProgress) return; - llvm::errs() << llvm::formatv( - "[DCP] start tasks={0} ready={1} cpus=0/{2} crossbars=0/{3}\n", - totalTasks, readyCount, maxCpuCount, xbarsCapacity); + llvm::errs() << llvm::formatv("[DCP] start tasks={0} ready={1} cpus=0/{2} crossbars=0/{3}\n", + totalTasks, + readyCount, + maxCpuCount, + xbarsCapacity); } void DcpProgressLogger::maybePrintSlowCandidate(size_t nodeIndex, @@ -72,18 +74,17 @@ void DcpProgressLogger::printProgress( double percent = totalTasks == 0 ? 100.0 : (100.0 * static_cast(completedTasks) / totalTasks); bool done = completedTasks == totalTasks; - llvm::errs() << llvm::formatv( - "[DCP] {0}/{1} ({2:F0}%) ready={3} cpus={4}/{5} crossbars={6}/{7} {8}{9}\n", - completedTasks, - totalTasks, - percent, - readyCount, - cpuCount, - maxCpuCount, - xbarsUsed, - xbarsAvailable, - llvm::formatv("elapsed={0}", formatDuration(elapsedSeconds)).str(), - done ? 
"" : llvm::formatv(" eta={0}", formatDuration(etaSeconds)).str()); + llvm::errs() << llvm::formatv("[DCP] {0}/{1} ({2:F0}%) ready={3} cpus={4}/{5} crossbars={6}/{7} {8}{9}\n", + completedTasks, + totalTasks, + percent, + readyCount, + cpuCount, + maxCpuCount, + xbarsUsed, + xbarsAvailable, + llvm::formatv("elapsed={0}", formatDuration(elapsedSeconds)).str(), + done ? "" : llvm::formatv(" eta={0}", formatDuration(etaSeconds)).str()); lastProgressPrint = now; } @@ -100,9 +101,7 @@ void DcpProgressLogger::printProgress(size_t, CPU, int, size_t, size_t, bool) {} #endif -void dumpGraphDot(const std::vector& nodes, - const std::vector>& cpuTasks, - CPU lastCpu) { +void dumpGraphDot(const std::vector& nodes, const std::vector>& cpuTasks, CPU lastCpu) { static int dumpIndex = 0; std::string outputDir = onnx_mlir::getOutputDir(); if (outputDir.empty()) diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/GraphDebug.hpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/GraphDebug.hpp index 14e3783..8097a0f 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/GraphDebug.hpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/GraphDebug.hpp @@ -9,9 +9,9 @@ #include "Task.hpp" #include "Utils.hpp" -// Uncomment to enable DCP progress logging and per-phase profiling during -// development. When disabled the logger methods are no-ops and the helpers -// compile away. +// Define DCP_DEBUG_ENABLED locally when debugging DCP progress and per-phase +// profiling. In normal builds the logger methods are no-ops and helpers compile +// away. 
#define DCP_DEBUG_ENABLED #ifdef DCP_DEBUG_ENABLED @@ -33,10 +33,11 @@ public: void printStart(size_t readyCount, int maxCpuCount, size_t xbarsCapacity) const; void maybePrintSlowCandidate(size_t nodeIndex, double elapsedSeconds, size_t readyCount, CPU cpuCount) const; - void printProgress(size_t readyCount, CPU cpuCount, int maxCpuCount, - size_t xbarsUsed, size_t xbarsAvailable, bool force); + void + printProgress(size_t readyCount, CPU cpuCount, int maxCpuCount, size_t xbarsUsed, size_t xbarsAvailable, bool force); #ifdef DCP_DEBUG_ENABLED + private: static std::string formatDuration(double seconds); @@ -51,8 +52,6 @@ private: #endif }; -void dumpGraphDot(const std::vector& nodes, - const std::vector>& cpuTasks, - CPU lastCpu); +void dumpGraphDot(const std::vector& nodes, const std::vector>& cpuTasks, CPU lastCpu); } // namespace dcp_graph diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp index 2d63429..d26799e 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp @@ -149,14 +149,6 @@ static SmallVector getMaterializedBatchCoreIds(size_t startCpu, size_t return coreIds; } -static SmallVector getBatchCoreIds(Operation* op, size_t laneCount) { - if (auto coreIdsAttr = op->getAttrOfType(onnx_mlir::kCoreIdsAttrName)) - return SmallVector(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end()); - if (auto coreIdAttr = op->getAttrOfType(onnx_mlir::kCoreIdAttrName)) - return SmallVector(laneCount, static_cast(coreIdAttr.getInt())); - return {}; -} - static std::optional getComputeCoreId(SpatCompute compute) { if (auto coreIdAttr = compute->getAttrOfType(onnx_mlir::kCoreIdAttrName)) return static_cast(coreIdAttr.getInt()); @@ -245,312 +237,6 @@ static bool areEquivalentForRebatch(SpatCompute lhs, SpatCompute rhs) { 
return lhsIt == lhsBlock.end() && rhsIt == rhsBlock.end(); } -static void sinkChannelsIntoBatchComputes(func::FuncOp funcOp, - IRRewriter& rewriter, - SmallVectorImpl& opsToErase, - int64_t& nextChannelId) { - SmallVector batches(funcOp.getOps()); - - for (auto batch : batches) { - if (batch.getInputs().empty() && batch.getResults().empty()) - continue; - - if (batch.getInputs().size() != static_cast(batch.getLaneCount())) - continue; - if (batch.getResults().size() != static_cast(batch.getLaneCount())) - continue; - - SmallVector inputReceives; - inputReceives.reserve(batch.getInputs().size()); - bool allInputsAreReceives = true; - for (Value input : batch.getInputs()) { - auto receiveOp = dyn_cast_or_null(input.getDefiningOp()); - if (!receiveOp) { - allInputsAreReceives = false; - break; - } - inputReceives.push_back(receiveOp); - } - - SmallVector resultSends; - resultSends.reserve(batch.getResults().size()); - bool allResultsAreSingleSends = true; - for (Value result : batch.getResults()) { - if (!result.hasOneUse()) { - allResultsAreSingleSends = false; - break; - } - auto sendOp = dyn_cast(*result.getUsers().begin()); - if (!sendOp) { - allResultsAreSingleSends = false; - break; - } - resultSends.push_back(sendOp); - } - - if (!allInputsAreReceives || !allResultsAreSingleSends) - continue; - - Block& oldBlock = batch.getBody().front(); - if (oldBlock.getNumArguments() != 1) - continue; - - SmallVector newWeights(batch.getWeights().begin(), batch.getWeights().end()); - rewriter.setInsertionPointAfter(batch); - auto newBatch = SpatComputeBatch::create(rewriter, - batch.getLoc(), - TypeRange {}, - rewriter.getI32IntegerAttr(batch.getLaneCount()), - ValueRange(newWeights), - ValueRange {}); - newBatch.getProperties().setOperandSegmentSizes({static_cast(newWeights.size()), 0}); - - SmallVector coreIds = getBatchCoreIds(batch, static_cast(batch.getLaneCount())); - if (!coreIds.empty()) - newBatch->setAttr(onnx_mlir::kCoreIdsAttrName, 
rewriter.getDenseI32ArrayAttr(coreIds)); - - auto* newBlock = - rewriter.createBlock(&newBatch.getBody(), newBatch.getBody().end(), TypeRange {}, ArrayRef {}); - - rewriter.setInsertionPointToStart(newBlock); - struct BatchReceiveEntry { - uint64_t channelId = 0; - uint32_t sourceCoreId = 0; - uint32_t targetCoreId = 0; - }; - SmallVector receiveEntries; - receiveEntries.reserve(inputReceives.size()); - for (auto receiveOp : inputReceives) - receiveEntries.push_back({receiveOp.getChannelId(), receiveOp.getSourceCoreId(), receiveOp.getTargetCoreId()}); - llvm::stable_sort(receiveEntries, [](const BatchReceiveEntry& lhs, const BatchReceiveEntry& rhs) { - return std::tuple(lhs.sourceCoreId, lhs.targetCoreId, lhs.channelId) - < std::tuple(rhs.sourceCoreId, rhs.targetCoreId, rhs.channelId); - }); - - SmallVector receiveChannelIds; - SmallVector receiveSourceCoreIds; - SmallVector receiveTargetCoreIds; - receiveChannelIds.reserve(receiveEntries.size()); - receiveSourceCoreIds.reserve(receiveEntries.size()); - receiveTargetCoreIds.reserve(receiveEntries.size()); - for (const BatchReceiveEntry& entry : receiveEntries) { - (void) entry; - receiveChannelIds.push_back(nextChannelId++); - receiveSourceCoreIds.push_back(static_cast(entry.sourceCoreId)); - receiveTargetCoreIds.push_back(static_cast(entry.targetCoreId)); - } - auto batchReceive = spatial::SpatChannelReceiveBatchOp::create(rewriter, - batch.getLoc(), - oldBlock.getArgument(0).getType(), - rewriter.getDenseI64ArrayAttr(receiveChannelIds), - rewriter.getDenseI32ArrayAttr(receiveSourceCoreIds), - rewriter.getDenseI32ArrayAttr(receiveTargetCoreIds)); - - IRMapping mapper; - mapper.map(oldBlock.getArgument(0), batchReceive.getOutput()); - - auto oldYield = cast(oldBlock.getTerminator()); - rewriter.setInsertionPointToEnd(newBlock); - for (Operation& op : oldBlock) { - if (&op == oldYield) - continue; - rewriter.clone(op, mapper); - } - - Value sendInput = mapper.lookup(oldYield.getOperand(0)); - struct BatchSendEntry { 
- uint64_t channelId = 0; - uint32_t sourceCoreId = 0; - uint32_t targetCoreId = 0; - }; - SmallVector sendEntries; - sendEntries.reserve(resultSends.size()); - for (auto sendOp : resultSends) - sendEntries.push_back({sendOp.getChannelId(), sendOp.getSourceCoreId(), sendOp.getTargetCoreId()}); - llvm::stable_sort(sendEntries, [](const BatchSendEntry& lhs, const BatchSendEntry& rhs) { - return std::tuple(lhs.sourceCoreId, lhs.targetCoreId, lhs.channelId) - < std::tuple(rhs.sourceCoreId, rhs.targetCoreId, rhs.channelId); - }); - - SmallVector sendChannelIds; - SmallVector sendSourceCoreIds; - SmallVector sendTargetCoreIds; - sendChannelIds.reserve(sendEntries.size()); - sendSourceCoreIds.reserve(sendEntries.size()); - sendTargetCoreIds.reserve(sendEntries.size()); - for (const BatchSendEntry& entry : sendEntries) { - (void) entry; - sendChannelIds.push_back(nextChannelId++); - sendSourceCoreIds.push_back(static_cast(entry.sourceCoreId)); - sendTargetCoreIds.push_back(static_cast(entry.targetCoreId)); - } - spatial::SpatChannelSendBatchOp::create(rewriter, - batch.getLoc(), - rewriter.getDenseI64ArrayAttr(sendChannelIds), - rewriter.getDenseI32ArrayAttr(sendSourceCoreIds), - rewriter.getDenseI32ArrayAttr(sendTargetCoreIds), - sendInput); - spatial::SpatYieldOp::create(rewriter, batch.getLoc(), ValueRange {}); - - for (auto receiveOp : inputReceives) - opsToErase.push_back(receiveOp); - for (auto sendOp : resultSends) - opsToErase.push_back(sendOp); - opsToErase.push_back(batch); - } -} - -void sinkChannelsIntoComputes(func::FuncOp funcOp, int64_t& nextChannelId) { - IRRewriter rewriter(funcOp.getContext()); - SmallVector computes(funcOp.getOps()); - SmallVector opsToErase; - - for (auto compute : computes) { - SmallVector keptInputIndices; - SmallVector keptResultIndices; - SmallVector internalizedReceives(compute.getInputs().size()); - SmallVector> resultSendOps(compute.getNumResults()); - - bool needsRewrite = false; - for (auto [inputIndex, input] : 
llvm::enumerate(compute.getInputs())) { - auto receiveOp = dyn_cast_or_null(input.getDefiningOp()); - if (!receiveOp) { - keptInputIndices.push_back(inputIndex); - continue; - } - - internalizedReceives[inputIndex] = receiveOp; - opsToErase.push_back(receiveOp); - needsRewrite = true; - } - - for (auto [resultIndex, result] : llvm::enumerate(compute.getResults())) { - bool hasNonSendUser = false; - for (Operation* user : result.getUsers()) { - if (auto sendOp = dyn_cast(user)) { - resultSendOps[resultIndex].push_back(sendOp); - opsToErase.push_back(sendOp); - needsRewrite = true; - continue; - } - hasNonSendUser = true; - } - - if (hasNonSendUser || resultSendOps[resultIndex].empty()) - keptResultIndices.push_back(resultIndex); - } - - if (!needsRewrite) - continue; - - SmallVector newOperands; - SmallVector newResultTypes; - SmallVector newBlockArgTypes; - SmallVector newBlockArgLocs; - newOperands.reserve(compute.getNumOperands()); - newResultTypes.reserve(keptResultIndices.size()); - newBlockArgTypes.reserve(keptInputIndices.size()); - newBlockArgLocs.reserve(keptInputIndices.size()); - - for (Value weight : compute.getWeights()) - newOperands.push_back(weight); - for (unsigned inputIndex : keptInputIndices) { - Value input = compute.getInputs()[inputIndex]; - newOperands.push_back(input); - newBlockArgTypes.push_back(input.getType()); - newBlockArgLocs.push_back(compute.getLoc()); - } - for (unsigned resultIndex : keptResultIndices) - newResultTypes.push_back(compute.getResult(resultIndex).getType()); - - rewriter.setInsertionPointAfter(compute); - auto newCompute = - SpatCompute::create(rewriter, compute.getLoc(), TypeRange(newResultTypes), ValueRange(newOperands)); - newCompute.getProperties().setOperandSegmentSizes( - {static_cast(compute.getWeights().size()), static_cast(keptInputIndices.size())}); - if (auto coreIdAttr = compute->getAttrOfType(onnx_mlir::kCoreIdAttrName)) - newCompute->setAttr(onnx_mlir::kCoreIdAttrName, coreIdAttr); - - auto* newBlock = 
rewriter.createBlock( - &newCompute.getBody(), newCompute.getBody().end(), TypeRange(newBlockArgTypes), newBlockArgLocs); - - IRMapping mapper; - for (auto [mappedIndex, inputIndex] : llvm::enumerate(keptInputIndices)) - mapper.map(compute.getBody().front().getArgument(inputIndex), newBlock->getArgument(mappedIndex)); - - rewriter.setInsertionPointToStart(newBlock); - for (auto [inputIndex, receiveOp] : llvm::enumerate(internalizedReceives)) { - if (!receiveOp) - continue; - - auto internalReceive = spatial::SpatChannelReceiveOp::create(rewriter, - receiveOp.getLoc(), - receiveOp.getResult().getType(), - receiveOp.getChannelIdAttr(), - receiveOp.getSourceCoreIdAttr(), - receiveOp.getTargetCoreIdAttr()); - mapper.map(compute.getBody().front().getArgument(inputIndex), internalReceive.getResult()); - } - - auto oldYieldOp = cast(compute.getBody().front().getTerminator()); - rewriter.setInsertionPointToEnd(newBlock); - for (Operation& op : compute.getBody().front()) { - if (&op == oldYieldOp) - continue; - rewriter.clone(op, mapper); - } - - for (auto [resultIndex, sendOps] : llvm::enumerate(resultSendOps)) { - if (sendOps.empty()) - continue; - - Value yieldedValue = mapper.lookup(oldYieldOp.getOperand(resultIndex)); - for (auto sendOp : sendOps) - spatial::SpatChannelSendOp::create(rewriter, - sendOp.getLoc(), - sendOp.getChannelIdAttr(), - sendOp.getSourceCoreIdAttr(), - sendOp.getTargetCoreIdAttr(), - yieldedValue); - } - - SmallVector keptYieldOperands; - keptYieldOperands.reserve(keptResultIndices.size()); - for (unsigned resultIndex : keptResultIndices) - keptYieldOperands.push_back(mapper.lookup(oldYieldOp.getOperand(resultIndex))); - spatial::SpatYieldOp::create(rewriter, oldYieldOp.getLoc(), ValueRange(keptYieldOperands)); - - for (auto [newResultIndex, oldResultIndex] : llvm::enumerate(keptResultIndices)) - compute.getResult(oldResultIndex).replaceAllUsesWith(newCompute.getResult(newResultIndex)); - - opsToErase.push_back(compute); - } - - 
sinkChannelsIntoBatchComputes(funcOp, rewriter, opsToErase, nextChannelId); - - SmallVector pendingRemovals(opsToErase.begin(), opsToErase.end()); - while (!pendingRemovals.empty()) { - bool erasedAny = false; - for (auto it = pendingRemovals.begin(); it != pendingRemovals.end();) { - if (!(*it)->use_empty()) { - ++it; - continue; - } - - rewriter.eraseOp(*it); - it = pendingRemovals.erase(it); - erasedAny = true; - } - - if (erasedAny) - continue; - - for (Operation* op : pendingRemovals) - op->emitError("failed to sink channel op into compute"); - llvm_unreachable("channel sinking left cyclic top-level dependencies"); - } -} - void rebatchEquivalentComputes(func::FuncOp funcOp, int64_t& nextChannelId) { IRRewriter rewriter(funcOp.getContext()); SmallVector computes(funcOp.getOps()); @@ -1280,7 +966,8 @@ public: void runOnOperation() override { mergeTriviallyConnectedComputes(getOperation()); - emitMotifProfile(getOperation()); + if (std::getenv("DCP_MOTIF_PROFILE")) + emitMotifProfile(getOperation()); func::FuncOp func = getOperation(); Location loc = func.getLoc(); @@ -1718,17 +1405,12 @@ public: for (Operation* user : result.getUsers()) remainingUsers.push_back(user); if (!remainingUsers.empty()) { - llvm::errs() << "[MergeComputeNodesPass] refusing to erase op with remaining uses: " << op->getName() << "\n"; - llvm::errs() << " erase-set: " << (allOpsToErase.contains(op) ? "yes" : "no") << "\n"; - op->print(llvm::errs(), mlir::OpPrintingFlags().skipRegions()); - llvm::errs() << "\n"; + InFlightDiagnostic diagnostic = op->emitOpError("still has uses during per-cpu merge cleanup") + << "; erase-set=" << (allOpsToErase.contains(op) ? "yes" : "no"); for (Operation* user : remainingUsers) { - llvm::errs() << " user: " << user->getName() - << " erase-set=" << (allOpsToErase.contains(user) ? 
"yes" : "no") << "\n"; - user->print(llvm::errs(), mlir::OpPrintingFlags().skipRegions()); - llvm::errs() << "\n"; + diagnostic.attachNote(user->getLoc()) + << "remaining user " << user->getName() << "; erase-set=" << (allOpsToErase.contains(user) ? "yes" : "no"); } - op->emitOpError("still has uses during per-cpu merge cleanup"); signalPassFailure(); return; } diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/RegularOpCompaction.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/RegularOpCompaction.cpp index 925ba85..74082fd 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/RegularOpCompaction.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/RegularOpCompaction.cpp @@ -1,6 +1,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/Value.h" @@ -40,6 +41,176 @@ struct RegularChunk { Value output; }; +static RankedTensorType getPackedTensorType(RankedTensorType elementType, int64_t count) { + SmallVector packedShape(elementType.getShape().begin(), elementType.getShape().end()); + packedShape[0] *= count; + return RankedTensorType::get(packedShape, elementType.getElementType()); +} + +static Value +extractPackedChunk(Value packedValue, RankedTensorType chunkType, unsigned index, IRRewriter& rewriter, Location loc) { + SmallVector offsets; + SmallVector sizes; + SmallVector strides; + offsets.reserve(chunkType.getRank()); + sizes.reserve(chunkType.getRank()); + strides.reserve(chunkType.getRank()); + + offsets.push_back(rewriter.getIndexAttr(static_cast(index) * chunkType.getDimSize(0))); + sizes.push_back(rewriter.getIndexAttr(chunkType.getDimSize(0))); + strides.push_back(rewriter.getIndexAttr(1)); + for (int64_t dim = 1; dim < chunkType.getRank(); ++dim) { + offsets.push_back(rewriter.getIndexAttr(0)); + 
sizes.push_back(rewriter.getIndexAttr(chunkType.getDimSize(dim))); + strides.push_back(rewriter.getIndexAttr(1)); + } + + return tensor::ExtractSliceOp::create(rewriter, loc, chunkType, packedValue, offsets, sizes, strides).getResult(); +} + +static Value createPackedExtractRowsSlice( + spatial::SpatExtractRowsOp extractRowsOp, unsigned startIndex, unsigned count, IRRewriter& rewriter, Location loc) { + auto rowType = dyn_cast(extractRowsOp.getOutputs()[startIndex].getType()); + auto inputType = dyn_cast(extractRowsOp.getInput().getType()); + if (!rowType || !inputType || !rowType.hasStaticShape() || !inputType.hasStaticShape() || rowType.getRank() == 0) + return {}; + + int64_t rowsPerValue = rowType.getDimSize(0); + if (ShapedType::isDynamic(rowsPerValue)) + return {}; + + auto packedType = getPackedTensorType(rowType, static_cast(count)); + SmallVector offsets; + SmallVector sizes; + SmallVector strides; + offsets.reserve(inputType.getRank()); + sizes.reserve(inputType.getRank()); + strides.reserve(inputType.getRank()); + + offsets.push_back(rewriter.getIndexAttr(static_cast(startIndex) * rowsPerValue)); + sizes.push_back(rewriter.getIndexAttr(static_cast(count) * rowsPerValue)); + strides.push_back(rewriter.getIndexAttr(1)); + for (int64_t dim = 1; dim < inputType.getRank(); ++dim) { + offsets.push_back(rewriter.getIndexAttr(0)); + sizes.push_back(rewriter.getIndexAttr(inputType.getDimSize(dim))); + strides.push_back(rewriter.getIndexAttr(1)); + } + + return tensor::ExtractSliceOp::create(rewriter, loc, packedType, extractRowsOp.getInput(), offsets, sizes, strides) + .getResult(); +} + +static Value createPackedExtractSliceTensor(ValueRange values, IRRewriter& rewriter, Location loc) { + if (values.empty()) + return {}; + if (values.size() == 1) + return values.front(); + + auto firstSliceOp = values.front().getDefiningOp(); + if (!firstSliceOp) + return {}; + + auto firstType = dyn_cast(firstSliceOp.getResult().getType()); + auto sourceType = 
dyn_cast(firstSliceOp.getSource().getType()); + if (!firstType || !sourceType || !firstType.hasStaticShape() || !sourceType.hasStaticShape() + || firstType.getRank() == 0) + return {}; + + auto hasStaticValues = [](ArrayRef values) { + return llvm::all_of(values, [](int64_t value) { return !ShapedType::isDynamic(value); }); + }; + if (!hasStaticValues(firstSliceOp.getStaticOffsets()) || !hasStaticValues(firstSliceOp.getStaticSizes()) + || !hasStaticValues(firstSliceOp.getStaticStrides())) + return {}; + + ArrayRef firstOffsets = firstSliceOp.getStaticOffsets(); + ArrayRef firstSizes = firstSliceOp.getStaticSizes(); + ArrayRef firstStrides = firstSliceOp.getStaticStrides(); + int64_t rowsPerValue = firstSizes[0]; + if (ShapedType::isDynamic(rowsPerValue)) + return {}; + + for (size_t index = 1; index < values.size(); ++index) { + auto sliceOp = values[index].getDefiningOp(); + if (!sliceOp || sliceOp.getSource() != firstSliceOp.getSource() + || sliceOp.getResult().getType() != firstSliceOp.getResult().getType() + || !hasStaticValues(sliceOp.getStaticOffsets()) || !hasStaticValues(sliceOp.getStaticSizes()) + || !hasStaticValues(sliceOp.getStaticStrides())) + return {}; + + if (sliceOp.getStaticSizes() != firstSizes || sliceOp.getStaticStrides() != firstStrides) + return {}; + + if (sliceOp.getStaticOffsets()[0] != firstOffsets[0] + static_cast(index) * rowsPerValue) + return {}; + + for (int64_t dim = 1; dim < firstType.getRank(); ++dim) + if (sliceOp.getStaticOffsets()[dim] != firstOffsets[dim]) + return {}; + } + + auto packedType = getPackedTensorType(firstType, static_cast(values.size())); + SmallVector offsets; + SmallVector sizes; + SmallVector strides; + offsets.reserve(firstType.getRank()); + sizes.reserve(firstType.getRank()); + strides.reserve(firstType.getRank()); + + offsets.push_back(rewriter.getIndexAttr(firstOffsets[0])); + sizes.push_back(rewriter.getIndexAttr(rowsPerValue * static_cast(values.size()))); + 
strides.push_back(rewriter.getIndexAttr(firstStrides[0])); + for (int64_t dim = 1; dim < firstType.getRank(); ++dim) { + offsets.push_back(rewriter.getIndexAttr(firstOffsets[dim])); + sizes.push_back(rewriter.getIndexAttr(firstSizes[dim])); + strides.push_back(rewriter.getIndexAttr(firstStrides[dim])); + } + + return tensor::ExtractSliceOp::create(rewriter, loc, packedType, firstSliceOp.getSource(), offsets, sizes, strides) + .getResult(); +} + +static bool getContiguousOpResults(ValueRange values, Operation*& owner, unsigned& startIndex) { + if (values.empty()) + return false; + + auto firstResult = dyn_cast(values.front()); + if (!firstResult) + return false; + + owner = firstResult.getOwner(); + startIndex = firstResult.getResultNumber(); + for (auto [index, value] : llvm::enumerate(values)) { + auto result = dyn_cast(value); + if (!result || result.getOwner() != owner || result.getResultNumber() != startIndex + index) + return false; + } + + return true; +} + +static Value createPackedTensorForValues(ValueRange values, IRRewriter& rewriter, Location loc) { + if (values.empty()) + return {}; + if (Value packedSlice = createPackedExtractSliceTensor(values, rewriter, loc)) + return packedSlice; + + Operation* owner = nullptr; + unsigned startIndex = 0; + if (getContiguousOpResults(values, owner, startIndex)) + if (auto extractRowsOp = dyn_cast(owner)) + return createPackedExtractRowsSlice( + extractRowsOp, startIndex, static_cast(values.size()), rewriter, loc); + + auto firstType = dyn_cast(values.front().getType()); + if (!firstType || !firstType.hasStaticShape() || firstType.getRank() == 0) + return {}; + if (!llvm::all_of(values.drop_front(), [&](Value value) { return value.getType() == firstType; })) + return {}; + + return tensor::ConcatOp::create(rewriter, loc, /*dim=*/0, values).getResult(); +} + static bool areEquivalentRegularSteps(const RegularStep& lhs, const RegularStep& rhs) { return lhs.kind == rhs.kind && lhs.weightIndex == rhs.weightIndex && 
lhs.invariantOperand == rhs.invariantOperand && lhs.resultType == rhs.resultType; @@ -89,45 +260,97 @@ static FailureOr analyzeRegularChunk(spatial::SpatVMMOp startOp) { return chunk; } -static void buildRegularMapBody(spatial::SpatMapOp mapOp, const RegularChunk& anchorChunk, IRRewriter& rewriter) { - auto* block = rewriter.createBlock( - &mapOp.getBody(), mapOp.getBody().end(), TypeRange {anchorChunk.input.getType()}, {anchorChunk.startOp->getLoc()}); - rewriter.setInsertionPointToEnd(block); - - IRMapping mapping; - mapping.map(anchorChunk.input, block->getArgument(0)); - - for (Operation* op : anchorChunk.ops) { - Operation* cloned = rewriter.clone(*op, mapping); - for (auto [oldResult, newResult] : llvm::zip(op->getResults(), cloned->getResults())) - mapping.map(oldResult, newResult); - } - - spatial::SpatYieldOp::create( - rewriter, anchorChunk.startOp->getLoc(), ValueRange {mapping.lookup(anchorChunk.output)}); -} - static void compactRegularChunkRun(IRRewriter& rewriter, ArrayRef run) { assert(!run.empty() && "expected a non-empty regular chunk run"); const RegularChunk& anchorChunk = run.front(); SmallVector inputs; - SmallVector outputTypes; inputs.reserve(run.size()); - outputTypes.reserve(run.size()); - for (const RegularChunk& chunk : run) { + for (const RegularChunk& chunk : run) inputs.push_back(chunk.input); - outputTypes.push_back(chunk.output.getType()); - } rewriter.setInsertionPoint(anchorChunk.startOp); - auto mapOp = - spatial::SpatMapOp::create(rewriter, anchorChunk.startOp->getLoc(), TypeRange(outputTypes), ValueRange(inputs)); - buildRegularMapBody(mapOp, anchorChunk, rewriter); + Value packedInput = createPackedTensorForValues(ValueRange(inputs), rewriter, anchorChunk.startOp->getLoc()); + if (!packedInput) + return; + + auto inputType = cast(anchorChunk.input.getType()); + auto outputType = cast(anchorChunk.output.getType()); + auto packedOutputType = getPackedTensorType(outputType, static_cast(run.size())); + auto packedInit = 
tensor::EmptyOp::create( + rewriter, anchorChunk.startOp->getLoc(), packedOutputType.getShape(), packedOutputType.getElementType()); + auto zero = arith::ConstantIndexOp::create(rewriter, anchorChunk.startOp->getLoc(), 0); + auto upper = arith::ConstantIndexOp::create(rewriter, anchorChunk.startOp->getLoc(), run.size()); + auto step = arith::ConstantIndexOp::create(rewriter, anchorChunk.startOp->getLoc(), 1); + auto loop = + scf::ForOp::create(rewriter, anchorChunk.startOp->getLoc(), zero, upper, step, ValueRange {packedInit.getResult()}); + + { + OpBuilder::InsertionGuard guard(rewriter); + Block* loopBlock = loop.getBody(); + rewriter.setInsertionPointToStart(loopBlock); + Value iv = loopBlock->getArgument(0); + Value acc = loopBlock->getArgument(1); + + Value inputRowOffset = iv; + if (inputType.getDimSize(0) != 1) { + auto rowsPerValue = + arith::ConstantIndexOp::create(rewriter, anchorChunk.startOp->getLoc(), inputType.getDimSize(0)); + inputRowOffset = arith::MulIOp::create(rewriter, anchorChunk.startOp->getLoc(), iv, rowsPerValue); + } + + SmallVector extractOffsets; + SmallVector extractSizes; + SmallVector extractStrides; + extractOffsets.push_back(inputRowOffset); + extractSizes.push_back(rewriter.getIndexAttr(inputType.getDimSize(0))); + extractStrides.push_back(rewriter.getIndexAttr(1)); + for (int64_t dim = 1; dim < inputType.getRank(); ++dim) { + extractOffsets.push_back(rewriter.getIndexAttr(0)); + extractSizes.push_back(rewriter.getIndexAttr(inputType.getDimSize(dim))); + extractStrides.push_back(rewriter.getIndexAttr(1)); + } + auto inputSlice = tensor::ExtractSliceOp::create( + rewriter, anchorChunk.startOp->getLoc(), inputType, packedInput, extractOffsets, extractSizes, extractStrides); + + IRMapping mapping; + mapping.map(anchorChunk.input, inputSlice.getResult()); + for (Operation* op : anchorChunk.ops) { + Operation* cloned = rewriter.clone(*op, mapping); + for (auto [oldResult, newResult] : llvm::zip(op->getResults(), cloned->getResults())) + 
mapping.map(oldResult, newResult); + } + + Value mappedOutput = mapping.lookup(anchorChunk.output); + Value outputRowOffset = iv; + if (outputType.getDimSize(0) != 1) { + auto rowsPerValue = + arith::ConstantIndexOp::create(rewriter, anchorChunk.startOp->getLoc(), outputType.getDimSize(0)); + outputRowOffset = arith::MulIOp::create(rewriter, anchorChunk.startOp->getLoc(), iv, rowsPerValue); + } + + SmallVector insertOffsets; + SmallVector insertSizes; + SmallVector insertStrides; + insertOffsets.push_back(outputRowOffset); + insertSizes.push_back(rewriter.getIndexAttr(outputType.getDimSize(0))); + insertStrides.push_back(rewriter.getIndexAttr(1)); + for (int64_t dim = 1; dim < outputType.getRank(); ++dim) { + insertOffsets.push_back(rewriter.getIndexAttr(0)); + insertSizes.push_back(rewriter.getIndexAttr(outputType.getDimSize(dim))); + insertStrides.push_back(rewriter.getIndexAttr(1)); + } + + auto inserted = tensor::InsertSliceOp::create( + rewriter, anchorChunk.startOp->getLoc(), mappedOutput, acc, insertOffsets, insertSizes, insertStrides); + scf::YieldOp::create(rewriter, anchorChunk.startOp->getLoc(), inserted.getResult()); + } for (auto [index, chunk] : llvm::enumerate(run)) { + Value replacement = extractPackedChunk( + loop.getResult(0), outputType, static_cast(index), rewriter, chunk.startOp->getLoc()); Value output = chunk.output; - output.replaceAllUsesWith(mapOp.getResult(index)); + output.replaceAllUsesWith(replacement); } SmallVector opsToErase; @@ -178,28 +401,29 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) { SmallVector channelIds; SmallVector sourceCoreIds; SmallVector targetCoreIds; - SmallVector outputTypes; channelIds.reserve(sortedEntries.size()); sourceCoreIds.reserve(sortedEntries.size()); targetCoreIds.reserve(sortedEntries.size()); - outputTypes.reserve(sortedEntries.size()); for (ReceiveEntry& entry : sortedEntries) { (void) entry; channelIds.push_back(nextChannelId++); 
sourceCoreIds.push_back(static_cast(entry.sourceCoreId)); targetCoreIds.push_back(static_cast(entry.targetCoreId)); - outputTypes.push_back(entry.op.getOutput().getType()); } + auto rowType = cast(run.front().getOutput().getType()); + auto packedType = getPackedTensorType(rowType, static_cast(sortedEntries.size())); rewriter.setInsertionPoint(run.front()); - auto compactReceive = spatial::SpatChannelReceiveManyOp::create(rewriter, - run.front().getLoc(), - TypeRange(outputTypes), - rewriter.getDenseI64ArrayAttr(channelIds), - rewriter.getDenseI32ArrayAttr(sourceCoreIds), - rewriter.getDenseI32ArrayAttr(targetCoreIds)); + auto compactReceive = + spatial::SpatChannelReceiveTensorOp::create(rewriter, + run.front().getLoc(), + packedType, + rewriter.getDenseI64ArrayAttr(channelIds), + rewriter.getDenseI32ArrayAttr(sourceCoreIds), + rewriter.getDenseI32ArrayAttr(targetCoreIds)); for (auto [sortedIndex, entry] : llvm::enumerate(sortedEntries)) - entry.op.getOutput().replaceAllUsesWith(compactReceive.getResult(sortedIndex)); + entry.op.getOutput().replaceAllUsesWith(extractPackedChunk( + compactReceive.getOutput(), rowType, static_cast(sortedIndex), rewriter, entry.op.getLoc())); for (auto op : run) rewriter.eraseOp(op); @@ -255,17 +479,20 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) { } rewriter.setInsertionPoint(run.front()); - spatial::SpatChannelSendManyOp::create(rewriter, - run.front().getLoc(), - rewriter.getDenseI64ArrayAttr(channelIds), - rewriter.getDenseI32ArrayAttr(sourceCoreIds), - rewriter.getDenseI32ArrayAttr(targetCoreIds), - ValueRange(inputs)); - for (auto op : run) - rewriter.eraseOp(op); + Value packedInput = createPackedTensorForValues(ValueRange(inputs), rewriter, run.front().getLoc()); + if (packedInput) { + spatial::SpatChannelSendTensorOp::create(rewriter, + run.front().getLoc(), + rewriter.getDenseI64ArrayAttr(channelIds), + rewriter.getDenseI32ArrayAttr(sourceCoreIds), + 
rewriter.getDenseI32ArrayAttr(targetCoreIds), + packedInput); + for (auto op : run) + rewriter.eraseOp(op); - it = runIt; - continue; + it = runIt; + continue; + } } } @@ -297,25 +524,25 @@ void compactBatchChannelRuns(func::FuncOp funcOp) { SmallVector channelIds; SmallVector sourceCoreIds; SmallVector targetCoreIds; - SmallVector outputTypes; - outputTypes.reserve(run.size()); for (auto op : run) { llvm::append_range(channelIds, op.getChannelIds()); llvm::append_range(sourceCoreIds, op.getSourceCoreIds()); llvm::append_range(targetCoreIds, op.getTargetCoreIds()); - outputTypes.push_back(op.getOutput().getType()); } + auto rowType = cast(run.front().getOutput().getType()); + auto packedType = getPackedTensorType(rowType, static_cast(run.size())); rewriter.setInsertionPoint(run.front()); auto compactReceive = - spatial::SpatChannelReceiveManyBatchOp::create(rewriter, - run.front().getLoc(), - TypeRange(outputTypes), - rewriter.getDenseI64ArrayAttr(channelIds), - rewriter.getDenseI32ArrayAttr(sourceCoreIds), - rewriter.getDenseI32ArrayAttr(targetCoreIds)); + spatial::SpatChannelReceiveTensorBatchOp::create(rewriter, + run.front().getLoc(), + packedType, + rewriter.getDenseI64ArrayAttr(channelIds), + rewriter.getDenseI32ArrayAttr(sourceCoreIds), + rewriter.getDenseI32ArrayAttr(targetCoreIds)); for (auto [index, op] : llvm::enumerate(run)) - op.getOutput().replaceAllUsesWith(compactReceive.getResult(index)); + op.getOutput().replaceAllUsesWith(extractPackedChunk( + compactReceive.getOutput(), rowType, static_cast(index), rewriter, op.getLoc())); for (auto op : run) rewriter.eraseOp(op); @@ -352,17 +579,20 @@ void compactBatchChannelRuns(func::FuncOp funcOp) { } rewriter.setInsertionPoint(run.front()); - spatial::SpatChannelSendManyBatchOp::create(rewriter, - run.front().getLoc(), - rewriter.getDenseI64ArrayAttr(channelIds), - rewriter.getDenseI32ArrayAttr(sourceCoreIds), - rewriter.getDenseI32ArrayAttr(targetCoreIds), - ValueRange(inputs)); - for (auto op : run) - 
rewriter.eraseOp(op); + Value packedInput = createPackedTensorForValues(ValueRange(inputs), rewriter, run.front().getLoc()); + if (packedInput) { + spatial::SpatChannelSendTensorBatchOp::create(rewriter, + run.front().getLoc(), + rewriter.getDenseI64ArrayAttr(channelIds), + rewriter.getDenseI32ArrayAttr(sourceCoreIds), + rewriter.getDenseI32ArrayAttr(targetCoreIds), + packedInput); + for (auto op : run) + rewriter.eraseOp(op); - it = runIt; - continue; + it = runIt; + continue; + } } } diff --git a/src/PIM/Pass/PimCodegen/HostConstantFolding/Common.cpp b/src/PIM/Pass/PimCodegen/HostConstantFolding/Common.cpp index 0a8e08c..8cbaeab 100644 --- a/src/PIM/Pass/PimCodegen/HostConstantFolding/Common.cpp +++ b/src/PIM/Pass/PimCodegen/HostConstantFolding/Common.cpp @@ -1,9 +1,9 @@ -#include "Common.hpp" -#include "src/Accelerators/PIM/Common/PimCommon.hpp" - #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/Hashing.h" +#include "Common.hpp" +#include "src/Accelerators/PIM/Common/PimCommon.hpp" + using namespace mlir; namespace onnx_mlir { @@ -31,7 +31,8 @@ struct DenseSubviewKeyInfo { static unsigned getHashValue(const DenseSubviewKey& key) { return static_cast( - llvm::hash_combine(key.source, llvm::hash_combine_range(key.offsets.begin(), key.offsets.end()), + llvm::hash_combine(key.source, + llvm::hash_combine_range(key.offsets.begin(), key.offsets.end()), llvm::hash_combine_range(key.resultShape.begin(), key.resultShape.end()))); } @@ -98,16 +99,16 @@ memref::GlobalOp createFoldedGlobal(ModuleOp moduleOp, alignment); } -FailureOr foldDenseSubview(DenseElementsAttr denseAttr, - ArrayRef staticOffsets, - ArrayRef resultShape) { +FailureOr +foldDenseSubview(DenseElementsAttr denseAttr, ArrayRef staticOffsets, ArrayRef resultShape) { auto sourceType = dyn_cast(denseAttr.getType()); if (!sourceType || !sourceType.hasStaticShape() || sourceType.getRank() != static_cast(staticOffsets.size()) || sourceType.getRank() != static_cast(resultShape.size())) return failure(); static 
DenseMap cache; - DenseSubviewKey key {denseAttr, SmallVector(staticOffsets.begin(), staticOffsets.end()), + DenseSubviewKey key {denseAttr, + SmallVector(staticOffsets.begin(), staticOffsets.end()), SmallVector(resultShape.begin(), resultShape.end())}; if (auto cached = cache.find(key); cached != cache.end()) return cached->second; @@ -152,6 +153,30 @@ FailureOr getDenseGlobalValue(ModuleOp moduleOp, Value value) return denseAttr; } +FailureOr foldDenseSourceToType(ModuleOp moduleOp, Value source, MemRefType resultType) { + auto srcSubview = getStaticSubviewInfo(source); + Value globalSource = succeeded(srcSubview) ? srcSubview->source : stripMemRefCasts(source); + + auto denseAttr = getDenseGlobalValue(moduleOp, globalSource); + if (failed(denseAttr)) + return failure(); + + if (succeeded(srcSubview)) { + if (llvm::any_of(srcSubview->strides, [](int64_t stride) { return stride != 1; })) + return failure(); + auto staticOffsets = getStaticSubviewOffsets(*srcSubview); + if (failed(staticOffsets)) + return failure(); + + return foldDenseSubview(*denseAttr, *staticOffsets, resultType.getShape()); + } + + auto resultTensorType = RankedTensorType::get(resultType.getShape(), resultType.getElementType()); + if (resultTensorType != denseAttr->getType()) + return failure(); + return *denseAttr; +} + FailureOr getStaticSubviewInfo(Value value) { value = stripMemRefViewOps(value); auto subviewOp = value.getDefiningOp(); diff --git a/src/PIM/Pass/PimCodegen/HostConstantFolding/Common.hpp b/src/PIM/Pass/PimCodegen/HostConstantFolding/Common.hpp index c355de9..2498658 100644 --- a/src/PIM/Pass/PimCodegen/HostConstantFolding/Common.hpp +++ b/src/PIM/Pass/PimCodegen/HostConstantFolding/Common.hpp @@ -36,6 +36,9 @@ llvm::FailureOr foldDenseSubview(mlir::DenseElementsAtt llvm::FailureOr getDenseGlobalValue(mlir::ModuleOp moduleOp, mlir::Value value); +llvm::FailureOr +foldDenseSourceToType(mlir::ModuleOp moduleOp, mlir::Value source, mlir::MemRefType resultType); + llvm::FailureOr 
getStaticSubviewInfo(mlir::Value value); /// Returns the offsets in `info` as int64_t, failing if any offset is dynamic. diff --git a/src/PIM/Pass/PimCodegen/HostConstantFolding/Patterns/Constant.cpp b/src/PIM/Pass/PimCodegen/HostConstantFolding/Patterns/Constant.cpp index a69a8a2..d575c84 100644 --- a/src/PIM/Pass/PimCodegen/HostConstantFolding/Patterns/Constant.cpp +++ b/src/PIM/Pass/PimCodegen/HostConstantFolding/Patterns/Constant.cpp @@ -90,6 +90,7 @@ static FailureOr getConstantMapYield(linalg::MapOp mapOp) { return attr; } +// Folds constant linalg fills inside cores into private globals plus device copies. struct FoldConstantCoreMapPattern final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -249,6 +250,7 @@ static FailureOr foldConstantAlloc(memref::AllocOp allocOp, M return DenseElementsAttr::get(resultTensorType, resultValues); } +// Folds transposes of constant globals so weight-only transposes stay host-side. struct FoldConstantTransposePattern final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -304,11 +306,9 @@ struct FoldConstantTransposePattern final : OpRewritePatterngetUsers().empty() - && llvm::all_of(transposeOp->getUsers(), [](Operation* user) { - return isa(user); - }); + bool isAlwaysWeight = !transposeOp->getUsers().empty() + && llvm::all_of(transposeOp->getUsers(), + [](Operation* user) { return isa(user); }); if (isAlwaysWeight) { markWeightAlways(newGlobal); markWeightAlways(newGetGlobal); @@ -330,6 +330,7 @@ struct FoldConstantTransposePattern final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -367,9 +368,8 @@ struct FoldConstantAllocPattern final : OpRewritePattern { } if (!llvm::all_of(castsToReplace, [](memref::CastOp castOp) { - return llvm::all_of(castOp->getUsers(), [](Operation* user) { - return isa(user); - }); + return llvm::all_of(castOp->getUsers(), + [](Operation* user) { return isa(user); }); })) { allLiveUsersAreCoreOps = false; } @@ -417,6 +417,7 @@ struct 
FoldConstantAllocPattern final : OpRewritePattern { } }; +// Converts host copies from dense globals into direct folded globals. struct FoldConstantHostCopyPattern final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -431,37 +432,14 @@ struct FoldConstantHostCopyPattern final : OpRewritePattern { if (!allocType || !allocType.hasStaticShape()) return failure(); - auto srcSubview = getStaticSubviewInfo(copyOp.getSource()); - Value globalSource = succeeded(srcSubview) ? srcSubview->source : stripMemRefCasts(copyOp.getSource()); - auto moduleOp = copyOp->getParentOfType(); if (!moduleOp) return failure(); - auto denseAttr = getDenseGlobalValue(moduleOp, globalSource); - if (failed(denseAttr)) + auto foldedAttr = foldDenseSourceToType(moduleOp, copyOp.getSource(), allocType); + if (failed(foldedAttr)) return failure(); - DenseElementsAttr foldedAttr; - if (succeeded(srcSubview)) { - if (llvm::any_of(srcSubview->strides, [](int64_t stride) { return stride != 1; })) - return failure(); - auto staticOffsets = getStaticSubviewOffsets(*srcSubview); - if (failed(staticOffsets)) - return failure(); - - auto maybeFoldedAttr = foldDenseSubview(*denseAttr, *staticOffsets, allocType.getShape()); - if (failed(maybeFoldedAttr)) - return failure(); - foldedAttr = *maybeFoldedAttr; - } - else { - auto resultTensorType = RankedTensorType::get(allocType.getShape(), allocType.getElementType()); - if (resultTensorType != denseAttr->getType()) - return failure(); - foldedAttr = *denseAttr; - } - bool allLiveUsersAreCores = true; for (Operation* user : allocOp->getUsers()) { if (user == copyOp) @@ -477,7 +455,7 @@ struct FoldConstantHostCopyPattern final : OpRewritePattern { return failure(); } - auto newGlobal = createFoldedGlobal(moduleOp, allocOp.getLoc(), allocType, foldedAttr, "pim_folded_host_copy"); + auto newGlobal = createFoldedGlobal(moduleOp, allocOp.getLoc(), allocType, *foldedAttr, "pim_folded_host_copy"); if (allLiveUsersAreCores) 
markWeightAlways(newGlobal); @@ -494,6 +472,7 @@ struct FoldConstantHostCopyPattern final : OpRewritePattern { } }; +// Converts PIM copies from dense globals into direct folded globals before codegen. struct FoldConstantMemCpPattern final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -511,37 +490,14 @@ struct FoldConstantMemCpPattern final : OpRewritePattern { if (copyOp.getTargetOffset() != 0 || copyOp.getSourceOffset() != 0) return failure(); - auto srcSubview = getStaticSubviewInfo(copyOp.getSource()); - Value globalSource = succeeded(srcSubview) ? srcSubview->source : stripMemRefCasts(copyOp.getSource()); - auto moduleOp = copyOp->getParentOfType(); if (!moduleOp) return failure(); - auto denseAttr = getDenseGlobalValue(moduleOp, globalSource); - if (failed(denseAttr)) + auto foldedAttr = foldDenseSourceToType(moduleOp, copyOp.getSource(), allocType); + if (failed(foldedAttr)) return failure(); - DenseElementsAttr foldedAttr; - if (succeeded(srcSubview)) { - if (llvm::any_of(srcSubview->strides, [](int64_t stride) { return stride != 1; })) - return failure(); - auto staticOffsets = getStaticSubviewOffsets(*srcSubview); - if (failed(staticOffsets)) - return failure(); - - auto maybeFoldedAttr = foldDenseSubview(*denseAttr, *staticOffsets, allocType.getShape()); - if (failed(maybeFoldedAttr)) - return failure(); - foldedAttr = *maybeFoldedAttr; - } - else { - auto resultTensorType = RankedTensorType::get(allocType.getShape(), allocType.getElementType()); - if (resultTensorType != denseAttr->getType()) - return failure(); - foldedAttr = *denseAttr; - } - bool allLiveUsersAreCores = true; for (Operation* user : allocOp->getUsers()) { if (user == copyOp) @@ -557,7 +513,7 @@ struct FoldConstantMemCpPattern final : OpRewritePattern { return failure(); } - auto newGlobal = createFoldedGlobal(moduleOp, allocOp.getLoc(), allocType, foldedAttr, "pim_folded_memcp"); + auto newGlobal = createFoldedGlobal(moduleOp, allocOp.getLoc(), allocType, 
*foldedAttr, "pim_folded_memcp"); if (allLiveUsersAreCores) markWeightAlways(newGlobal); @@ -577,13 +533,11 @@ struct FoldConstantMemCpPattern final : OpRewritePattern { } // namespace void populateConstantFoldingConstantPatterns(RewritePatternSet& patterns) { - patterns - .add( - patterns.getContext()); + patterns.add(patterns.getContext()); } } // namespace onnx_mlir diff --git a/src/PIM/Pass/PimCodegen/HostConstantFolding/Patterns/Subview.cpp b/src/PIM/Pass/PimCodegen/HostConstantFolding/Patterns/Subview.cpp index 586522d..d6bddc9 100644 --- a/src/PIM/Pass/PimCodegen/HostConstantFolding/Patterns/Subview.cpp +++ b/src/PIM/Pass/PimCodegen/HostConstantFolding/Patterns/Subview.cpp @@ -128,6 +128,7 @@ static LogicalResult rewriteSubviewCopyLikeOp(CopyOp copyOp, return success(); } +// Splits core copies through subviews into contiguous copy chunks for codegen. struct RewriteCoreSubviewCopyPattern final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -162,6 +163,7 @@ struct RewriteCoreSubviewCopyPattern final : OpRewritePattern } }; +// Splits host-to-device subview loads into contiguous copy chunks. 
struct RewriteHostSubviewLoadPattern final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -193,6 +195,7 @@ struct RewriteHostSubviewLoadPattern final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -224,6 +227,7 @@ struct RewriteHostSubviewStorePattern final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; diff --git a/src/PIM/Pass/PimCodegen/MaterializeHostConstantsPass.cpp b/src/PIM/Pass/PimCodegen/MaterializeHostConstantsPass.cpp index 31b72d5..abe483b 100644 --- a/src/PIM/Pass/PimCodegen/MaterializeHostConstantsPass.cpp +++ b/src/PIM/Pass/PimCodegen/MaterializeHostConstantsPass.cpp @@ -1,5 +1,6 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/Builders.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" @@ -9,6 +10,8 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/Support/MathExtras.h" +#include + #include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" @@ -21,6 +24,8 @@ namespace { static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) { if (isa(op)) return operandIndex == 1; + if (isa(op)) + return operandIndex == 1; if (isa(op)) return operandIndex == 0; return false; @@ -33,6 +38,91 @@ static int64_t getValueSizeInBytes(Value value) { return type.getNumElements() * type.getElementTypeBitWidth() / 8; } +template +static void materializeHostConstantsInCore(CoreOpTy coreOp, IRRewriter& rewriter, bool& hasFailure) { + DenseMap>> materializedValues; + SmallVector ops; + coreOp.getBody().front().walk([&](Operation* op) { + if (!isa(op)) + ops.push_back(op); + }); + + for (Operation* op : ops) { + for (OpOperand& operand : op->getOpOperands()) { + Value originalValue = operand.get(); + if (!isa(originalValue.getType()) || isExplicitHostOperand(op, operand.getOperandNumber())) + continue; + + auto resolvedAddress = 
resolveContiguousAddress(originalValue); + if (failed(resolvedAddress)) + continue; + + auto getGlobalOp = dyn_cast_or_null(resolvedAddress->base.getDefiningOp()); + if (!getGlobalOp) + continue; + + auto originalType = dyn_cast(originalValue.getType()); + if (!originalType || !originalType.hasStaticShape()) { + op->emitOpError("host constant materialization requires a static memref operand"); + hasFailure = true; + continue; + } + + auto& cachedByOffset = materializedValues[resolvedAddress->base]; + auto& cachedByType = cachedByOffset[resolvedAddress->byteOffset]; + auto cachedValue = cachedByType.find(originalType); + if (cachedValue != cachedByType.end()) { + operand.set(cachedValue->second); + continue; + } + + int64_t totalBytes = getValueSizeInBytes(originalValue); + if (totalBytes < 0 || !llvm::isInt<32>(totalBytes) || !llvm::isInt<32>(resolvedAddress->byteOffset)) { + op->emitOpError("host constant materialization requires 32-bit copy sizes and offsets"); + hasFailure = true; + continue; + } + + auto contiguousType = MemRefType::get(originalType.getShape(), originalType.getElementType()); + + rewriter.setInsertionPoint(op); + Value localAlloc = memref::AllocOp::create(rewriter, op->getLoc(), contiguousType); + Value deviceDst = localAlloc; + if (contiguousType != originalType) + deviceDst = memref::CastOp::create(rewriter, op->getLoc(), originalType, localAlloc); + + Value copiedValue; + if constexpr (std::is_same_v) { + copiedValue = pim::PimMemCopyHostToDevBatchOp::create( + rewriter, + op->getLoc(), + originalType, + deviceDst, + getGlobalOp.getResult(), + rewriter.getI32IntegerAttr(0), + rewriter.getI32IntegerAttr(static_cast(resolvedAddress->byteOffset)), + rewriter.getI32IntegerAttr(static_cast(totalBytes))) + .getOutput(); + } + else { + copiedValue = pim::PimMemCopyHostToDevOp::create( + rewriter, + op->getLoc(), + originalType, + deviceDst, + getGlobalOp.getResult(), + rewriter.getI32IntegerAttr(0), + 
rewriter.getI32IntegerAttr(static_cast(resolvedAddress->byteOffset)), + rewriter.getI32IntegerAttr(static_cast(totalBytes))) + .getOutput(); + } + + cachedByType[originalType] = copiedValue; + operand.set(copiedValue); + } + } +} + struct MaterializeHostConstantsPass : PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MaterializeHostConstantsPass) @@ -50,71 +140,11 @@ struct MaterializeHostConstantsPass : PassWrapper()) { - DenseMap>> materializedValues; + for (pim::PimCoreOp coreOp : funcOp.getOps()) + materializeHostConstantsInCore(coreOp, rewriter, hasFailure); - for (Operation& op : llvm::make_early_inc_range(coreOp.getBody().front())) { - if (isa(op)) - continue; - - for (OpOperand& operand : op.getOpOperands()) { - Value originalValue = operand.get(); - if (!isa(originalValue.getType()) || isExplicitHostOperand(&op, operand.getOperandNumber())) - continue; - - auto resolvedAddress = resolveContiguousAddress(originalValue); - if (failed(resolvedAddress)) - continue; - - auto getGlobalOp = dyn_cast_or_null(resolvedAddress->base.getDefiningOp()); - if (!getGlobalOp) - continue; - - auto originalType = dyn_cast(originalValue.getType()); - if (!originalType || !originalType.hasStaticShape()) { - op.emitOpError("host constant materialization requires a static memref operand"); - hasFailure = true; - continue; - } - - auto& cachedByOffset = materializedValues[resolvedAddress->base]; - auto& cachedByType = cachedByOffset[resolvedAddress->byteOffset]; - auto cachedValue = cachedByType.find(originalType); - if (cachedValue != cachedByType.end()) { - operand.set(cachedValue->second); - continue; - } - - int64_t totalBytes = getValueSizeInBytes(originalValue); - if (totalBytes < 0 || !llvm::isInt<32>(totalBytes) || !llvm::isInt<32>(resolvedAddress->byteOffset)) { - op.emitOpError("host constant materialization requires 32-bit copy sizes and offsets"); - hasFailure = true; - continue; - } - - auto contiguousType = MemRefType::get(originalType.getShape(), 
originalType.getElementType()); - - rewriter.setInsertionPoint(&op); - Value localAlloc = memref::AllocOp::create(rewriter, op.getLoc(), contiguousType); - Value deviceDst = localAlloc; - if (contiguousType != originalType) - deviceDst = memref::CastOp::create(rewriter, op.getLoc(), originalType, localAlloc); - - auto hostToDevCopy = pim::PimMemCopyHostToDevOp::create( - rewriter, - op.getLoc(), - originalType, - deviceDst, - getGlobalOp.getResult(), - rewriter.getI32IntegerAttr(0), - rewriter.getI32IntegerAttr(static_cast(resolvedAddress->byteOffset)), - rewriter.getI32IntegerAttr(static_cast(totalBytes))); - - cachedByType[originalType] = hostToDevCopy.getResult(); - operand.set(hostToDevCopy.getResult()); - } - } - } + for (pim::PimCoreBatchOp coreBatchOp : funcOp.getOps()) + materializeHostConstantsInCore(coreBatchOp, rewriter, hasFailure); SmallVector hostCompactOps; for (Operation& op : funcOp.getBody().front()) diff --git a/src/PIM/Pass/PimCodegen/VerificationPass.cpp b/src/PIM/Pass/PimCodegen/VerificationPass.cpp index 81b9de9..8d858f5 100644 --- a/src/PIM/Pass/PimCodegen/VerificationPass.cpp +++ b/src/PIM/Pass/PimCodegen/VerificationPass.cpp @@ -122,6 +122,14 @@ struct VerificationPass : PassWrapper> ModuleOp moduleOp = getOperation(); bool hasFailure = false; + moduleOp.walk([&](Operation* op) { + if (op->getDialect()->getNamespace() != "spat") + return; + + op->emitError("illegal Spatial operation reached PIM codegen verification"); + hasFailure = true; + }); + for (func::FuncOp funcOp : moduleOp.getOps()) { if (funcOp.isExternal()) continue; diff --git a/validation/raptor.py b/validation/raptor.py index 506cd1f..9d9faa3 100644 --- a/validation/raptor.py +++ b/validation/raptor.py @@ -33,8 +33,6 @@ def _parse_pim_pass_timings(output_text): pass_timings[label] = pass_timings.get(label, 0.0) + duration break - if not pass_timings: - raise RuntimeError("Raptor timing report did not contain any PIM pass timings.") return pass_timings @@ -43,7 +41,7 @@ def 
_format_command(cmd): def compile_with_raptor(network_path, raptor_onnx_path: Path, output_base: Path, - crossbar_size, crossbar_count, core_count=None, cwd=None, reporter=None): + crossbar_size, crossbar_count, core_count=None, cwd=None, verbose=False, reporter=None): # Define the arguments, with the possibility to set crossbar size and count args = [ network_path, @@ -51,13 +49,13 @@ def compile_with_raptor(network_path, raptor_onnx_path: Path, output_base: Path, output_base, "--maccel=PIM", "--EmitPimCodegen", - # "--use-experimental-conv-impl=true", f"--crossbar-size={crossbar_size}", f"--crossbar-count={crossbar_count}", - "--enable-timing", ] if core_count is not None: args.append(f"--core-count={core_count}") + if verbose: + args.append("--enable-timing") cmd = [str(raptor_onnx_path)] + [str(arg) for arg in args] if reporter is not None: diff --git a/validation/subprocess_utils.py b/validation/subprocess_utils.py index 2c2225c..6513ec7 100644 --- a/validation/subprocess_utils.py +++ b/validation/subprocess_utils.py @@ -47,7 +47,9 @@ def _stream_output(fd, process, reporter, treat_eio_as_eof=False, stream_output= return_code = process.wait() if return_code != 0: error_output = captured_output if not stream_output else recent_output - raise subprocess.CalledProcessError(return_code, process.args, output=bytes(error_output)) + exc = subprocess.CalledProcessError(return_code, process.args, output=bytes(error_output)) + exc.output_already_streamed = stream_output and bool(captured_output) + raise exc return bytes(captured_output) @@ -67,15 +69,15 @@ def run_command_with_reporter(cmd, cwd=None, reporter=None, capture_output=False stream_output = bool(getattr(reporter, "verbose", False)) if not stream_output: - process = subprocess.Popen( + completed = subprocess.run( cmd, cwd=cwd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, ) - assert process.stdout is not None - output = _stream_output(process.stdout.fileno(), process, reporter, stream_output=False) - 
return output.decode("utf-8", errors="replace") if capture_output else None + if completed.returncode != 0: + raise subprocess.CalledProcessError(completed.returncode, completed.args, output=completed.stdout) + return completed.stdout.decode("utf-8", errors="replace") if capture_output else None try: master_fd, slave_fd = pty.openpty() diff --git a/validation/validate.py b/validation/validate.py index 9121efe..bff3de1 100644 --- a/validation/validate.py +++ b/validation/validate.py @@ -27,7 +27,9 @@ def print_validation_error(reporter, rel, exc): file=sys.stderr, flush=True) if isinstance(exc, subprocess.CalledProcessError): print(format_return_status(exc.returncode), file=sys.stderr, flush=True) - if exc.output: + if getattr(exc, "output_already_streamed", False): + print("Failure log already printed above.", file=sys.stderr, flush=True) + elif exc.output: output_text = exc.output.decode("utf-8", errors="replace") if isinstance(exc.output, bytes) else str(exc.output) if output_text: print(output_text, file=sys.stderr, end="" if output_text.endswith("\n") else "\n", flush=True) @@ -160,12 +162,13 @@ def main(): Fore.RED + plain_status.ljust(status_width) + Style.RESET_ALL print(f"| {rel.ljust(path_width)} | {status} |") print(separator) - print_average_pim_pass_timings( - pass_timing_sums, - pass_timing_counts, - total_timing_sum, - timed_benchmark_count, - ) + if a.verbose: + print_average_pim_pass_timings( + pass_timing_sums, + pass_timing_counts, + total_timing_sum, + timed_benchmark_count, + ) sys.exit(0 if n_passed == n_total else 1) diff --git a/validation/validate_one.py b/validation/validate_one.py index 8232f23..c00a433 100644 --- a/validation/validate_one.py +++ b/validation/validate_one.py @@ -1,6 +1,5 @@ import json import numpy as np -import subprocess import shutil import sys from dataclasses import dataclass, field @@ -11,7 +10,6 @@ from raptor import compile_with_raptor from gen_network_runner import gen_network_runner from subprocess_utils import 
run_command_with_reporter - STAGE_TITLES = ( "Compile ONNX", "Build Runner", @@ -48,10 +46,12 @@ class ProgressReporter: self.verbose = verbose self.columns = shutil.get_terminal_size((100, 20)).columns self.suspended = False + self.rendered_width = 0 def _clear(self): if self.enabled: - sys.stdout.write("\033[2K\r") + sys.stdout.write("\r" + (" " * self.rendered_width) + "\r") + sys.stdout.flush() def _render(self): if not self.enabled or self.suspended: @@ -70,16 +70,16 @@ class ProgressReporter: prefix = Fore.CYAN + prefix_text + Style.RESET_ALL counts = ( - " " - + Style.BRIGHT - + Fore.GREEN - + f"P:{self.passed_models}" - + Style.RESET_ALL - + " " - + Style.BRIGHT - + Fore.RED - + f"F:{self.failed_models}" - + Style.RESET_ALL + " " + + Style.BRIGHT + + Fore.GREEN + + f"P:{self.passed_models}" + + Style.RESET_ALL + + " " + + Style.BRIGHT + + Fore.RED + + f"F:{self.failed_models}" + + Style.RESET_ALL ) model_counter = "" label = "" @@ -92,9 +92,12 @@ class ProgressReporter: available_label_width = max(0, self.columns - len(prefix_text) - len(model_counter) - len(counts_text) - 3) label = label[:available_label_width] - - sys.stdout.write("\r" + prefix + model_counter + counts + label + Style.RESET_ALL) + plain_line = prefix_text + model_counter + f" P:{self.passed_models} F:{self.failed_models}" + label + rendered_line = prefix + model_counter + counts + label + Style.RESET_ALL + padded_width = max(self.rendered_width, len(plain_line)) + sys.stdout.write("\r" + rendered_line + (" " * max(0, padded_width - len(plain_line)))) sys.stdout.flush() + self.rendered_width = len(plain_line) def log(self, message="", color=None): if not self.verbose: @@ -124,18 +127,19 @@ class ProgressReporter: self._render() def suspend(self): + if self.enabled: + self._clear() self.suspended = True - self._clear() - sys.stdout.flush() def resume(self): self.suspended = False + self._render() def finish(self): if self.enabled: self.suspended = True self._clear() - sys.stdout.flush() + 
self.rendered_width = 0 def run_command(cmd, cwd=None, reporter=None): @@ -212,7 +216,8 @@ def build_dump_ranges(config_path, outputs_descriptor): def run_pim_simulator(simulator_dir, pim_dir, output_bin_path, dump_ranges, reporter=None): run_command( - ["cargo", "run", "--no-default-features", "--release", "--package", "pim-simulator", "--bin", "pim-simulator", "--", + ["cargo", "run", "--no-default-features", "--release", "--package", "pim-simulator", "--bin", "pim-simulator", + "--", "-f", str(pim_dir), "-o", str(output_bin_path), "-d", dump_ranges], cwd=simulator_dir, reporter=reporter, @@ -293,7 +298,8 @@ def validate_network(network_onnx_path, raptor_path, onnx_include_dir, reporter.advance() print_stage(reporter, model_index, model_total, network_onnx_path.name, "Build Runner") - gen_network_runner(network_onnx_path, network_so_path, onnx_include_dir, out=runner_dir / "runner.c", verbose=False) + gen_network_runner(network_onnx_path, network_so_path, onnx_include_dir, out=runner_dir / "runner.c", + verbose=False) runner_path = build_onnx_runner(runner_dir, runner_build_dir, reporter=reporter) print_info(reporter, f"Runner built at {runner_path}") reporter.advance() @@ -316,9 +322,8 @@ def validate_network(network_onnx_path, raptor_path, onnx_include_dir, print_stage(reporter, model_index, model_total, network_onnx_path.name, "Compile PIM") pim_pass_timings = compile_with_raptor( - network_mlir_path, raptor_path, raptor_dir / network_onnx_path.stem, - crossbar_size, crossbar_count, core_count=core_count, - cwd=raptor_dir, reporter=reporter) + network_mlir_path, raptor_path, raptor_dir / network_onnx_path.stem, crossbar_size, crossbar_count, + core_count=core_count, cwd=raptor_dir, verbose=verbose, reporter=reporter) print_info(reporter, f"PIM artifacts saved to {raptor_dir / 'pim'}") reporter.advance()