From b605585b1f6cb7763cabeedcdcaea911e9f7e552 Mon Sep 17 00:00:00 2001 From: NiccoloN Date: Sun, 3 May 2026 14:14:14 +0200 Subject: [PATCH] compact spatial IR through different new operations and dedicated syntax fast spatial node merging with batch operations --- .gitignore | 9 + README.md | 2 +- src/PIM/Common/PimCommon.cpp | 71 +- src/PIM/Common/PimCommon.hpp | 5 +- src/PIM/Compiler/PimCodeGen.cpp | 359 ++-- src/PIM/Conversion/ONNXToSpatial/Common.hpp | 26 + .../ONNXToSpatial/ONNXToSpatialPass.cpp | 47 +- .../ONNXToSpatial/Patterns/Math/Conv.cpp | 368 ++-- .../ONNXToSpatial/Patterns/Math/Gemm.cpp | 169 +- .../ONNXToSpatial/Patterns/Math/MatMul.cpp | 4 +- .../Patterns/Math/ReduceMean.cpp | 3 +- .../ONNXToSpatial/Patterns/NN/Pool.cpp | 4 +- .../ONNXToSpatial/Patterns/NN/Softmax.cpp | 3 +- .../ONNXToSpatial/Patterns/Tensor/Concat.cpp | 3 +- .../ONNXToSpatial/Patterns/Tensor/Gather.cpp | 6 +- .../ONNXToSpatial/Patterns/Tensor/Resize.cpp | 2 +- src/PIM/Conversion/SpatialToPim/Common.cpp | 42 - src/PIM/Conversion/SpatialToPim/Common.hpp | 17 - .../Conversion/SpatialToPim/SpatialToPim.td | 25 - .../SpatialToPim/SpatialToPimPass.cpp | 1053 +++++++---- src/PIM/Dialect/Pim/Pim.td | 80 + .../OpBufferizationInterfaces.cpp | 151 ++ .../Bufferization/PimBufferizationPass.cpp | 9 +- src/PIM/Dialect/Spatial/CMakeLists.txt | 1 + src/PIM/Dialect/Spatial/Channels.cpp | 120 ++ src/PIM/Dialect/Spatial/Channels.hpp | 43 + src/PIM/Dialect/Spatial/Spatial.td | 159 +- src/PIM/Dialect/Spatial/SpatialOps.cpp | 1137 ++++++++++++ .../DCPGraph/DCPAnalysis.cpp | 263 ++- .../DCPGraph/DCPAnalysis.hpp | 39 +- .../MergeComputeNodes/DCPGraph/Graph.cpp | 26 +- .../MergeComputeNodes/DCPGraph/Graph.hpp | 7 +- .../MergeComputeNodesPass.cpp | 1592 ++++++++++++----- src/PIM/Pass/PimCodegen/VerificationPass.cpp | 19 +- 34 files changed, 4419 insertions(+), 1445 deletions(-) create mode 100644 src/PIM/Dialect/Spatial/Channels.cpp create mode 100644 src/PIM/Dialect/Spatial/Channels.hpp diff --git a/.gitignore b/.gitignore index 47c814e..dc780e7 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,14 @@ +.zed .idea **/.vscode + .claude AGENTS.md + +CMakeUserPresets.json + build +cmake-build-debug +cmake-build-release + +**/__pycache__ diff --git a/README.md b/README.md index e973656..5017289 100644 --- a/README.md +++ b/README.md @@ -135,7 +135,7 @@ validate.py \ --raptor-path ../cmake-build-release/Release/bin/onnx-mlir \ --onnx-include-dir ../onnx-mlir/include \ --operations-dir ./networks/yolo11n/depth_04 \ - --crossbar-size 2048 + --crossbar-size 2048 --crossbar-count 256 ``` Available networks under `validation/networks/`: `vgg16`, `yolo11n`. diff --git a/src/PIM/Common/PimCommon.cpp b/src/PIM/Common/PimCommon.cpp index 3ca1839..116eb32 100644 --- a/src/PIM/Common/PimCommon.cpp +++ b/src/PIM/Common/PimCommon.cpp @@ -48,7 +48,9 @@ void dumpModule(ModuleOp moduleOp, const std::string& name) { std::fstream file(dialectsDir + "/" + name + ".mlir", std::ios::out); llvm::raw_os_ostream os(file); - os << *moduleOp; + OpPrintingFlags flags; + flags.elideLargeElementsAttrs(); + moduleOp.print(os, flags); os.flush(); file.close(); } @@ -173,6 +175,13 @@ void walkPimMvmVmmWeightUses(Operation* root, function_ref cal root->walk([&](pim::PimCoreOp coreOp) { walkMvmVmmWeightUses(coreOp, callback); }); + root->walk([&](pim::PimCoreBatchOp coreBatchOp) { + auto weights = coreBatchOp.getWeights(); + for (auto weight : weights) + for (OpOperand& use : weight.getUses()) + if (use.getOwner() == coreBatchOp.getOperation()) + callback(use); + }); } memref::GlobalOp lookupGlobalForGetGlobal(ModuleOp moduleOp, memref::GetGlobalOp getGlobalOp) { @@ -181,66 +190,6 @@ memref::GlobalOp lookupGlobalForGetGlobal(ModuleOp moduleOp, memref::GetGlobalOp return moduleOp.lookupSymbol(getGlobalOp.getName()); } -FailureOr getOtherEndOfChannel(Operation* op, bool opIsReceive, RewriterBase& rewriter) { - - auto channelNewOp = op->getOperand(0).getDefiningOp(); - if (!channelNewOp) { - op->emitError("User of Channel must have the first operand created by ChannelNewOp."); - return failure(); - } - // channelNewOp should have two users: `op` and a - // `ChannelSendOp`/`ChannelReceiveOp` - auto channelUsers = channelNewOp->getUsers(); - auto usersIterator = channelUsers.begin(); - auto firstUser = *usersIterator; - usersIterator++; - if (usersIterator == channelUsers.end()) { - op->emitError("Operand generated by ChannelNewOp must have two users, " - "only one found."); - channelNewOp->dump(); - op->dump(); - channelNewOp->getParentOp()->dump(); - return failure(); - } - auto secondUser = *usersIterator; - usersIterator++; - if (usersIterator != channelUsers.end()) { - op->emitError("Operand generated by ChannelNewOp must have two users, " - "more than two found."); - return failure(); - } - Operation* notOpUser; - if (firstUser == op) { - notOpUser = secondUser; - } - else if (secondUser == op) { - notOpUser = firstUser; - } - else { - op->emitError("Operand generated by ChannelNewOp must have two users, " - "and one of them must be me, but" - "none of them is actually me."); - return failure(); - } - - if (opIsReceive) { - if (!isa(notOpUser)) { - op->emitError("Operand generated by ChannelNewOp has two user, one is " - "me, the other is not a ChannelSendOp."); - return failure(); - } - return notOpUser; - } - else { - if (!isa(notOpUser)) { - op->emitError("Operand generated by ChannelNewOp has two user, one is " - "me, the other is not a ChannelReceiveOp."); - return failure(); - } - return notOpUser; - } -} - SmallVector computeRowMajorStrides(ArrayRef shape) { SmallVector strides(shape.size(), 1); for (int64_t dim = static_cast(shape.size()) - 2; dim >= 0; --dim) diff --git a/src/PIM/Common/PimCommon.hpp b/src/PIM/Common/PimCommon.hpp index d65ad8c..0ad4fd3 100644 --- a/src/PIM/Common/PimCommon.hpp +++ b/src/PIM/Common/PimCommon.hpp @@ -17,6 +17,8 @@ inline constexpr llvm::StringRef PimWeightAlwaysAttrName = "weightAlways"; namespace onnx_mlir { +inline constexpr llvm::StringLiteral kCoreIdAttrName = "core_id"; + struct ResolvedContiguousAddress { mlir::Value base; int64_t byteOffset = 0; @@ -48,9 +50,6 @@ void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref -getOtherEndOfChannel(mlir::Operation* op, bool opIsReceive, mlir::RewriterBase& rewriter); - llvm::SmallVector computeRowMajorStrides(llvm::ArrayRef shape); llvm::SmallVector diff --git a/src/PIM/Compiler/PimCodeGen.cpp b/src/PIM/Compiler/PimCodeGen.cpp index a2deda2..b15074b 100644 --- a/src/PIM/Compiler/PimCodeGen.cpp +++ b/src/PIM/Compiler/PimCodeGen.cpp @@ -1,8 +1,10 @@ +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/IRMapping.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/SmallPtrSet.h" @@ -474,19 +476,113 @@ std::string getMemorySizeAsString(size_t size) { return std::to_string(size) + " Bytes"; } -static SmallVector getUsedWeightIndices(pim::PimCoreOp coreOp) { +static SmallVector getUsedWeightIndices(Block& block) { SmallVector indices; auto addIndex = [&](unsigned weightIndex) { if (!llvm::is_contained(indices, weightIndex)) indices.push_back(weightIndex); }; - coreOp.walk([&](pim::PimMVMOp mvmOp) { addIndex(mvmOp.getWeightIndex()); }); - coreOp.walk([&](pim::PimVMMOp vmmOp) { addIndex(vmmOp.getWeightIndex()); }); + block.walk([&](pim::PimMVMOp mvmOp) { addIndex(mvmOp.getWeightIndex()); }); + block.walk([&](pim::PimVMMOp vmmOp) { addIndex(vmmOp.getWeightIndex()); }); llvm::sort(indices); return indices; } +static SmallVector getUsedWeightIndices(pim::PimCoreOp coreOp) { + return getUsedWeightIndices(coreOp.getBody().front()); +} + +static SmallVector getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp) { + auto coreIdsAttr = coreBatchOp->getAttrOfType(onnx_mlir::kCoreIdAttrName); + assert(coreIdsAttr && "pim.core_batch requires core_id array attribute"); + return SmallVector(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end()); +} + +static SmallVector collectTopLevelCoreLikeOps(func::FuncOp funcOp) { + SmallVector coreLikeOps; + for (Operation& op : funcOp.getBody().front()) { + if (dyn_cast(&op) || dyn_cast(&op)) + coreLikeOps.push_back(&op); + } + return coreLikeOps; +} + +static pim::PimCoreOp materializeScalarCoreFromBatchLane(pim::PimCoreBatchOp coreBatchOp, unsigned lane) { + OpBuilder builder(coreBatchOp); + builder.setInsertionPointAfter(coreBatchOp); + + size_t laneCount = static_cast(coreBatchOp.getLaneCount()); + size_t weightsPerLane = coreBatchOp.getWeights().size() / laneCount; + SmallVector laneWeights; + laneWeights.reserve(weightsPerLane); + for (size_t weightIndex = 0; weightIndex < weightsPerLane; ++weightIndex) + laneWeights.push_back(coreBatchOp.getWeights()[lane * weightsPerLane + weightIndex]); + + auto coreIds = getBatchCoreIds(coreBatchOp); + auto scalarCore = pim::PimCoreOp::create(builder, + coreBatchOp.getLoc(), + ValueRange(laneWeights), + builder.getI32IntegerAttr(coreIds[lane])); + Block* block = builder.createBlock(&scalarCore.getBody(), scalarCore.getBody().end()); + IRMapping mapper; + if (coreBatchOp.getBody().front().getNumArguments() == 1) + mapper.map(coreBatchOp.getBody().front().getArgument(0), coreBatchOp.getInputs()[lane]); + + builder.setInsertionPointToEnd(block); + for (Operation& op : coreBatchOp.getBody().front()) { + if (isa(op)) { + pim::PimHaltOp::create(builder, op.getLoc()); + continue; + } + + if (auto sendBatchOp = dyn_cast(op)) { + pim::PimSendOp::create(builder, + sendBatchOp.getLoc(), + mapper.lookup(sendBatchOp.getInput()), + sendBatchOp.getSizeAttr(), + builder.getI32IntegerAttr(sendBatchOp.getTargetCoreIds()[lane])); + continue; + } + + if (auto receiveBatchOp = dyn_cast(op)) { + auto scalarReceive = pim::PimReceiveOp::create(builder, + receiveBatchOp.getLoc(), + receiveBatchOp.getOutput().getType(), + mapper.lookup(receiveBatchOp.getOutputBuffer()), + receiveBatchOp.getSizeAttr(), + builder.getI32IntegerAttr(receiveBatchOp.getSourceCoreIds()[lane])); + mapper.map(receiveBatchOp.getOutput(), scalarReceive.getOutput()); + continue; + } + + if (auto memcpBatchOp = dyn_cast(op)) { + mlir::Value hostSource = mapper.lookupOrNull(memcpBatchOp.getHostSource()); + if (!hostSource) + hostSource = memcpBatchOp.getHostSource(); + + auto scalarCopy = pim::PimMemCopyHostToDevOp::create(builder, + memcpBatchOp.getLoc(), + memcpBatchOp.getOutput().getType(), + mapper.lookup(memcpBatchOp.getDeviceTarget()), + hostSource, + memcpBatchOp.getDeviceTargetOffsetAttr(), + memcpBatchOp.getHostSourceOffsetAttr(), + memcpBatchOp.getSizeAttr()); + mapper.map(memcpBatchOp.getOutput(), scalarCopy.getOutput()); + continue; + } + + Operation* cloned = builder.clone(op, mapper); + for (auto [originalResult, clonedResult] : llvm::zip(op.getResults(), cloned->getResults())) + mapper.map(originalResult, clonedResult); + } + + if (block->empty() || !isa(block->back())) + pim::PimHaltOp::create(builder, coreBatchOp.getLoc()); + return scalarCore; +} + /// Write global constant data into a binary memory image at their allocated addresses. static OnnxMlirCompilerErrorCodes writeMemoryBinary(ModuleOp moduleOp, func::FuncOp funcOp, PimAcceleratorMemory& memory, StringRef outputDirPath) { @@ -670,7 +766,7 @@ static OnnxMlirCompilerErrorCodes writeCrossbarWeights(ModuleOp moduleOp, return CompilerSuccess; } -llvm::DenseMap> +llvm::DenseMap> createAndPopulateWeightFolder(func::FuncOp funcOp, StringRef outputDirPath) { ModuleOp moduleOp = funcOp->getParentOfType(); auto coreWeightsDirPath = outputDirPath + "/weights"; @@ -679,85 +775,104 @@ createAndPopulateWeightFolder(func::FuncOp funcOp, StringRef outputDirPath) { size_t indexFileName = 0; int64_t xbarSize = crossbarSize.getValue(); - llvm::DenseMap> mapCoreWeightToFileName; + llvm::DenseMap> mapCoreWeightToFileName; llvm::DenseMap mapGlobalOpToFileName; - for (pim::PimCoreOp coreOp : funcOp.getOps()) { - for (unsigned index : getUsedWeightIndices(coreOp)) { - if (index >= coreOp.getWeights().size()) { - coreOp.emitWarning("Weight index " + std::to_string(index) + " is out of range"); - assert(index < coreOp.getWeights().size() && "Weight index is out of range"); - } - mlir::Value weight = coreOp.getWeights()[index]; + SmallVector coreLikeOps = collectTopLevelCoreLikeOps(funcOp); - auto getGlobalOp = weight.getDefiningOp(); - if (!getGlobalOp) { - coreOp.emitWarning("Weight is not from a memref.get_global at index " + std::to_string(index)); - assert(!getGlobalOp && "Weight is not from a memref.get_global"); - } + for (Operation* op : coreLikeOps) { + SmallVector scalarCores; + if (auto coreOp = dyn_cast(op)) { + scalarCores.push_back(coreOp); + } + else { + auto coreBatchOp = cast(op); + for (unsigned lane = 0; lane < static_cast(coreBatchOp.getLaneCount()); ++lane) + scalarCores.push_back(materializeScalarCoreFromBatchLane(coreBatchOp, lane)); + } - auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp); - if (!globalOp) { - coreOp.emitWarning("Could not find memref.global for weight at index " + std::to_string(index)); - assert(!globalOp && "Could not find memref.global"); - } + for (pim::PimCoreOp coreOp : scalarCores) { + size_t coreId = static_cast(coreOp.getCoreId()); + for (unsigned index : getUsedWeightIndices(coreOp)) { + if (index >= coreOp.getWeights().size()) { + coreOp.emitWarning("Weight index " + std::to_string(index) + " is out of range"); + assert(index < coreOp.getWeights().size() && "Weight index is out of range"); + } + mlir::Value weight = coreOp.getWeights()[index]; - auto initialValue = globalOp.getInitialValue(); - if (!initialValue) { - coreOp.emitWarning("memref.global has no initial value at index " + std::to_string(index)); - assert(!initialValue && "memref.global has no initial value"); - } + auto getGlobalOp = weight.getDefiningOp(); + if (!getGlobalOp) { + coreOp.emitWarning("Weight is not from a memref.get_global at index " + std::to_string(index)); + assert(!getGlobalOp && "Weight is not from a memref.get_global"); + } - auto denseAttr = dyn_cast(*initialValue); - if (!denseAttr) { - coreOp.emitWarning("memref.global initial value is not dense at index " + std::to_string(index)); - assert(!denseAttr && "memref.global initial value is not dense"); - } + auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp); + if (!globalOp) { + coreOp.emitWarning("Could not find memref.global for weight at index " + std::to_string(index)); + assert(!globalOp && "Could not find memref.global"); + } - if (mapGlobalOpToFileName.contains(globalOp)) { - auto& fileName = mapGlobalOpToFileName[globalOp]; - std::pair weightToFile = {weight, fileName}; - mapCoreWeightToFileName[coreOp].insert(weightToFile); - continue; - } + auto initialValue = globalOp.getInitialValue(); + if (!initialValue) { + coreOp.emitWarning("memref.global has no initial value at index " + std::to_string(index)); + assert(!initialValue && "memref.global has no initial value"); + } - auto type = denseAttr.getType(); - auto shape = type.getShape(); - assert(isMatrixShape(shape) && "Weight matrix must be 2-dimensional"); - int64_t numRows = shape[0]; - int64_t numCols = shape[1]; - assert(numRows <= xbarSize && numCols <= xbarSize && "Weight dimensions must not exceed crossbar size"); + auto denseAttr = dyn_cast(*initialValue); + if (!denseAttr) { + coreOp.emitWarning("memref.global initial value is not dense at index " + std::to_string(index)); + assert(!denseAttr && "memref.global initial value is not dense"); + } - size_t elementByteWidth = type.getElementType().getIntOrFloatBitWidth() / 8; + if (mapGlobalOpToFileName.contains(globalOp)) { + auto& fileName = mapGlobalOpToFileName[globalOp]; + std::pair weightToFile = {weight, fileName}; + mapCoreWeightToFileName[coreId].insert(weightToFile); + continue; + } - std::string newFileName = "crossbar_" + std::to_string(indexFileName++) + ".bin"; - auto weightFilePath = (coreWeightsDirPath + "/" + newFileName).str(); - std::error_code errorCode; - raw_fd_ostream weightFileStream(weightFilePath, errorCode, sys::fs::OF_None); - if (errorCode) { - errs() << "Error while opening weight file `" << weightFilePath << "`: " << errorCode.message() << '\n'; - assert(errorCode); - } + auto type = denseAttr.getType(); + auto shape = type.getShape(); + assert(isMatrixShape(shape) && "Weight matrix must be 2-dimensional"); + int64_t numRows = shape[0]; + int64_t numCols = shape[1]; + assert(numRows <= xbarSize && numCols <= xbarSize && "Weight dimensions must not exceed crossbar size"); - uint64_t zero = 0; - for (int64_t row = 0; row < xbarSize; row++) { - for (int64_t col = 0; col < xbarSize; col++) { - if (row < numRows && col < numCols) { - int64_t index = row * numCols + col; - APInt bits = denseAttr.getValues()[index].bitcastToAPInt(); - uint64_t word = bits.getZExtValue(); - weightFileStream.write(reinterpret_cast(&word), elementByteWidth); - } - else { - weightFileStream.write(reinterpret_cast(&zero), elementByteWidth); + size_t elementByteWidth = type.getElementType().getIntOrFloatBitWidth() / 8; + + std::string newFileName = "crossbar_" + std::to_string(indexFileName++) + ".bin"; + auto weightFilePath = (coreWeightsDirPath + "/" + newFileName).str(); + std::error_code errorCode; + raw_fd_ostream weightFileStream(weightFilePath, errorCode, sys::fs::OF_None); + if (errorCode) { + errs() << "Error while opening weight file `" << weightFilePath << "`: " << errorCode.message() << '\n'; + assert(errorCode); + } + + uint64_t zero = 0; + for (int64_t row = 0; row < xbarSize; row++) { + for (int64_t col = 0; col < xbarSize; col++) { + if (row < numRows && col < numCols) { + int64_t index = row * numCols + col; + APInt bits = denseAttr.getValues()[index].bitcastToAPInt(); + uint64_t word = bits.getZExtValue(); + weightFileStream.write(reinterpret_cast(&word), elementByteWidth); + } + else { + weightFileStream.write(reinterpret_cast(&zero), elementByteWidth); + } } } - } - weightFileStream.close(); - mapGlobalOpToFileName.insert({globalOp, newFileName}); - mapCoreWeightToFileName[coreOp].insert({weight, newFileName}); + weightFileStream.close(); + mapGlobalOpToFileName.insert({globalOp, newFileName}); + mapCoreWeightToFileName[coreId].insert({weight, newFileName}); + } } + + for (pim::PimCoreOp coreOp : scalarCores) + if (coreOp.getOperation() != op) + coreOp.erase(); } return mapCoreWeightToFileName; } @@ -850,60 +965,76 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std:: // Create Weight Folder auto mapCoreWeightToFileName = createAndPopulateWeightFolder(funcOp, outputDirPath); - for (auto coreOp : funcOp.getOps()) { - auto coreId = coreOp.getCoreId(); - coreCount++; + SmallVector coreLikeOps = collectTopLevelCoreLikeOps(funcOp); - std::error_code errorCode; - auto outputCorePath = outputDirPath + "/core_" + std::to_string(coreId) + ".json"; - raw_fd_ostream coreFileStream(outputCorePath, errorCode); - if (errorCode) { - errs() << "Error while opening core file `" << outputCorePath << "`: " << errorCode.message() << '\n'; - return InvalidOutputFileAccess; + for (Operation* op : coreLikeOps) { + SmallVector scalarCores; + if (auto coreOp = dyn_cast(op)) { + scalarCores.push_back(coreOp); } - coreFileStream << '['; - - PimCodeGen coreCodeGen(memory, coreFileStream); - memory.getOrCreateDeviceMem(coreId).allocateCore(coreOp); - - int64_t processedOperations = codeGenCoreOps(coreOp.getBody().front(), coreCodeGen); - if (processedOperations < 0) - return CompilerFailure; - assert(processedOperations > 0); - - // Remove trailing comma, close JSON array - coreFileStream.seek(coreFileStream.tell() - 1); - coreFileStream << ']'; - coreFileStream.close(); - - // Write crossbar weights for this core - auto coreWeightsDirPath = outputDirPath + "/core_" + std::to_string(coreId); - if (auto error = sys::fs::create_directory(coreWeightsDirPath)) { - errs() << "Error creating core directory: " << coreWeightsDirPath << ": " << error.message() << '\n'; - return InvalidOutputFileAccess; + else { + auto coreBatchOp = cast(op); + for (unsigned lane = 0; lane < static_cast(coreBatchOp.getLaneCount()); ++lane) + scalarCores.push_back(materializeScalarCoreFromBatchLane(coreBatchOp, lane)); } - auto& mapWeightToFile = mapCoreWeightToFileName[coreOp]; - json::Array xbarsPerGroup; - for (unsigned index : getUsedWeightIndices(coreOp)) { - if (index >= coreOp.getWeights().size()) { - coreOp.emitWarning("Weight index " + std::to_string(index) + " is out of range"); - assert(index < coreOp.getWeights().size() && "Weight index is out of range"); - } - mlir::Value weight = coreOp.getWeights()[index]; - xbarsPerGroup.push_back(index); - assert(mapWeightToFile.contains(weight) && "Weight was not materialized into a file!!"); - auto& fileName = mapWeightToFile[weight]; - if (auto error = sys::fs::create_link(outputDirPath + "/weights/" + fileName, - coreWeightsDirPath + "/crossbar_" + std::to_string(index) + ".bin")) { - errs() << "Error creating link file: " << (outputDirPath + "/weights/" + fileName) << " to " - << (coreWeightsDirPath + "/crossbar_" + std::to_string(index) + ".bin") << "\nError:" << error.message() - << '\n'; + for (pim::PimCoreOp coreOp : scalarCores) { + auto coreId = coreOp.getCoreId(); + coreCount++; + + std::error_code errorCode; + auto outputCorePath = outputDirPath + "/core_" + std::to_string(coreId) + ".json"; + raw_fd_ostream coreFileStream(outputCorePath, errorCode); + if (errorCode) { + errs() << "Error while opening core file `" << outputCorePath << "`: " << errorCode.message() << '\n'; return InvalidOutputFileAccess; } + coreFileStream << '['; + + PimCodeGen coreCodeGen(memory, coreFileStream); + memory.getOrCreateDeviceMem(coreId).allocateCore(coreOp); + + int64_t processedOperations = codeGenCoreOps(coreOp.getBody().front(), coreCodeGen); + if (processedOperations < 0) + return CompilerFailure; + assert(processedOperations > 0); + + coreFileStream.seek(coreFileStream.tell() - 1); + coreFileStream << ']'; + coreFileStream.close(); + + auto coreWeightsDirPath = outputDirPath + "/core_" + std::to_string(coreId); + if (auto error = sys::fs::create_directory(coreWeightsDirPath)) { + errs() << "Error creating core directory: " << coreWeightsDirPath << ": " << error.message() << '\n'; + return InvalidOutputFileAccess; + } + + auto& mapWeightToFile = mapCoreWeightToFileName[static_cast(coreId)]; + json::Array xbarsPerGroup; + for (unsigned index : getUsedWeightIndices(coreOp)) { + if (index >= coreOp.getWeights().size()) { + coreOp.emitWarning("Weight index " + std::to_string(index) + " is out of range"); + assert(index < coreOp.getWeights().size() && "Weight index is out of range"); + } + mlir::Value weight = coreOp.getWeights()[index]; + xbarsPerGroup.push_back(index); + assert(mapWeightToFile.contains(weight) && "Weight was not materialized into a file!!"); + auto& fileName = mapWeightToFile[weight]; + if (auto error = sys::fs::create_link(outputDirPath + "/weights/" + fileName, + coreWeightsDirPath + "/crossbar_" + std::to_string(index) + ".bin")) { + errs() << "Error creating link file: " << (outputDirPath + "/weights/" + fileName) << " to " + << (coreWeightsDirPath + "/crossbar_" + std::to_string(index) + ".bin") << "\nError:" + << error.message() << '\n'; + return InvalidOutputFileAccess; + } + } + + xbarsPerArrayGroup["core" + std::to_string(coreId)] = std::move(xbarsPerGroup); } - xbarsPerArrayGroup["core" + std::to_string(coreId)] = std::move(xbarsPerGroup); + for (pim::PimCoreOp coreOp : scalarCores) + if (coreOp.getOperation() != op) + coreOp.erase(); } return writeConfigJson(funcOp, memory, coreCount, std::move(xbarsPerArrayGroup), outputDirPath); diff --git a/src/PIM/Conversion/ONNXToSpatial/Common.hpp b/src/PIM/Conversion/ONNXToSpatial/Common.hpp index 5cde963..3c49f34 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Common.hpp +++ b/src/PIM/Conversion/ONNXToSpatial/Common.hpp @@ -12,6 +12,7 @@ #include #include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/STLExtras.h" #include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" @@ -174,6 +175,31 @@ using InvokeWithValueRangeResultT = std::invoke_result_t; } // namespace detail +template +inline mlir::Value createSpatConcat(RewriterT& rewriter, mlir::Location loc, int64_t axis, mlir::ValueRange inputs) { + assert(!inputs.empty() && "spat.concat requires at least one input"); + if (inputs.size() == 1) + return inputs.front(); + + auto firstType = mlir::cast(inputs.front().getType()); + auto outputShape = llvm::to_vector(firstType.getShape()); + int64_t concatDimSize = 0; + bool concatDimDynamic = false; + + for (mlir::Value input : inputs) { + auto inputType = mlir::cast(input.getType()); + assert(inputType.getRank() == firstType.getRank() && "spat.concat expects same-rank inputs"); + if (mlir::ShapedType::isDynamic(inputType.getDimSize(axis))) + concatDimDynamic = true; + else + concatDimSize += inputType.getDimSize(axis); + } + + outputShape[axis] = concatDimDynamic ? mlir::ShapedType::kDynamic : concatDimSize; + auto outputType = mlir::RankedTensorType::get(outputShape, firstType.getElementType(), firstType.getEncoding()); + return spatial::SpatConcatOp::create(rewriter, loc, outputType, rewriter.getI64IntegerAttr(axis), inputs).getOutput(); +} + template auto createSpatCompute(RewriterT& rewriter, mlir::Location loc, diff --git a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp index 6fc2240..90d91fd 100644 --- a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp @@ -54,6 +54,43 @@ private: } // namespace +static void foldSingleLaneComputeBatches(func::FuncOp funcOp) { + IRRewriter rewriter(funcOp.getContext()); + SmallVector batchOps; + funcOp.walk([&](spatial::SpatComputeBatch batchOp) { batchOps.push_back(batchOp); }); + + for (auto batchOp : batchOps) { + if (batchOp.getLaneCount() != 1) + continue; + + auto loc = batchOp.getLoc(); + rewriter.setInsertionPoint(batchOp); + auto computeOp = spatial::SpatCompute::create(rewriter, loc, batchOp.getResultTypes(), batchOp.getWeights(), batchOp.getInputs()); + computeOp.getProperties().setOperandSegmentSizes( + {static_cast(batchOp.getWeights().size()), static_cast(batchOp.getInputs().size())}); + + Block& templateBlock = batchOp.getBody().front(); + SmallVector blockArgTypes; + SmallVector blockArgLocs; + for (BlockArgument arg : templateBlock.getArguments()) { + blockArgTypes.push_back(arg.getType()); + blockArgLocs.push_back(loc); + } + auto* newBlock = rewriter.createBlock( + &computeOp.getBody(), computeOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs); + + IRMapping mapper; + for (auto [oldArg, newArg] : llvm::zip(templateBlock.getArguments(), newBlock->getArguments())) + mapper.map(oldArg, newArg); + rewriter.setInsertionPointToEnd(newBlock); + for (Operation& op : templateBlock) + rewriter.clone(op, mapper); + + batchOp.replaceAllUsesWith(computeOp.getResults()); + rewriter.eraseOp(batchOp); + } +} + void ONNXToSpatialPass::runOnOperation() { ModuleOp moduleOp = getOperation(); MLIRContext* ctx = &getContext(); @@ -124,6 +161,8 @@ void ONNXToSpatialPass::runOnOperation() { return; } + foldSingleLaneComputeBatches(*entryFunc); + // Count the number of compute ops and check they do not exceed the core count if (coresCount != -1) { int computeOpsCount = 0; @@ -196,8 +235,12 @@ bool encapsulateConcat(IRRewriter& rewriter, Location loc, Operation* inst) { IRMapping mapper; for (auto [source, bbArg] : llvm::zip(sources, BB->getArguments())) mapper.map(source, bbArg); - auto newConcat = rewriter.clone(*inst, mapper); - spatial::SpatYieldOp::create(rewriter, loc, newConcat->getResults()); + auto newConcat = spatial::SpatConcatOp::create(rewriter, + loc, + toRemoveOp.getType(), + rewriter.getI64IntegerAttr(toRemoveOp.getDim()), + ValueRange(BB->getArguments())); + spatial::SpatYieldOp::create(rewriter, loc, newConcat.getOutput()); inst->replaceAllUsesWith(newCompute->getResults()); inst->erase(); return true; diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp index bffdc72..70bf871 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp @@ -147,161 +147,148 @@ static Value buildPackedBias(bool hasBias, return arith::ConstantOp::create(rewriter, loc, packedBiasType, packedBiasAttr).getResult(); } -static SmallVector createIm2colRowComputes(Value x, - RankedTensorType xType, - RankedTensorType im2colType, - RankedTensorType im2colRowType, - RankedTensorType gemmInputRowType, - int64_t batchSize, - int64_t numChannelsIn, - int64_t xHeight, - int64_t xWidth, - int64_t wHeight, - int64_t wWidth, - int64_t padHeightBegin, - int64_t padHeightEnd, - int64_t padWidthBegin, - int64_t padWidthEnd, - int64_t strideHeight, - int64_t strideWidth, - int64_t dilationHeight, - int64_t dilationWidth, - int64_t outWidth, - int64_t patchSize, - int64_t numPatches, - int64_t numPatchesPerBatch, - int64_t packFactor, - ConversionPatternRewriter& rewriter, - Location loc) { +static Value createIm2colRowComputes(Value x, + RankedTensorType xType, + RankedTensorType im2colType, + RankedTensorType im2colRowType, + RankedTensorType gemmInputRowsType, + int64_t batchSize, + int64_t numChannelsIn, + int64_t xHeight, + int64_t xWidth, + int64_t wHeight, + int64_t wWidth, + int64_t padHeightBegin, + int64_t padHeightEnd, + int64_t padWidthBegin, + int64_t padWidthEnd, + int64_t strideHeight, + int64_t strideWidth, + int64_t dilationHeight, + int64_t dilationWidth, + int64_t outWidth, + int64_t patchSize, + int64_t numPatches, + int64_t numPatchesPerBatch, + int64_t packFactor, + ConversionPatternRewriter& rewriter, + Location loc) { auto elemType = xType.getElementType(); constexpr size_t numInputs = 1; const int64_t packedNumRows = ceilIntegerDivide(numPatches, packFactor); - SmallVector resultTypes(packedNumRows, gemmInputRowType); - auto im2colComputeOp = createSpatCompute(rewriter, loc, resultTypes, {}, x, [&](Value xArg) { - Value paddedInput = xArg; + auto im2colComputeOp = + createSpatCompute(rewriter, loc, TypeRange {gemmInputRowsType}, {}, x, [&](Value xArg) { + Value paddedInput = xArg; - // Pad input with zeros if needed: - // [1, numChannelsIn, xHeight, xWidth] -> [1, numChannelsIn, xHeight+padHeight, xWidth+padWidth] - if (padHeightBegin || padHeightEnd || padWidthBegin || padWidthEnd) { - const int64_t paddedHeight = xHeight + padHeightBegin + padHeightEnd; - const int64_t paddedWidth = xWidth + padWidthBegin + padWidthEnd; - auto paddedType = RankedTensorType::get({batchSize, numChannelsIn, paddedHeight, paddedWidth}, elemType); - SmallVector lowPads = {rewriter.getIndexAttr(0), - rewriter.getIndexAttr(0), - rewriter.getIndexAttr(padHeightBegin), - rewriter.getIndexAttr(padWidthBegin)}; - SmallVector highPads = {rewriter.getIndexAttr(0), - rewriter.getIndexAttr(0), - rewriter.getIndexAttr(padHeightEnd), - rewriter.getIndexAttr(padWidthEnd)}; - auto padOp = tensor::PadOp::create(rewriter, loc, paddedType, paddedInput, lowPads, highPads); - auto* padBlock = new Block(); - for (int i = 0; i < 4; i++) - padBlock->addArgument(rewriter.getIndexType(), loc); - padOp.getRegion().push_back(padBlock); - rewriter.setInsertionPointToStart(padBlock); - auto zero = arith::ConstantOp::create(rewriter, loc, elemType, rewriter.getFloatAttr(elemType, 0.0)); - tensor::YieldOp::create(rewriter, loc, zero.getResult()); - rewriter.setInsertionPointAfter(padOp); - paddedInput = padOp.getResult(); - } + // Pad input with zeros if needed: + // [1, numChannelsIn, xHeight, xWidth] -> [1, numChannelsIn, xHeight+padHeight, xWidth+padWidth] + if (padHeightBegin || padHeightEnd || padWidthBegin || padWidthEnd) { + const int64_t paddedHeight = xHeight + padHeightBegin + padHeightEnd; + const int64_t paddedWidth = xWidth + padWidthBegin + padWidthEnd; + auto paddedType = RankedTensorType::get({batchSize, numChannelsIn, paddedHeight, paddedWidth}, elemType); + SmallVector lowPads = {rewriter.getIndexAttr(0), + rewriter.getIndexAttr(0), + rewriter.getIndexAttr(padHeightBegin), + rewriter.getIndexAttr(padWidthBegin)}; + SmallVector highPads = {rewriter.getIndexAttr(0), + rewriter.getIndexAttr(0), + rewriter.getIndexAttr(padHeightEnd), + rewriter.getIndexAttr(padWidthEnd)}; + auto padOp = tensor::PadOp::create(rewriter, loc, paddedType, paddedInput, lowPads, highPads); + auto* padBlock = new Block(); + for (int i = 0; i < 4; i++) + padBlock->addArgument(rewriter.getIndexType(), loc); + padOp.getRegion().push_back(padBlock); + rewriter.setInsertionPointToStart(padBlock); + auto zero = arith::ConstantOp::create(rewriter, loc, elemType, rewriter.getFloatAttr(elemType, 0.0)); + tensor::YieldOp::create(rewriter, loc, zero.getResult()); + rewriter.setInsertionPointAfter(padOp); + paddedInput = padOp.getResult(); + } - // Build im2col [numPatches, patchSize] incrementally to keep the IR small - // until the late PIM unrolling step. - Value im2colInit = tensor::EmptyOp::create(rewriter, loc, im2colType.getShape(), elemType); - auto c0 = arith::ConstantIndexOp::create(rewriter, loc, 0); - auto c1 = arith::ConstantIndexOp::create(rewriter, loc, 1); - auto cNumPatches = arith::ConstantIndexOp::create(rewriter, loc, numPatches); - auto cNumPatchesPerBatch = arith::ConstantIndexOp::create(rewriter, loc, numPatchesPerBatch); - auto cOutWidth = arith::ConstantIndexOp::create(rewriter, loc, outWidth); - auto cStrideHeight = arith::ConstantIndexOp::create(rewriter, loc, strideHeight); - auto cStrideWidth = arith::ConstantIndexOp::create(rewriter, loc, strideWidth); + // Build im2col [numPatches, patchSize] incrementally to keep the IR small + // until the late PIM unrolling step. + Value im2colInit = tensor::EmptyOp::create(rewriter, loc, im2colType.getShape(), elemType); + auto c0 = arith::ConstantIndexOp::create(rewriter, loc, 0); + auto c1 = arith::ConstantIndexOp::create(rewriter, loc, 1); + auto cNumPatches = arith::ConstantIndexOp::create(rewriter, loc, numPatches); + auto cNumPatchesPerBatch = arith::ConstantIndexOp::create(rewriter, loc, numPatchesPerBatch); + auto cOutWidth = arith::ConstantIndexOp::create(rewriter, loc, outWidth); + auto cStrideHeight = arith::ConstantIndexOp::create(rewriter, loc, strideHeight); + auto cStrideWidth = arith::ConstantIndexOp::create(rewriter, loc, strideWidth); - auto im2colLoop = scf::ForOp::create(rewriter, loc, c0, cNumPatches, c1, ValueRange {im2colInit}); - rewriter.setInsertionPointToStart(im2colLoop.getBody()); + auto im2colLoop = scf::ForOp::create(rewriter, loc, c0, cNumPatches, c1, ValueRange {im2colInit}); + rewriter.setInsertionPointToStart(im2colLoop.getBody()); - Value patchIndex = im2colLoop.getInductionVar(); - Value im2colAcc = im2colLoop.getRegionIterArgs().front(); + Value patchIndex = im2colLoop.getInductionVar(); + Value im2colAcc = im2colLoop.getRegionIterArgs().front(); - Value batchIndex = arith::DivUIOp::create(rewriter, loc, patchIndex, cNumPatchesPerBatch); - Value batchPatchIndex = arith::RemUIOp::create(rewriter, loc, patchIndex, cNumPatchesPerBatch); - Value outHeightIndex = arith::DivUIOp::create(rewriter, loc, batchPatchIndex, cOutWidth); - Value outWidthIndex = arith::RemUIOp::create(rewriter, loc, batchPatchIndex, cOutWidth); - Value inputHeightOffset = arith::MulIOp::create(rewriter, loc, outHeightIndex, cStrideHeight); - Value inputWidthOffset = arith::MulIOp::create(rewriter, loc, outWidthIndex, cStrideWidth); + Value batchIndex = arith::DivUIOp::create(rewriter, loc, patchIndex, cNumPatchesPerBatch); + Value batchPatchIndex = arith::RemUIOp::create(rewriter, loc, patchIndex, cNumPatchesPerBatch); + Value outHeightIndex = arith::DivUIOp::create(rewriter, loc, batchPatchIndex, cOutWidth); + Value outWidthIndex = arith::RemUIOp::create(rewriter, loc, batchPatchIndex, cOutWidth); + Value inputHeightOffset = arith::MulIOp::create(rewriter, loc, outHeightIndex, cStrideHeight); + Value inputWidthOffset = arith::MulIOp::create(rewriter, loc, outWidthIndex, cStrideWidth); - SmallVector offsets = {batchIndex, rewriter.getIndexAttr(0), inputHeightOffset, inputWidthOffset}; - SmallVector sizes = {rewriter.getIndexAttr(1), - rewriter.getIndexAttr(numChannelsIn), - rewriter.getIndexAttr(wHeight), - rewriter.getIndexAttr(wWidth)}; - SmallVector strides = {rewriter.getIndexAttr(1), - rewriter.getIndexAttr(1), - rewriter.getIndexAttr(dilationHeight), - rewriter.getIndexAttr(dilationWidth)}; - auto patchType = RankedTensorType::get({1, numChannelsIn, wHeight, wWidth}, elemType); - Value patch = tensor::ExtractSliceOp::create(rewriter, loc, patchType, paddedInput, offsets, sizes, strides); + SmallVector offsets = {batchIndex, rewriter.getIndexAttr(0), inputHeightOffset, inputWidthOffset}; + SmallVector sizes = {rewriter.getIndexAttr(1), + rewriter.getIndexAttr(numChannelsIn), + rewriter.getIndexAttr(wHeight), + rewriter.getIndexAttr(wWidth)}; + SmallVector strides = {rewriter.getIndexAttr(1), + rewriter.getIndexAttr(1), + rewriter.getIndexAttr(dilationHeight), + rewriter.getIndexAttr(dilationWidth)}; + auto patchType = RankedTensorType::get({1, numChannelsIn, wHeight, wWidth}, elemType); + Value patch = tensor::ExtractSliceOp::create(rewriter, loc, patchType, paddedInput, offsets, sizes, strides); - Value row = tensor::CollapseShapeOp::create(rewriter, - loc, - im2colRowType, - patch, - SmallVector { - {0}, - {1, 2, 3} + Value row = tensor::CollapseShapeOp::create(rewriter, + loc, + im2colRowType, + patch, + SmallVector { + {0}, + {1, 2, 3} + }); + + SmallVector rowOffsets = {patchIndex, rewriter.getIndexAttr(0)}; + SmallVector rowSizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(patchSize)}; + SmallVector rowStrides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; + Value updatedIm2col = + tensor::InsertSliceOp::create(rewriter, loc, row, im2colAcc, rowOffsets, rowSizes, rowStrides); + scf::YieldOp::create(rewriter, loc, updatedIm2col); + + rewriter.setInsertionPointAfter(im2colLoop); + Value im2col = im2colLoop.getResult(0); + + Value gemmInputRows = im2col; + if (packFactor != 1) { + const int64_t paddedNumPatches = packedNumRows * packFactor; + auto groupedType = RankedTensorType::get({packedNumRows, packFactor, patchSize}, elemType); + auto packedType = RankedTensorType::get({packedNumRows, packFactor * patchSize}, elemType); + Value paddedIm2col = createPaddedRows(im2col, im2colType, paddedNumPatches, rewriter, loc); + Value groupedIm2col = tensor::ExpandShapeOp::create(rewriter, + loc, + groupedType, + paddedIm2col, + SmallVector { + {0, 1}, + {2} + }); + gemmInputRows = tensor::CollapseShapeOp::create(rewriter, + loc, + packedType, + groupedIm2col, + SmallVector { + {0}, + {1, 2} + }); + } + + spatial::SpatYieldOp::create(rewriter, loc, gemmInputRows); }); - SmallVector rowOffsets = {patchIndex, rewriter.getIndexAttr(0)}; - SmallVector rowSizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(patchSize)}; - SmallVector rowStrides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; - Value updatedIm2col = - tensor::InsertSliceOp::create(rewriter, loc, row, im2colAcc, rowOffsets, rowSizes, rowStrides); - scf::YieldOp::create(rewriter, loc, updatedIm2col); - - rewriter.setInsertionPointAfter(im2colLoop); - Value im2col = im2colLoop.getResult(0); - - Value gemmInputRows = im2col; - if (packFactor != 1) { - const int64_t paddedNumPatches = packedNumRows * packFactor; - auto groupedType = RankedTensorType::get({packedNumRows, packFactor, patchSize}, elemType); - auto packedType = RankedTensorType::get({packedNumRows, packFactor * patchSize}, elemType); - Value paddedIm2col = createPaddedRows(im2col, im2colType, paddedNumPatches, rewriter, loc); - Value groupedIm2col = tensor::ExpandShapeOp::create(rewriter, - loc, - groupedType, - paddedIm2col, - SmallVector { - {0, 1}, - {2} - }); - gemmInputRows = tensor::CollapseShapeOp::create(rewriter, - loc, - packedType, - groupedIm2col, - SmallVector { - {0}, - {1, 2} - }); - } - - SmallVector rowResults; - rowResults.reserve(packedNumRows); - for (int64_t rowIdx = 0; rowIdx < packedNumRows; rowIdx++) { - SmallVector offsets = {rewriter.getIndexAttr(rowIdx), rewriter.getIndexAttr(0)}; - SmallVector sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(packFactor * patchSize)}; - SmallVector strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; - rowResults.push_back( - tensor::ExtractSliceOp::create(rewriter, loc, gemmInputRowType, gemmInputRows, offsets, sizes, strides)); - } - spatial::SpatYieldOp::create(rewriter, loc, rowResults); - }); - - SmallVector rows; - rows.reserve(im2colComputeOp.getNumResults()); - for (Value result : im2colComputeOp.getResults()) - rows.push_back(result); - return rows; + return im2colComputeOp.getResult(0); } static Value createCollectedConvOutput(ValueRange gemmRows, @@ -319,15 +306,12 @@ static Value createCollectedConvOutput(ValueRange gemmRows, auto collectComputeOp = createSpatCompute(rewriter, loc, convType, {}, gemmRows, [&](ValueRange gemmRowArgs) { Value gemmOut; if (packFactor == 1) { - gemmOut = gemmRowArgs.size() == 1 ? gemmRowArgs.front() - : tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, gemmRowArgs).getResult(); + gemmOut = createSpatConcat(rewriter, loc, /*axis=*/0, gemmRowArgs); } else { auto expandedType = RankedTensorType::get({packedNumRows, packFactor, numChannelsOut}, outType.getElementType()); auto paddedType = RankedTensorType::get({paddedNumPatches, numChannelsOut}, outType.getElementType()); - Value packedOutput = gemmRowArgs.size() == 1 - ? gemmRowArgs.front() - : tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, gemmRowArgs).getResult(); + Value packedOutput = createSpatConcat(rewriter, loc, /*axis=*/0, gemmRowArgs); Value expandedOutput = tensor::ExpandShapeOp::create(rewriter, loc, expandedType, @@ -509,35 +493,36 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp, // A_packed: [ceil(numPatches / N), N * patchSize] // B_packed: [N * patchSize, N * cOut] // Y_packed: [ceil(numPatches / N), N * cOut] - auto gemmInputRowType = RankedTensorType::get({1, effectiveMaxParallelPixels * patchSize}, elemType); - auto gemmOutputRowType = - RankedTensorType::get({1, effectiveMaxParallelPixels * numChannelsOut}, outType.getElementType()); - SmallVector gemmInputRows = createIm2colRowComputes(x, - xType, - im2colType, - rowType, - gemmInputRowType, - batchSize, - numChannelsIn, - xHeight, - xWidth, - wHeight, - wWidth, - padHeightBegin, - padHeightEnd, - padWidthBegin, - padWidthEnd, - strideHeight, - strideWidth, - dilationHeight, - dilationWidth, - outWidth, - patchSize, - numPatches, - numPatchesPerBatch, - effectiveMaxParallelPixels, - rewriter, - loc); + const int64_t packedNumRows = ceilIntegerDivide(numPatches, effectiveMaxParallelPixels); + auto gemmInputRowsType = RankedTensorType::get({packedNumRows, effectiveMaxParallelPixels * patchSize}, elemType); + auto gemmOutputRowsType = + RankedTensorType::get({packedNumRows, effectiveMaxParallelPixels * numChannelsOut}, outType.getElementType()); + Value gemmInputRows = createIm2colRowComputes(x, + xType, + im2colType, + rowType, + gemmInputRowsType, + batchSize, + numChannelsIn, + xHeight, + xWidth, + wHeight, + wWidth, + padHeightBegin, + padHeightEnd, + padWidthBegin, + padWidthEnd, + strideHeight, + strideWidth, + dilationHeight, + dilationWidth, + outWidth, + patchSize, + numPatches, + numPatchesPerBatch, + effectiveMaxParallelPixels, + rewriter, + loc); Value gemmB = buildPackedWeight(wDenseAttr, wTrans, @@ -553,25 +538,20 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp, Value gemmC = buildPackedBias( hasB, gemmBias, biasMatrix, biasDenseAttr, outType, numChannelsOut, effectiveMaxParallelPixels, rewriter, loc); - SmallVector gemmRows; - gemmRows.reserve(gemmInputRows.size()); - for (Value gemmInputRow : gemmInputRows) { - Value gemmRow = ONNXGemmOp::create(rewriter, - loc, - gemmOutputRowType, - gemmInputRow, - gemmB, - gemmC, - rewriter.getF32FloatAttr(1.0f), - rewriter.getF32FloatAttr(1.0f), - rewriter.getBoolAttr(false), - rewriter.getBoolAttr(false)) - .getY(); - gemmRows.push_back(gemmRow); - } + Value gemmRows = ONNXGemmOp::create(rewriter, + loc, + gemmOutputRowsType, + gemmInputRows, + gemmB, + gemmC, + rewriter.getF32FloatAttr(1.0f), + rewriter.getF32FloatAttr(1.0f), + rewriter.getBoolAttr(false), + rewriter.getBoolAttr(false)) + .getY(); rewriter.replaceOp(convOp, - createCollectedConvOutput(gemmRows, + createCollectedConvOutput(ValueRange {gemmRows}, convOp.getType(), gemmOutType, nhwcType, diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp index fe0d1cb..55ff2c5 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp @@ -1,6 +1,7 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/IRMapping.h" #include "mlir/IR/Location.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/DialectConversion.h" @@ -65,6 +66,66 @@ struct GemvToSpatialCompute : OpConversionPattern { ConversionPatternRewriter& rewriter) const override; }; +struct GemmToSpatialComputeBatch : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(ONNXGemmOp gemmOp, + ONNXGemmOpAdaptor gemmOpAdaptor, + ConversionPatternRewriter& rewriter) const override; +}; + +static SmallVector materializeBatchRowSlices(Value matrix, + RankedTensorType matrixType, + ConversionPatternRewriter& rewriter, + Location loc) { + const int64_t numRows = matrixType.getDimSize(0); + auto rowType = RankedTensorType::get({1, matrixType.getDimSize(1)}, matrixType.getElementType()); + SmallVector resultTypes(static_cast(numRows), rowType); + + auto buildRowSlices = [&](Value matrixArg) { + auto extractRowsOp = spatial::SpatExtractRowsOp::create(rewriter, loc, TypeRange(resultTypes), matrixArg); + return SmallVector(extractRowsOp->result_begin(), extractRowsOp->result_end()); + }; + + auto cloneBatchInputChainIntoSliceCompute = + [&](Value rootInput, SmallVector chainOps, Value rootValue) -> SmallVector { + auto sliceCompute = + createSpatCompute<1>(rewriter, loc, TypeRange(resultTypes), {}, ValueRange {rootInput}, [&](Value input) { + Value transformedMatrix = input; + if (!chainOps.empty()) { + IRMapping mapper; + mapper.map(rootValue, input); + for (Operation* chainOp : chainOps) + rewriter.clone(*chainOp, mapper); + transformedMatrix = cast(mapper.lookup(matrix)); + } + spatial::SpatYieldOp::create(rewriter, loc, buildRowSlices(transformedMatrix)); + }); + SmallVector rowSlices(sliceCompute->result_begin(), sliceCompute->result_end()); + return rowSlices; + }; + + SmallVector chainOps; + Value rootValue = matrix; + while (Operation* definingOp = rootValue.getDefiningOp()) { + if (auto rootCompute = dyn_cast(definingOp)) { + SmallVector reversedChainOps(chainOps.rbegin(), chainOps.rend()); + return cloneBatchInputChainIntoSliceCompute( + rootCompute.getResult(cast(rootValue).getResultNumber()), reversedChainOps, rootValue); + } + + if (definingOp->getNumOperands() != 1) + break; + if (!isa(definingOp)) + break; + + chainOps.push_back(definingOp); + rootValue = definingOp->getOperand(0); + } + + return buildRowSlices(matrix); +} + } // namespace LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp, @@ -156,8 +217,7 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp, } auto concatComputeOp = createSpatCompute(rewriter, loc, gemmOp.getType(), {}, gemvOps, [&](ValueRange gemvOpsArgs) { - auto concatOp = tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, gemvOpsArgs); - spatial::SpatYieldOp::create(rewriter, loc, concatOp.getResult()); + spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, /*axis=*/0, gemvOpsArgs)); }); rewriter.replaceOp(gemmOp, concatComputeOp); @@ -313,15 +373,116 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp, auto concatComputeOp = createSpatCompute(rewriter, gemmLoc, gemmOp.getType(), {}, outHSlices, [&](ValueRange blockArgs) { - auto concatOp = tensor::ConcatOp::create(rewriter, gemmLoc, /*axis=*/1, blockArgs); - spatial::SpatYieldOp::create(rewriter, gemmLoc, concatOp.getResult()); + spatial::SpatYieldOp::create(rewriter, gemmLoc, createSpatConcat(rewriter, gemmLoc, /*axis=*/1, blockArgs)); }); rewriter.replaceOp(gemmOp, concatComputeOp); return success(); } +LogicalResult GemmToSpatialComputeBatch::matchAndRewrite(ONNXGemmOp gemmOp, + ONNXGemmOpAdaptor gemmOpAdaptor, + ConversionPatternRewriter& rewriter) const { + Location loc = gemmOp.getLoc(); + Value a = gemmOpAdaptor.getA(); + Value b = gemmOpAdaptor.getB(); + Value c = gemmOpAdaptor.getC(); + + assert("A should have been transposed already" && !gemmOpAdaptor.getTransA()); + + bool hasC = !isa(c.getDefiningOp()); + + auto aType = cast(a.getType()); + auto bType = cast(b.getType()); + auto outType = cast(gemmOp.getY().getType()); + assert("Only support static shapes" && aType.hasStaticShape() && bType.hasStaticShape() && outType.hasStaticShape()); + + const int64_t numOutRows = aType.getDimSize(0); + if (numOutRows <= 1) + return failure(); + + // Only handle the single-tile case: K <= crossbarSize and N <= crossbarSize + if (aType.getDimSize(1) > static_cast(crossbarSize.getValue()) + || outType.getDimSize(1) > static_cast(crossbarSize.getValue())) + return failure(); + + auto scaledB = materializeScaledConstantTensor(b, gemmOpAdaptor.getAlpha().convertToFloat(), rewriter, loc); + if (failed(scaledB)) + return failure(); + b = *scaledB; + bType = cast(b.getType()); + + if (gemmOpAdaptor.getTransB()) { + auto bShape = bType.getShape(); + auto transposedType = bType.cloneWith(ArrayRef({bShape[1], bShape[0]}), bType.getElementType()); + b = ONNXTransposeOp::create(rewriter, loc, transposedType, b, rewriter.getI64ArrayAttr({1, 0})); + bType = cast(b.getType()); + } + (void) bType; + + Value sharedBias; + if (hasC) { + auto scaledC = materializeScaledConstantTensor(c, gemmOpAdaptor.getBeta().convertToFloat(), rewriter, loc); + if (failed(scaledC)) + return failure(); + c = *scaledC; + auto cType = cast(c.getType()); + if (cType.getRank() == 1) { + auto expandedType = RankedTensorType::get({1, cType.getDimSize(0)}, cType.getElementType()); + c = tensor::ExpandShapeOp::create(rewriter, + loc, + expandedType, + c, + SmallVector { + {0, 1} + }); + cType = cast(c.getType()); + } + assert("Only support rank 2 tensor for C" && cType.getRank() == 2); + // Row-specific bias can't share a single template body; fall through to GemmToManyGemv + if (cType.getDimSize(0) == numOutRows && numOutRows > 1) + return failure(); + if (cType.getDimSize(0) == 1 && cType.getDimSize(1) == 1) + c = broadcastToVector(c, outType.getDimSize(1), rewriter, loc); + sharedBias = c; + } + + SmallVector aSlices = materializeBatchRowSlices(a, aType, rewriter, loc); + auto aSliceType = cast(aSlices.front().getType()); + + auto outRowType = RankedTensorType::get({1, outType.getDimSize(1)}, outType.getElementType()); + SmallVector resultTypes(static_cast(numOutRows), outRowType); + SmallVector weights(static_cast(numOutRows), b); + + auto batchOp = spatial::SpatComputeBatch::create(rewriter, + loc, + TypeRange(resultTypes), + rewriter.getI32IntegerAttr(static_cast(numOutRows)), + ValueRange(weights), + ValueRange(aSlices)); + + Block* body = rewriter.createBlock( + &batchOp.getBody(), batchOp.getBody().end(), TypeRange {aSliceType}, SmallVector(1, loc)); + rewriter.setInsertionPointToEnd(body); + + Value vmmResult = spatial::SpatWeightedVMMOp::create(rewriter, loc, outRowType, 0, body->getArgument(0)).getResult(); + Value laneResult = vmmResult; + if (sharedBias) + laneResult = spatial::SpatVAddOp::create(rewriter, loc, outRowType, vmmResult, sharedBias).getResult(); + spatial::SpatYieldOp::create(rewriter, loc, laneResult); + + rewriter.setInsertionPointAfter(batchOp); + SmallVector laneResults(batchOp->result_begin(), batchOp->result_end()); + auto concatComputeOp = createSpatCompute(rewriter, loc, gemmOp.getType(), {}, laneResults, [&](ValueRange args) { + spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, /*axis=*/0, args)); + }); + + rewriter.replaceOp(gemmOp, concatComputeOp); + return success(); +} + void populateGemmPatterns(RewritePatternSet& patterns, MLIRContext* ctx) { + patterns.insert(ctx, PatternBenefit(2)); patterns.insert(ctx); patterns.insert(ctx); } diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/MatMul.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/MatMul.cpp index 5bf5801..0cbe033 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/MatMul.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/MatMul.cpp @@ -232,9 +232,7 @@ struct MatMulToGemm : OpRewritePattern { })); } - Value result = batchResults.size() == 1 - ? batchResults.front() - : tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, batchResults).getResult(); + Value result = createSpatConcat(rewriter, loc, /*axis=*/0, batchResults); rewriter.replaceOp(matmulOp, result); return success(); } diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/ReduceMean.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/ReduceMean.cpp index a90f107..d236593 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/ReduceMean.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/ReduceMean.cpp @@ -100,8 +100,7 @@ static Value buildReduceMeanKeepdims(Value input, for (Value slice : slices) reducedSlices.push_back(buildReduceMeanKeepdims(slice, reducedAxes, axis + 1, leafType, rewriter, loc)); - return reducedSlices.size() == 1 ? reducedSlices.front() - : tensor::ConcatOp::create(rewriter, loc, axis, reducedSlices).getResult(); + return createSpatConcat(rewriter, loc, axis, reducedSlices); } static Value squeezeReducedAxes(Value keepdimsValue, diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Pool.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Pool.cpp index 8d6000a..56df1f9 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Pool.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Pool.cpp @@ -33,9 +33,7 @@ static int64_t getOptionalI64(std::optional arrayAttr, size_t index, static Value concatAlongAxis(ConversionPatternRewriter& rewriter, Location loc, int64_t axis, ArrayRef values) { assert(!values.empty() && "Expected at least one value to concatenate."); - if (values.size() == 1) - return values.front(); - return tensor::ConcatOp::create(rewriter, loc, axis, values); + return createSpatConcat(rewriter, loc, axis, values); } static Value materializeContiguousTile(ConversionPatternRewriter& rewriter, Location loc, Value tile) { diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Softmax.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Softmax.cpp index e141615..0f43e99 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Softmax.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Softmax.cpp @@ -47,8 +47,7 @@ buildSoftmax(Value input, int64_t softmaxAxis, int64_t axis, ConversionPatternRe for (Value slice : slices) rebuiltSlices.push_back(buildSoftmax(slice, softmaxAxis, axis + 1, rewriter, loc)); - return rebuiltSlices.size() == 1 ? rebuiltSlices.front() - : tensor::ConcatOp::create(rewriter, loc, axis, rebuiltSlices).getResult(); + return createSpatConcat(rewriter, loc, axis, rebuiltSlices); } struct SoftmaxToSpatialCompute : OpConversionPattern { diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Concat.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Concat.cpp index a32c551..87f8a21 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Concat.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Concat.cpp @@ -2,6 +2,7 @@ #include "mlir/IR/PatternMatch.h" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp" +#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" using namespace mlir; @@ -17,7 +18,7 @@ struct Concat : public OpConversionPattern { auto inputs = adaptor.getInputs(); int64_t axis = adaptor.getAxis(); - rewriter.replaceOpWithNewOp(maxpoolOp, axis, inputs); + rewriter.replaceOp(maxpoolOp, createSpatConcat(rewriter, maxpoolOp.getLoc(), axis, inputs)); return success(); } diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Gather.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Gather.cpp index 6605dc1..2e59cf1 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Gather.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Gather.cpp @@ -49,7 +49,7 @@ static Value concatGatherSlices(Value data, } if (slices.empty()) return {}; - return slices.size() == 1 ? slices.front() : tensor::ConcatOp::create(rewriter, loc, axis, slices).getResult(); + return createSpatConcat(rewriter, loc, axis, slices); } static Value addLeadingGatherDim(Value value, int64_t axis, ConversionPatternRewriter& rewriter, Location loc) { @@ -130,9 +130,7 @@ struct Gather : OpConversionPattern { return failure(); rows.push_back(addLeadingGatherDim(gatheredRow, axis, rewriter, loc)); } - result = rows.size() == 1 - ? rows.front() - : tensor::ConcatOp::create(rewriter, loc, /*axis=*/axis, rows).getResult(); + result = createSpatConcat(rewriter, loc, /*axis=*/axis, rows); } else { return failure(); diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Resize.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Resize.cpp index d5e340e..2e6dbcf 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Resize.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Resize.cpp @@ -50,7 +50,7 @@ static Value buildNearestResize(Value input, slices.push_back(buildNearestResize(slice, inputShape, outputShape, axis + 1, rewriter, loc)); } - return slices.size() == 1 ? slices.front() : tensor::ConcatOp::create(rewriter, loc, axis, slices).getResult(); + return createSpatConcat(rewriter, loc, axis, slices); } struct Resize : OpConversionPattern { diff --git a/src/PIM/Conversion/SpatialToPim/Common.cpp b/src/PIM/Conversion/SpatialToPim/Common.cpp index 0f6859f..6788872 100644 --- a/src/PIM/Conversion/SpatialToPim/Common.cpp +++ b/src/PIM/Conversion/SpatialToPim/Common.cpp @@ -7,23 +7,12 @@ #include #include "Common.hpp" -#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" using namespace llvm; using namespace mlir; namespace onnx_mlir { -namespace { - -IntegerAttr getRequiredI32Attr(Builder& builder, Operation* op, llvm::StringRef attrName) { - auto attr = op->getAttrOfType(attrName); - assert(attr && "required precomputed channel attr is missing"); - return IntegerAttr::get(builder.getI32Type(), attr.getInt()); -} - -} // namespace - size_t getSliceActualOffset(tensor::ExtractSliceOp& sliceOp, ShapedType& inputShape) { /* EXAMPLE RUN: @@ -74,37 +63,6 @@ IntegerAttr getTensorSizeInBytesAttr(Builder& builder, mlir::Value value) { return builder.getI32IntegerAttr(static_cast(getShapedTypeSizeInBytes(cast(value.getType())))); } -IntegerAttr getSpatialChannelSourceCoreIdAttr(Builder& builder, mlir::Value channel) { - auto channelNewOp = channel.getDefiningOp(); - assert(channelNewOp && "spatial channel value must come from spat.channel_new"); - return getRequiredI32Attr(builder, channelNewOp, kChannelSourceCoreIdAttrName); -} - -IntegerAttr getSpatialChannelTargetCoreIdAttr(Builder& builder, mlir::Value channel) { - auto channelNewOp = channel.getDefiningOp(); - assert(channelNewOp && "spatial channel value must come from spat.channel_new"); - return getRequiredI32Attr(builder, channelNewOp, kChannelTargetCoreIdAttrName); -} - -bool hasSpatialChannelSourceCoreIdAttr(mlir::Value channel) { - auto channelNewOp = channel.getDefiningOp(); - return channelNewOp && channelNewOp->hasAttr(kChannelSourceCoreIdAttrName); -} - -bool hasSpatialChannelTargetCoreIdAttr(mlir::Value channel) { - auto channelNewOp = channel.getDefiningOp(); - return channelNewOp && channelNewOp->hasAttr(kChannelTargetCoreIdAttrName); -} - -mlir::Value -createPimReceiveFromSpatialChannel(PatternRewriter& rewriter, Location loc, mlir::Value output, mlir::Value channel) { - mlir::Value outputBuffer = getBestOutputTensorFromOperandsOrAllocate(rewriter, output.getDefiningOp()); - auto sizeAttr = getTensorSizeInBytesAttr(rewriter, output); - auto sourceCoreIdAttr = getSpatialChannelSourceCoreIdAttr(rewriter, channel); - return pim::PimReceiveOp::create(rewriter, loc, outputBuffer.getType(), outputBuffer, sizeAttr, sourceCoreIdAttr) - .getOutput(); -} - Operation* getEarliestUserWithinBlock(mlir::Value value) { auto users = value.getUsers(); diff --git a/src/PIM/Conversion/SpatialToPim/Common.hpp b/src/PIM/Conversion/SpatialToPim/Common.hpp index d8e9d9d..e006189 100644 --- a/src/PIM/Conversion/SpatialToPim/Common.hpp +++ b/src/PIM/Conversion/SpatialToPim/Common.hpp @@ -2,16 +2,10 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "llvm/ADT/StringRef.h" - #include "src/Accelerators/PIM/Common/PimCommon.hpp" -#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" namespace onnx_mlir { -inline constexpr llvm::StringLiteral kChannelSourceCoreIdAttrName = "precomp_source_core_id"; -inline constexpr llvm::StringLiteral kChannelTargetCoreIdAttrName = "precomp_target_core_id"; - /** * \brief Get the offset of the ExtractSliceOp based on its static offsets and * its static tensor input. @@ -30,17 +24,6 @@ size_t getShapedTypeSizeInBytes(mlir::ShapedType shapedType); mlir::IntegerAttr getTensorSizeInBytesAttr(mlir::Builder& builder, mlir::Value value); -mlir::IntegerAttr getSpatialChannelSourceCoreIdAttr(mlir::Builder& builder, mlir::Value channel); - -mlir::IntegerAttr getSpatialChannelTargetCoreIdAttr(mlir::Builder& builder, mlir::Value channel); - -bool hasSpatialChannelSourceCoreIdAttr(mlir::Value channel); - -bool hasSpatialChannelTargetCoreIdAttr(mlir::Value channel); - -mlir::Value createPimReceiveFromSpatialChannel( - mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::Value output, mlir::Value channel); - template size_t rangeLength(const mlir::iterator_range range) { return std::distance(range.begin(), range.end()); diff --git a/src/PIM/Conversion/SpatialToPim/SpatialToPim.td b/src/PIM/Conversion/SpatialToPim/SpatialToPim.td index 0e323cd..a0fbce5 100644 --- a/src/PIM/Conversion/SpatialToPim/SpatialToPim.td +++ b/src/PIM/Conversion/SpatialToPim/SpatialToPim.td @@ -9,17 +9,6 @@ include "src/Accelerators/PIM/Dialect/Spatial/Spatial.td" include "src/Accelerators/PIM/Dialect/Pim/Pim.td" #endif // OP_BASE -def HasSpatialChannelSourceCoreIdAttr: Constraint< - CPred<"onnx_mlir::hasSpatialChannelSourceCoreIdAttr($0)">, - "spatial channel has precomputed source core id">; - -def HasSpatialChannelTargetCoreIdAttr: Constraint< - CPred<"onnx_mlir::hasSpatialChannelTargetCoreIdAttr($0)">, - "spatial channel has precomputed target core id">; - -def createPimReceiveFromSpatialChannelValue: NativeCodeCall< - "onnx_mlir::createPimReceiveFromSpatialChannel($_builder, $_loc, $0, $1)">; - def onnxToPimTranspose : Pat< (ONNXTransposeOp:$srcOpRes $data, $perms), (PimTransposeOp $data, $perms, @@ -80,18 +69,4 @@ def spatToPimVSoftmax : Pat< (NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes)) >; -def spatChannelSendToPimSend : Pat< - (SpatChannelSendOp $channel, $input), - (PimSendOp $input, - (NativeCodeCall<"onnx_mlir::getTensorSizeInBytesAttr($_builder, $0)"> $input), - (NativeCodeCall<"onnx_mlir::getSpatialChannelTargetCoreIdAttr($_builder, $0)"> $channel)), - [(HasSpatialChannelTargetCoreIdAttr $channel)] ->; - -def spatChannelReceiveToPimReceive : Pat< - (SpatChannelReceiveOp:$srcOpRes $channel), - (createPimReceiveFromSpatialChannelValue $srcOpRes, $channel), - [(HasSpatialChannelSourceCoreIdAttr $channel)] ->; - #endif // SPATIAL_TO_PIM diff --git a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp index ebdb429..ed3454b 100644 --- a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp +++ b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp @@ -19,6 +19,7 @@ #include #include +#include #include #include @@ -26,6 +27,7 @@ #include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" +#include "src/Accelerators/PIM/Dialect/Spatial/Channels.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Pass/PIMPasses.h" #include "src/Compiler/CompilerOptions.hpp" @@ -59,21 +61,10 @@ private: LogicalResult allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter); - void runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, IRRewriter& rewriter); - void - addReceiveOps(Value channelSourceOp, spatial::SpatChannelNewOp& channel, bool useBroadcastOp, IRRewriter& rewriter); - void replaceBlockArgumentWithRecvOp(spatial::SpatCompute& computeOp, - unsigned int argIndex, - Value channelSourceOp, - Value consumerValue, - spatial::SpatChannelNewOp& channel, - bool useBroadcastOp, - IRRewriter& rewriter); void markOpToRemove(Operation* op); - void annotateChannelCoreIds(func::FuncOp funcOp); - void lowerBroadcastChannelOps(func::FuncOp funcOp, IRRewriter& rewriter); void runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter& rewriter); + void runOnComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, IRRewriter& rewriter); void enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter); @@ -88,6 +79,7 @@ static bool isChannelUseChainOp(Operation* op) { tensor::ExpandShapeOp, tensor::CastOp, tosa::ReshapeOp, + ONNXTransposeOp, pim::PimTransposeOp>(op); } @@ -110,27 +102,338 @@ static void cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewri } } -static size_t countComputeLeafUsers(Value value) { - size_t leafUserCount = 0; +static int32_t translateSpatialCoreIdToPimCoreId(size_t spatialCoreId) { return static_cast(spatialCoreId); } - auto walkUses = [&](Value currentValue, auto& self) -> void { - for (OpOperand& use : currentValue.getUses()) { - Operation* owner = use.getOwner(); - if (isa(owner)) { - leafUserCount++; - continue; - } +static int32_t getPimCoreIdForComputeOp(spatial::SpatCompute computeOp, size_t& fallbackCoreId) { + if (auto spatialCoreIdAttr = computeOp->getAttrOfType(onnx_mlir::kCoreIdAttrName)) + return static_cast(spatialCoreIdAttr.getInt()); + return static_cast(fallbackCoreId++); +} - if (!isChannelUseChainOp(owner)) - llvm_unreachable("Channel use chain contains unsupported op"); +static SmallVector getPimCoreIdsForBatchOp(spatial::SpatComputeBatch computeBatchOp, size_t& fallbackCoreId) { + if (auto coreIdsAttr = computeBatchOp->getAttrOfType(onnx_mlir::kCoreIdAttrName)) + return SmallVector(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end()); - assert(owner->getNumResults() == 1 && "Channel use chain op must have a single result"); - self(owner->getResult(0), self); - } + SmallVector coreIds; + coreIds.reserve(static_cast(computeBatchOp.getLaneCount())); + for (int32_t lane = 0; lane < computeBatchOp.getLaneCount(); ++lane) + coreIds.push_back(static_cast(fallbackCoreId++)); + return coreIds; +} + +static void lowerChannelSend(spatial::SpatChannelSendOp sendOp, IRRewriter& rewriter) { + auto sizeAttr = getTensorSizeInBytesAttr(rewriter, sendOp.getInput()); + auto targetCoreIdAttr = rewriter.getI32IntegerAttr(translateSpatialCoreIdToPimCoreId(sendOp.getTargetCoreId())); + + rewriter.setInsertionPoint(sendOp); + PimSendOp::create(rewriter, sendOp.getLoc(), sendOp.getInput(), sizeAttr, targetCoreIdAttr); + rewriter.eraseOp(sendOp); +} + +static void lowerChannelReceive(spatial::SpatChannelReceiveOp receiveOp, IRRewriter& rewriter) { + if (receiveOp->use_empty()) { + rewriter.eraseOp(receiveOp); + return; + } + + auto outputType = cast(receiveOp.getResult().getType()); + rewriter.setInsertionPoint(receiveOp); + auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveOp.getLoc(), outputType); + auto sizeAttr = getTensorSizeInBytesAttr(rewriter, receiveOp.getResult()); + auto sourceCoreIdAttr = rewriter.getI32IntegerAttr(translateSpatialCoreIdToPimCoreId(receiveOp.getSourceCoreId())); + + Value received = + PimReceiveOp::create(rewriter, receiveOp.getLoc(), outputBuffer.getType(), outputBuffer, sizeAttr, sourceCoreIdAttr) + .getOutput(); + rewriter.replaceOp(receiveOp, received); +} + +static void lowerChannelSendMany(spatial::SpatChannelSendManyOp sendManyOp, IRRewriter& rewriter) { + rewriter.setInsertionPoint(sendManyOp); + for (auto [input, targetCoreId] : llvm::zip(sendManyOp.getInputs(), sendManyOp.getTargetCoreIds())) { + auto sizeAttr = getTensorSizeInBytesAttr(rewriter, input); + auto targetCoreIdAttr = rewriter.getI32IntegerAttr(translateSpatialCoreIdToPimCoreId(targetCoreId)); + PimSendOp::create(rewriter, sendManyOp.getLoc(), input, sizeAttr, targetCoreIdAttr); + } + rewriter.eraseOp(sendManyOp); +} + +static void lowerChannelReceiveMany(spatial::SpatChannelReceiveManyOp receiveManyOp, IRRewriter& rewriter) { + SmallVector replacements; + replacements.reserve(receiveManyOp.getNumResults()); + + rewriter.setInsertionPoint(receiveManyOp); + for (auto [output, sourceCoreId] : llvm::zip(receiveManyOp.getOutputs(), receiveManyOp.getSourceCoreIds())) { + auto outputType = cast(output.getType()); + auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveManyOp.getLoc(), outputType); + auto sizeAttr = getTensorSizeInBytesAttr(rewriter, output); + auto sourceCoreIdAttr = rewriter.getI32IntegerAttr(translateSpatialCoreIdToPimCoreId(sourceCoreId)); + Value received = + PimReceiveOp::create( + rewriter, receiveManyOp.getLoc(), outputBuffer.getType(), outputBuffer, sizeAttr, sourceCoreIdAttr) + .getOutput(); + replacements.push_back(received); + } + + rewriter.replaceOp(receiveManyOp, ValueRange(replacements)); +} + +static void lowerExtractRows(spatial::SpatExtractRowsOp extractRowsOp, IRRewriter& rewriter) { + auto inputType = cast(extractRowsOp.getInput().getType()); + int64_t numCols = inputType.getDimSize(1); + + SmallVector replacements; + replacements.reserve(extractRowsOp.getNumResults()); + + rewriter.setInsertionPoint(extractRowsOp); + for (auto [rowIndex, output] : llvm::enumerate(extractRowsOp.getOutputs())) { + SmallVector offsets = {rewriter.getIndexAttr(static_cast(rowIndex)), rewriter.getIndexAttr(0)}; + SmallVector sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(numCols)}; + SmallVector strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; + auto rowSlice = tensor::ExtractSliceOp::create( + rewriter, extractRowsOp.getLoc(), cast(output.getType()), extractRowsOp.getInput(), offsets, sizes, strides); + replacements.push_back(rowSlice.getResult()); + } + + rewriter.replaceOp(extractRowsOp, ValueRange(replacements)); +} + +static void lowerConcat(spatial::SpatConcatOp concatOp, IRRewriter& rewriter) { + rewriter.setInsertionPoint(concatOp); + Value concatenated = + tensor::ConcatOp::create(rewriter, concatOp.getLoc(), concatOp.getAxis(), concatOp.getInputs()).getResult(); + rewriter.replaceOp(concatOp, concatenated); +} + +struct ReturnUseInfo { + size_t returnIndex; + SmallVector helperChain; +}; + +struct ConcatReturnUseInfo { + size_t returnIndex; + SmallVector sliceOffsets; + SmallVector concatShape; + SmallVector helperChain; +}; + +static int64_t computeFlatElementIndex(ArrayRef indices, ArrayRef shape) { + int64_t flatIndex = 0; + for (size_t i = 0; i < shape.size(); ++i) { + flatIndex *= shape[i]; + flatIndex += indices[i]; + } + return flatIndex; +} + +static SmallVector expandFlatElementIndex(int64_t flatIndex, ArrayRef shape) { + SmallVector indices(shape.size(), 0); + for (int64_t dim = static_cast(shape.size()) - 1; dim >= 0; --dim) { + indices[dim] = flatIndex % shape[dim]; + flatIndex /= shape[dim]; + } + return indices; +} + +static std::optional analyzeReturnUse(Value value) { + auto uses = value.getUses(); + if (rangeLength(uses) != 1) + return std::nullopt; + + SmallVector helperChain; + Value currentValue = value; + Operation* currentUser = uses.begin()->getOwner(); + + while (isChannelUseChainOp(currentUser)) { + helperChain.push_back(currentUser); + auto currentUses = currentUser->getResult(0).getUses(); + if (rangeLength(currentUses) != 1) + return std::nullopt; + currentValue = currentUser->getResult(0); + currentUser = currentUses.begin()->getOwner(); + } + + if (!isa(currentUser)) + return std::nullopt; + + return ReturnUseInfo { + currentValue.getUses().begin()->getOperandNumber(), + std::move(helperChain), + }; +} + +static std::optional analyzeConcatReturnUse(Value value) { + auto uses = value.getUses(); + if (rangeLength(uses) != 1 || !isa(uses.begin()->getOwner())) + return std::nullopt; + + auto valueType = dyn_cast(value.getType()); + if (!valueType || !valueType.hasStaticShape()) + return std::nullopt; + + SmallVector sliceOffsets(valueType.getRank(), 0); + SmallVector concatShape(valueType.getShape().begin(), valueType.getShape().end()); + Value currentValue = value; + Operation* currentUser = uses.begin()->getOwner(); + + while (auto concatOp = dyn_cast(currentUser)) { + size_t operandIndex = currentValue.getUses().begin()->getOperandNumber(); + int64_t axis = concatOp.getDim(); + for (Value operand : concatOp.getOperands().take_front(operandIndex)) + sliceOffsets[axis] += cast(operand.getType()).getShape()[axis]; + + auto concatType = dyn_cast(concatOp.getResult().getType()); + if (!concatType || !concatType.hasStaticShape()) + return std::nullopt; + concatShape.assign(concatType.getShape().begin(), concatType.getShape().end()); + + currentValue = concatOp.getResult(); + auto currentUses = currentValue.getUses(); + if (rangeLength(currentUses) != 1) + return std::nullopt; + currentUser = currentUses.begin()->getOwner(); + } + + SmallVector helperChain; + while (isChannelUseChainOp(currentUser)) { + helperChain.push_back(currentUser); + auto currentUses = currentUser->getResult(0).getUses(); + if (rangeLength(currentUses) != 1) + return std::nullopt; + currentValue = currentUser->getResult(0); + currentUser = currentUses.begin()->getOwner(); + } + + if (!isa(currentUser)) + return std::nullopt; + + return ConcatReturnUseInfo { + currentValue.getUses().begin()->getOperandNumber(), + std::move(sliceOffsets), + std::move(concatShape), + std::move(helperChain), + }; +} + +static LogicalResult mapIndicesThroughHelperChain(ArrayRef sourceIndices, + ArrayRef sourceShape, + ArrayRef helperChain, + SmallVectorImpl& mappedIndices) { + SmallVector currentIndices(sourceIndices.begin(), sourceIndices.end()); + SmallVector currentShape(sourceShape.begin(), sourceShape.end()); + + auto reshapeToResultShape = [&](Operation* op) -> LogicalResult { + auto resultType = dyn_cast(op->getResult(0).getType()); + if (!resultType || !resultType.hasStaticShape()) + return failure(); + int64_t flatIndex = computeFlatElementIndex(currentIndices, currentShape); + currentShape.assign(resultType.getShape().begin(), resultType.getShape().end()); + currentIndices = expandFlatElementIndex(flatIndex, currentShape); + return success(); }; - walkUses(value, walkUses); - return leafUserCount; + for (Operation* op : helperChain) { + if (auto extractSliceOp = dyn_cast(op)) { + auto hasStaticValues = [](ArrayRef values) { + return llvm::all_of(values, [](int64_t value) { return !ShapedType::isDynamic(value); }); + }; + if (!hasStaticValues(extractSliceOp.getStaticOffsets()) + || !hasStaticValues(extractSliceOp.getStaticSizes()) + || !hasStaticValues(extractSliceOp.getStaticStrides())) + return failure(); + + SmallVector nextIndices; + nextIndices.reserve(currentIndices.size()); + for (auto [index, offset, size, stride] : llvm::zip_equal(currentIndices, + extractSliceOp.getStaticOffsets(), + extractSliceOp.getStaticSizes(), + extractSliceOp.getStaticStrides())) { + if (stride != 1 || index < offset || index >= offset + size) + return failure(); + nextIndices.push_back(index - offset); + } + + auto resultType = dyn_cast(extractSliceOp.getResult().getType()); + if (!resultType || !resultType.hasStaticShape()) + return failure(); + currentIndices = std::move(nextIndices); + currentShape.assign(resultType.getShape().begin(), resultType.getShape().end()); + continue; + } + + if (auto transposeOp = dyn_cast(op)) { + SmallVector nextIndices(currentIndices.size()); + SmallVector nextShape(currentShape.size()); + for (auto [destIndex, attr] : llvm::enumerate(transposeOp.getPermAttr().getAsRange())) { + int64_t sourceIndex = attr.getInt(); + nextIndices[destIndex] = currentIndices[sourceIndex]; + nextShape[destIndex] = currentShape[sourceIndex]; + } + currentIndices = std::move(nextIndices); + currentShape = std::move(nextShape); + continue; + } + + if (auto transposeOp = dyn_cast(op)) { + SmallVector nextIndices(currentIndices.size()); + SmallVector nextShape(currentShape.size()); + for (auto [destIndex, attr] : llvm::enumerate(transposeOp.getPermutation().getAsRange())) { + int64_t sourceIndex = attr.getInt(); + nextIndices[destIndex] = currentIndices[sourceIndex]; + nextShape[destIndex] = currentShape[sourceIndex]; + } + currentIndices = std::move(nextIndices); + currentShape = std::move(nextShape); + continue; + } + + if (isa(op)) { + if (failed(reshapeToResultShape(op))) + return failure(); + continue; + } + + return failure(); + } + + mappedIndices.assign(currentIndices.begin(), currentIndices.end()); + return success(); +} + +static void cloneHelperChain(Value sourceValue, + ArrayRef helperChain, + IRRewriter& rewriter, + Value& clonedValue) { + IRMapping mapping; + mapping.map(sourceValue, sourceValue); + clonedValue = sourceValue; + + rewriter.setInsertionPointAfterValue(sourceValue); + for (Operation* op : helperChain) { + cloneMappedHelperOperands(op, mapping, rewriter); + Operation* clonedOp = rewriter.clone(*op, mapping); + for (auto [originalResult, newResult] : llvm::zip(op->getResults(), clonedOp->getResults())) + mapping.map(originalResult, newResult); + clonedValue = clonedOp->getResult(0); + rewriter.setInsertionPointAfter(clonedOp); + } +} + +static void emitHostCopy(IRRewriter& rewriter, + Location loc, + Value outputTensor, + Value sourceValue, + int32_t hostTargetOffset, + int32_t deviceSourceOffset, + int32_t sizeInBytes) { + PimMemCopyDevToHostOp::create(rewriter, + loc, + outputTensor.getType(), + outputTensor, + sourceValue, + rewriter.getI32IntegerAttr(hostTargetOffset), + rewriter.getI32IntegerAttr(deviceSourceOffset), + rewriter.getI32IntegerAttr(sizeInBytes)); } void SpatialToPimPass::runOnOperation() { @@ -138,6 +441,15 @@ void SpatialToPimPass::runOnOperation() { ModuleOp moduleOp = getOperation(); MLIRContext* ctx = moduleOp.getContext(); + auto entryFunc = getPimEntryFunc(moduleOp); + if (failed(entryFunc)) { + signalPassFailure(); + return; + } + func::FuncOp funcOp = *entryFunc; + + IRRewriter rewriter(&getContext()); + ConversionTarget target(*ctx); target.addLegalDialect(funcOp.front().getTerminator()); addResultBuffer(returnOp, rewriter); @@ -170,17 +473,56 @@ void SpatialToPimPass::runOnOperation() { return; } - for (auto receiveOp : funcOp.getOps()) { - markOpToRemove(receiveOp); - runOnReceiveOp(receiveOp, rewriter); - } + SmallVector concatOps; + funcOp.walk([&](spatial::SpatConcatOp op) { concatOps.push_back(op); }); + for (auto concatOp : concatOps) + lowerConcat(concatOp, rewriter); + for (auto computeOp : funcOp.getOps()) { markOpToRemove(computeOp); runOnComputeOp(computeOp, rewriter); } - annotateChannelCoreIds(funcOp); - lowerBroadcastChannelOps(funcOp, rewriter); + for (auto computeBatchOp : funcOp.getOps()) { + markOpToRemove(computeBatchOp); + runOnComputeBatchOp(computeBatchOp, rewriter); + } + + SmallVector receiveOps; + funcOp.walk([&](spatial::SpatChannelReceiveOp op) { receiveOps.push_back(op); }); + for (auto receiveOp : receiveOps) { + bool onlyPendingRemovalUsers = llvm::all_of( + receiveOp->getUsers(), [&](Operation* user) { return llvm::is_contained(operationsToRemove, user); }); + if (onlyPendingRemovalUsers) { + markOpToRemove(receiveOp); + continue; + } + if (receiveOp->use_empty()) { + rewriter.eraseOp(receiveOp); + continue; + } + lowerChannelReceive(receiveOp, rewriter); + } + + SmallVector receiveManyOps; + funcOp.walk([&](spatial::SpatChannelReceiveManyOp op) { receiveManyOps.push_back(op); }); + for (auto receiveManyOp : receiveManyOps) + lowerChannelReceiveMany(receiveManyOp, rewriter); + + SmallVector sendOps; + funcOp.walk([&](spatial::SpatChannelSendOp op) { sendOps.push_back(op); }); + for (auto sendOp : sendOps) + lowerChannelSend(sendOp, rewriter); + + SmallVector sendManyOps; + funcOp.walk([&](spatial::SpatChannelSendManyOp op) { sendManyOps.push_back(op); }); + for (auto sendManyOp : sendManyOps) + lowerChannelSendMany(sendManyOp, rewriter); + + SmallVector extractRowsOps; + funcOp.walk([&](spatial::SpatExtractRowsOp op) { extractRowsOps.push_back(op); }); + for (auto extractRowsOp : extractRowsOps) + lowerExtractRows(extractRowsOp, rewriter); RewritePatternSet channelPatterns(ctx); populateWithGenerated(channelPatterns); @@ -218,6 +560,18 @@ void SpatialToPimPass::runOnOperation() { assert(false && "tracked op removal reached a cycle or missed dependency"); } + // Dump to file for debug + bool hasSpatialOps = false; + moduleOp.walk([&](Operation* op) { + if (op->getDialect()->getNamespace() == "spat") + hasSpatialOps = true; + }); + if (hasSpatialOps) { + moduleOp.emitError("SpatialToPim left illegal Spatial operations in the module"); + signalPassFailure(); + return; + } + // Dump to file for debug dumpModule(moduleOp, "pim0"); } @@ -228,6 +582,23 @@ void SpatialToPimPass::runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter auto& block = computeOp.getRegion().front(); auto yieldOp = cast(block.getTerminator()); + for (auto [argIndex, blockArg] : llvm::enumerate(block.getArguments())) { + auto receiveOp = dyn_cast_or_null(computeOp.getInputs()[argIndex].getDefiningOp()); + if (!receiveOp || blockArg.use_empty()) + continue; + + rewriter.setInsertionPoint(getEarliestUserWithinBlock(blockArg)); + auto outputType = cast(blockArg.getType()); + auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveOp.getLoc(), outputType); + auto sizeAttr = getTensorSizeInBytesAttr(rewriter, blockArg); + auto sourceCoreIdAttr = rewriter.getI32IntegerAttr(translateSpatialCoreIdToPimCoreId(receiveOp.getSourceCoreId())); + Value received = PimReceiveOp::create( + rewriter, receiveOp.getLoc(), outputBuffer.getType(), outputBuffer, sizeAttr, sourceCoreIdAttr) + .getOutput(); + blockArg.replaceAllUsesWith(received); + markOpToRemove(receiveOp); + } + if (computeOp.getNumResults() != yieldOp.getNumOperands()) llvm_unreachable("ComputeOp must have same number of results as yieldOp operands"); @@ -237,145 +608,119 @@ void SpatialToPimPass::runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter auto yieldType = cast(yieldValue.getType()); + if (auto returnUse = analyzeReturnUse(result)) { + Value storedValue = yieldValue; + cloneHelperChain(yieldValue, returnUse->helperChain, rewriter, storedValue); + for (Operation* op : returnUse->helperChain) + markOpToRemove(op); + + auto storedType = cast(storedValue.getType()); + size_t elementSize = storedType.getElementTypeBitWidth() / 8; + Value outputTensor = outputTensors[returnUse->returnIndex]; + if (auto storedOp = storedValue.getDefiningOp()) + rewriter.setInsertionPointAfter(storedOp); + emitHostCopy(rewriter, + loc, + outputTensor, + storedValue, + 0, + 0, + static_cast(storedType.getNumElements() * elementSize)); + continue; + } + auto resultUses = result.getUses(); - auto numResultUses = rangeLength(resultUses); - if (numResultUses == 1) { + if (rangeLength(resultUses) == 1) { OpOperand& resultUse = *resultUses.begin(); Operation* resultUser = resultUse.getOwner(); - if (isChannelUseChainOp(resultUser)) { - SmallVector returnChain; - Value chainedValue = result; - Operation* chainUser = resultUser; - - while (isChannelUseChainOp(chainUser)) { - returnChain.push_back(chainUser); - auto chainUses = chainUser->getResult(0).getUses(); - if (rangeLength(chainUses) != 1) - break; - chainedValue = chainUser->getResult(0); - chainUser = chainUses.begin()->getOwner(); - } - - if (isa(chainUser)) { - size_t resultIndexInReturn = chainedValue.getUses().begin()->getOperandNumber(); - - rewriter.setInsertionPoint(yieldOp); - IRMapping mapping; - mapping.map(result, yieldValue); - - Value storedValue = yieldValue; - for (Operation* op : returnChain) { - cloneMappedHelperOperands(op, mapping, rewriter); - Operation* clonedOp = rewriter.clone(*op, mapping); - for (auto [originalResult, newResult] : llvm::zip(op->getResults(), clonedOp->getResults())) - mapping.map(originalResult, newResult); - storedValue = clonedOp->getResult(0); - rewriter.setInsertionPointAfter(clonedOp); - markOpToRemove(op); - } - - auto storedType = cast(storedValue.getType()); - size_t elementSize = storedType.getElementTypeBitWidth() / 8; - - Value outputTensor = outputTensors[resultIndexInReturn]; - if (auto storedOp = storedValue.getDefiningOp()) - rewriter.setInsertionPointAfter(storedOp); - PimMemCopyDevToHostOp::create(rewriter, - loc, - outputTensor.getType(), - outputTensor, - storedValue, - rewriter.getI32IntegerAttr(0), - rewriter.getI32IntegerAttr(0), - rewriter.getI32IntegerAttr(storedType.getNumElements() * elementSize)); - continue; - } - } - if (isa(resultUser)) { size_t resultIndexInReturn = resultUse.getOperandNumber(); - size_t offset = 0; - size_t numElements = yieldType.getNumElements(); size_t elementSize = yieldType.getElementType().getIntOrFloatBitWidth() / 8; - - // Store to global memory Value outputTensor = outputTensors[resultIndexInReturn]; rewriter.setInsertionPointAfterValue(yieldValue); - PimMemCopyDevToHostOp::create(rewriter, - loc, - outputTensor.getType(), - outputTensor, - yieldValue, - rewriter.getI32IntegerAttr(offset), - rewriter.getI32IntegerAttr(0), - rewriter.getI32IntegerAttr(numElements * elementSize)); + emitHostCopy(rewriter, + loc, + outputTensor, + yieldValue, + 0, + 0, + static_cast(yieldType.getNumElements() * elementSize)); continue; } - if (isa(resultUser)) { - auto concatOp = resultUser; - auto concatValue = concatOp->getResult(0); - auto concatUses = concatValue.getUses(); - auto numConcatUses = rangeLength(concatUses); - if (numConcatUses == 1) { - Value chainedValue = concatValue; - Operation* concatUser = concatUses.begin()->getOwner(); - - while (isChannelUseChainOp(concatUser)) { - auto chainUses = concatUser->getResult(0).getUses(); - if (rangeLength(chainUses) != 1) - break; - chainedValue = concatUser->getResult(0); - concatUser = chainUses.begin()->getOwner(); - } - - if (isa(concatUser)) { - size_t concatIndexInReturn = chainedValue.getUses().begin()->getOperandNumber(); - size_t resultIndexInConcat = resultUses.begin()->getOperandNumber(); - size_t offset = 0; - for (auto operand : concatOp->getOperands().take_front(resultIndexInConcat)) - offset += cast(operand.getType()).getNumElements() - * cast(operand.getType()).getElementTypeBitWidth() / 8; - - size_t elementSize = yieldType.getElementTypeBitWidth() / 8; - - // Store to global memory - Value outputTensor = outputTensors[concatIndexInReturn]; - rewriter.setInsertionPointAfterValue(yieldValue); - PimMemCopyDevToHostOp::create(rewriter, - loc, - outputTensor.getType(), - outputTensor, - yieldValue, - rewriter.getI32IntegerAttr(offset), - rewriter.getI32IntegerAttr(0), - rewriter.getI32IntegerAttr(yieldType.getNumElements() * elementSize)); - continue; - } - } - } + if (isa(resultUser)) + continue; } - // If this pattern was not found, then create a channel and send the value + if (auto concatReturnUse = analyzeConcatReturnUse(result)) { + Value outputTensor = outputTensors[concatReturnUse->returnIndex]; + auto outputType = cast(outputTensor.getType()); + size_t elementSize = yieldType.getElementTypeBitWidth() / 8; - // 1. Create a new ChannelOp - rewriter.setInsertionPoint(computeOp); - auto channelType = spatial::SpatChannelType::get(computeOp.getContext()); - auto channelOp = spatial::SpatChannelNewOp::create(rewriter, loc, channelType); + if (concatReturnUse->helperChain.empty()) { + int64_t flatOffset = computeFlatElementIndex(concatReturnUse->sliceOffsets, outputType.getShape()); + rewriter.setInsertionPointAfterValue(yieldValue); + emitHostCopy(rewriter, + loc, + outputTensor, + yieldValue, + static_cast(flatOffset * elementSize), + 0, + static_cast(yieldType.getNumElements() * elementSize)); + continue; + } - // 2. Receive value through the channel. Broadcast is needed whenever the - // value eventually reaches more than one compute consumer, even through a - // chain of view-like ops. - bool useBroadcastOp = countComputeLeafUsers(result) > 1; - addReceiveOps(result, channelOp, useBroadcastOp, rewriter); + auto storedType = cast(yieldValue.getType()); + for (int64_t linearIndex = 0; linearIndex < storedType.getNumElements(); ++linearIndex) { + SmallVector sourceIndices = expandFlatElementIndex(linearIndex, storedType.getShape()); + for (auto [dim, idx] : llvm::enumerate(sourceIndices)) + sourceIndices[dim] = concatReturnUse->sliceOffsets[dim] + idx; - // 3. Send the value through the channel - rewriter.setInsertionPointAfterValue(yieldValue); - if (useBroadcastOp) - spatial::SpatChannelBroadcastSendOp::create(rewriter, loc, channelOp, yieldValue); - else - spatial::SpatChannelSendOp::create(rewriter, loc, channelOp, yieldValue); + SmallVector destinationIndices; + if (failed(mapIndicesThroughHelperChain(sourceIndices, + concatReturnUse->concatShape, + concatReturnUse->helperChain, + destinationIndices))) { + computeOp.emitOpError("has an unsupported concat-return helper chain during Spatial-to-PIM lowering"); + signalPassFailure(); + return; + } + + SmallVector extractOffsets; + SmallVector extractSizes; + SmallVector extractStrides; + extractOffsets.reserve(storedType.getRank()); + extractSizes.reserve(storedType.getRank()); + extractStrides.reserve(storedType.getRank()); + for (int64_t idx : expandFlatElementIndex(linearIndex, storedType.getShape())) { + extractOffsets.push_back(rewriter.getIndexAttr(idx)); + extractSizes.push_back(rewriter.getIndexAttr(1)); + extractStrides.push_back(rewriter.getIndexAttr(1)); + } + + auto scalarTensorType = + RankedTensorType::get(SmallVector(storedType.getRank(), 1), storedType.getElementType()); + rewriter.setInsertionPointAfterValue(yieldValue); + auto elementSlice = tensor::ExtractSliceOp::create( + rewriter, loc, scalarTensorType, yieldValue, extractOffsets, extractSizes, extractStrides); + rewriter.setInsertionPointAfter(elementSlice); + + int64_t destinationFlatOffset = computeFlatElementIndex(destinationIndices, outputType.getShape()); + emitHostCopy(rewriter, + loc, + outputTensor, + elementSlice.getResult(), + static_cast(destinationFlatOffset * elementSize), + 0, + static_cast(elementSize)); + } + continue; + } + + computeOp.emitOpError("has an unsupported remaining result use during Spatial-to-PIM lowering"); + signalPassFailure(); + return; } // Use `HaltOp` instead of `YieldOp` @@ -384,8 +729,12 @@ void SpatialToPimPass::runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter // Replace `spat.compute` with `pim.core` rewriter.setInsertionPointAfter(computeOp); - auto coreOp = PimCoreOp::create(rewriter, loc, computeOp.getWeights(), rewriter.getI32IntegerAttr(coreId++)); + auto coreOp = PimCoreOp::create( + rewriter, loc, computeOp.getWeights(), rewriter.getI32IntegerAttr(getPimCoreIdForComputeOp(computeOp, coreId))); auto& coreOpBlocks = coreOp.getBody().getBlocks(); + for (auto [argIndex, blockArg] : llvm::enumerate(block.getArguments())) + if (!blockArg.use_empty()) + blockArg.replaceAllUsesWith(computeOp.getInputs()[argIndex]); block.eraseArguments(0, block.getNumArguments()); coreOpBlocks.splice(coreOpBlocks.begin(), computeOp.getBody().getBlocks()); Block* tempComputeBlock = new Block(); @@ -394,6 +743,131 @@ void SpatialToPimPass::runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter PimHaltOp::create(rewriter, computeOp.getLoc()); } +void SpatialToPimPass::runOnComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, IRRewriter& rewriter) { + if (std::getenv("PIM_BATCH_LOWER_DEBUG")) + llvm::errs() << "lowering compute_batch lanes=" << computeBatchOp.getLaneCount() << "\n"; + + if (computeBatchOp.getNumResults() != 0) { + computeBatchOp.emitOpError( + "batched Spatial-to-PIM lowering currently requires channelized compute_batch with no results"); + signalPassFailure(); + return; + } + + Location loc = computeBatchOp.getLoc(); + Block& oldBlock = computeBatchOp.getBody().front(); + auto oldYield = cast(oldBlock.getTerminator()); + if (oldYield.getNumOperands() != 0) { + computeBatchOp.emitOpError("batched Spatial-to-PIM lowering currently requires empty spat.yield"); + signalPassFailure(); + return; + } + + SmallVector coreIds = getPimCoreIdsForBatchOp(computeBatchOp, coreId); + + rewriter.setInsertionPointAfter(computeBatchOp); + auto coreBatchOp = pim::PimCoreBatchOp::create(rewriter, + loc, + rewriter.getI32IntegerAttr(computeBatchOp.getLaneCount()), + computeBatchOp.getWeights(), + computeBatchOp.getInputs()); + coreBatchOp.getProperties().setOperandSegmentSizes( + {static_cast(computeBatchOp.getWeights().size()), static_cast(computeBatchOp.getInputs().size())}); + coreBatchOp->setAttr(onnx_mlir::kCoreIdAttrName, rewriter.getDenseI32ArrayAttr(coreIds)); + + SmallVector blockArgTypes; + SmallVector blockArgLocs; + for (BlockArgument arg : oldBlock.getArguments()) { + blockArgTypes.push_back(arg.getType()); + blockArgLocs.push_back(arg.getLoc()); + } + Block* newBlock = + rewriter.createBlock(&coreBatchOp.getBody(), coreBatchOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs); + + IRMapping mapper; + rewriter.setInsertionPointToStart(newBlock); + for (auto [oldArg, newArg] : llvm::zip(oldBlock.getArguments(), newBlock->getArguments())) { + auto newArgType = cast(newArg.getType()); + auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, newArgType); + auto copied = pim::PimMemCopyHostToDevBatchOp::create(rewriter, + loc, + outputBuffer.getType(), + outputBuffer, + newArg, + rewriter.getI32IntegerAttr(0), + rewriter.getI32IntegerAttr(0), + getTensorSizeInBytesAttr(rewriter, newArg)) + .getOutput(); + mapper.map(oldArg, copied); + } + + auto materializeCapturedTensor = [&](Value capturedTensor) -> Value { + if (auto mapped = mapper.lookupOrNull(capturedTensor)) + return mapped; + + auto capturedType = cast(capturedTensor.getType()); + auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, capturedType); + auto copied = pim::PimMemCopyHostToDevBatchOp::create(rewriter, + loc, + outputBuffer.getType(), + outputBuffer, + capturedTensor, + rewriter.getI32IntegerAttr(0), + rewriter.getI32IntegerAttr(0), + getTensorSizeInBytesAttr(rewriter, capturedTensor)) + .getOutput(); + mapper.map(capturedTensor, copied); + return copied; + }; + + rewriter.setInsertionPointToEnd(newBlock); + for (Operation& op : oldBlock) { + if (isa(op)) + continue; + + if (auto sendBatchOp = dyn_cast(op)) { + pim::PimSendBatchOp::create(rewriter, + loc, + mapper.lookup(sendBatchOp.getInput()), + getTensorSizeInBytesAttr(rewriter, mapper.lookup(sendBatchOp.getInput())), + sendBatchOp.getTargetCoreIdsAttr()); + continue; + } + + if (auto receiveBatchOp = dyn_cast(op)) { + auto outputType = cast(receiveBatchOp.getOutput().getType()); + auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, outputType); + auto received = pim::PimReceiveBatchOp::create(rewriter, + loc, + outputBuffer.getType(), + outputBuffer, + getTensorSizeInBytesAttr(rewriter, receiveBatchOp.getOutput()), + receiveBatchOp.getSourceCoreIdsAttr()) + .getOutput(); + mapper.map(receiveBatchOp.getOutput(), received); + continue; + } + + for (Value operand : op.getOperands()) { + if (!isa(operand.getType()) || mapper.contains(operand)) + continue; + + Operation* definingOp = operand.getDefiningOp(); + if (definingOp && definingOp->getBlock() == &oldBlock) + continue; + + materializeCapturedTensor(operand); + } + + Operation* cloned = rewriter.clone(op, mapper); + for (auto [originalResult, clonedResult] : llvm::zip(op.getResults(), cloned->getResults())) + mapper.map(originalResult, clonedResult); + } + + rewriter.setInsertionPointToEnd(newBlock); + PimHaltOp::create(rewriter, loc); +} + void SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter) { auto enlargeTiedDpsChain = [&](Value value, RankedTensorType newType, auto& self) -> void { auto* definingOp = value.getDefiningOp(); @@ -545,8 +1019,9 @@ LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::Fu else tensorSource = cast>(computeOpInput); - // Compute results must be transferred through channels via send/receive - if (isa(tensorSource.getDefiningOp())) + // Values already produced inside the device-side graph must not be + // copied back through a host-to-device staging step here. + if (isa(tensorSource.getDefiningOp())) continue; BlockArgument computeBlockArgToReplace = computeOp.getBody().front().getArgument(computeInputIdx); @@ -561,179 +1036,39 @@ LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::Fu return success(); } -void SpatialToPimPass::replaceBlockArgumentWithRecvOp(spatial::SpatCompute& computeOp, - unsigned int argIndex, - Value channelSourceOp, - Value consumerValue, - spatial::SpatChannelNewOp& channel, - bool useBroadcastOp, - IRRewriter& rewriter) { - auto& computeBlock = computeOp.getRegion().front(); - //(remember that WeightedCompute have weights as first operands, however these - // weights are not included in the block arguments. Thus, when indexing the - // block argument we need to remove the weights count) - auto computeWeightsCount = computeOp.getWeights().size(); - auto blockArg = computeBlock.getArgument(argIndex - computeWeightsCount); - // Receive the tensor just before the first use of the value - rewriter.setInsertionPoint(getEarliestUserWithinBlock(blockArg)); - Value receivedValue; - if (useBroadcastOp) - receivedValue = - spatial::SpatChannelBroadcastReceiveOp::create(rewriter, computeOp.getLoc(), channelSourceOp.getType(), channel); - else - receivedValue = - spatial::SpatChannelReceiveOp::create(rewriter, computeOp.getLoc(), channelSourceOp.getType(), channel); - - Value replacementValue = receivedValue; - if (consumerValue != channelSourceOp) { - SmallVector clonedChain; - Value currentValue = consumerValue; - while (currentValue != channelSourceOp) { - Operation* definingOp = currentValue.getDefiningOp(); - if (!definingOp || !isChannelUseChainOp(definingOp)) - llvm_unreachable("Unsupported channel use chain while replaying value into consumer compute"); - - clonedChain.push_back(definingOp); - currentValue = definingOp->getOperand(0); - } - - IRMapping mapping; - mapping.map(channelSourceOp, receivedValue); - for (Operation* op : llvm::reverse(clonedChain)) { - cloneMappedHelperOperands(op, mapping, rewriter); - Operation* clonedOp = rewriter.clone(*op, mapping); - for (auto [originalResult, newResult] : llvm::zip(op->getResults(), clonedOp->getResults())) - mapping.map(originalResult, newResult); - markOpToRemove(op); - } - - replacementValue = cast(mapping.lookup(consumerValue)); - } - - assert(replacementValue.getType() == blockArg.getType() - && "Replayed channel use chain must match block argument type"); - blockArg.replaceAllUsesWith(replacementValue); -} - -void SpatialToPimPass::addReceiveOps(Value channelSourceOp, - spatial::SpatChannelNewOp& channel, - bool useBroadcastOp, - IRRewriter& rewriter) { - auto replayUsesIntoConsumers = [&](Value currentValue, auto& self) -> void { - for (OpOperand& use : currentValue.getUses()) { - Operation* owner = use.getOwner(); - if (auto computeUser = dyn_cast(owner)) { - replaceBlockArgumentWithRecvOp( - computeUser, use.getOperandNumber(), channelSourceOp, currentValue, channel, useBroadcastOp, rewriter); - continue; - } - - if (!isChannelUseChainOp(owner)) - llvm_unreachable("User of channel-carried value is not a compute nor a supported view-like op"); - - markOpToRemove(owner); - assert(owner->getNumResults() == 1 && "Channel use chain op must have a single result"); - self(owner->getResult(0), self); - } - }; - - replayUsesIntoConsumers(channelSourceOp, replayUsesIntoConsumers); -} - void SpatialToPimPass::markOpToRemove(Operation* op) { if (!llvm::is_contained(operationsToRemove, op)) operationsToRemove.push_back(op); } -void SpatialToPimPass::annotateChannelCoreIds(func::FuncOp funcOp) { - funcOp.walk([&](spatial::SpatChannelNewOp channelNewOp) { - markOpToRemove(channelNewOp); - - if (channelNewOp->use_empty()) - return; - - spatial::SpatChannelSendOp sendOp; - spatial::SpatChannelReceiveOp receiveOp; - spatial::SpatChannelBroadcastSendOp broadcastSendOp; - - for (Operation* user : channelNewOp->getUsers()) { - if (auto op = dyn_cast(user)) { - sendOp = op; - continue; - } - if (auto op = dyn_cast(user)) { - receiveOp = op; - continue; - } - if (auto op = dyn_cast(user)) { - broadcastSendOp = op; - continue; - } - if (auto op = dyn_cast(user)) - continue; - llvm_unreachable("Unexpected user of spat.channel_new during Spatial-to-PIM lowering"); - } - - if (broadcastSendOp) { - auto sourceCoreIdAttr = cast(broadcastSendOp->getParentOp()).getCoreIdAttr(); - channelNewOp->setAttr(kChannelSourceCoreIdAttrName, sourceCoreIdAttr); - return; - } - - if (!sendOp || !receiveOp) - llvm_unreachable("spat.channel_new must connect exactly one send and one receive"); - - auto sourceCoreIdAttr = cast(sendOp->getParentOp()).getCoreIdAttr(); - auto targetCoreIdAttr = cast(receiveOp->getParentOp()).getCoreIdAttr(); - channelNewOp->setAttr(kChannelSourceCoreIdAttrName, sourceCoreIdAttr); - channelNewOp->setAttr(kChannelTargetCoreIdAttrName, targetCoreIdAttr); - }); -} - -void SpatialToPimPass::lowerBroadcastChannelOps(func::FuncOp funcOp, IRRewriter& rewriter) { - SmallVector broadcastSendOps; - funcOp.walk([&](spatial::SpatChannelBroadcastSendOp op) { broadcastSendOps.push_back(op); }); - - for (auto sendOp : broadcastSendOps) { - auto channelNewOp = cast(sendOp.getChannel().getDefiningOp()); - auto sizeAttr = getTensorSizeInBytesAttr(rewriter, sendOp.getInput()); - - rewriter.setInsertionPoint(sendOp); - bool foundReceiver = false; - for (Operation* user : channelNewOp->getUsers()) { - auto receiveOp = dyn_cast(user); - if (!receiveOp) - continue; - - foundReceiver = true; - auto targetCoreIdAttr = cast(receiveOp->getParentOp()).getCoreIdAttr(); - PimSendOp::create(rewriter, sendOp.getLoc(), sendOp.getInput(), sizeAttr, targetCoreIdAttr); - } - - if (!foundReceiver) - llvm_unreachable("spat.channel_broadcast_send has no matching broadcast receive"); - - rewriter.eraseOp(sendOp); - } - - SmallVector broadcastReceiveOps; - funcOp.walk([&](spatial::SpatChannelBroadcastReceiveOp op) { broadcastReceiveOps.push_back(op); }); - - for (auto receiveOp : broadcastReceiveOps) { - rewriter.setInsertionPoint(receiveOp); - auto outputType = cast(receiveOp.getResult().getType()); - Value outputBuffer = createEmptyTensorFromShaped(rewriter, receiveOp.getLoc(), outputType); - auto sizeAttr = getTensorSizeInBytesAttr(rewriter, receiveOp.getResult()); - auto sourceCoreIdAttr = getSpatialChannelSourceCoreIdAttr(rewriter, receiveOp.getChannel()); - Value receivedValue = - PimReceiveOp::create( - rewriter, receiveOp.getLoc(), outputBuffer.getType(), outputBuffer, sizeAttr, sourceCoreIdAttr) - .getOutput(); - rewriter.replaceOp(receiveOp, receivedValue); - } -} - void SpatialToPimPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewriter& rewriter) { + auto markOwnedReturnChain = [&](Operation* op, auto&& markOwnedReturnChain) -> void { + if (!op) + return; + + bool isExclusivelyOwnedByReturnChain = op->use_empty(); + if (!isExclusivelyOwnedByReturnChain && op->hasOneUse()) { + Operation* onlyUser = *op->getUsers().begin(); + isExclusivelyOwnedByReturnChain = + isa(onlyUser) || isChannelUseChainOp(onlyUser); + } + if (!isExclusivelyOwnedByReturnChain) + return; + + if (isChannelUseChainOp(op)) { + Value source = op->getOperand(0); + markOpToRemove(op); + markOwnedReturnChain(source.getDefiningOp(), markOwnedReturnChain); + return; + } + + if (auto concatOp = dyn_cast(op)) { + markOpToRemove(concatOp); + for (Value operand : concatOp.getOperands()) + markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain); + } + }; + SmallVector originalOperands(returnOp.getOperands().begin(), returnOp.getOperands().end()); for (auto it : llvm::enumerate(originalOperands)) { size_t orderWithinReturn = it.index(); @@ -741,53 +1076,7 @@ void SpatialToPimPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewri rewriter.modifyOpInPlace(returnOp, [&] { returnOp.setOperand(orderWithinReturn, outputTensors[orderWithinReturn]); }); - - Operation* opToErase = returnOperand; - while (opToErase) { - bool isExclusivelyOwnedByReturnChain = opToErase->use_empty(); - if (!isExclusivelyOwnedByReturnChain && opToErase->hasOneUse()) { - Operation* onlyUser = *opToErase->getUsers().begin(); - isExclusivelyOwnedByReturnChain = - isa(onlyUser) || isChannelUseChainOp(onlyUser); - } - if (!isExclusivelyOwnedByReturnChain) - break; - - if (isChannelUseChainOp(opToErase)) { - Value source = opToErase->getOperand(0); - markOpToRemove(opToErase); - opToErase = source.getDefiningOp(); - continue; - } - - if (isa(opToErase)) - markOpToRemove(opToErase); - break; - } - } -} - -void SpatialToPimPass::runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, IRRewriter& rewriter) { - - auto channel = cast(receiveOp.getChannel().getDefiningOp()); - - auto sendOpOpt = getOtherEndOfChannel(receiveOp, true, rewriter); - if (failed(sendOpOpt)) - llvm_unreachable("ChannelReceiveOp has no matching SendOp"); - - auto sendOp = cast(*sendOpOpt); - - Value receiveRes = receiveOp.getResult(); - - bool useBroadcastOp = countComputeLeafUsers(receiveRes) > 1; - addReceiveOps(receiveRes, channel, useBroadcastOp, rewriter); - - if (useBroadcastOp) { - // When receiving, we actually noticed that the value has more than one - // user. This means that we need to get the replace the original SendOp with - // a BroadcastSendOp - rewriter.setInsertionPoint(sendOp); - rewriter.replaceOpWithNewOp(sendOp, sendOp.getChannel(), sendOp.getInput()); + markOwnedReturnChain(returnOperand, markOwnedReturnChain); } } diff --git a/src/PIM/Dialect/Pim/Pim.td b/src/PIM/Dialect/Pim/Pim.td index 2a6d8e9..f203f5f 100644 --- a/src/PIM/Dialect/Pim/Pim.td +++ b/src/PIM/Dialect/Pim/Pim.td @@ -39,6 +39,22 @@ def PimCoreOp : PimOp<"core", [SingleBlock]> { }]; } +def PimCoreBatchOp : PimOp<"core_batch", [SingleBlock, AttrSizedOperandSegments]> { + let summary = "Execute equivalent batched core bodies"; + + let regions = (region SizedRegion<1>:$body); + + let arguments = (ins + I32Attr:$laneCount, + Variadic:$weights, + Variadic:$inputs + ); + + let assemblyFormat = [{ + `lanes` $laneCount `(` $weights `)` `[` $inputs `]` attr-dict regions `:` type($weights) `[` type($inputs) `]` `->` `(` `)` + }]; +} + def PimHaltOp : PimOp<"halt", [Terminator]> { let summary = "Halt execution of the core"; @@ -65,6 +81,20 @@ def PimSendOp : PimOp<"send", []> { }]; } +def PimSendBatchOp : PimOp<"send_batch", []> { + let summary = "Send a per-lane tensor to target cores from a batched core"; + + let arguments = (ins + PimTensor:$input, + I32Attr:$size, + DenseI32ArrayAttr:$targetCoreIds + ); + + let assemblyFormat = [{ + `(` $input `)` attr-dict `:` type($input) `->` `(` `)` + }]; +} + def PimReceiveOp : PimOp<"receive", [DestinationStyleOpInterface]> { let summary = "Receive a tensor from another core"; @@ -89,6 +119,30 @@ def PimReceiveOp : PimOp<"receive", [DestinationStyleOpInterface]> { }]; } +def PimReceiveBatchOp : PimOp<"receive_batch", [DestinationStyleOpInterface]> { + let summary = "Receive per-lane tensors from source cores into a batched core"; + + let arguments = (ins + PimTensor:$outputBuffer, + I32Attr:$size, + DenseI32ArrayAttr:$sourceCoreIds + ); + + let results = (outs + PimTensor:$output + ); + + let extraClassDeclaration = [{ + mlir::MutableOperandRange getDpsInitsMutable() { + return getOutputBufferMutable(); + } + }]; + + let assemblyFormat = [{ + `(` $outputBuffer `)` attr-dict `:` type($outputBuffer) `->` type($output) + }]; +} + def PimMemCopyHostToDevOp : PimOp<"memcp_hd", [DestinationStyleOpInterface]> { let summary = "Copy a memory region from host memory into device memory"; @@ -115,6 +169,32 @@ def PimMemCopyHostToDevOp : PimOp<"memcp_hd", [DestinationStyleOpInterface]> { }]; } +def PimMemCopyHostToDevBatchOp : PimOp<"memcp_hd_batch", [DestinationStyleOpInterface]> { + let summary = "Copy a per-lane tensor from host memory into device memory inside a batched core"; + + let arguments = (ins + PimTensor:$deviceTarget, + PimTensor:$hostSource, + I32Attr:$deviceTargetOffset, + I32Attr:$hostSourceOffset, + I32Attr:$size + ); + + let results = (outs + PimTensor:$output + ); + + let extraClassDeclaration = [{ + mlir::MutableOperandRange getDpsInitsMutable() { + return getDeviceTargetMutable(); + } + }]; + + let assemblyFormat = [{ + `(` $deviceTarget `,` $hostSource `)` attr-dict `:` `(` type($deviceTarget) `,` type($hostSource) `)` `->` type($output) + }]; +} + def PimMemCopyDevToHostOp : PimOp<"memcp_dh", [DestinationStyleOpInterface]> { let summary = "Copy a memory region from device memory into host memory"; diff --git a/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp b/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp index e3445f8..c9a6769 100644 --- a/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp +++ b/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp @@ -1,6 +1,7 @@ #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h" +#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "OpBufferizationInterfaces.hpp" @@ -65,6 +66,32 @@ struct MemCopyHostToDevOpInterface } }; +struct MemCopyHostToDevBatchOpInterface +: DstBufferizableOpInterfaceExternalModel { + LogicalResult bufferize(Operation* op, + RewriterBase& rewriter, + const BufferizationOptions& options, + BufferizationState& state) const { + auto memCopyHostToDevOp = cast(op); + auto deviceTargetOpt = getBuffer(rewriter, memCopyHostToDevOp.getDeviceTarget(), options, state); + if (failed(deviceTargetOpt)) + return failure(); + auto hostSourceOpt = getBuffer(rewriter, memCopyHostToDevOp.getHostSource(), options, state); + if (failed(hostSourceOpt)) + return failure(); + + replaceOpWithNewBufferizedOp(rewriter, + memCopyHostToDevOp, + deviceTargetOpt->getType(), + *deviceTargetOpt, + *hostSourceOpt, + memCopyHostToDevOp.getDeviceTargetOffsetAttr(), + memCopyHostToDevOp.getHostSourceOffsetAttr(), + memCopyHostToDevOp.getSizeAttr()); + return success(); + } +}; + struct MemCopyDevToHostOpInterface : DstBufferizableOpInterfaceExternalModel { LogicalResult bufferize(Operation* op, @@ -122,6 +149,127 @@ struct ReceiveOpInterface : DstBufferizableOpInterfaceExternalModel { + bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { + return !cast(op).isDpsInit(&opOperand); + } + + LogicalResult bufferize(Operation* op, + RewriterBase& rewriter, + const BufferizationOptions& options, + BufferizationState& state) const { + auto receiveOp = cast(op); + auto outputBufferOpt = getBuffer(rewriter, receiveOp.getOutputBuffer(), options, state); + if (failed(outputBufferOpt)) + return failure(); + + replaceOpWithNewBufferizedOp(rewriter, + op, + outputBufferOpt->getType(), + *outputBufferOpt, + receiveOp.getSizeAttr(), + receiveOp.getSourceCoreIdsAttr()); + return success(); + } +}; + +struct CoreBatchOpInterface : BufferizableOpInterface::ExternalModel { + bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { + return true; + } + + bool bufferizesToMemoryWrite(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { + return false; + } + + AliasingValueList getAliasingValues(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { + return {}; + } + + AliasingOpOperandList getAliasingOpOperands(Operation* op, Value value, const AnalysisState& state) const { + auto coreBatchOp = cast(op); + auto bbArg = dyn_cast(value); + if (!bbArg || bbArg.getOwner() != &coreBatchOp.getBody().front()) + return {}; + + unsigned inputOperandIndex = coreBatchOp.getWeights().size() + bbArg.getArgNumber(); + return {{&coreBatchOp->getOpOperand(inputOperandIndex), BufferRelation::Equivalent}}; + } + + bool isWritable(Operation* op, Value value, const AnalysisState& state) const { + return false; + } + + FailureOr + getBufferType(Operation* op, + Value value, + const BufferizationOptions& options, + const BufferizationState& state, + SmallVector& invocationStack) const { + auto coreBatchOp = cast(op); + auto bbArg = dyn_cast(value); + if (!bbArg || bbArg.getOwner() != &coreBatchOp.getBody().front()) + return failure(); + + Value tiedInput = coreBatchOp.getInputs()[bbArg.getArgNumber()]; + if (auto memRefType = dyn_cast(tiedInput.getType())) + return memRefType; + + return bufferization::getBufferType(tiedInput, options, state, invocationStack); + } + + LogicalResult bufferize(Operation* op, + RewriterBase& rewriter, + const BufferizationOptions& options, + BufferizationState& state) const { + auto coreBatchOp = cast(op); + + SmallVector weights; + SmallVector inputs; + weights.reserve(coreBatchOp.getWeights().size()); + inputs.reserve(coreBatchOp.getInputs().size()); + + for (Value weight : coreBatchOp.getWeights()) { + if (isa(weight.getType())) { + auto weightOpt = getBuffer(rewriter, weight, options, state); + if (failed(weightOpt)) + return failure(); + weights.push_back(*weightOpt); + } + else { + weights.push_back(weight); + } + } + + for (Value input : coreBatchOp.getInputs()) { + if (isa(input.getType())) { + auto inputOpt = getBuffer(rewriter, input, options, state); + if (failed(inputOpt)) + return failure(); + inputs.push_back(*inputOpt); + } + else { + inputs.push_back(input); + } + } + + rewriter.setInsertionPoint(coreBatchOp); + auto newOp = PimCoreBatchOp::create( + rewriter, coreBatchOp.getLoc(), coreBatchOp.getLaneCountAttr(), ValueRange(weights), ValueRange(inputs)); + newOp.getProperties().setOperandSegmentSizes({static_cast(weights.size()), static_cast(inputs.size())}); + if (auto coreIdsAttr = coreBatchOp->getAttr(onnx_mlir::kCoreIdAttrName)) + newOp->setAttr(onnx_mlir::kCoreIdAttrName, coreIdsAttr); + + rewriter.inlineRegionBefore(coreBatchOp.getBody(), newOp.getBody(), newOp.getBody().begin()); + for (Block& block : newOp.getBody()) + if (failed(bufferization::bufferizeBlockSignature(&block, rewriter, options, state))) + return failure(); + + rewriter.eraseOp(coreBatchOp); + return success(); + } +}; + struct TransposeOpInterface : DstBufferizableOpInterfaceExternalModel { bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { return !cast(op).isDpsInit(&opOperand); @@ -287,8 +435,11 @@ struct UnaryDstOpInterface : DstBufferizableOpInterfaceExternalModel(*ctx); PimReceiveOp::attachInterface(*ctx); + PimReceiveBatchOp::attachInterface(*ctx); PimMemCopyHostToDevOp::attachInterface(*ctx); + PimMemCopyHostToDevBatchOp::attachInterface(*ctx); PimMemCopyDevToHostOp::attachInterface(*ctx); PimTransposeOp::attachInterface(*ctx); PimVMMOp::attachInterface(*ctx); diff --git a/src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferizationPass.cpp b/src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferizationPass.cpp index 0cd8482..cb5b9b2 100644 --- a/src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferizationPass.cpp +++ b/src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferizationPass.cpp @@ -93,8 +93,8 @@ void PimBufferizationPass::runOnOperation() { } void PimBufferizationPass::annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const { - funcOp.walk([&](PimCoreOp coreOp) { - walkPimMvmVmmWeightUses(coreOp, [&](OpOperand& weightUse) { + auto markWeights = [&](Operation* op) { + walkPimMvmVmmWeightUses(op, [&](OpOperand& weightUse) { Value weight = weightUse.get(); auto getGlobalOp = weight.getDefiningOp(); if (!getGlobalOp) @@ -104,7 +104,10 @@ void PimBufferizationPass::annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncO markWeightAlways(getGlobalOp); markWeightAlways(globalMemrefOp); }); - }); + }; + + funcOp.walk([&](PimCoreOp coreOp) { markWeights(coreOp); }); + funcOp.walk([&](PimCoreBatchOp coreBatchOp) { markWeights(coreBatchOp); }); } std::unique_ptr createPimBufferizationPass() { return std::make_unique(); } diff --git a/src/PIM/Dialect/Spatial/CMakeLists.txt b/src/PIM/Dialect/Spatial/CMakeLists.txt index 641f011..473e8ec 100644 --- a/src/PIM/Dialect/Spatial/CMakeLists.txt +++ b/src/PIM/Dialect/Spatial/CMakeLists.txt @@ -2,6 +2,7 @@ add_onnx_mlir_dialect(Spatial spat) add_onnx_mlir_dialect_doc(spat Spatial.td) add_pim_library(SpatialOps + Channels.cpp SpatialOps.cpp Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp Transforms/MergeComputeNodes/DCPGraph/Graph.cpp diff --git a/src/PIM/Dialect/Spatial/Channels.cpp b/src/PIM/Dialect/Spatial/Channels.cpp new file mode 100644 index 0000000..44a6ff1 --- /dev/null +++ b/src/PIM/Dialect/Spatial/Channels.cpp @@ -0,0 +1,120 @@ +#include "src/Accelerators/PIM/Dialect/Spatial/Channels.hpp" + +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Diagnostics.h" + +using namespace mlir; + +namespace onnx_mlir::spatial { + +namespace { + +static Channels::ChannelId getChannelId(SpatChannelSendOp sendOp) { return sendOp.getChannelId(); } + +static Channels::ChannelId getChannelId(SpatChannelReceiveOp receiveOp) { return receiveOp.getChannelId(); } + +static LogicalResult verifyEndpointPair(ChannelEndpoints endpoints) { + if (!endpoints.send || !endpoints.receive) + return failure(); + + if (endpoints.send.getSourceCoreId() != endpoints.receive.getSourceCoreId()) { + endpoints.send.emitOpError("sourceCoreId does not match paired spat.channel_receive"); + return failure(); + } + if (endpoints.send.getTargetCoreId() != endpoints.receive.getTargetCoreId()) { + endpoints.send.emitOpError("targetCoreId does not match paired spat.channel_receive"); + return failure(); + } + if (endpoints.send.getInput().getType() != endpoints.receive.getOutput().getType()) { + endpoints.send.emitOpError("input type does not match paired spat.channel_receive result type"); + return failure(); + } + + return success(); +} + +} // namespace + +Channels::Channels(func::FuncOp funcOp) { + if (!funcOp) + return; + + funcOp.walk([&](SpatChannelSendOp sendOp) { insertSend(sendOp); }); + funcOp.walk([&](SpatChannelReceiveOp receiveOp) { insertReceive(receiveOp); }); +} + +Channels::ChannelId Channels::allocate() { return nextChannelId++; } + +void Channels::insertSend(SpatChannelSendOp sendOp) { + ChannelId channelId = getChannelId(sendOp); + nextChannelId = std::max(nextChannelId, channelId + 1); + endpoints[channelId].send = sendOp; +} + +void Channels::insertReceive(SpatChannelReceiveOp receiveOp) { + ChannelId channelId = getChannelId(receiveOp); + nextChannelId = std::max(nextChannelId, channelId + 1); + endpoints[channelId].receive = receiveOp; +} + +void Channels::eraseSend(SpatChannelSendOp sendOp) { + ChannelId channelId = getChannelId(sendOp); + auto it = endpoints.find(channelId); + if (it == endpoints.end()) + return; + it->second.send = {}; + if (!it->second.receive) + endpoints.erase(it); +} + +void Channels::eraseReceive(SpatChannelReceiveOp receiveOp) { + ChannelId channelId = getChannelId(receiveOp); + auto it = endpoints.find(channelId); + if (it == endpoints.end()) + return; + it->second.receive = {}; + if (!it->second.send) + endpoints.erase(it); +} + +FailureOr Channels::lookup(ChannelId id) const { + auto it = endpoints.find(id); + if (it == endpoints.end()) + return failure(); + return it->second; +} + +FailureOr Channels::getReceiveFor(SpatChannelSendOp sendOp) const { + auto endpointsOr = lookup(getChannelId(sendOp)); + if (failed(endpointsOr) || !endpointsOr->receive) + return failure(); + return endpointsOr->receive; +} + +FailureOr Channels::getSendFor(SpatChannelReceiveOp receiveOp) const { + auto endpointsOr = lookup(getChannelId(receiveOp)); + if (failed(endpointsOr) || !endpointsOr->send) + return failure(); + return endpointsOr->send; +} + +LogicalResult Channels::verify() const { + for (const auto& [channelId, pair] : endpoints) { + if (!pair.send || !pair.receive) { + if (pair.send) { + auto sendOp = pair.send; + sendOp.emitOpError("channel_id ") << channelId << " is missing a paired spat.channel_receive"; + } + else if (pair.receive) { + auto receiveOp = pair.receive; + receiveOp.emitOpError("channel_id ") << channelId << " is missing a paired spat.channel_send"; + } + return failure(); + } + if (failed(verifyEndpointPair(pair))) + return failure(); + } + return success(); +} + +} // namespace onnx_mlir::spatial diff --git a/src/PIM/Dialect/Spatial/Channels.hpp b/src/PIM/Dialect/Spatial/Channels.hpp new file mode 100644 index 0000000..5f99569 --- /dev/null +++ b/src/PIM/Dialect/Spatial/Channels.hpp @@ -0,0 +1,43 @@ +#pragma once + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Support/LogicalResult.h" + +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/StringRef.h" + +#include "src/Accelerators/PIM/Common/PimCommon.hpp" +#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" + +namespace onnx_mlir::spatial { + +struct ChannelEndpoints { + SpatChannelSendOp send; + SpatChannelReceiveOp receive; +}; + +class Channels { +public: + using ChannelId = int64_t; + + explicit Channels(mlir::func::FuncOp funcOp); + + ChannelId allocate(); + + void insertSend(SpatChannelSendOp sendOp); + void insertReceive(SpatChannelReceiveOp receiveOp); + void eraseSend(SpatChannelSendOp sendOp); + void eraseReceive(SpatChannelReceiveOp receiveOp); + + llvm::FailureOr lookup(ChannelId id) const; + llvm::FailureOr getReceiveFor(SpatChannelSendOp sendOp) const; + llvm::FailureOr getSendFor(SpatChannelReceiveOp receiveOp) const; + + mlir::LogicalResult verify() const; + +private: + ChannelId nextChannelId = 0; + llvm::DenseMap endpoints; +}; + +} // namespace onnx_mlir::spatial diff --git a/src/PIM/Dialect/Spatial/Spatial.td b/src/PIM/Dialect/Spatial/Spatial.td index 594d159..238472f 100644 --- a/src/PIM/Dialect/Spatial/Spatial.td +++ b/src/PIM/Dialect/Spatial/Spatial.td @@ -9,7 +9,6 @@ def SpatialDialect : Dialect { let name = "spat"; let summary = "Dialect designed for deep learning computation in a spatial architecture"; let cppNamespace = "::onnx_mlir::spatial"; - let useDefaultTypePrinterParser = 1; } class SpatOp traits = []> : @@ -19,15 +18,6 @@ class SpatOp traits = []> : def SpatTensor : AnyTypeOf<[AnyMemRef, AnyRankedTensor], "", "::mlir::ShapedType">; -class SpatType traits = []> - : TypeDef { - let mnemonic = typeMnemonic; -} - -def SpatChannelType : SpatType<"SpatChannel", "ch"> { - let summary = "Virtual channel type"; -} - //===----------------------------------------------------------------------===// // Execution //===----------------------------------------------------------------------===// @@ -48,10 +38,27 @@ def SpatCompute : SpatOp<"compute", [SingleBlock, AttrSizedOperandSegments]> { let hasVerifier = 1; let hasFolder = 1; + let hasCustomAssemblyFormat = 1; +} - let assemblyFormat = [{ - `[` $weights `]` `(` $inputs `)` attr-dict `:` `[` type($weights) `]` `(` type($inputs) `)` `->` type($outputs) $body - }]; +def SpatComputeBatch : SpatOp<"compute_batch", + [SingleBlock, AttrSizedOperandSegments]> { + let summary = "Compressed batch of independent equivalent compute lanes"; + + let arguments = (ins + I32Attr:$laneCount, + Variadic:$weights, + Variadic:$inputs + ); + + let results = (outs + Variadic:$outputs + ); + + let regions = (region SizedRegion<1>:$body); + + let hasVerifier = 1; + let hasCustomAssemblyFormat = 1; } def SpatYieldOp : SpatOp<"yield", [Terminator]> { @@ -61,51 +68,66 @@ def SpatYieldOp : SpatOp<"yield", [Terminator]> { Variadic:$outputs ); - let assemblyFormat = [{ - $outputs attr-dict `:` type($outputs) - }]; + let hasCustomAssemblyFormat = 1; +} + +def SpatExtractRowsOp : SpatOp<"extract_rows", []> { + let summary = "Extract every row of a rank-2 tensor as separate rank-2 row tensors"; + + let arguments = (ins + SpatTensor:$input + ); + + let results = (outs + Variadic:$outputs + ); + + let hasVerifier = 1; + let hasCustomAssemblyFormat = 1; +} + +def SpatConcatOp : SpatOp<"concat", []> { + let summary = "Concatenate tensors with compact Spatial operand syntax"; + + let arguments = (ins + I64Attr:$axis, + Variadic:$inputs + ); + + let results = (outs + SpatTensor:$output + ); + + let hasVerifier = 1; + let hasCustomAssemblyFormat = 1; } //===----------------------------------------------------------------------===// // Communication //===----------------------------------------------------------------------===// -def SpatChannelNewOp : SpatOp<"channel_new", []> { - let summary = "Create a new virtual channel"; - - let results = (outs - SpatChannelType:$channel - ); - - let builders = [ - OpBuilder<(ins ), [{ - $_state.addTypes(SpatChannelType()); - }]> - ]; - - let assemblyFormat = [{ - attr-dict - }]; -} - def SpatChannelSendOp : SpatOp<"channel_send", []> { - let summary = "Send a tensor through a channel"; + let summary = "Send a tensor through a logical channel"; let arguments = (ins - SpatChannelType:$channel, + I64Attr:$channelId, + I32Attr:$sourceCoreId, + I32Attr:$targetCoreId, SpatTensor:$input ); let assemblyFormat = [{ - $input `to` $channel attr-dict `:` `(` type($input) `->` type($channel) `)` + $input attr-dict `:` type($input) }]; } def SpatChannelReceiveOp : SpatOp<"channel_receive", []> { - let summary = "Receive a tensor from a channel"; + let summary = "Receive a tensor from a logical channel"; let arguments = (ins - SpatChannelType:$channel + I64Attr:$channelId, + I32Attr:$sourceCoreId, + I32Attr:$targetCoreId ); let results = (outs @@ -113,37 +135,70 @@ def SpatChannelReceiveOp : SpatOp<"channel_receive", []> { ); let assemblyFormat = [{ - $channel attr-dict `:` `(` type($channel) `->` type($output) `)` + attr-dict `:` type($output) }]; } -def SpatChannelBroadcastSendOp : SpatOp<"channel_broadcast_send", []> { - let summary = "Broadcast a tensor through a shared channel buffer"; +def SpatChannelSendManyOp : SpatOp<"channel_send_many", []> { + let summary = "Send multiple tensors through logical channels"; let arguments = (ins - SpatChannelType:$channel, + DenseI64ArrayAttr:$channelIds, + DenseI32ArrayAttr:$sourceCoreIds, + DenseI32ArrayAttr:$targetCoreIds, + Variadic:$inputs + ); + + let hasVerifier = 1; + let hasCustomAssemblyFormat = 1; +} + +def SpatChannelReceiveManyOp : SpatOp<"channel_receive_many", []> { + let summary = "Receive multiple tensors from logical channels"; + + let arguments = (ins + DenseI64ArrayAttr:$channelIds, + DenseI32ArrayAttr:$sourceCoreIds, + DenseI32ArrayAttr:$targetCoreIds + ); + + let results = (outs + Variadic:$outputs + ); + + let hasVerifier = 1; + let hasCustomAssemblyFormat = 1; +} + +def SpatChannelSendBatchOp : SpatOp<"channel_send_batch", []> { + let summary = "Send per-lane tensors through logical channels in a batch body"; + + let arguments = (ins + DenseI64ArrayAttr:$channelIds, + DenseI32ArrayAttr:$sourceCoreIds, + DenseI32ArrayAttr:$targetCoreIds, SpatTensor:$input ); - let assemblyFormat = [{ - $input `to` $channel attr-dict `:` `(` type($input) `->` type($channel) `)` - }]; + let hasVerifier = 1; + let hasCustomAssemblyFormat = 1; } -def SpatChannelBroadcastReceiveOp : SpatOp<"channel_broadcast_receive", []> { - let summary = "Receive a tensor from a shared channel buffer"; +def SpatChannelReceiveBatchOp : SpatOp<"channel_receive_batch", []> { + let summary = "Receive a per-lane tensor through logical channels in a batch body"; let arguments = (ins - SpatChannelType:$channel + DenseI64ArrayAttr:$channelIds, + DenseI32ArrayAttr:$sourceCoreIds, + DenseI32ArrayAttr:$targetCoreIds ); let results = (outs SpatTensor:$output ); - let assemblyFormat = [{ - $channel attr-dict `:` `(` type($channel) `->` type($output) `)` - }]; + let hasVerifier = 1; + let hasCustomAssemblyFormat = 1; } //===----------------------------------------------------------------------===// diff --git a/src/PIM/Dialect/Spatial/SpatialOps.cpp b/src/PIM/Dialect/Spatial/SpatialOps.cpp index 83289b3..66d2fef 100644 --- a/src/PIM/Dialect/Spatial/SpatialOps.cpp +++ b/src/PIM/Dialect/Spatial/SpatialOps.cpp @@ -16,8 +16,10 @@ #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/LogicalResult.h" +#include "llvm/Support/raw_ostream.h" #include +#include #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp" @@ -42,6 +44,448 @@ void SpatialDialect::initialize() { >(); } +namespace { + +enum class ListDelimiter { + Square, + Paren +}; + +static ParseResult parseOpenDelimiter(OpAsmParser& parser, ListDelimiter delimiter) { + if (delimiter == ListDelimiter::Square) + return parser.parseLSquare(); + return parser.parseLParen(); +} + +static ParseResult parseOptionalCloseDelimiter(OpAsmParser& parser, ListDelimiter delimiter) { + if (delimiter == ListDelimiter::Square) + return parser.parseOptionalRSquare(); + return parser.parseOptionalRParen(); +} + +static void printOpenDelimiter(OpAsmPrinter& printer, ListDelimiter delimiter) { + printer << (delimiter == ListDelimiter::Square ? "[" : "("); +} + +static void printCloseDelimiter(OpAsmPrinter& printer, ListDelimiter delimiter) { + printer << (delimiter == ListDelimiter::Square ? "]" : ")"); +} + +template +static ParseResult parseCompressedRepeatedList(OpAsmParser& parser, + ListDelimiter delimiter, + SmallVectorImpl& entries, + ParseEntryFn parseEntry) { + if (parseOpenDelimiter(parser, delimiter)) + return failure(); + if (succeeded(parseOptionalCloseDelimiter(parser, delimiter))) + return success(); + + while (true) { + EntryT entry; + if (parseEntry(entry)) + return failure(); + + int64_t repeatCount = 1; + if (succeeded(parser.parseOptionalKeyword("x"))) { + if (parser.parseInteger(repeatCount) || repeatCount <= 0) + return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive"); + } + for (int64_t index = 0; index < repeatCount; ++index) + entries.push_back(entry); + + if (succeeded(parseOptionalCloseDelimiter(parser, delimiter))) + break; + if (parser.parseComma()) + return failure(); + } + + return success(); +} + +template +static ParseResult parseCompressedIntegerList(OpAsmParser& parser, SmallVectorImpl& values) { + if (parser.parseLSquare()) + return failure(); + if (succeeded(parser.parseOptionalRSquare())) + return success(); + + while (true) { + int64_t first = 0; + if (parser.parseInteger(first)) + return failure(); + + if (succeeded(parser.parseOptionalKeyword("to"))) { + int64_t last = 0; + if (parser.parseInteger(last) || last < first) + return parser.emitError(parser.getCurrentLocation(), "invalid ascending range"); + + int64_t step = 1; + if (succeeded(parser.parseOptionalKeyword("by"))) { + if (parser.parseInteger(step) || step <= 0) + return parser.emitError(parser.getCurrentLocation(), "step after 'by' must be positive"); + } + int64_t repeatCount = 1; + if (succeeded(parser.parseOptionalKeyword("x"))) { + if (parser.parseInteger(repeatCount) || repeatCount <= 0) + return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive"); + } + if ((last - first) % step != 0) + return parser.emitError(parser.getCurrentLocation(), "range end must be reachable from start using the given step"); + + for (int64_t value = first; value <= last; value += step) + for (int64_t index = 0; index < repeatCount; ++index) + values.push_back(static_cast(value)); + } + else { + int64_t repeatCount = 1; + if (succeeded(parser.parseOptionalKeyword("x"))) { + if (parser.parseInteger(repeatCount) || repeatCount <= 0) + return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive"); + } + for (int64_t index = 0; index < repeatCount; ++index) + values.push_back(static_cast(first)); + } + + if (succeeded(parser.parseOptionalRSquare())) + break; + if (parser.parseComma()) + return failure(); + } + + return success(); +} + +template +static void printCompressedEqualRuns(OpAsmPrinter& printer, RangeT entries, PrintEntryFn printEntry) { + for (size_t index = 0; index < entries.size();) { + size_t runEnd = index + 1; + while (runEnd < entries.size() && entries[runEnd] == entries[index]) + ++runEnd; + + if (index != 0) + printer << ", "; + printEntry(entries[index]); + size_t runLength = runEnd - index; + if (runLength > 1) + printer << " x" << runLength; + index = runEnd; + } +} + +template +static void printCompressedIntegerList(OpAsmPrinter& printer, ArrayRef values) { + printer << "["; + for (size_t index = 0; index < values.size();) { + if (index != 0) + printer << ", "; + + auto findEqualRunEnd = [&](size_t start) { + size_t end = start + 1; + while (end < values.size() && values[end] == values[start]) + ++end; + return end; + }; + + size_t firstRunEnd = findEqualRunEnd(index); + size_t repeatCount = firstRunEnd - index; + size_t progressionEnd = firstRunEnd; + int64_t step = 0; + IntT lastValue = values[index]; + + if (firstRunEnd < values.size()) { + size_t secondRunEnd = findEqualRunEnd(firstRunEnd); + step = static_cast(values[firstRunEnd]) - static_cast(values[index]); + if (step > 0 && secondRunEnd - firstRunEnd == repeatCount) { + progressionEnd = secondRunEnd; + lastValue = values[firstRunEnd]; + size_t currentRunStart = secondRunEnd; + while (currentRunStart < values.size()) { + size_t currentRunEnd = findEqualRunEnd(currentRunStart); + if (currentRunEnd - currentRunStart != repeatCount) + break; + if (static_cast(values[currentRunStart]) != static_cast(lastValue) + step) + break; + lastValue = values[currentRunStart]; + progressionEnd = currentRunEnd; + currentRunStart = currentRunEnd; + } + } + else { + step = 0; + } + } + + size_t progressionValueCount = repeatCount == 0 ? 0 : (progressionEnd - index) / repeatCount; + if (progressionEnd > firstRunEnd && progressionValueCount >= 3) { + printer << values[index] << " to " << lastValue; + if (step != 1) + printer << " by " << step; + if (repeatCount > 1) + printer << " x" << repeatCount; + index = progressionEnd; + continue; + } + + if (repeatCount > 1) { + printer << values[index] << " x" << repeatCount; + index = firstRunEnd; + continue; + } + + printer << values[index]; + index = firstRunEnd; + } + printer << "]"; +} + +static void printCompressedValueList(OpAsmPrinter& printer, ValueRange values, ListDelimiter delimiter) { + printOpenDelimiter(printer, delimiter); + for (size_t index = 0; index < values.size();) { + size_t equalRunEnd = index + 1; + while (equalRunEnd < values.size() && values[equalRunEnd] == values[index]) + ++equalRunEnd; + + if (index != 0) + printer << ", "; + if (equalRunEnd - index > 1) { + printer.printOperand(values[index]); + printer << " x" << (equalRunEnd - index); + index = equalRunEnd; + continue; + } + + size_t rangeEnd = index + 1; + if (auto firstResult = dyn_cast(values[index])) { + while (rangeEnd < values.size()) { + auto nextResult = dyn_cast(values[rangeEnd]); + if (!nextResult || nextResult.getOwner() != firstResult.getOwner() + || nextResult.getResultNumber() != firstResult.getResultNumber() + (rangeEnd - index)) + break; + ++rangeEnd; + } + } + else if (auto firstArg = dyn_cast(values[index])) { + while (rangeEnd < values.size()) { + auto nextArg = dyn_cast(values[rangeEnd]); + if (!nextArg || nextArg.getOwner() != firstArg.getOwner() + || nextArg.getArgNumber() != firstArg.getArgNumber() + (rangeEnd - index)) + break; + ++rangeEnd; + } + } + + printer.printOperand(values[index]); + if (rangeEnd - index >= 3) { + printer << " to "; + printer.printOperand(values[rangeEnd - 1]); + } + else if (rangeEnd - index == 2) { + printer << ", "; + printer.printOperand(values[index + 1]); + } + index = rangeEnd; + } + printCloseDelimiter(printer, delimiter); +} + +static void printCompressedTypeList(OpAsmPrinter& printer, TypeRange types, ListDelimiter delimiter) { + printOpenDelimiter(printer, delimiter); + printCompressedEqualRuns(printer, types, [&](Type type) { printer.printType(type); }); + printCloseDelimiter(printer, delimiter); +} + +static ParseResult parseCompressedOperandEntryWithFirst( + OpAsmParser& parser, + OpAsmParser::UnresolvedOperand firstOperand, + SmallVectorImpl& operands) { + if (succeeded(parser.parseOptionalKeyword("to"))) { + OpAsmParser::UnresolvedOperand lastOperand; + if (parser.parseOperand(lastOperand)) + return failure(); + if (firstOperand.name != lastOperand.name || firstOperand.number > lastOperand.number) + return parser.emitError(parser.getCurrentLocation(), "invalid operand range"); + for (unsigned number = firstOperand.number; number <= lastOperand.number; ++number) + operands.push_back({firstOperand.location, firstOperand.name, number}); + } + else { + int64_t repeatCount = 1; + if (succeeded(parser.parseOptionalKeyword("x"))) { + if (parser.parseInteger(repeatCount) || repeatCount <= 0) + return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive"); + } + for (int64_t index = 0; index < repeatCount; ++index) + operands.push_back(firstOperand); + } + + return success(); +} + +static ParseResult parseOneCompressedOperandEntry(OpAsmParser& parser, + SmallVectorImpl& operands) { + OpAsmParser::UnresolvedOperand firstOperand; + if (parser.parseOperand(firstOperand)) + return failure(); + return parseCompressedOperandEntryWithFirst(parser, firstOperand, operands); +} + +static ParseResult parseCompressedOperandList(OpAsmParser& parser, + ListDelimiter delimiter, + SmallVectorImpl& operands) { + if (parseOpenDelimiter(parser, delimiter)) + return failure(); + if (succeeded(parseOptionalCloseDelimiter(parser, delimiter))) + return success(); + + while (true) { + if (parseOneCompressedOperandEntry(parser, operands)) + return failure(); + if (succeeded(parseOptionalCloseDelimiter(parser, delimiter))) + break; + if (parser.parseComma()) + return failure(); + } + + return success(); +} + +static ParseResult parseCompressedOperandSequence(OpAsmParser& parser, + SmallVectorImpl& operands) { + if (parseOneCompressedOperandEntry(parser, operands)) + return failure(); + while (succeeded(parser.parseOptionalComma())) + if (parseOneCompressedOperandEntry(parser, operands)) + return failure(); + return success(); +} + +static void printCompressedValueSequence(OpAsmPrinter& printer, ValueRange values) { + for (size_t index = 0; index < values.size();) { + size_t equalRunEnd = index + 1; + while (equalRunEnd < values.size() && values[equalRunEnd] == values[index]) + ++equalRunEnd; + + if (index != 0) + printer << ", "; + if (equalRunEnd - index > 1) { + printer.printOperand(values[index]); + printer << " x" << (equalRunEnd - index); + index = equalRunEnd; + continue; + } + + size_t rangeEnd = index + 1; + if (auto firstResult = dyn_cast(values[index])) { + while (rangeEnd < values.size()) { + auto nextResult = dyn_cast(values[rangeEnd]); + if (!nextResult || nextResult.getOwner() != firstResult.getOwner() + || nextResult.getResultNumber() != firstResult.getResultNumber() + (rangeEnd - index)) + break; + ++rangeEnd; + } + } + else if (auto firstArg = dyn_cast(values[index])) { + while (rangeEnd < values.size()) { + auto nextArg = dyn_cast(values[rangeEnd]); + if (!nextArg || nextArg.getOwner() != firstArg.getOwner() + || nextArg.getArgNumber() != firstArg.getArgNumber() + (rangeEnd - index)) + break; + ++rangeEnd; + } + } + + printer.printOperand(values[index]); + if (rangeEnd - index >= 3) { + printer << " to "; + printer.printOperand(values[rangeEnd - 1]); + } + else if (rangeEnd - index == 2) { + printer << ", "; + printer.printOperand(values[index + 1]); + } + index = rangeEnd; + } +} + +static void printCompressedTypeSequence(OpAsmPrinter& printer, TypeRange types) { + printCompressedEqualRuns(printer, types, [&](Type type) { printer.printType(type); }); +} + +static ParseResult parseCompressedTypeSequence(OpAsmParser& parser, SmallVectorImpl& types, bool allowEmpty) { + Type firstType; + OptionalParseResult firstTypeResult = parser.parseOptionalType(firstType); + if (!firstTypeResult.has_value()) { + if (allowEmpty) + return success(); + return parser.emitError(parser.getCurrentLocation(), "expected type"); + } + if (failed(*firstTypeResult)) + return failure(); + + auto appendType = [&](Type type) -> ParseResult { + int64_t repeatCount = 1; + if (succeeded(parser.parseOptionalKeyword("x"))) { + if (parser.parseInteger(repeatCount) || repeatCount <= 0) + return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive"); + } + for (int64_t index = 0; index < repeatCount; ++index) + types.push_back(type); + return success(); + }; + + if (appendType(firstType)) + return failure(); + + while (succeeded(parser.parseOptionalComma())) { + Type nextType; + if (parser.parseType(nextType) || appendType(nextType)) + return failure(); + } + + return success(); +} + +static void printChannelMetadata(OpAsmPrinter& printer, + ArrayRef channelIds, + ArrayRef sourceCoreIds, + ArrayRef targetCoreIds) { + printer << " channels "; + printCompressedIntegerList(printer, channelIds); + printer << " from "; + printCompressedIntegerList(printer, sourceCoreIds); + printer << " to "; + printCompressedIntegerList(printer, targetCoreIds); +} + +static DenseI64ArrayAttr getDenseI64ArrayAttr(OpAsmParser& parser, ArrayRef values) { + return parser.getBuilder().getDenseI64ArrayAttr(values); +} + +static DenseI32ArrayAttr getDenseI32ArrayAttr(OpAsmParser& parser, ArrayRef values) { + return parser.getBuilder().getDenseI32ArrayAttr(values); +} + +static IntegerAttr getI32Attr(OpAsmParser& parser, int32_t value) { + return parser.getBuilder().getI32IntegerAttr(value); +} + +static void buildImplicitRegionArgs(OpAsmParser& parser, + ArrayRef inputTypes, + SmallVectorImpl& generatedNames, + SmallVectorImpl& arguments) { + generatedNames.reserve(inputTypes.size()); + arguments.reserve(inputTypes.size()); + for (auto [index, inputType] : llvm::enumerate(inputTypes)) { + generatedNames.push_back("arg" + std::to_string(index + 1)); + OpAsmParser::Argument arg; + arg.ssaName = {parser.getCurrentLocation(), generatedNames.back(), 0}; + arg.type = inputType; + arguments.push_back(arg); + } +} + +} // namespace + inline LogicalResult mvmOpVerifySize2(SpatWeightedMVMOp* emitter, ArrayRef& matrixShape, ArrayRef& vectorShape, @@ -121,6 +565,14 @@ llvm::FailureOr> getWeightShapeForWeightedOp(Operation* weigth if (auto coreOp = dyn_cast(weigthedOp->getParentOp())) return cast(coreOp.getWeights()[weightIndex].getType()).getShape(); + // In compute_batch bodies, weightIndex refers to the lane-local template + // weight index, so lane 0's weight slice is representative for type checks. + if (auto batchOp = dyn_cast(weigthedOp->getParentOp())) { + if (batchOp.getWeights().empty() || weightIndex >= batchOp.getWeights().size()) + return failure(); + return cast(batchOp.getWeights()[weightIndex].getType()).getShape(); + } + return failure(); } @@ -193,6 +645,190 @@ LogicalResult SpatVMaxOp::verify() { return OpTrait::impl::verifySameOperandsAndResultType(*this); } +void SpatYieldOp::print(OpAsmPrinter& printer) { + printer << " "; + printCompressedValueSequence(printer, getOutputs()); + printer.printOptionalAttrDict((*this)->getAttrs()); + printer << " : "; + printCompressedTypeSequence(printer, getOutputs().getTypes()); +} + +ParseResult SpatYieldOp::parse(OpAsmParser& parser, OperationState& result) { + SmallVector outputs; + SmallVector outputTypes; + + OpAsmParser::UnresolvedOperand firstOutput; + OptionalParseResult firstOutputResult = parser.parseOptionalOperand(firstOutput); + if (firstOutputResult.has_value()) { + if (failed(*firstOutputResult)) + return failure(); + if (parseCompressedOperandEntryWithFirst(parser, firstOutput, outputs)) + return failure(); + while (succeeded(parser.parseOptionalComma())) + if (parseOneCompressedOperandEntry(parser, outputs)) + return failure(); + } + + if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() + || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/true)) + return failure(); + + if (outputs.size() != outputTypes.size()) + return parser.emitError(parser.getCurrentLocation(), "number of outputs and output types must match"); + + return parser.resolveOperands(outputs, outputTypes, parser.getCurrentLocation(), result.operands); +} + +LogicalResult SpatExtractRowsOp::verify() { + auto inputType = dyn_cast(getInput().getType()); + if (!inputType || !inputType.hasRank() || inputType.getRank() != 2) + return emitError("input must be a rank-2 shaped type"); + + int64_t numRows = inputType.getShape()[0]; + int64_t numCols = inputType.getShape()[1]; + Type elementType = inputType.getElementType(); + + if (numRows >= 0 && static_cast(getNumResults()) != numRows) + return emitError("number of outputs must match the number of input rows"); + + for (Type output : getResultTypes()) { + auto outputType = dyn_cast(output); + if (!outputType || !outputType.hasRank() || outputType.getRank() != 2) + return emitError("outputs must all be rank-2 shaped types"); + if (outputType.getElementType() != elementType) + return emitError("output element types must match input element type"); + auto outputShape = outputType.getShape(); + if (outputShape[0] != 1) + return emitError("each output must have exactly one row"); + if (numCols >= 0 && outputShape[1] != numCols) + return emitError("output column count must match input column count"); + } + + return success(); +} + +void SpatExtractRowsOp::print(OpAsmPrinter& printer) { + printer << " "; + printer.printOperand(getInput()); + printer.printOptionalAttrDict((*this)->getAttrs()); + printer << " : "; + printer.printType(getInput().getType()); + printer << " -> "; + printCompressedTypeSequence(printer, getResultTypes()); +} + +ParseResult SpatExtractRowsOp::parse(OpAsmParser& parser, OperationState& result) { + OpAsmParser::UnresolvedOperand input; + Type inputType; + SmallVector outputTypes; + + if (parser.parseOperand(input) || parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() + || parser.parseType(inputType) || parser.parseArrow() + || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/false)) + return failure(); + + if (parser.resolveOperand(input, inputType, result.operands)) + return failure(); + result.addTypes(outputTypes); + return success(); +} + +LogicalResult SpatConcatOp::verify() { + if (getInputs().empty()) + return emitError("requires at least one input"); + + auto outputType = dyn_cast(getOutput().getType()); + if (!outputType || !outputType.hasRank()) + return emitError("output must be a ranked shaped type"); + + int64_t axis = getAxis(); + int64_t rank = outputType.getRank(); + if (axis < 0 || axis >= rank) + return emitError("axis must be within the output rank"); + + int64_t concatenatedDimSize = 0; + bool concatenatedDimDynamic = false; + Type outputElementType = outputType.getElementType(); + + for (Value input : getInputs()) { + auto inputType = dyn_cast(input.getType()); + if (!inputType || !inputType.hasRank()) + return emitError("inputs must be ranked shaped types"); + if (inputType.getRank() != rank) + return emitError("all inputs must have the same rank as the output"); + if (inputType.getElementType() != outputElementType) + return emitError("all inputs must have the same element type as the output"); + + for (int64_t dim = 0; dim < rank; ++dim) { + if (dim == axis) + continue; + int64_t inputDim = inputType.getDimSize(dim); + int64_t outputDim = outputType.getDimSize(dim); + if (!ShapedType::isDynamic(inputDim) && !ShapedType::isDynamic(outputDim) && inputDim != outputDim) + return emitError("non-concatenated dimensions must match the output shape"); + } + + int64_t inputConcatDim = inputType.getDimSize(axis); + if (ShapedType::isDynamic(inputConcatDim)) { + concatenatedDimDynamic = true; + continue; + } + concatenatedDimSize += inputConcatDim; + } + + int64_t outputConcatDim = outputType.getDimSize(axis); + if (!concatenatedDimDynamic && !ShapedType::isDynamic(outputConcatDim) && concatenatedDimSize != outputConcatDim) + return emitError("output concatenated dimension must equal the sum of input sizes"); + + return success(); +} + +void SpatConcatOp::print(OpAsmPrinter& printer) { + printer << " axis " << getAxis(); + printer << " args = "; + printCompressedValueList(printer, getInputs(), ListDelimiter::Paren); + printer.printOptionalAttrDict((*this)->getAttrs(), {getAxisAttrName().getValue()}); + printer << " : "; + printCompressedTypeList(printer, TypeRange(getInputs()), ListDelimiter::Paren); + printer << " -> "; + printer.printType(getOutput().getType()); +} + +ParseResult SpatConcatOp::parse(OpAsmParser& parser, OperationState& result) { + int64_t axis = 0; + SmallVector inputs; + SmallVector inputTypes; + Type outputType; + + if (parser.parseKeyword("axis") || parser.parseInteger(axis)) + return failure(); + + if (succeeded(parser.parseOptionalKeyword("args"))) { + if (parser.parseEqual() || parseCompressedOperandList(parser, ListDelimiter::Paren, inputs)) + return failure(); + } + else if (parseCompressedOperandList(parser, ListDelimiter::Paren, inputs)) { + return failure(); + } + + if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() + || parseCompressedRepeatedList( + parser, ListDelimiter::Paren, inputTypes, [&](Type& type) { return parser.parseType(type); }) + || parser.parseArrow() || parser.parseType(outputType)) + return failure(); + + if (inputs.size() != inputTypes.size()) + return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match"); + if (result.attributes.get("axis")) + return parser.emitError(parser.getCurrentLocation(), "axis cannot be specified both positionally and in attr-dict"); + + result.addAttribute("axis", parser.getBuilder().getI64IntegerAttr(axis)); + if (parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands)) + return failure(); + result.addTypes(outputType); + return success(); +} + LogicalResult SpatCompute::verify() { // Check that the terminator yields the same number and types as the compute results. auto& block = getBody().front(); @@ -270,6 +906,507 @@ LogicalResult SpatCompute::fold(FoldAdaptor adaptor, ::llvm::SmallVectorImpl<::m return success(); } +void SpatCompute::print(OpAsmPrinter& printer) { + printer << " "; + printCompressedValueList(printer, getWeights(), ListDelimiter::Square); + printer << " args = "; + printCompressedValueList(printer, getInputs(), ListDelimiter::Paren); + + if (auto coreIdAttr = (*this)->getAttrOfType(onnx_mlir::kCoreIdAttrName)) { + printer << " core_id " << coreIdAttr.getInt(); + } + + printer.printOptionalAttrDict( + (*this)->getAttrs(), + {getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdAttrName}); + + printer << " : "; + printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Square); + printer << " "; + printCompressedTypeList(printer, TypeRange(getInputs()), ListDelimiter::Paren); + printer << " -> "; + printCompressedTypeSequence(printer, getResultTypes()); + printer << " "; + printer.printRegion(getBody(), /*printEntryBlockArgs=*/false); +} + +ParseResult SpatCompute::parse(OpAsmParser& parser, OperationState& result) { + SmallVector regionArgs; + SmallVector generatedArgNames; + SmallVector weights; + SmallVector inputs; + SmallVector weightTypes; + SmallVector inputTypes; + SmallVector outputTypes; + int32_t coreId = 0; + + if (parseCompressedOperandList(parser, ListDelimiter::Square, weights)) + return failure(); + + if (succeeded(parser.parseOptionalKeyword("args"))) { + if (parser.parseEqual() || parseCompressedOperandList(parser, ListDelimiter::Paren, inputs)) + return failure(); + } + else if (parseCompressedOperandList(parser, ListDelimiter::Paren, inputs)) { + return failure(); + } + + bool hasCoreId = succeeded(parser.parseOptionalKeyword("core_id")); + if (hasCoreId && parser.parseInteger(coreId)) + return failure(); + + if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() + || parseCompressedRepeatedList( + parser, ListDelimiter::Square, weightTypes, [&](Type& type) { return parser.parseType(type); }) + || parseCompressedRepeatedList( + parser, ListDelimiter::Paren, inputTypes, [&](Type& type) { return parser.parseType(type); }) + || parser.parseArrow() || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/true)) + return failure(); + + if (weights.size() != weightTypes.size()) + return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match"); + if (inputs.size() != inputTypes.size()) + return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match"); + if (hasCoreId && result.attributes.get(onnx_mlir::kCoreIdAttrName)) + return parser.emitError(parser.getCurrentLocation(), "core_id cannot be specified both positionally and in attr-dict"); + + auto& builder = parser.getBuilder(); + result.addAttribute( + "operandSegmentSizes", + builder.getDenseI32ArrayAttr({static_cast(weights.size()), static_cast(inputs.size())})); + if (hasCoreId) + result.addAttribute(onnx_mlir::kCoreIdAttrName, getI32Attr(parser, coreId)); + + if (parser.resolveOperands(weights, weightTypes, parser.getCurrentLocation(), result.operands) + || parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands)) + return failure(); + result.addTypes(outputTypes); + + Region* body = result.addRegion(); + buildImplicitRegionArgs(parser, inputTypes, generatedArgNames, regionArgs); + return parser.parseRegion(*body, regionArgs); +} + +static FailureOr getParentBatchLaneCount(Operation* op) { + auto batchOp = op->getParentOfType(); + if (!batchOp) + return failure(); + return batchOp.getLaneCount(); +} + +static LogicalResult verifyManyChannelSizes(Operation* op, + ArrayRef channelIds, + ArrayRef sourceCoreIds, + ArrayRef targetCoreIds, + size_t valueCount) { + if (channelIds.size() != sourceCoreIds.size() || channelIds.size() != targetCoreIds.size()) + return op->emitError("channelIds, sourceCoreIds, and targetCoreIds must have the same length"); + if (channelIds.size() != valueCount) + return op->emitError("channel metadata length must match the number of values"); + return success(); +} + +static LogicalResult verifyManyChannelTypes(Operation* op, TypeRange types, StringRef kind) { + if (types.empty()) + return op->emitError() << kind << " must carry at least one value"; + + Type firstType = types.front(); + for (Type type : types.drop_front()) + if (type != firstType) + return op->emitError() << kind << " values must all have the same type"; + return success(); +} + +static LogicalResult verifyBatchChannelSizes(Operation* op, + ArrayRef channelIds, + ArrayRef sourceCoreIds, + ArrayRef targetCoreIds) { + if (channelIds.size() != sourceCoreIds.size() || channelIds.size() != targetCoreIds.size()) + return op->emitError("channelIds, sourceCoreIds, and targetCoreIds must have the same length"); + + auto laneCount = getParentBatchLaneCount(op); + if (failed(laneCount)) + return op->emitError("must be nested inside spat.compute_batch"); + if (channelIds.size() != static_cast(*laneCount)) + return op->emitError("channel metadata length must match parent laneCount"); + + return success(); +} + +LogicalResult SpatChannelSendManyOp::verify() { + if (failed(verifyManyChannelSizes( + getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds(), getInputs().size()))) + return failure(); + return verifyManyChannelTypes(getOperation(), getInputs().getTypes(), "channel_send_many"); +} + +LogicalResult SpatChannelReceiveManyOp::verify() { + if (failed(verifyManyChannelSizes( + getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds(), getOutputs().size()))) + return failure(); + return verifyManyChannelTypes(getOperation(), getOperation()->getResultTypes(), "channel_receive_many"); +} + +LogicalResult SpatChannelSendBatchOp::verify() { + return verifyBatchChannelSizes(getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds()); +} + +LogicalResult SpatChannelReceiveBatchOp::verify() { + return verifyBatchChannelSizes(getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds()); +} + +static LogicalResult verifyBatchBody(Operation* op, Block& block, TypeRange outputTypes, size_t weightsPerLane) { + auto yieldOp = dyn_cast_or_null(block.getTerminator()); + if (!yieldOp) + return op->emitError("body must terminate with spat.yield"); + if (outputTypes.empty()) { + if (yieldOp.getNumOperands() != 0) + return op->emitError("body yield must be empty when compute_batch has no results"); + } + else { + if (yieldOp.getNumOperands() != 1) + return op->emitError("body yield must produce exactly one value"); + if (yieldOp.getOperand(0).getType() != outputTypes[0]) + return op->emitError("body yield type must match output type"); + } + + for (auto& bodyOp : block) { + if (auto wvmm = dyn_cast(&bodyOp)) + if (wvmm.getWeightIndex() < 0 || static_cast(wvmm.getWeightIndex()) >= weightsPerLane) + return op->emitError("compute_batch body Wvmm weightIndex is out of range for one lane"); + if (auto wmvm = dyn_cast(&bodyOp)) + if (wmvm.getWeightIndex() < 0 || static_cast(wmvm.getWeightIndex()) >= weightsPerLane) + return op->emitError("compute_batch body Wmvm weightIndex is out of range for one lane"); + } + return success(); +} + +LogicalResult SpatComputeBatch::verify() { + int32_t count = getLaneCount(); + if (count <= 0) + return emitError("laneCount must be positive"); + + auto laneCountSz = static_cast(count); + if (getWeights().size() % laneCountSz != 0) + return emitError("number of weights must be a multiple of laneCount"); + + if (!getInputs().empty() && getInputs().size() != laneCountSz) + return emitError("number of inputs must be either 0 or laneCount"); + if (!getOutputs().empty() && getOutputs().size() != laneCountSz) + return emitError("number of outputs must be either 0 or laneCount"); + + size_t weightsPerLane = getWeights().size() / laneCountSz; + for (size_t weightIndex = 0; weightIndex < weightsPerLane; ++weightIndex) { + Type weightType = getWeights()[weightIndex].getType(); + for (size_t lane = 1; lane < laneCountSz; ++lane) + if (getWeights()[lane * weightsPerLane + weightIndex].getType() != weightType) + return emitError("corresponding weights across lanes must have the same type"); + } + + if (!getInputs().empty()) { + Type inputType = getInputs()[0].getType(); + for (Value in : getInputs().drop_front()) + if (in.getType() != inputType) + return emitError("all inputs must have the same type"); + } + + if (!getOutputs().empty()) { + Type outputType = getOutputs()[0].getType(); + for (Value out : getOutputs().drop_front()) + if (out.getType() != outputType) + return emitError("all outputs must have the same type"); + } + + if (auto coreIdAttr = (*this)->getAttr(onnx_mlir::kCoreIdAttrName)) { + auto coreIdsAttr = dyn_cast(coreIdAttr); + if (!coreIdsAttr) + return emitError("compute_batch core_id attribute must be a dense i32 array"); + if (coreIdsAttr.size() != laneCountSz) + return emitError("compute_batch core_id array length must match laneCount"); + if (llvm::any_of(coreIdsAttr.asArrayRef(), [](int32_t coreId) { return coreId <= 0; })) + return emitError("compute_batch core_id values must be positive"); + } + + Block& block = getBody().front(); + if (getInputs().empty()) { + if (block.getNumArguments() != 0) + return emitError("compute_batch body must have no block arguments when there are no inputs"); + } + else { + if (block.getNumArguments() != 1) + return emitError("compute_batch body must have exactly one block argument"); + if (block.getArgument(0).getType() != getInputs()[0].getType()) + return emitError("body block argument type must match input type"); + } + + return verifyBatchBody(getOperation(), block, getResultTypes(), weightsPerLane); +} + +void SpatComputeBatch::print(OpAsmPrinter& printer) { + printer << " lanes " << getLaneCount() << " "; + printCompressedValueList(printer, getWeights(), ListDelimiter::Square); + printer << " args = "; + printCompressedValueList(printer, getInputs(), ListDelimiter::Paren); + + if (auto coreIdsAttr = (*this)->getAttrOfType(onnx_mlir::kCoreIdAttrName)) { + printer << " core_ids "; + printCompressedIntegerList(printer, coreIdsAttr.asArrayRef()); + } + + printer.printOptionalAttrDict( + (*this)->getAttrs(), + {getLaneCountAttrName().getValue(), getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdAttrName}); + + printer << " : "; + printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Square); + printer << " "; + printCompressedTypeList(printer, TypeRange(getInputs()), ListDelimiter::Paren); + printer << " -> "; + printCompressedTypeSequence(printer, getResultTypes()); + printer << " "; + printer.printRegion(getBody(), /*printEntryBlockArgs=*/false); +} + +ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result) { + int32_t laneCount = 0; + SmallVector regionArgs; + SmallVector generatedArgNames; + SmallVector weights; + SmallVector inputs; + SmallVector weightTypes; + SmallVector inputTypes; + SmallVector outputTypes; + SmallVector coreIds; + + if (parser.parseKeyword("lanes") || parser.parseInteger(laneCount)) + return failure(); + + if (parseCompressedOperandList(parser, ListDelimiter::Square, weights)) + return failure(); + + if (succeeded(parser.parseOptionalKeyword("args"))) { + if (parser.parseEqual() || parseCompressedOperandList(parser, ListDelimiter::Paren, inputs)) + return failure(); + } + else if (parseCompressedOperandList(parser, ListDelimiter::Paren, inputs)) { + return failure(); + } + + bool hasCoreIds = succeeded(parser.parseOptionalKeyword("core_ids")); + if (hasCoreIds && parseCompressedIntegerList(parser, coreIds)) + return failure(); + + if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() + || parseCompressedRepeatedList( + parser, ListDelimiter::Square, weightTypes, [&](Type& type) { return parser.parseType(type); }) + || parseCompressedRepeatedList( + parser, ListDelimiter::Paren, inputTypes, [&](Type& type) { return parser.parseType(type); }) + || parser.parseArrow() || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/true)) + return failure(); + + if (weights.size() != weightTypes.size()) + return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match"); + if (inputs.size() != inputTypes.size()) + return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match"); + if (hasCoreIds && result.attributes.get(onnx_mlir::kCoreIdAttrName)) + return parser.emitError(parser.getCurrentLocation(), "core_id cannot be specified both in core_ids and attr-dict"); + + auto& builder = parser.getBuilder(); + result.addAttribute("laneCount", builder.getI32IntegerAttr(laneCount)); + result.addAttribute( + "operandSegmentSizes", + builder.getDenseI32ArrayAttr({static_cast(weights.size()), static_cast(inputs.size())})); + if (hasCoreIds) + result.addAttribute(onnx_mlir::kCoreIdAttrName, getDenseI32ArrayAttr(parser, coreIds)); + + if (parser.resolveOperands(weights, weightTypes, parser.getCurrentLocation(), result.operands) + || parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands)) + return failure(); + result.addTypes(outputTypes); + + Region* body = result.addRegion(); + buildImplicitRegionArgs(parser, inputTypes, generatedArgNames, regionArgs); + return parser.parseRegion(*body, regionArgs); +} + +void SpatChannelSendManyOp::print(OpAsmPrinter& printer) { + printer << " "; + printCompressedValueSequence(printer, getInputs()); + printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds()); + printer.printOptionalAttrDict( + (*this)->getAttrs(), + {getChannelIdsAttrName().getValue(), getSourceCoreIdsAttrName().getValue(), getTargetCoreIdsAttrName().getValue()}); + printer << " : "; + printCompressedTypeSequence(printer, TypeRange(getInputs())); +} + +ParseResult SpatChannelSendManyOp::parse(OpAsmParser& parser, OperationState& result) { + SmallVector inputs; + SmallVector inputTypes; + SmallVector channelIds; + SmallVector sourceCoreIds; + SmallVector targetCoreIds; + + if (parseCompressedOperandSequence(parser, inputs)) + return failure(); + + bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels")); + if (hasMetadata) { + if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from") + || parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to") + || parseCompressedIntegerList(parser, targetCoreIds)) + return failure(); + } + + if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() + || parseCompressedTypeSequence(parser, inputTypes, /*allowEmpty=*/false)) + return failure(); + + if (inputs.size() != inputTypes.size()) + return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match"); + if (hasMetadata + && (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds") + || result.attributes.get("targetCoreIds"))) + return parser.emitError(parser.getCurrentLocation(), + "channel metadata cannot be specified both positionally and in attr-dict"); + if (hasMetadata) { + result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds)); + result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds)); + result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds)); + } + + return parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands); +} + +void SpatChannelReceiveManyOp::print(OpAsmPrinter& printer) { + printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds()); + printer.printOptionalAttrDict( + (*this)->getAttrs(), + {getChannelIdsAttrName().getValue(), getSourceCoreIdsAttrName().getValue(), getTargetCoreIdsAttrName().getValue()}); + printer << " : "; + printCompressedTypeSequence(printer, getResultTypes()); +} + +ParseResult SpatChannelReceiveManyOp::parse(OpAsmParser& parser, OperationState& result) { + SmallVector outputTypes; + SmallVector channelIds; + SmallVector sourceCoreIds; + SmallVector targetCoreIds; + + bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels")); + if (hasMetadata) { + if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from") + || parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to") + || parseCompressedIntegerList(parser, targetCoreIds)) + return failure(); + } + + if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() + || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/false)) + return failure(); + + if (hasMetadata + && (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds") + || result.attributes.get("targetCoreIds"))) + return parser.emitError(parser.getCurrentLocation(), + "channel metadata cannot be specified both positionally and in attr-dict"); + if (hasMetadata) { + result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds)); + result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds)); + result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds)); + } + + result.addTypes(outputTypes); + return success(); +} + +void SpatChannelSendBatchOp::print(OpAsmPrinter& printer) { + printer << " "; + printer.printOperand(getInput()); + printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds()); + printer.printOptionalAttrDict( + (*this)->getAttrs(), + {getChannelIdsAttrName().getValue(), getSourceCoreIdsAttrName().getValue(), getTargetCoreIdsAttrName().getValue()}); + printer << " : "; + printer.printType(getInput().getType()); +} + +ParseResult SpatChannelSendBatchOp::parse(OpAsmParser& parser, OperationState& result) { + OpAsmParser::UnresolvedOperand input; + Type inputType; + SmallVector channelIds; + SmallVector sourceCoreIds; + SmallVector targetCoreIds; + + if (parser.parseOperand(input)) + return failure(); + + bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels")); + if (hasMetadata) { + if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from") + || parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to") + || parseCompressedIntegerList(parser, targetCoreIds)) + return failure(); + } + + if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(inputType)) + return failure(); + + if (hasMetadata + && (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds") + || result.attributes.get("targetCoreIds"))) + return parser.emitError(parser.getCurrentLocation(), + "channel metadata cannot be specified both positionally and in attr-dict"); + if (hasMetadata) { + result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds)); + result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds)); + result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds)); + } + + return parser.resolveOperand(input, inputType, result.operands); +} + +void SpatChannelReceiveBatchOp::print(OpAsmPrinter& printer) { + printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds()); + printer.printOptionalAttrDict( + (*this)->getAttrs(), + {getChannelIdsAttrName().getValue(), getSourceCoreIdsAttrName().getValue(), getTargetCoreIdsAttrName().getValue()}); + printer << " : "; + printer.printType(getOutput().getType()); +} + +ParseResult SpatChannelReceiveBatchOp::parse(OpAsmParser& parser, OperationState& result) { + Type outputType; + SmallVector channelIds; + SmallVector sourceCoreIds; + SmallVector targetCoreIds; + + bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels")); + if (hasMetadata) { + if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from") + || parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to") + || parseCompressedIntegerList(parser, targetCoreIds)) + return failure(); + } + + if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(outputType)) + return failure(); + + if (hasMetadata + && (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds") + || result.attributes.get("targetCoreIds"))) + return parser.emitError(parser.getCurrentLocation(), + "channel metadata cannot be specified both positionally and in attr-dict"); + if (hasMetadata) { + result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds)); + result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds)); + result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds)); + } + + result.addTypes(outputType); + return success(); +} + } // namespace spatial } // namespace onnx_mlir diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.cpp index a9e5e6c..3d47ea7 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.cpp @@ -28,6 +28,8 @@ namespace spatial { using namespace mlir; namespace { +using SpatCompute = onnx_mlir::spatial::SpatCompute; +using SpatComputeBatch = onnx_mlir::spatial::SpatComputeBatch; struct VirtualNode { SmallVector originalComputeIndices; @@ -54,6 +56,45 @@ struct WindowScheduleResult { size_t maxMergeGroupSize = 0; }; +constexpr CPU kDefaultMaxCpuCount = 1000; + +size_t getSchedulingCpuBudget() { + if (coresCount.getValue() > 0) + return static_cast(coresCount.getValue()); + return static_cast(kDefaultMaxCpuCount); +} + +size_t getBatchChunkTargetCount(int32_t laneCount) { + assert(laneCount > 0 && "laneCount must be positive"); + return std::min(static_cast(laneCount), std::max(1, getSchedulingCpuBudget())); +} + +ComputeInstance getBatchChunkForIndex(SpatComputeBatch batch, size_t chunkIndex) { + size_t totalLanes = static_cast(batch.getLaneCount()); + size_t chunkCount = getBatchChunkTargetCount(batch.getLaneCount()); + size_t baseChunkSize = totalLanes / chunkCount; + size_t largeChunkCount = totalLanes % chunkCount; + + size_t laneStart = chunkIndex * baseChunkSize + std::min(chunkIndex, largeChunkCount); + size_t laneCount = baseChunkSize + (chunkIndex < largeChunkCount ? 1 : 0); + return {batch.getOperation(), static_cast(laneStart), static_cast(laneCount)}; +} + +ComputeInstance getBatchChunkForLane(SpatComputeBatch batch, uint32_t lane) { + size_t totalLanes = static_cast(batch.getLaneCount()); + size_t chunkCount = getBatchChunkTargetCount(batch.getLaneCount()); + size_t baseChunkSize = totalLanes / chunkCount; + size_t largeChunkCount = totalLanes % chunkCount; + size_t largeChunkSpan = largeChunkCount * (baseChunkSize + 1); + + size_t chunkIndex = 0; + if (static_cast(lane) < largeChunkSpan) + chunkIndex = static_cast(lane) / (baseChunkSize + 1); + else + chunkIndex = largeChunkCount + (static_cast(lane) - largeChunkSpan) / baseChunkSize; + return getBatchChunkForIndex(batch, chunkIndex); +} + std::vector aggregateEdges(ArrayRef edges) { llvm::DenseMap, Weight> edgeWeights; for (auto [start, end, weight] : edges) { @@ -81,14 +122,96 @@ std::vector aggregateEdges(ArrayRef edges) { return aggregatedEdges; } -VirtualGraph buildInitialVirtualGraph(ArrayRef spatComputes, ArrayRef edges) { +Weight getComputeBodyWeight(Region& body) { + constexpr Weight kOperationWeight = 100; + Weight numOperations = 0; + for (auto& block : body) + for ([[maybe_unused]] auto& op : block) + numOperations = checkedAdd(numOperations, static_cast(1)); + return checkedMultiply(numOperations, kOperationWeight); +} + +CrossbarUsage getComputeBodyCrossbarUsage(Region& body) { + CrossbarUsage crossbarUsage = 0; + for (auto& block : body) + for (auto& op : block) + if (isa(op)) + crossbarUsage = checkedAdd(crossbarUsage, static_cast(1)); + return crossbarUsage; +} + +Weight getComputeInstanceWeight(const ComputeInstance& instance) { + if (auto spatCompute = dyn_cast(instance.op)) + return getSpatComputeWeight(spatCompute); + auto batch = cast(instance.op); + return checkedMultiply(getComputeBodyWeight(batch.getBody()), static_cast(instance.laneCount)); +} + +CrossbarUsage getComputeInstanceCrossbarUsage(const ComputeInstance& instance) { + if (auto spatCompute = dyn_cast(instance.op)) + return getSpatComputeCrossbarUsage(spatCompute); + auto batch = cast(instance.op); + return checkedMultiply(getComputeBodyCrossbarUsage(batch.getBody()), static_cast(instance.laneCount)); +} + +SmallVector getComputeInstanceInputs(const ComputeInstance& instance) { + if (auto spatCompute = dyn_cast(instance.op)) + return SmallVector(spatCompute.getInputs().begin(), spatCompute.getInputs().end()); + auto batch = cast(instance.op); + SmallVector inputs; + inputs.reserve(instance.laneCount); + for (uint32_t lane = instance.laneStart; lane < instance.laneStart + instance.laneCount; ++lane) + inputs.push_back(batch.getInputs()[lane]); + return inputs; +} + +std::optional getOriginalComputeInstance(Value value) { + Operation* op = value.getDefiningOp(); + if (!op) + return std::nullopt; + + while (auto extract = dyn_cast(op)) { + value = extract.getSource(); + op = value.getDefiningOp(); + if (!op) + return std::nullopt; + } + + if (auto spatCompute = dyn_cast(op)) + return ComputeInstance {spatCompute.getOperation(), 0, 1}; + if (auto batch = dyn_cast(op)) + return getBatchChunkForLane(batch, static_cast(cast(value).getResultNumber())); + return std::nullopt; +} + +SmallVector collectComputeInstances(Operation* entryOp) { + SmallVector instances; + for (Region& region : entryOp->getRegions()) { + for (Block& block : region) { + for (Operation& op : block) { + if (auto spatCompute = dyn_cast(&op)) { + instances.push_back({spatCompute.getOperation(), 0, 1}); + continue; + } + if (auto batch = dyn_cast(&op)) { + size_t chunkCount = getBatchChunkTargetCount(batch.getLaneCount()); + for (size_t chunkIndex = 0; chunkIndex < chunkCount; ++chunkIndex) + instances.push_back(getBatchChunkForIndex(batch, chunkIndex)); + } + } + } + } + return instances; +} + +VirtualGraph buildInitialVirtualGraph(ArrayRef computeInstances, ArrayRef edges) { VirtualGraph graph; - graph.nodes.reserve(spatComputes.size()); - for (auto [index, spatCompute] : llvm::enumerate(spatComputes)) { + graph.nodes.reserve(computeInstances.size()); + for (auto [index, computeInstance] : llvm::enumerate(computeInstances)) { VirtualNode node; node.originalComputeIndices.push_back(index); - node.weight = getSpatComputeWeight(spatCompute); - node.crossbarUsage = getSpatComputeCrossbarUsage(spatCompute); + node.weight = getComputeInstanceWeight(computeInstance); + node.crossbarUsage = getComputeInstanceCrossbarUsage(computeInstance); graph.nodes.push_back(std::move(node)); } graph.edges = aggregateEdges(edges); @@ -116,22 +239,34 @@ TimingInfo computeTiming(const VirtualGraph& graph) { incomingEdgeCount[endIndex]++; } - std::vector readyNodes; - readyNodes.reserve(nodeCount); + auto getVirtualNodeOrderKey = [&](size_t nodeIndex) { + const VirtualNode& node = graph.nodes[nodeIndex]; + if (!node.originalComputeIndices.empty()) + return node.originalComputeIndices.front(); + return nodeIndex; + }; + auto readyNodeGreater = [&](size_t lhs, size_t rhs) { + size_t lhsKey = getVirtualNodeOrderKey(lhs); + size_t rhsKey = getVirtualNodeOrderKey(rhs); + if (lhsKey != rhsKey) + return lhsKey > rhsKey; + return lhs > rhs; + }; + std::priority_queue, decltype(readyNodeGreater)> readyNodes(readyNodeGreater); for (size_t i = 0; i < nodeCount; ++i) if (incomingEdgeCount[i] == 0) - readyNodes.push_back(i); + readyNodes.push(i); - size_t readyIndex = 0; - while (readyIndex != readyNodes.size()) { - size_t current = readyNodes[readyIndex++]; + while (!readyNodes.empty()) { + size_t current = readyNodes.top(); + readyNodes.pop(); timing.topologicalOrder.push_back(current); for (auto [child, weight] : children[current]) { (void) weight; assert(incomingEdgeCount[child] > 0 && "incoming edge count underflow"); incomingEdgeCount[child]--; if (incomingEdgeCount[child] == 0) - readyNodes.push_back(child); + readyNodes.push(child); } } @@ -287,17 +422,21 @@ std::vector buildWindowEdges(const VirtualGraph& graph, const std:: WindowScheduleResult scheduleWindow(const VirtualGraph& graph, ArrayRef selectedNodes, MLIRContext* context) { std::vector windowWeights; std::vector windowCrossbarUsage; + std::vector windowNodeOrderKeys; std::vector nodeToWindowIndex(graph.nodes.size(), -1); windowWeights.reserve(selectedNodes.size()); windowCrossbarUsage.reserve(selectedNodes.size()); + windowNodeOrderKeys.reserve(selectedNodes.size()); for (auto [windowIndex, nodeIndex] : llvm::enumerate(selectedNodes)) { nodeToWindowIndex[nodeIndex] = static_cast(windowIndex); windowWeights.push_back(graph.nodes[nodeIndex].weight); windowCrossbarUsage.push_back(graph.nodes[nodeIndex].crossbarUsage); + windowNodeOrderKeys.push_back(static_cast(nodeIndex)); } - GraphDCP windowGraph(windowWeights, buildWindowEdges(graph, nodeToWindowIndex), windowCrossbarUsage); + GraphDCP windowGraph( + windowWeights, buildWindowEdges(graph, nodeToWindowIndex), windowNodeOrderKeys, windowCrossbarUsage); if (coresCount.getValue() > 0) windowGraph.setMaxCpuCount(static_cast(coresCount.getValue())); windowGraph.setContext(context); @@ -414,13 +553,7 @@ bool coarsenGraph(const VirtualGraph& graph, return true; } -constexpr CPU kDefaultMaxCpuCount = 1000; - -CPU getVirtualGraphMaxCpuCount() { - if (coresCount.getValue() > 0) - return static_cast(coresCount.getValue()); - return kDefaultMaxCpuCount; -} +CPU getVirtualGraphMaxCpuCount() { return static_cast(getSchedulingCpuBudget()); } size_t getDcpCoarseningWindowSize(size_t nodeCount) { size_t windowSize = std::min(dcpCriticalWindowSize.getValue(), nodeCount); @@ -430,7 +563,7 @@ size_t getDcpCoarseningWindowSize(size_t nodeCount) { return windowSize; } -DCPAnalysisResult buildResultFromVirtualGraph(const VirtualGraph& graph, ArrayRef spatComputes) { +DCPAnalysisResult buildResultFromVirtualGraph(const VirtualGraph& graph, ArrayRef computeInstances) { DCPAnalysisResult result; TimingInfo timing = computeTiming(graph); @@ -443,19 +576,19 @@ DCPAnalysisResult buildResultFromVirtualGraph(const VirtualGraph& graph, ArrayRe std::iota(virtualNodeOrder.begin(), virtualNodeOrder.end(), 0); } - std::vector originalComputeToCpu(spatComputes.size(), 0); + std::vector originalComputeToCpu(computeInstances.size(), 0); for (auto [cpu, virtualNodeIndex] : llvm::enumerate(virtualNodeOrder)) { const VirtualNode& virtualNode = graph.nodes[virtualNodeIndex]; for (size_t originalIndex : virtualNode.originalComputeIndices) originalComputeToCpu[originalIndex] = cpu; } - result.dominanceOrderCompute.reserve(spatComputes.size()); - for (auto [originalIndex, spatCompute] : llvm::enumerate(spatComputes)) { + result.dominanceOrderCompute.reserve(computeInstances.size()); + for (auto [originalIndex, computeInstance] : llvm::enumerate(computeInstances)) { size_t cpu = originalComputeToCpu[originalIndex]; - result.dominanceOrderCompute.push_back(spatCompute); - result.computeToCpuMap[spatCompute] = cpu; - result.cpuToLastComputeMap[cpu] = spatCompute; + result.dominanceOrderCompute.push_back(computeInstance); + result.computeToCpuMap[computeInstance] = cpu; + result.cpuToLastComputeMap[cpu] = computeInstance; } for (const auto& [cpu, lastCompute] : result.cpuToLastComputeMap) result.isLastComputeOfCpu.insert(lastCompute); @@ -463,13 +596,44 @@ DCPAnalysisResult buildResultFromVirtualGraph(const VirtualGraph& graph, ArrayRe return result; } -DCPAnalysisResult runLegacyDcp(ArrayRef spatComputes, ArrayRef edges, MLIRContext* context) { - GraphDCP graphDCP(spatComputes, edges); +DCPAnalysisResult buildResultFromScheduledGraph(GraphDCP& graphDCP, ArrayRef computeInstances) { + DCPAnalysisResult result; + result.dominanceOrderCompute.assign(computeInstances.begin(), computeInstances.end()); + + for (CPU cpu = 0; cpu < graphDCP.cpuCount(); ++cpu) { + auto scheduledTasks = graphDCP.getScheduledTasks(cpu); + if (scheduledTasks.empty()) + continue; + + for (const auto& task : scheduledTasks) + result.computeToCpuMap[computeInstances[task.nodeIndex]] = cpu; + result.cpuToLastComputeMap[cpu] = computeInstances[scheduledTasks.back().nodeIndex]; + result.isLastComputeOfCpu.insert(computeInstances[scheduledTasks.back().nodeIndex]); + } + + return result; +} + +DCPAnalysisResult +runLegacyDcp(ArrayRef computeInstances, ArrayRef edges, MLIRContext* context) { + SmallVector nodeWeights; + SmallVector nodeCrossbarUsage; + SmallVector nodeOrderKeys; + nodeWeights.reserve(computeInstances.size()); + nodeCrossbarUsage.reserve(computeInstances.size()); + nodeOrderKeys.reserve(computeInstances.size()); + for (auto [index, instance] : llvm::enumerate(computeInstances)) { + nodeWeights.push_back(getComputeInstanceWeight(instance)); + nodeCrossbarUsage.push_back(getComputeInstanceCrossbarUsage(instance)); + nodeOrderKeys.push_back(static_cast(index)); + } + + GraphDCP graphDCP(nodeWeights, edges, nodeOrderKeys, nodeCrossbarUsage); if (coresCount.getValue() > 0) graphDCP.setMaxCpuCount(static_cast(coresCount.getValue())); graphDCP.setContext(context); graphDCP.runDcp(); - return graphDCP.getResult(); + return buildResultFromScheduledGraph(graphDCP, computeInstances); } } // namespace @@ -488,27 +652,31 @@ SpatCompute getOriginalSpatCompute(Operation* op) { } DCPAnalysisResult DCPAnalysis::run() { - SmallVector spatComputes; + SmallVector computeInstances = collectComputeInstances(entryOp); SmallVector edges; - for (auto& region : entryOp->getRegions()) - for (SpatCompute spatCompute : region.getOps()) - spatComputes.push_back(spatCompute); - for (auto [indexEndEdge, spatCompute] : llvm::enumerate(spatComputes)) { - for (Value input : spatCompute.getInputs()) { - if (auto producerCompute = getOriginalSpatCompute(input.getDefiningOp())) { - auto producerIt = llvm::find(spatComputes, producerCompute); - assert(producerIt != spatComputes.end()); - auto indexStartEdge = std::distance(spatComputes.begin(), producerIt); - edges.push_back({indexStartEdge, indexEndEdge, getSizeInBytes(cast(input.getType()))}); + llvm::DenseMap instanceToIndex; + instanceToIndex.reserve(computeInstances.size()); + for (auto [index, instance] : llvm::enumerate(computeInstances)) + instanceToIndex[instance] = index; + + for (auto [indexEndEdge, computeInstance] : llvm::enumerate(computeInstances)) { + for (Value input : getComputeInstanceInputs(computeInstance)) { + if (auto producerInstance = getOriginalComputeInstance(input)) { + auto producerIt = instanceToIndex.find(*producerInstance); + assert(producerIt != instanceToIndex.end()); + auto indexStartEdge = producerIt->second; + edges.push_back({static_cast(indexStartEdge), + static_cast(indexEndEdge), + static_cast(getSizeInBytes(cast(input.getType())))}); } } } if (dcpCriticalWindowSize.getValue() == 0) - return runLegacyDcp(spatComputes, edges, entryOp->getContext()); + return runLegacyDcp(computeInstances, edges, entryOp->getContext()); - VirtualGraph virtualGraph = buildInitialVirtualGraph(spatComputes, edges); + VirtualGraph virtualGraph = buildInitialVirtualGraph(computeInstances, edges); size_t iteration = 0; auto tryCoarsenSelectedNodes = [&](ArrayRef selectedNodes) { size_t oldNodeCount = virtualGraph.nodes.size(); @@ -545,6 +713,13 @@ DCPAnalysisResult DCPAnalysis::run() { }; while (virtualGraph.nodes.size() > 1) { + if (virtualGraph.nodes.size() <= getSchedulingCpuBudget()) { + if (virtualGraph.nodes.size() >= 200) + llvm::errs() << llvm::formatv( + "[DCP-COARSEN] iter={0} old={1} stop=cpu-budget\n", iteration, virtualGraph.nodes.size()); + break; + } + iteration++; TimingInfo timing = computeTiming(virtualGraph); if (!timing.valid) { @@ -576,7 +751,7 @@ DCPAnalysisResult DCPAnalysis::run() { break; } - return buildResultFromVirtualGraph(virtualGraph, spatComputes); + return buildResultFromVirtualGraph(virtualGraph, computeInstances); } } // namespace spatial diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.hpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.hpp index 8b1e1d3..dd8adfa 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.hpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.hpp @@ -5,15 +5,28 @@ #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseSet.h" +#include #include #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" +// A scheduling identity that covers both spat.compute and scheduled shards of +// spat.compute_batch. +struct ComputeInstance { + mlir::Operation* op = nullptr; + uint32_t laneStart = 0; + uint32_t laneCount = 1; + + bool operator==(const ComputeInstance& other) const { + return op == other.op && laneStart == other.laneStart && laneCount == other.laneCount; + } +}; + struct DCPAnalysisResult { - std::vector dominanceOrderCompute; - llvm::DenseMap computeToCpuMap; - llvm::DenseSet isLastComputeOfCpu; - llvm::DenseMap cpuToLastComputeMap; + std::vector dominanceOrderCompute; + llvm::DenseMap computeToCpuMap; + llvm::DenseSet isLastComputeOfCpu; + llvm::DenseMap cpuToLastComputeMap; }; namespace onnx_mlir { @@ -34,3 +47,21 @@ public: } // namespace spatial } // namespace onnx_mlir + +namespace llvm { +template <> +struct DenseMapInfo { + static ComputeInstance getEmptyKey() { + return {DenseMapInfo::getEmptyKey(), UINT32_MAX, UINT32_MAX}; + } + static ComputeInstance getTombstoneKey() { + return {DenseMapInfo::getTombstoneKey(), UINT32_MAX, UINT32_MAX}; + } + static unsigned getHashValue(const ComputeInstance& v) { + return llvm::hash_combine(v.op, v.laneStart, v.laneCount); + } + static bool isEqual(const ComputeInstance& a, const ComputeInstance& b) { + return a == b; + } +}; +} // namespace llvm diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/Graph.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/Graph.cpp index 6c48b1d..3bca16c 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/Graph.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/Graph.cpp @@ -1491,18 +1491,21 @@ void GraphDCP::runDcp() { struct ReadyEntry { Time slack; Time aest; + int64_t orderKey; TaskDCP* task; bool operator>(const ReadyEntry& other) const { if (slack != other.slack) return slack > other.slack; - return aest > other.aest; + if (aest != other.aest) + return aest > other.aest; + return orderKey > other.orderKey; } }; std::priority_queue, std::greater> readyQueue; size_t readyCount = 0; auto pushReady = [&](TaskDCP* node) { - readyQueue.push({slackOrZero(node->getAest(), node->getAlst()), node->getAest(), node}); + readyQueue.push({slackOrZero(node->getAest(), node->getAlst()), node->getAest(), node->Id(), node}); }; for (auto& node : nodes) { @@ -1528,7 +1531,7 @@ void GraphDCP::runDcp() { candidate = entry.task; break; } - readyQueue.push({curSlack, curAest, entry.task}); + readyQueue.push({curSlack, curAest, entry.orderKey, entry.task}); } assert(candidate != nullptr && "readyCount > 0 but heap exhausted"); --readyCount; @@ -1579,8 +1582,11 @@ DCPAnalysisResult GraphDCP::getResult() { auto dominanceOrder = dcp_graph::collectDominanceOrder(getRoots(), nodes.size()); ret.dominanceOrderCompute.reserve(dominanceOrder.size()); - for (auto elem : dominanceOrder) - ret.dominanceOrderCompute.push_back(elem->getSpatCompute()); + for (auto elem : dominanceOrder) { + auto spatCompute = elem->getSpatCompute(); + if (spatCompute) + ret.dominanceOrderCompute.push_back({spatCompute.getOperation(), 0}); + } for (CPU cpu = 0; cpu < getLastCpu(); ++cpu) { const CpuTaskList* tasks = findCpuTasks(cpu); @@ -1588,10 +1594,14 @@ DCPAnalysisResult GraphDCP::getResult() { continue; size_t i = 0; for (auto node : *tasks) { - ret.computeToCpuMap[node->getSpatCompute()] = cpu; + auto spatCompute = node->getSpatCompute(); + if (!spatCompute) + continue; + ComputeInstance instance {spatCompute.getOperation(), 0}; + ret.computeToCpuMap[instance] = cpu; if (i++ == tasks->size() - 1) { - ret.isLastComputeOfCpu.insert(node->getSpatCompute()); - ret.cpuToLastComputeMap[cpu] = node->getSpatCompute(); + ret.isLastComputeOfCpu.insert(instance); + ret.cpuToLastComputeMap[cpu] = instance; } } } diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/Graph.hpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/Graph.hpp index e8303d6..32d1976 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/Graph.hpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/Graph.hpp @@ -138,13 +138,18 @@ public: GraphDCP(llvm::ArrayRef nodeWeights, llvm::ArrayRef edges, + llvm::ArrayRef nodeOrderKeys = {}, llvm::ArrayRef nodeCrossbarUsage = {}) : nodes(), cpuTasks(), cpuCrossbarUsage() { assert((nodeCrossbarUsage.empty() || nodeCrossbarUsage.size() == nodeWeights.size()) && "synthetic crossbar usage must match synthetic node weights"); + assert((nodeOrderKeys.empty() || nodeOrderKeys.size() == nodeWeights.size()) + && "synthetic node order keys must match synthetic node weights"); nodes.reserve(nodeWeights.size()); for (auto [index, weight] : llvm::enumerate(nodeWeights)) - nodes.emplace_back(index, weight, nodeCrossbarUsage.empty() ? 0 : nodeCrossbarUsage[index]); + nodes.emplace_back(nodeOrderKeys.empty() ? static_cast(index) : nodeOrderKeys[index], + weight, + nodeCrossbarUsage.empty() ? 0 : nodeCrossbarUsage[index]); for (auto [start, end, weight] : edges) makeEdge(start, end, weight); } diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp index 89acae3..1309c3b 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp @@ -10,6 +10,7 @@ #include "mlir/Support/LLVM.h" #include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseSet.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" @@ -35,12 +36,66 @@ #include "DCPGraph/DCPAnalysis.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" +#include "src/Accelerators/PIM/Dialect/Spatial/Channels.hpp" using namespace mlir; namespace onnx_mlir { namespace { using SpatCompute = spatial::SpatCompute; +using SpatComputeBatch = spatial::SpatComputeBatch; + +SpatCompute getOriginalSpatCompute(Operation* op); +struct ProducerValueRef { + ComputeInstance instance; + size_t resultIndex = 0; +}; + +std::optional getProducerValueRef(Value value); + +static size_t getFastPathCpuBudget() { + constexpr size_t kDefaultMaxCpuCount = 1000; + if (coresCount.getValue() > 0) + return static_cast(coresCount.getValue()); + return kDefaultMaxCpuCount; +} + +static size_t getBatchChunkTargetCount(int32_t laneCount) { + assert(laneCount > 0 && "laneCount must be positive"); + return std::min(static_cast(laneCount), std::max(1, getFastPathCpuBudget())); +} + +static ComputeInstance getBatchChunkForIndex(SpatComputeBatch batch, size_t chunkIndex) { + size_t totalLanes = static_cast(batch.getLaneCount()); + size_t chunkCount = getBatchChunkTargetCount(batch.getLaneCount()); + size_t baseChunkSize = totalLanes / chunkCount; + size_t largeChunkCount = totalLanes % chunkCount; + + size_t laneStart = chunkIndex * baseChunkSize + std::min(chunkIndex, largeChunkCount); + size_t laneCount = baseChunkSize + (chunkIndex < largeChunkCount ? 1 : 0); + return {batch.getOperation(), static_cast(laneStart), static_cast(laneCount)}; +} + +static ProducerValueRef getBatchChunkForLane(SpatComputeBatch batch, uint32_t lane) { + size_t totalLanes = static_cast(batch.getLaneCount()); + size_t chunkCount = getBatchChunkTargetCount(batch.getLaneCount()); + size_t baseChunkSize = totalLanes / chunkCount; + size_t largeChunkCount = totalLanes % chunkCount; + size_t largeChunkSpan = largeChunkCount * (baseChunkSize + 1); + + size_t chunkIndex = 0; + size_t resultIndex = 0; + if (static_cast(lane) < largeChunkSpan) { + chunkIndex = static_cast(lane) / (baseChunkSize + 1); + resultIndex = static_cast(lane) % (baseChunkSize + 1); + } + else { + size_t smallLane = static_cast(lane) - largeChunkSpan; + chunkIndex = largeChunkCount + smallLane / baseChunkSize; + resultIndex = smallLane % baseChunkSize; + } + return {getBatchChunkForIndex(batch, chunkIndex), resultIndex}; +} SpatCompute getOriginalSpatCompute(Operation* op) { while (auto extract = dyn_cast_if_present(op)) @@ -48,6 +103,742 @@ SpatCompute getOriginalSpatCompute(Operation* op) { return dyn_cast_if_present(op); } +std::optional getProducerValueRef(Value value) { + Operation* op = value.getDefiningOp(); + if (!op) + return std::nullopt; + + while (auto extract = dyn_cast(op)) { + value = extract.getSource(); + op = value.getDefiningOp(); + if (!op) + return std::nullopt; + } + + if (auto compute = dyn_cast(op)) { + return ProducerValueRef { + ComputeInstance {compute.getOperation(), 0, 1}, + static_cast(cast(value).getResultNumber()) + }; + } + + if (auto batch = dyn_cast(op)) + return getBatchChunkForLane(batch, static_cast(cast(value).getResultNumber())); + + return std::nullopt; +} + +static int32_t getPhysicalCoreId(size_t schedulerCpu) { return static_cast(schedulerCpu + 1); } + +static SmallVector getBatchCoreIds(Operation* op, size_t laneCount) { + if (auto coreIdsAttr = op->getAttrOfType(onnx_mlir::kCoreIdAttrName)) + return SmallVector(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end()); + if (auto coreIdAttr = op->getAttrOfType(onnx_mlir::kCoreIdAttrName)) + return SmallVector(laneCount, static_cast(coreIdAttr.getInt())); + return {}; +} + +static std::optional getComputeCoreId(SpatCompute compute) { + if (auto coreIdAttr = compute->getAttrOfType(onnx_mlir::kCoreIdAttrName)) + return static_cast(coreIdAttr.getInt()); + return std::nullopt; +} + +static bool areEquivalentForRebatch(SpatCompute lhs, SpatCompute rhs) { + if (!lhs || !rhs) + return false; + if (lhs.getInputs().size() != rhs.getInputs().size()) + return false; + if (lhs.getResultTypes() != rhs.getResultTypes()) + return false; + if (lhs.getWeights().size() != rhs.getWeights().size()) + return false; + if (!llvm::equal(lhs.getWeights(), rhs.getWeights())) + return false; + + auto& lhsBlock = lhs.getBody().front(); + auto& rhsBlock = rhs.getBody().front(); + if (lhsBlock.getNumArguments() != rhsBlock.getNumArguments()) + return false; + + DenseMap mappedValues; + for (auto [lhsArg, rhsArg] : llvm::zip(lhsBlock.getArguments(), rhsBlock.getArguments())) { + if (lhsArg.getType() != rhsArg.getType()) + return false; + mappedValues[lhsArg] = rhsArg; + } + auto lhsIt = lhsBlock.begin(); + auto rhsIt = rhsBlock.begin(); + for (; lhsIt != lhsBlock.end() && rhsIt != rhsBlock.end(); ++lhsIt, ++rhsIt) { + Operation& lhsOp = *lhsIt; + Operation& rhsOp = *rhsIt; + + if (lhsOp.getName() != rhsOp.getName()) + return false; + if (lhsOp.getNumOperands() != rhsOp.getNumOperands()) + return false; + if (lhsOp.getNumResults() != rhsOp.getNumResults()) + return false; + if (lhsOp.getNumRegions() != 0 || rhsOp.getNumRegions() != 0) + return false; + + for (auto [lhsOperand, rhsOperand] : llvm::zip(lhsOp.getOperands(), rhsOp.getOperands())) { + auto mapped = mappedValues.find(lhsOperand); + if (mapped != mappedValues.end()) { + if (mapped->second != rhsOperand) + return false; + continue; + } + if (lhsOperand != rhsOperand) + return false; + } + + if (auto lhsReceive = dyn_cast(lhsOp)) { + auto rhsReceive = cast(rhsOp); + if (lhsReceive.getOutput().getType() != rhsReceive.getOutput().getType()) + return false; + } + else if (auto lhsSend = dyn_cast(lhsOp)) { + auto rhsSend = cast(rhsOp); + if (lhsSend.getInput().getType() != rhsSend.getInput().getType()) + return false; + } + else if (lhsOp.getAttrs() != rhsOp.getAttrs()) { + return false; + } + + if (lhsOp.getResultTypes() != rhsOp.getResultTypes()) + return false; + for (auto [lhsResult, rhsResult] : llvm::zip(lhsOp.getResults(), rhsOp.getResults())) + mappedValues[lhsResult] = rhsResult; + } + + return lhsIt == lhsBlock.end() && rhsIt == rhsBlock.end(); +} + +static void sinkChannelsIntoBatchComputes(func::FuncOp funcOp, + IRRewriter& rewriter, + SmallVectorImpl& opsToErase, + int64_t& nextChannelId) { + SmallVector batches(funcOp.getOps()); + + for (auto batch : batches) { + if (batch.getInputs().empty() && batch.getResults().empty()) + continue; + + if (batch.getInputs().size() != static_cast(batch.getLaneCount())) + continue; + if (batch.getResults().size() != static_cast(batch.getLaneCount())) + continue; + + SmallVector inputReceives; + inputReceives.reserve(batch.getInputs().size()); + bool allInputsAreReceives = true; + for (Value input : batch.getInputs()) { + auto receiveOp = dyn_cast_or_null(input.getDefiningOp()); + if (!receiveOp) { + allInputsAreReceives = false; + break; + } + inputReceives.push_back(receiveOp); + } + + SmallVector resultSends; + resultSends.reserve(batch.getResults().size()); + bool allResultsAreSingleSends = true; + for (Value result : batch.getResults()) { + if (!result.hasOneUse()) { + allResultsAreSingleSends = false; + break; + } + auto sendOp = dyn_cast(*result.getUsers().begin()); + if (!sendOp) { + allResultsAreSingleSends = false; + break; + } + resultSends.push_back(sendOp); + } + + if (!allInputsAreReceives || !allResultsAreSingleSends) + continue; + + Block& oldBlock = batch.getBody().front(); + if (oldBlock.getNumArguments() != 1) + continue; + + SmallVector newWeights(batch.getWeights().begin(), batch.getWeights().end()); + rewriter.setInsertionPointAfter(batch); + auto newBatch = SpatComputeBatch::create(rewriter, + batch.getLoc(), + TypeRange {}, + rewriter.getI32IntegerAttr(batch.getLaneCount()), + ValueRange(newWeights), + ValueRange {}); + newBatch.getProperties().setOperandSegmentSizes({static_cast(newWeights.size()), 0}); + + SmallVector coreIds = getBatchCoreIds(batch, static_cast(batch.getLaneCount())); + if (!coreIds.empty()) + newBatch->setAttr(onnx_mlir::kCoreIdAttrName, rewriter.getDenseI32ArrayAttr(coreIds)); + + auto* newBlock = + rewriter.createBlock(&newBatch.getBody(), newBatch.getBody().end(), TypeRange {}, ArrayRef {}); + + rewriter.setInsertionPointToStart(newBlock); + struct BatchReceiveEntry { + uint64_t channelId = 0; + uint32_t sourceCoreId = 0; + uint32_t targetCoreId = 0; + }; + SmallVector receiveEntries; + receiveEntries.reserve(inputReceives.size()); + for (auto receiveOp : inputReceives) + receiveEntries.push_back({receiveOp.getChannelId(), receiveOp.getSourceCoreId(), receiveOp.getTargetCoreId()}); + llvm::stable_sort(receiveEntries, [](const BatchReceiveEntry& lhs, const BatchReceiveEntry& rhs) { + return std::tuple(lhs.sourceCoreId, lhs.targetCoreId, lhs.channelId) + < std::tuple(rhs.sourceCoreId, rhs.targetCoreId, rhs.channelId); + }); + + SmallVector receiveChannelIds; + SmallVector receiveSourceCoreIds; + SmallVector receiveTargetCoreIds; + receiveChannelIds.reserve(receiveEntries.size()); + receiveSourceCoreIds.reserve(receiveEntries.size()); + receiveTargetCoreIds.reserve(receiveEntries.size()); + for (const BatchReceiveEntry& entry : receiveEntries) { + (void) entry; + receiveChannelIds.push_back(nextChannelId++); + receiveSourceCoreIds.push_back(static_cast(entry.sourceCoreId)); + receiveTargetCoreIds.push_back(static_cast(entry.targetCoreId)); + } + auto batchReceive = spatial::SpatChannelReceiveBatchOp::create(rewriter, + batch.getLoc(), + oldBlock.getArgument(0).getType(), + rewriter.getDenseI64ArrayAttr(receiveChannelIds), + rewriter.getDenseI32ArrayAttr(receiveSourceCoreIds), + rewriter.getDenseI32ArrayAttr(receiveTargetCoreIds)); + + IRMapping mapper; + mapper.map(oldBlock.getArgument(0), batchReceive.getOutput()); + + auto oldYield = cast(oldBlock.getTerminator()); + rewriter.setInsertionPointToEnd(newBlock); + for (Operation& op : oldBlock) { + if (&op == oldYield) + continue; + rewriter.clone(op, mapper); + } + + Value sendInput = mapper.lookup(oldYield.getOperand(0)); + struct BatchSendEntry { + uint64_t channelId = 0; + uint32_t sourceCoreId = 0; + uint32_t targetCoreId = 0; + }; + SmallVector sendEntries; + sendEntries.reserve(resultSends.size()); + for (auto sendOp : resultSends) + sendEntries.push_back({sendOp.getChannelId(), sendOp.getSourceCoreId(), sendOp.getTargetCoreId()}); + llvm::stable_sort(sendEntries, [](const BatchSendEntry& lhs, const BatchSendEntry& rhs) { + return std::tuple(lhs.sourceCoreId, lhs.targetCoreId, lhs.channelId) + < std::tuple(rhs.sourceCoreId, rhs.targetCoreId, rhs.channelId); + }); + + SmallVector sendChannelIds; + SmallVector sendSourceCoreIds; + SmallVector sendTargetCoreIds; + sendChannelIds.reserve(sendEntries.size()); + sendSourceCoreIds.reserve(sendEntries.size()); + sendTargetCoreIds.reserve(sendEntries.size()); + for (const BatchSendEntry& entry : sendEntries) { + (void) entry; + sendChannelIds.push_back(nextChannelId++); + sendSourceCoreIds.push_back(static_cast(entry.sourceCoreId)); + sendTargetCoreIds.push_back(static_cast(entry.targetCoreId)); + } + spatial::SpatChannelSendBatchOp::create(rewriter, + batch.getLoc(), + rewriter.getDenseI64ArrayAttr(sendChannelIds), + rewriter.getDenseI32ArrayAttr(sendSourceCoreIds), + rewriter.getDenseI32ArrayAttr(sendTargetCoreIds), + sendInput); + spatial::SpatYieldOp::create(rewriter, batch.getLoc(), ValueRange {}); + + for (auto receiveOp : inputReceives) + opsToErase.push_back(receiveOp); + for (auto sendOp : resultSends) + opsToErase.push_back(sendOp); + opsToErase.push_back(batch); + } +} + +void sinkChannelsIntoComputes(func::FuncOp funcOp, int64_t& nextChannelId) { + IRRewriter rewriter(funcOp.getContext()); + SmallVector computes(funcOp.getOps()); + SmallVector opsToErase; + + for (auto compute : computes) { + SmallVector keptInputIndices; + SmallVector keptResultIndices; + SmallVector internalizedReceives(compute.getInputs().size()); + SmallVector> resultSendOps(compute.getNumResults()); + + bool needsRewrite = false; + for (auto [inputIndex, input] : llvm::enumerate(compute.getInputs())) { + auto receiveOp = dyn_cast_or_null(input.getDefiningOp()); + if (!receiveOp) { + keptInputIndices.push_back(inputIndex); + continue; + } + + internalizedReceives[inputIndex] = receiveOp; + opsToErase.push_back(receiveOp); + needsRewrite = true; + } + + for (auto [resultIndex, result] : llvm::enumerate(compute.getResults())) { + bool hasNonSendUser = false; + for (Operation* user : result.getUsers()) { + if (auto sendOp = dyn_cast(user)) { + resultSendOps[resultIndex].push_back(sendOp); + opsToErase.push_back(sendOp); + needsRewrite = true; + continue; + } + hasNonSendUser = true; + } + + if (hasNonSendUser || resultSendOps[resultIndex].empty()) + keptResultIndices.push_back(resultIndex); + } + + if (!needsRewrite) + continue; + + SmallVector newOperands; + SmallVector newResultTypes; + SmallVector newBlockArgTypes; + SmallVector newBlockArgLocs; + newOperands.reserve(compute.getNumOperands()); + newResultTypes.reserve(keptResultIndices.size()); + newBlockArgTypes.reserve(keptInputIndices.size()); + newBlockArgLocs.reserve(keptInputIndices.size()); + + for (Value weight : compute.getWeights()) + newOperands.push_back(weight); + for (unsigned inputIndex : keptInputIndices) { + Value input = compute.getInputs()[inputIndex]; + newOperands.push_back(input); + newBlockArgTypes.push_back(input.getType()); + newBlockArgLocs.push_back(compute.getLoc()); + } + for (unsigned resultIndex : keptResultIndices) + newResultTypes.push_back(compute.getResult(resultIndex).getType()); + + rewriter.setInsertionPointAfter(compute); + auto newCompute = + SpatCompute::create(rewriter, compute.getLoc(), TypeRange(newResultTypes), ValueRange(newOperands)); + newCompute.getProperties().setOperandSegmentSizes( + {static_cast(compute.getWeights().size()), static_cast(keptInputIndices.size())}); + if (auto coreIdAttr = compute->getAttrOfType(onnx_mlir::kCoreIdAttrName)) + newCompute->setAttr(onnx_mlir::kCoreIdAttrName, coreIdAttr); + + auto* newBlock = rewriter.createBlock( + &newCompute.getBody(), newCompute.getBody().end(), TypeRange(newBlockArgTypes), newBlockArgLocs); + + IRMapping mapper; + for (auto [mappedIndex, inputIndex] : llvm::enumerate(keptInputIndices)) + mapper.map(compute.getBody().front().getArgument(inputIndex), newBlock->getArgument(mappedIndex)); + + rewriter.setInsertionPointToStart(newBlock); + for (auto [inputIndex, receiveOp] : llvm::enumerate(internalizedReceives)) { + if (!receiveOp) + continue; + + auto internalReceive = spatial::SpatChannelReceiveOp::create(rewriter, + receiveOp.getLoc(), + receiveOp.getResult().getType(), + receiveOp.getChannelIdAttr(), + receiveOp.getSourceCoreIdAttr(), + receiveOp.getTargetCoreIdAttr()); + mapper.map(compute.getBody().front().getArgument(inputIndex), internalReceive.getResult()); + } + + auto oldYieldOp = cast(compute.getBody().front().getTerminator()); + rewriter.setInsertionPointToEnd(newBlock); + for (Operation& op : compute.getBody().front()) { + if (&op == oldYieldOp) + continue; + rewriter.clone(op, mapper); + } + + for (auto [resultIndex, sendOps] : llvm::enumerate(resultSendOps)) { + if (sendOps.empty()) + continue; + + Value yieldedValue = mapper.lookup(oldYieldOp.getOperand(resultIndex)); + for (auto sendOp : sendOps) + spatial::SpatChannelSendOp::create(rewriter, + sendOp.getLoc(), + sendOp.getChannelIdAttr(), + sendOp.getSourceCoreIdAttr(), + sendOp.getTargetCoreIdAttr(), + yieldedValue); + } + + SmallVector keptYieldOperands; + keptYieldOperands.reserve(keptResultIndices.size()); + for (unsigned resultIndex : keptResultIndices) + keptYieldOperands.push_back(mapper.lookup(oldYieldOp.getOperand(resultIndex))); + spatial::SpatYieldOp::create(rewriter, oldYieldOp.getLoc(), ValueRange(keptYieldOperands)); + + for (auto [newResultIndex, oldResultIndex] : llvm::enumerate(keptResultIndices)) + compute.getResult(oldResultIndex).replaceAllUsesWith(newCompute.getResult(newResultIndex)); + + opsToErase.push_back(compute); + } + + sinkChannelsIntoBatchComputes(funcOp, rewriter, opsToErase, nextChannelId); + + SmallVector pendingRemovals(opsToErase.begin(), opsToErase.end()); + while (!pendingRemovals.empty()) { + bool erasedAny = false; + for (auto it = pendingRemovals.begin(); it != pendingRemovals.end();) { + if (!(*it)->use_empty()) { + ++it; + continue; + } + + rewriter.eraseOp(*it); + it = pendingRemovals.erase(it); + erasedAny = true; + } + + if (erasedAny) + continue; + + for (Operation* op : pendingRemovals) + op->emitError("failed to sink channel op into compute"); + llvm_unreachable("channel sinking left cyclic top-level dependencies"); + } +} + +static void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) { + IRRewriter rewriter(funcOp.getContext()); + + for (auto compute : funcOp.getOps()) { + Block& block = compute.getBody().front(); + for (auto it = block.begin(); it != block.end();) { + auto receiveOp = dyn_cast(&*it); + if (receiveOp) { + SmallVector run; + Type outputType = receiveOp.getOutput().getType(); + auto runIt = it; + while (runIt != block.end()) { + auto current = dyn_cast(&*runIt); + if (!current || current.getOutput().getType() != outputType) + break; + run.push_back(current); + ++runIt; + } + + if (run.size() > 1) { + struct ReceiveEntry { + spatial::SpatChannelReceiveOp op; + size_t originalIndex = 0; + uint32_t sourceCoreId = 0; + uint32_t targetCoreId = 0; + uint64_t channelId = 0; + }; + SmallVector sortedEntries; + sortedEntries.reserve(run.size()); + for (auto [originalIndex, op] : llvm::enumerate(run)) + sortedEntries.push_back({op, originalIndex, op.getSourceCoreId(), op.getTargetCoreId(), op.getChannelId()}); + llvm::stable_sort(sortedEntries, [](const ReceiveEntry& lhs, const ReceiveEntry& rhs) { + return std::tuple(lhs.sourceCoreId, lhs.targetCoreId, lhs.channelId) + < std::tuple(rhs.sourceCoreId, rhs.targetCoreId, rhs.channelId); + }); + + SmallVector channelIds; + SmallVector sourceCoreIds; + SmallVector targetCoreIds; + SmallVector outputTypes; + channelIds.reserve(sortedEntries.size()); + sourceCoreIds.reserve(sortedEntries.size()); + targetCoreIds.reserve(sortedEntries.size()); + outputTypes.reserve(sortedEntries.size()); + for (ReceiveEntry& entry : sortedEntries) { + (void) entry; + channelIds.push_back(nextChannelId++); + sourceCoreIds.push_back(static_cast(entry.sourceCoreId)); + targetCoreIds.push_back(static_cast(entry.targetCoreId)); + outputTypes.push_back(entry.op.getOutput().getType()); + } + + rewriter.setInsertionPoint(run.front()); + auto compactReceive = spatial::SpatChannelReceiveManyOp::create(rewriter, + run.front().getLoc(), + TypeRange(outputTypes), + rewriter.getDenseI64ArrayAttr(channelIds), + rewriter.getDenseI32ArrayAttr(sourceCoreIds), + rewriter.getDenseI32ArrayAttr(targetCoreIds)); + for (auto [sortedIndex, entry] : llvm::enumerate(sortedEntries)) + entry.op.getOutput().replaceAllUsesWith(compactReceive.getResult(sortedIndex)); + for (auto op : run) + rewriter.eraseOp(op); + + it = compactReceive->getIterator(); + ++it; + continue; + } + } + + auto sendOp = dyn_cast(&*it); + if (sendOp) { + SmallVector run; + Type inputType = sendOp.getInput().getType(); + auto runIt = it; + while (runIt != block.end()) { + auto current = dyn_cast(&*runIt); + if (!current || current.getInput().getType() != inputType) + break; + run.push_back(current); + ++runIt; + } + + if (run.size() > 1) { + struct SendEntry { + spatial::SpatChannelSendOp op; + uint32_t sourceCoreId = 0; + uint32_t targetCoreId = 0; + uint64_t channelId = 0; + }; + SmallVector sortedEntries; + sortedEntries.reserve(run.size()); + for (auto op : run) + sortedEntries.push_back({op, op.getSourceCoreId(), op.getTargetCoreId(), op.getChannelId()}); + llvm::stable_sort(sortedEntries, [](const SendEntry& lhs, const SendEntry& rhs) { + return std::tuple(lhs.sourceCoreId, lhs.targetCoreId, lhs.channelId) + < std::tuple(rhs.sourceCoreId, rhs.targetCoreId, rhs.channelId); + }); + + SmallVector channelIds; + SmallVector sourceCoreIds; + SmallVector targetCoreIds; + SmallVector inputs; + channelIds.reserve(sortedEntries.size()); + sourceCoreIds.reserve(sortedEntries.size()); + targetCoreIds.reserve(sortedEntries.size()); + inputs.reserve(sortedEntries.size()); + for (SendEntry& entry : sortedEntries) { + (void) entry; + channelIds.push_back(nextChannelId++); + sourceCoreIds.push_back(static_cast(entry.sourceCoreId)); + targetCoreIds.push_back(static_cast(entry.targetCoreId)); + inputs.push_back(entry.op.getInput()); + } + + rewriter.setInsertionPoint(run.front()); + spatial::SpatChannelSendManyOp::create(rewriter, + run.front().getLoc(), + rewriter.getDenseI64ArrayAttr(channelIds), + rewriter.getDenseI32ArrayAttr(sourceCoreIds), + rewriter.getDenseI32ArrayAttr(targetCoreIds), + ValueRange(inputs)); + for (auto op : run) + rewriter.eraseOp(op); + + it = runIt; + continue; + } + } + + ++it; + } + } +} + +void rebatchEquivalentComputes(func::FuncOp funcOp, int64_t& nextChannelId) { + IRRewriter rewriter(funcOp.getContext()); + SmallVector computes(funcOp.getOps()); + DenseSet consumed; + + for (size_t index = 0; index < computes.size(); ++index) { + auto anchor = computes[index]; + if (consumed.contains(anchor)) + continue; + if (anchor.getInputs().size() > 1) + continue; + + SmallVector group {anchor}; + if (!anchor.getResults().empty()) + continue; + for (size_t candidateIndex = index + 1; candidateIndex < computes.size(); ++candidateIndex) { + auto candidate = computes[candidateIndex]; + if (consumed.contains(candidate)) + continue; + if (candidate.getInputs().size() > 1) + continue; + if (!candidate.getResults().empty()) + continue; + if (areEquivalentForRebatch(anchor, candidate)) + group.push_back(candidate); + } + + if (group.size() <= 1) + continue; + + auto insertionAnchor = group.front(); + if (llvm::all_of(group, [](SpatCompute compute) { return getComputeCoreId(compute).has_value(); })) { + llvm::stable_sort( + group, [](SpatCompute lhs, SpatCompute rhs) { return *getComputeCoreId(lhs) < *getComputeCoreId(rhs); }); + } + + SmallVector weights; + weights.reserve(group.size() * anchor.getWeights().size()); + SmallVector inputs; + inputs.reserve(group.size() * anchor.getInputs().size()); + SmallVector coreIds; + coreIds.reserve(group.size()); + bool haveAllCoreIds = true; + for (auto compute : group) { + llvm::append_range(weights, compute.getWeights()); + llvm::append_range(inputs, compute.getInputs()); + auto coreIdAttr = compute->getAttrOfType(onnx_mlir::kCoreIdAttrName); + if (!coreIdAttr) + haveAllCoreIds = false; + else if (haveAllCoreIds) + coreIds.push_back(static_cast(coreIdAttr.getInt())); + } + + rewriter.setInsertionPoint(insertionAnchor); + auto rebatched = SpatComputeBatch::create(rewriter, + insertionAnchor.getLoc(), + TypeRange {}, + rewriter.getI32IntegerAttr(static_cast(group.size())), + ValueRange(weights), + ValueRange(inputs)); + rebatched.getProperties().setOperandSegmentSizes( + {static_cast(weights.size()), static_cast(inputs.size())}); + if (haveAllCoreIds) + rebatched->setAttr(onnx_mlir::kCoreIdAttrName, rewriter.getDenseI32ArrayAttr(coreIds)); + + SmallVector blockArgTypes; + SmallVector blockArgLocs; + for (BlockArgument arg : anchor.getBody().front().getArguments()) { + blockArgTypes.push_back(arg.getType()); + blockArgLocs.push_back(arg.getLoc()); + } + auto* newBlock = + rewriter.createBlock(&rebatched.getBody(), rebatched.getBody().end(), TypeRange(blockArgTypes), blockArgLocs); + rewriter.setInsertionPointToEnd(newBlock); + + IRMapping mapper; + auto& anchorBlock = anchor.getBody().front(); + for (auto [oldArg, newArg] : llvm::zip(anchorBlock.getArguments(), newBlock->getArguments())) + mapper.map(oldArg, newArg); + auto opIts = llvm::map_to_vector(group, [](SpatCompute compute) { return compute.getBody().front().begin(); }); + for (Operation& anchorOp : anchorBlock) { + if (auto receiveOp = dyn_cast(&anchorOp)) { + struct BatchReceiveEntry { + uint64_t channelId = 0; + uint32_t sourceCoreId = 0; + uint32_t targetCoreId = 0; + }; + SmallVector entries; + entries.reserve(group.size()); + for (auto [groupIndex, compute] : llvm::enumerate(group)) { + auto groupReceive = cast(&*opIts[groupIndex]); + entries.push_back( + {groupReceive.getChannelId(), groupReceive.getSourceCoreId(), groupReceive.getTargetCoreId()}); + ++opIts[groupIndex]; + } + llvm::stable_sort(entries, [](const BatchReceiveEntry& lhs, const BatchReceiveEntry& rhs) { + return std::tuple(lhs.sourceCoreId, lhs.targetCoreId, lhs.channelId) + < std::tuple(rhs.sourceCoreId, rhs.targetCoreId, rhs.channelId); + }); + + SmallVector channelIds; + SmallVector sourceCoreIds; + SmallVector targetCoreIds; + channelIds.reserve(group.size()); + sourceCoreIds.reserve(group.size()); + targetCoreIds.reserve(group.size()); + for (const BatchReceiveEntry& entry : entries) { + (void) entry; + channelIds.push_back(nextChannelId++); + sourceCoreIds.push_back(static_cast(entry.sourceCoreId)); + targetCoreIds.push_back(static_cast(entry.targetCoreId)); + } + auto batchReceive = spatial::SpatChannelReceiveBatchOp::create(rewriter, + receiveOp.getLoc(), + receiveOp.getOutput().getType(), + rewriter.getDenseI64ArrayAttr(channelIds), + rewriter.getDenseI32ArrayAttr(sourceCoreIds), + rewriter.getDenseI32ArrayAttr(targetCoreIds)); + mapper.map(receiveOp.getOutput(), batchReceive.getOutput()); + continue; + } + + if (auto sendOp = dyn_cast(&anchorOp)) { + struct BatchSendEntry { + uint64_t channelId = 0; + uint32_t sourceCoreId = 0; + uint32_t targetCoreId = 0; + }; + SmallVector entries; + entries.reserve(group.size()); + for (auto [groupIndex, compute] : llvm::enumerate(group)) { + auto groupSend = cast(&*opIts[groupIndex]); + entries.push_back({groupSend.getChannelId(), groupSend.getSourceCoreId(), groupSend.getTargetCoreId()}); + ++opIts[groupIndex]; + } + llvm::stable_sort(entries, [](const BatchSendEntry& lhs, const BatchSendEntry& rhs) { + return std::tuple(lhs.sourceCoreId, lhs.targetCoreId, lhs.channelId) + < std::tuple(rhs.sourceCoreId, rhs.targetCoreId, rhs.channelId); + }); + + SmallVector channelIds; + SmallVector sourceCoreIds; + SmallVector targetCoreIds; + channelIds.reserve(group.size()); + sourceCoreIds.reserve(group.size()); + targetCoreIds.reserve(group.size()); + for (const BatchSendEntry& entry : entries) { + (void) entry; + channelIds.push_back(nextChannelId++); + sourceCoreIds.push_back(static_cast(entry.sourceCoreId)); + targetCoreIds.push_back(static_cast(entry.targetCoreId)); + } + spatial::SpatChannelSendBatchOp::create(rewriter, + sendOp.getLoc(), + rewriter.getDenseI64ArrayAttr(channelIds), + rewriter.getDenseI32ArrayAttr(sourceCoreIds), + rewriter.getDenseI32ArrayAttr(targetCoreIds), + mapper.lookup(sendOp.getInput())); + continue; + } + + if (isa(anchorOp)) { + for (auto& opIt : opIts) + ++opIt; + spatial::SpatYieldOp::create(rewriter, anchorOp.getLoc(), ValueRange {}); + continue; + } + + Operation* cloned = rewriter.clone(anchorOp, mapper); + for (auto [originalResult, clonedResult] : llvm::zip(anchorOp.getResults(), cloned->getResults())) + mapper.map(originalResult, clonedResult); + for (auto& opIt : opIts) + ++opIt; + } + + for (auto compute : group) { + consumed.insert(compute); + rewriter.eraseOp(compute); + } + } +} + struct ComputeMotifInfo { uint64_t instructionCount = 0; uint64_t weightedMvmCount = 0; @@ -164,206 +955,6 @@ void mergeTriviallyConnectedComputes(func::FuncOp funcOp) { } } -struct WeightedVmmBandCandidate { - Operation* parent; - SpatCompute compute; -}; - -bool isSingleWeightedVmmCompute(SpatCompute compute) { - if (compute.getNumResults() != 1 || compute.getWeights().size() != 1 || compute.getInputs().size() != 1) - return false; - - uint64_t weightedVmmCount = 0; - for (Operation& op : compute.getBody().front()) { - if (isa(&op)) - return false; - if (isa(&op)) - weightedVmmCount++; - } - return weightedVmmCount == 1; -} - -std::optional getWeightedVmmBandCandidate(SpatCompute compute) { - if (!isSingleWeightedVmmCompute(compute) || !compute->hasOneUse()) - return std::nullopt; - - auto& use = *compute->getUses().begin(); - auto child = dyn_cast(use.getOwner()); - if (!child || use.getOperandNumber() < child.getWeights().size()) - return std::nullopt; - - auto parent = getOriginalSpatCompute(compute.getInputs().front().getDefiningOp()); - if (!parent || parent == child) - return std::nullopt; - - return WeightedVmmBandCandidate {parent.getOperation(), compute}; -} - -bool haveSameWeightedVmmBandShape(SpatCompute lhs, SpatCompute rhs) { - return lhs.getWeights().front().getType() == rhs.getWeights().front().getType() - && lhs.getInputs().front().getType() == rhs.getInputs().front().getType() - && lhs.getResult(0).getType() == rhs.getResult(0).getType(); -} - -SpatCompute packWeightedVmmComputes(func::FuncOp funcOp, ArrayRef computes) { - assert(!computes.empty() && "expected at least one compute to pack"); - - IRRewriter rewriter(funcOp->getContext()); - SpatCompute child = cast((*computes.front()->getUses().begin()).getOwner()); - rewriter.setInsertionPoint(child); - - SmallVector operands; - SmallVector inputTypes; - SmallVector inputLocs; - SmallVector resultTypes; - operands.reserve(computes.size() * 2); - inputTypes.reserve(computes.size()); - inputLocs.reserve(computes.size()); - resultTypes.reserve(computes.size()); - - for (SpatCompute compute : computes) - for (Value weight : compute.getWeights()) - operands.push_back(weight); - for (SpatCompute compute : computes) { - for (Value input : compute.getInputs()) { - operands.push_back(input); - inputTypes.push_back(input.getType()); - inputLocs.push_back(input.getLoc()); - } - for (Type resultType : compute.getResultTypes()) - resultTypes.push_back(resultType); - } - - auto packedCompute = SpatCompute::create(rewriter, funcOp.getLoc(), resultTypes, operands); - packedCompute.getProperties().setOperandSegmentSizes( - {static_cast(computes.size()), static_cast(inputTypes.size())}); - auto* block = rewriter.createBlock(&packedCompute.getBody(), packedCompute.getBody().end(), inputTypes, inputLocs); - rewriter.setInsertionPointToEnd(block); - - SmallVector yieldValues; - yieldValues.reserve(resultTypes.size()); - size_t inputBaseIndex = 0; - size_t weightBaseIndex = 0; - for (SpatCompute compute : computes) { - IRMapping mapper; - for (auto [weightIndex, weight] : llvm::enumerate(compute.getWeights())) - mapper.map(weight, *std::next(packedCompute.getWeights().begin(), weightBaseIndex + weightIndex)); - for (auto [inputIndex, bbArg] : llvm::enumerate(compute.getBody().front().getArguments())) - mapper.map(bbArg, block->getArgument(inputBaseIndex + inputIndex)); - - auto remapWeightIndex = [&](auto weightedOp) { - weightedOp.setWeightIndex(weightBaseIndex + weightedOp.getWeightIndex()); - }; - for (Operation& op : compute.getBody().front()) { - if (auto yield = dyn_cast(&op)) { - for (Value yieldOperand : yield.getOperands()) - yieldValues.push_back(mapper.lookup(yieldOperand)); - continue; - } - - Operation* cloned = rewriter.clone(op, mapper); - if (auto weightedMvmOp = dyn_cast(cloned)) - remapWeightIndex(weightedMvmOp); - if (auto weightedVmmOp = dyn_cast(cloned)) - remapWeightIndex(weightedVmmOp); - } - - inputBaseIndex += compute.getInputs().size(); - weightBaseIndex += compute.getWeights().size(); - } - spatial::SpatYieldOp::create(rewriter, funcOp.getLoc(), yieldValues); - - size_t resultIndex = 0; - for (SpatCompute compute : computes) - for (OpResult result : compute->getResults()) - result.replaceAllUsesWith(packedCompute.getResult(resultIndex++)); - for (SpatCompute compute : llvm::reverse(computes)) - compute.erase(); - - return packedCompute; -} - -size_t getFastPathCpuBudget() { - constexpr size_t kDefaultMaxCpuCount = 1000; - if (coresCount.getValue() > 0) - return static_cast(coresCount.getValue()); - return kDefaultMaxCpuCount; -} - -size_t packWideWeightedVmmBands(func::FuncOp funcOp) { - constexpr size_t kMinBandSize = 64; - size_t cpuBudget = std::max(1, getFastPathCpuBudget()); - SmallVector candidates; - for (SpatCompute compute : funcOp.getOps()) - if (auto candidate = getWeightedVmmBandCandidate(compute)) - candidates.push_back(*candidate); - - llvm::stable_sort(candidates, [](const WeightedVmmBandCandidate& lhs, const WeightedVmmBandCandidate& rhs) { - if (lhs.parent != rhs.parent) - return lhs.parent < rhs.parent; - return lhs.compute->isBeforeInBlock(rhs.compute); - }); - - size_t packedBandCount = 0; - size_t packedNodeCount = 0; - size_t createdNodeCount = 0; - size_t minChunkSize = std::numeric_limits::max(); - size_t maxChunkSize = 0; - SmallVector> shapeGroups; - for (size_t parentBegin = 0; parentBegin < candidates.size();) { - size_t parentEnd = parentBegin + 1; - while (parentEnd < candidates.size() && candidates[parentEnd].parent == candidates[parentBegin].parent) - parentEnd++; - - shapeGroups.clear(); - for (size_t index = parentBegin; index < parentEnd; ++index) { - SpatCompute compute = candidates[index].compute; - auto groupIt = llvm::find_if(shapeGroups, [&](const SmallVector& group) { - return haveSameWeightedVmmBandShape(group.front(), compute); - }); - if (groupIt == shapeGroups.end()) - shapeGroups.push_back({compute}); - else - groupIt->push_back(compute); - } - - for (ArrayRef band : shapeGroups) { - size_t bandSize = band.size(); - if (bandSize < kMinBandSize || bandSize <= cpuBudget) - continue; - - size_t chunkSize = (bandSize + cpuBudget - 1) / cpuBudget; - packedBandCount++; - SmallVector chunk; - chunk.reserve(std::min(chunkSize, bandSize)); - for (auto [index, compute] : llvm::enumerate(band)) { - chunk.push_back(compute); - if (chunk.size() == chunkSize || index + 1 == band.size()) { - minChunkSize = std::min(minChunkSize, chunk.size()); - maxChunkSize = std::max(maxChunkSize, chunk.size()); - packWeightedVmmComputes(funcOp, chunk); - packedNodeCount += chunk.size(); - createdNodeCount++; - chunk.clear(); - } - } - } - parentBegin = parentEnd; - } - - if (packedNodeCount != 0) - llvm::errs() << llvm::formatv("[DCP-FASTPATH] wvmmBands={0} packedNodes={1} createdNodes={2} changed={3} " - "cpuBudget={4} chunkSizeRange={5}-{6}\n", - packedBandCount, - packedNodeCount, - createdNodeCount, - packedNodeCount - createdNodeCount, - cpuBudget, - minChunkSize, - maxChunkSize); - return packedNodeCount - createdNodeCount; -} - void emitMotifProfile(func::FuncOp funcOp) { if (!std::getenv("DCP_MOTIF_PROFILE")) return; @@ -577,7 +1168,7 @@ void emitMotifProfile(func::FuncOp funcOp) { } } -void generateReport(func::FuncOp funcOp, const std::string& name) { +void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpuCount = 0) { std::string outputDir = getOutputDir(); if (outputDir.empty()) return; @@ -588,58 +1179,95 @@ void generateReport(func::FuncOp funcOp, const std::string& name) { std::fstream file(dialectsDir + "/" + name + ".txt", std::ios::out); llvm::raw_os_ostream os(file); - uint64_t numSpatCompute = 0; - std::vector> collectedData; + struct ReportRow { + uint64_t opId = 0; + uint64_t logicalComputeCount = 0; + uint64_t weightCount = 0; + uint64_t instructionCount = 0; + bool isRebatched = false; + }; - for (auto spatCompute : funcOp.getOps()) { - uint64_t numInst = 0; - for (auto& _ : spatCompute.getRegion().front()) - numInst++; - collectedData.push_back({numSpatCompute++, spatCompute.getWeights().size(), numInst}); + uint64_t totalComputeOps = 0; + uint64_t totalLogicalComputes = 0; + uint64_t totalBatchComputeOps = 0; + uint64_t totalMultiLaneBatchComputeOps = 0; + std::vector collectedData; + + for (Operation& op : funcOp.getBody().front()) { + if (auto spatCompute = dyn_cast(&op)) { + uint64_t numInst = 0; + for (auto& _ : spatCompute.getRegion().front()) + numInst++; + collectedData.push_back({totalComputeOps++, 1, spatCompute.getWeights().size(), numInst, false}); + totalLogicalComputes += 1; + continue; + } + if (auto batch = dyn_cast(&op)) { + uint64_t numInst = 0; + for (auto& _ : batch.getRegion().front()) + numInst++; + uint64_t logicalCount = static_cast(batch.getLaneCount()); + collectedData.push_back({totalComputeOps++, logicalCount, batch.getWeights().size(), numInst, true}); + totalLogicalComputes += logicalCount; + totalBatchComputeOps += 1; + if (batch.getLaneCount() > 1) + totalMultiLaneBatchComputeOps += 1; + } } - std::stable_sort(collectedData.begin(), - collectedData.end(), - [](std::tuple lft, std::tuple rgt) { - auto [iLft, weightLft, numInstLft] = lft; - auto [iRgt, weightRgt, numInstRgt] = rgt; + os << "Used CPUs: " << usedCpuCount << "\n"; + os << "Number of top-level compute ops: " << totalComputeOps << "\n"; + os << "Number of logical computes: " << totalLogicalComputes << "\n"; + os << "Number of top-level batch compute ops: " << totalBatchComputeOps << "\n"; + os << "Number of top-level multi-lane batch compute ops: " << totalMultiLaneBatchComputeOps << "\n\n"; - if (numInstLft < numInstRgt) - return false; - else if (numInstRgt < numInstLft) - return true; + std::stable_sort(collectedData.begin(), collectedData.end(), [](const ReportRow& lft, const ReportRow& rgt) { + if (lft.isRebatched != rgt.isRebatched) + return lft.isRebatched > rgt.isRebatched; - if (weightLft < weightRgt) - return false; - else if (weightRgt < weightLft) - return true; + if (lft.instructionCount < rgt.instructionCount) + return false; + else if (rgt.instructionCount < lft.instructionCount) + return true; - if (iLft < iRgt) - return true; - else if (iRgt < iLft) - return false; + if (lft.weightCount < rgt.weightCount) + return false; + else if (rgt.weightCount < lft.weightCount) + return true; - return true; - }); + if (lft.logicalComputeCount < rgt.logicalComputeCount) + return false; + else if (rgt.logicalComputeCount < lft.logicalComputeCount) + return true; - for (uint64_t cI = 0; cI < numSpatCompute; ++cI) { + if (lft.opId < rgt.opId) + return true; + else if (rgt.opId < lft.opId) + return false; + + return true; + }); + + for (uint64_t cI = 0; cI < totalComputeOps; ++cI) { uint64_t lastIndex = cI; - auto [currentComputeId, currentWeight, currentNumInst] = collectedData[cI]; + ReportRow current = collectedData[cI]; - for (uint64_t nI = cI + 1; nI < numSpatCompute; ++nI) { - auto [nextComputeId, nextWeight, nextNumInst] = collectedData[nI]; - if (currentWeight == nextWeight && currentNumInst == nextNumInst) + for (uint64_t nI = cI + 1; nI < totalComputeOps; ++nI) { + ReportRow next = collectedData[nI]; + if (current.isRebatched == next.isRebatched && current.weightCount == next.weightCount + && current.instructionCount == next.instructionCount + && current.logicalComputeCount == next.logicalComputeCount) lastIndex = nI; else break; } - os << "Compute " << currentComputeId; - auto expectedPrintedValue = currentComputeId + 1; + os << (current.isRebatched ? "Batch " : "Compute ") << current.opId; + auto expectedPrintedValue = current.opId + 1; bool rangePrinted = false; cI++; for (; cI <= lastIndex; ++cI) { - auto candidateToPrint = std::get<0>(collectedData[cI]); + auto candidateToPrint = collectedData[cI].opId; if (candidateToPrint == expectedPrintedValue) { expectedPrintedValue = candidateToPrint + 1; rangePrinted = true; @@ -652,12 +1280,13 @@ void generateReport(func::FuncOp funcOp, const std::string& name) { expectedPrintedValue = candidateToPrint + 1; } } - if (rangePrinted && currentComputeId != expectedPrintedValue - 1) + if (rangePrinted && current.opId != expectedPrintedValue - 1) os << " - " << expectedPrintedValue - 1; os << " :\n"; - os << "\tNumber of instructions " << currentNumInst << "\n"; - os << "\tNumber of used crossbars " << currentWeight << "\n"; + os << "\tNumber of logical computes " << current.logicalComputeCount << "\n"; + os << "\tNumber of instructions " << current.instructionCount << "\n"; + os << "\tNumber of used crossbars " << current.weightCount << "\n"; cI = lastIndex; } @@ -667,62 +1296,64 @@ void generateReport(func::FuncOp funcOp, const std::string& name) { struct ComputeValueResults { SmallVector innerValues; + SmallVector outerValues; - Value get(size_t resultIndex) const { + Value getInner(size_t resultIndex) const { assert(resultIndex < innerValues.size() && "compute result index out of range"); return innerValues[resultIndex]; } + + Value getOuter(size_t resultIndex) const { + assert(resultIndex < outerValues.size() && "compute result index out of range"); + return outerValues[resultIndex]; + } }; class LazyInsertComputeResult { using InsertPoint = mlir::IRRewriter::InsertPoint; - ComputeValueResults computeResults; - bool onlyChannel; - std::function>(size_t)> channelNewInserter; public: - LazyInsertComputeResult(ComputeValueResults computeValueResults, - std::function>(size_t)> channelNewInserter, - bool isOnlyChannel) - : computeResults(computeValueResults), onlyChannel(isOnlyChannel), channelNewInserter(channelNewInserter) {} + struct ChannelInfo { + int64_t channelId = -1; + int32_t sourceCoreId = -1; + int32_t targetCoreId = -1; + }; + + LazyInsertComputeResult( + ComputeValueResults computeValueResults, + std::function>(size_t, size_t)> channelInserter) + : computeResults(computeValueResults), channelInserter(channelInserter) {} struct ChannelOrLocalOp { Value data; bool isChannel; + ChannelInfo channelInfo; }; - bool onlyChanneled() const { return onlyChannel; } - - ChannelOrLocalOp getAsChannelValueAndInsertSender(SpatCompute currentCompute, size_t resultIndex) { - Value innerValue = computeResults.get(resultIndex); - - auto [channelValue, channelSendInserter] = channelNewInserter(resultIndex); + ChannelOrLocalOp getAsChannelValueAndInsertSender(size_t resultIndex, size_t targetCpu) { + Value innerValue = computeResults.getInner(resultIndex); + auto [channelInfo, channelSendInserter] = channelInserter(resultIndex, targetCpu); InsertPoint sendInsertPoint; auto* block = innerValue.getParentBlock(); if (!block->empty() && isa(block->back())) sendInsertPoint = InsertPoint(block, --block->end()); else sendInsertPoint = InsertPoint(block, block->end()); - if (currentCompute) { - for (auto& block : currentCompute.getBody()) - if (&block == sendInsertPoint.getBlock()) - return {innerValue, false}; - } channelSendInserter(sendInsertPoint); - return {channelValue, true}; + return {innerValue, true, channelInfo}; } - ChannelOrLocalOp getAsChannelValueAndInsertSender(size_t resultIndex) { - return getAsChannelValueAndInsertSender({}, resultIndex); - } +private: + ComputeValueResults computeResults; + std::function>(size_t, size_t)> channelInserter; }; struct MergeComputeNodesPass : PassWrapper> { private: - DenseMap newComputeNodeResults; - DenseMap oldToNewComputeMap; - DenseMap cpuToNewComputeMap; + DenseMap newComputeNodeResults; + DenseMap oldToNewOpMap; + int64_t nextChannelId = 0; public: MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MergeComputeNodesPass) @@ -737,262 +1368,289 @@ public: void runOnOperation() override { mergeTriviallyConnectedComputes(getOperation()); - packWideWeightedVmmBands(getOperation()); emitMotifProfile(getOperation()); DCPAnalysisResult& analysisResult = getAnalysis().getResult(); - auto& lastComputeOfCpu = analysisResult.isLastComputeOfCpu; - auto& cpuToLastComputeMap = analysisResult.cpuToLastComputeMap; + DenseSet materializedInstances; + for (size_t index = 0; index < analysisResult.dominanceOrderCompute.size(); ++index) { + ComputeInstance currentInstance = analysisResult.dominanceOrderCompute[index]; + if (!materializedInstances.insert(currentInstance).second) + continue; - for (auto currentComputeNode : analysisResult.dominanceOrderCompute) { - size_t cpu = analysisResult.computeToCpuMap.at(currentComputeNode); - if (!cpuToNewComputeMap.contains(cpu)) { - ValueTypeRange newComputeType = cpuToLastComputeMap.at(cpu).getResultTypes(); - auto [newCompute, computeValueResult] = - createNewComputeNode(currentComputeNode, newComputeType, lastComputeOfCpu.contains(currentComputeNode)); - cpuToNewComputeMap[cpu] = newCompute; - newComputeNodeResults.insert(std::make_pair( - currentComputeNode, - createLazyComputeResult(newCompute, computeValueResult, lastComputeOfCpu.contains(currentComputeNode)))); - } - else { - auto [newCompute, computeValueResult] = mergeIntoComputeNode( - cpuToNewComputeMap[cpu], currentComputeNode, lastComputeOfCpu.contains(currentComputeNode)); - newComputeNodeResults.insert(std::make_pair( - currentComputeNode, - createLazyComputeResult(newCompute, computeValueResult, lastComputeOfCpu.contains(currentComputeNode)))); + size_t cpu = analysisResult.computeToCpuMap.at(currentInstance); + if (auto batch = dyn_cast(currentInstance.op)) { + createNewBatchCompute(batch, currentInstance.laneStart, currentInstance.laneCount, cpu, analysisResult); + continue; } + + auto scalarCompute = cast(currentInstance.op); + auto [newCompute, computeValueResults] = createNewComputeNode(scalarCompute, cpu, analysisResult); + newComputeNodeResults.insert({currentInstance, createLazyComputeResult(newCompute, computeValueResults, cpu)}); } - for (auto computeNodeToRemove : llvm::make_early_inc_range(llvm::reverse(analysisResult.dominanceOrderCompute))) { - for (auto users : computeNodeToRemove->getUsers()) - users->dump(); - computeNodeToRemove.erase(); + DenseSet toEraseSet; + for (ComputeInstance instance : analysisResult.dominanceOrderCompute) + toEraseSet.insert(instance.op); + + DenseSet externalUsersToMove; + auto collectExternalUsers = [&](Operation* op, auto&& collectExternalUsers) -> void { + if (!externalUsersToMove.insert(op).second) + return; + for (Value result : op->getResults()) { + for (Operation* user : result.getUsers()) { + if (toEraseSet.contains(user) || isa(user)) + continue; + collectExternalUsers(user, collectExternalUsers); + } + } + }; + + DenseSet erasedOps; + for (ComputeInstance instance : llvm::reverse(analysisResult.dominanceOrderCompute)) { + if (!erasedOps.insert(instance.op).second) + continue; + Operation* oldOp = instance.op; + if (Operation* newOp = oldToNewOpMap.lookup(oldOp)) { + for (unsigned i = 0; i < oldOp->getNumResults(); ++i) { + for (auto& use : llvm::make_early_inc_range(oldOp->getResult(i).getUses())) { + Operation* useOwner = use.getOwner(); + if (!toEraseSet.contains(useOwner)) { + use.assign(newOp->getResult(i)); + if (!isa(useOwner) && useOwner->isBeforeInBlock(newOp)) + collectExternalUsers(useOwner, collectExternalUsers); + } + } + } + } + oldOp->erase(); } + func::FuncOp func = getOperation(); + auto returnOp = cast(func.getBody().front().getTerminator()); + SmallVector orderedUsersToMove; + for (Operation& op : func.getBody().front()) { + if (&op == returnOp.getOperation()) + break; + if (externalUsersToMove.contains(&op)) + orderedUsersToMove.push_back(&op); + } + for (Operation* op : orderedUsersToMove) + op->moveBefore(returnOp); + + sinkChannelsIntoComputes(func, nextChannelId); + rebatchEquivalentComputes(func, nextChannelId); + compactScalarChannelRuns(func, nextChannelId); dumpModule(cast(func->getParentOp()), "spatial1_dcp_merged"); - generateReport(func, "spatial1_dcp_merged_report"); + generateReport(func, "spatial1_dcp_merged_report", analysisResult.cpuToLastComputeMap.size()); } private: std::pair - createNewComputeNode(SpatCompute oldCompute, ValueTypeRange newComputeType, bool lastCompute) { + createNewComputeNode(SpatCompute oldCompute, size_t currentCpu, const DCPAnalysisResult& analysisResult) { func::FuncOp func = getOperation(); auto loc = func.getLoc(); IRRewriter rewriter(&getContext()); rewriter.setInsertionPoint(&*std::prev(func.getBody().front().end(), 1)); ComputeValueResults computeValueResults; - IRMapping mapper; - llvm::SmallVector newComputeOperand; - llvm::SmallVector newBBOperandType; - llvm::SmallVector newBBLocations; - for (auto arg : oldCompute.getWeights()) - newComputeOperand.push_back(arg); + SmallVector newComputeOperands; + SmallVector newBBOperandTypes; + SmallVector newBBLocations; + newComputeOperands.reserve(oldCompute.getNumOperands()); + newBBOperandTypes.reserve(oldCompute.getInputs().size()); + newBBLocations.reserve(oldCompute.getInputs().size()); - for (auto arg : oldCompute.getInputs()) - if (!llvm::isa_and_present(arg.getDefiningOp())) { - newComputeOperand.push_back(arg); - newBBOperandType.push_back(arg.getType()); - newBBLocations.push_back(loc); + for (Value weight : oldCompute.getWeights()) + newComputeOperands.push_back(weight); + + for (Value input : oldCompute.getInputs()) { + Value resolvedInput = input; + if (auto producerRef = getProducerValueRef(input)) { + LazyInsertComputeResult& producer = newComputeNodeResults.at(producerRef->instance); + auto [channelVal, isChannel, channelInfo] = producer.getAsChannelValueAndInsertSender(producerRef->resultIndex, currentCpu); + (void) isChannel; + (void) channelVal; + resolvedInput = spatial::SpatChannelReceiveOp::create(rewriter, + loc, + input.getType(), + rewriter.getI64IntegerAttr(channelInfo.channelId), + rewriter.getI32IntegerAttr(channelInfo.sourceCoreId), + rewriter.getI32IntegerAttr(channelInfo.targetCoreId)) + .getResult(); } - auto newCompute = SpatCompute::create(rewriter, loc, newComputeType, newComputeOperand); - - rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), newBBOperandType, newBBLocations); - newCompute.getProperties().setOperandSegmentSizes( - {(int) oldCompute.getWeights().size(), (int) newBBOperandType.size()}); - - auto& newBB = newCompute.getBody().front(); - auto& oldBB = oldCompute.getBody().front(); - rewriter.setInsertionPointToEnd(&newBB); - - int indexNew = 0; - size_t indexOld = oldCompute.getWeights().size(); - size_t indexOldStart = oldCompute.getWeights().size(); - for (; indexOld < oldCompute.getNumOperands(); ++indexOld) { - if (!llvm::isa_and_present(oldCompute.getOperand(indexOld).getDefiningOp())) { - mapper.map(oldBB.getArgument(indexOld - indexOldStart), newBB.getArgument(indexNew++)); - } - else { - auto argWeightCompute = llvm::dyn_cast_if_present(oldCompute.getOperand(indexOld).getDefiningOp()); - auto argResultIndex = cast(oldCompute.getOperand(indexOld)).getResultNumber(); - - LazyInsertComputeResult& lazyArgWeight = newComputeNodeResults.at(argWeightCompute); - auto [channelVal, isChannel] = lazyArgWeight.getAsChannelValueAndInsertSender(argResultIndex); - assert(isChannel == true); - spatial::SpatChannelReceiveOp receiveOp = - spatial::SpatChannelReceiveOp::create(rewriter, loc, oldCompute.getOperand(indexOld).getType(), channelVal); - mapper.map(oldBB.getArgument(indexOld - indexOldStart), receiveOp); - } + newComputeOperands.push_back(resolvedInput); + newBBOperandTypes.push_back(resolvedInput.getType()); + newBBLocations.push_back(loc); } - for (auto& op : oldCompute.getOps()) { + auto newCompute = SpatCompute::create(rewriter, loc, oldCompute.getResultTypes(), newComputeOperands); + newCompute.getProperties().setOperandSegmentSizes( + {static_cast(oldCompute.getWeights().size()), static_cast(oldCompute.getInputs().size())}); + newCompute->setAttr(onnx_mlir::kCoreIdAttrName, rewriter.getI32IntegerAttr(getPhysicalCoreId(currentCpu))); + auto* newBB = + rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), newBBOperandTypes, newBBLocations); + + auto& oldBB = oldCompute.getBody().front(); + for (auto [argIndex, oldArg] : llvm::enumerate(oldBB.getArguments())) + mapper.map(oldArg, newBB->getArgument(argIndex)); + + rewriter.setInsertionPointToEnd(newBB); + for (Operation& op : oldBB) { if (auto yield = dyn_cast(&op)) { - computeValueResults.innerValues.reserve(yield.getNumOperands()); for (Value yieldOperand : yield.getOperands()) computeValueResults.innerValues.push_back(mapper.lookup(yieldOperand)); - if (lastCompute) - rewriter.clone(op, mapper); - } - else rewriter.clone(op, mapper); + continue; + } + rewriter.clone(op, mapper); } - for (auto& use : llvm::make_early_inc_range(oldCompute->getUses())) - if (isa(use.getOwner())) { - auto resultIndex = cast(use.get()).getResultNumber(); - use.assign(newCompute.getResult(resultIndex)); - } + computeValueResults.outerValues.assign(newCompute->result_begin(), newCompute->result_end()); + if (computeValueResults.innerValues.empty()) + computeValueResults.innerValues = computeValueResults.outerValues; - oldToNewComputeMap.insert({oldCompute, newCompute}); - return {cast(newCompute), computeValueResults}; + for (auto& use : llvm::make_early_inc_range(oldCompute->getUses())) + if (isa(use.getOwner())) + use.assign(newCompute.getResult(cast(use.get()).getResultNumber())); + + oldToNewOpMap[oldCompute.getOperation()] = newCompute.getOperation(); + return {newCompute, computeValueResults}; } - std::pair - mergeIntoComputeNode(SpatCompute toCompute, SpatCompute fromCompute, bool lastCompute) { + void createNewBatchCompute(SpatComputeBatch batch, + uint32_t firstLane, + uint32_t laneCount, + size_t currentCpu, + const DCPAnalysisResult& analysisResult) { func::FuncOp func = getOperation(); auto loc = func.getLoc(); IRRewriter rewriter(&getContext()); + rewriter.setInsertionPoint(&*std::prev(func.getBody().front().end(), 1)); + + SmallVector weights; + SmallVector inputs; + SmallVector resultTypes; + weights.reserve(laneCount); + inputs.reserve(laneCount); + resultTypes.reserve(laneCount); + + for (uint32_t lane = firstLane; lane < firstLane + laneCount; ++lane) { + weights.push_back(batch.getWeights()[lane]); + resultTypes.push_back(batch.getOutputs()[lane].getType()); + + Value input = batch.getInputs()[lane]; + Value resolvedInput = input; + if (auto producerRef = getProducerValueRef(input)) { + LazyInsertComputeResult& producer = newComputeNodeResults.at(producerRef->instance); + auto [channelVal, isChannel, channelInfo] = producer.getAsChannelValueAndInsertSender(producerRef->resultIndex, currentCpu); + (void) isChannel; + (void) channelVal; + resolvedInput = spatial::SpatChannelReceiveOp::create(rewriter, + loc, + input.getType(), + rewriter.getI64IntegerAttr(channelInfo.channelId), + rewriter.getI32IntegerAttr(channelInfo.sourceCoreId), + rewriter.getI32IntegerAttr(channelInfo.targetCoreId)) + .getResult(); + } + inputs.push_back(resolvedInput); + } + + Block& templateBlock = batch.getBody().front(); + if (laneCount == 1) { + auto compute = + SpatCompute::create(rewriter, loc, TypeRange(resultTypes), ValueRange(weights), ValueRange(inputs)); + compute.getProperties().setOperandSegmentSizes( + {static_cast(weights.size()), static_cast(inputs.size())}); + compute->setAttr(onnx_mlir::kCoreIdAttrName, rewriter.getI32IntegerAttr(getPhysicalCoreId(currentCpu))); + + auto* newBlock = rewriter.createBlock( + &compute.getBody(), compute.getBody().end(), TypeRange {templateBlock.getArgument(0).getType()}, {loc}); + IRMapping mapper; + mapper.map(templateBlock.getArgument(0), newBlock->getArgument(0)); + rewriter.setInsertionPointToEnd(newBlock); + for (Operation& op : templateBlock) + rewriter.clone(op, mapper); + + ComputeValueResults results; + results.outerValues.assign(compute->result_begin(), compute->result_end()); + results.innerValues = results.outerValues; + newComputeNodeResults.insert({ + ComputeInstance {batch.getOperation(), firstLane, laneCount}, + createLazyComputeResult(compute, results, currentCpu) + }); + return; + } + + auto rebatched = SpatComputeBatch::create(rewriter, + loc, + TypeRange(resultTypes), + rewriter.getI32IntegerAttr(static_cast(laneCount)), + ValueRange(weights), + ValueRange(inputs)); + rebatched->setAttr(onnx_mlir::kCoreIdAttrName, + rewriter.getDenseI32ArrayAttr(SmallVector(laneCount, getPhysicalCoreId(currentCpu)))); + + auto* newBlock = rewriter.createBlock(&rebatched.getBody(), + rebatched.getBody().end(), + TypeRange {templateBlock.getArgument(0).getType()}, + SmallVector(1, loc)); IRMapping mapper; - - DenseMap> toWeightIndices; - for (auto [weightIndex, weight] : llvm::enumerate(toCompute.getWeights())) - toWeightIndices[weight].push_back(weightIndex); - DenseMap usedFromWeightOccurrences; - SmallVector fromWeightToNewIndex; - fromWeightToNewIndex.reserve(fromCompute.getWeights().size()); - - auto weightMutableIter = toCompute.getWeightsMutable(); - for (auto weight : fromCompute.getWeights()) { - size_t occurrence = usedFromWeightOccurrences[weight]++; - auto& matchingIndices = toWeightIndices[weight]; - if (occurrence >= matchingIndices.size()) { - size_t sizeW = toCompute.getWeights().size(); - size_t sizeI = toCompute.getInputs().size(); - weightMutableIter.append(weight); - auto last = weightMutableIter.end(); - last = std::prev(last, 1); - mapper.map(weight, last->get()); - matchingIndices.push_back(sizeW); - fromWeightToNewIndex.push_back(sizeW); - assert(sizeW + 1 == toCompute.getWeights().size()); - assert(sizeI == toCompute.getInputs().size()); - assert(sizeW + sizeI + 1 == toCompute.getOperands().size()); - } - else { - size_t newIndex = matchingIndices[occurrence]; - mapper.map(weight, *std::next(toCompute.getWeights().begin(), newIndex)); - fromWeightToNewIndex.push_back(newIndex); - } - } - - auto& toBB = toCompute.getBody().front(); - auto& fromBB = fromCompute.getBody().front(); - auto inputArgMutable = toCompute.getInputsMutable(); - // Insert receiveOp - rewriter.setInsertionPointToEnd(&toBB); - for (auto [bbIndex, arg] : llvm::enumerate(fromCompute.getInputs())) { - if (auto argWeightCompute = llvm::dyn_cast_if_present(arg.getDefiningOp())) { - LazyInsertComputeResult& lazyArgWeight = newComputeNodeResults.at(argWeightCompute); - auto argResultIndex = cast(arg).getResultNumber(); - - LazyInsertComputeResult::ChannelOrLocalOp channelOrLocal = - lazyArgWeight.getAsChannelValueAndInsertSender(toCompute, argResultIndex); - if (channelOrLocal.isChannel) { - spatial::SpatChannelReceiveOp receiveOp = - spatial::SpatChannelReceiveOp::create(rewriter, loc, arg.getType(), channelOrLocal.data); - mapper.map(fromBB.getArgument(bbIndex), receiveOp.getResult()); - } - else { - mapper.map(fromBB.getArgument(bbIndex), channelOrLocal.data); - } - } - else { - - auto founded = llvm::find(toCompute.getInputs(), arg); - if (founded == toCompute.getInputs().end()) { - size_t sizeW = toCompute.getWeights().size(); - size_t sizeI = toCompute.getInputs().size(); - inputArgMutable.append(arg); - assert(sizeW == toCompute.getWeights().size()); - assert(sizeI + 1 == toCompute.getInputs().size()); - assert(sizeW + sizeI + 1 == toCompute.getOperands().size()); - - toBB.addArgument(fromBB.getArgument(bbIndex).getType(), loc); - mapper.map(fromBB.getArgument(bbIndex), toBB.getArguments().back()); - } - else { - auto distance = std::distance(toCompute.getInputs().begin(), founded); - mapper.map(fromBB.getArgument(bbIndex), toBB.getArgument(distance)); - } - } - } - - for (auto oldBBarg : fromBB.getArguments()) - assert(mapper.contains(oldBBarg)); - - ComputeValueResults computeValueResults; - auto remapWeightIndex = [&](auto weightedOp) { - auto oldIndex = weightedOp.getWeightIndex(); - assert(static_cast(oldIndex) < fromWeightToNewIndex.size() && "weight index out of range"); - weightedOp.setWeightIndex(fromWeightToNewIndex[oldIndex]); - }; - for (auto& op : fromCompute.getOps()) { + mapper.map(templateBlock.getArgument(0), newBlock->getArgument(0)); + rewriter.setInsertionPointToEnd(newBlock); + for (Operation& op : templateBlock) { if (auto yield = dyn_cast(&op)) { - computeValueResults.innerValues.reserve(yield.getNumOperands()); - for (Value yieldOperand : yield.getOperands()) - computeValueResults.innerValues.push_back(mapper.lookup(yieldOperand)); - if (lastCompute) - rewriter.clone(op, mapper); - } - else { - auto newInst = rewriter.clone(op, mapper); - if (auto weightedMvmOp = llvm::dyn_cast(newInst)) - remapWeightIndex(weightedMvmOp); - if (auto weightedVmmOp = llvm::dyn_cast(newInst)) - remapWeightIndex(weightedVmmOp); + rewriter.clone(op, mapper); + continue; } + rewriter.clone(op, mapper); } - for (auto& use : llvm::make_early_inc_range(fromCompute->getUses())) - if (isa(use.getOwner())) { - auto resultIndex = cast(use.get()).getResultNumber(); - use.assign(toCompute.getResult(resultIndex)); - } - - oldToNewComputeMap.insert({fromCompute, toCompute}); - return {cast(toCompute), computeValueResults}; + auto yieldOp = cast(newBlock->getTerminator()); + ComputeValueResults results; + results.outerValues.assign(rebatched->result_begin(), rebatched->result_end()); + results.innerValues = results.outerValues; + if (results.innerValues.empty()) + results.innerValues.push_back(yieldOp.getOperand(0)); + newComputeNodeResults.insert({ + ComputeInstance {batch.getOperation(), firstLane, laneCount}, + createLazyComputeResult(rebatched, results, currentCpu) + }); } LazyInsertComputeResult - createLazyComputeResult(SpatCompute compute, ComputeValueResults computeValueResults, bool lastCompute) { - func::FuncOp funcOp = cast(compute->getParentOp()); + createLazyComputeResult(Operation* producerOp, ComputeValueResults computeValueResults, size_t producerCpu) { + func::FuncOp funcOp = cast(producerOp->getParentOp()); auto* context = &getContext(); auto loc = funcOp.getLoc(); IRRewriter rewriter(context); - rewriter.setInsertionPointToStart(&funcOp.front()); - auto savedChannelInsertPoint = rewriter.saveInsertionPoint(); - auto insertNew = [savedChannelInsertPoint, context, loc, computeValueResults](size_t resultIndex) { - IRRewriter rewriter(context); - rewriter.restoreInsertionPoint(savedChannelInsertPoint); - auto channelOp = spatial::SpatChannelNewOp::create(rewriter, loc, spatial::SpatChannelType::get(context)); - auto channelVal = channelOp.getResult(); - auto insertVal = - [&context, loc, computeValueResults, channelVal, resultIndex](mlir::IRRewriter::InsertPoint sendInsertPoint) { - IRRewriter rewriter(context); - rewriter.restoreInsertionPoint(sendInsertPoint); - auto spatSend = - spatial::SpatChannelSendOp::create(rewriter, loc, channelVal, computeValueResults.get(resultIndex)); - return spatSend; - }; - std::pair> ret {channelVal, insertVal}; + rewriter.setInsertionPointAfter(producerOp); + auto savedSendInsertPoint = rewriter.saveInsertionPoint(); + auto insertNew = [this, savedSendInsertPoint, context, loc, computeValueResults, producerCpu](size_t resultIndex, + size_t targetCpu) { + auto channelId = nextChannelId++; + LazyInsertComputeResult::ChannelInfo channelInfo { + channelId, getPhysicalCoreId(producerCpu), getPhysicalCoreId(targetCpu)}; + auto insertVal = [&context, loc, computeValueResults, channelInfo, resultIndex, savedSendInsertPoint]( + mlir::IRRewriter::InsertPoint) { + IRRewriter rewriter(context); + rewriter.restoreInsertionPoint(savedSendInsertPoint); + spatial::SpatChannelSendOp::create(rewriter, + loc, + rewriter.getI64IntegerAttr(channelInfo.channelId), + rewriter.getI32IntegerAttr(channelInfo.sourceCoreId), + rewriter.getI32IntegerAttr(channelInfo.targetCoreId), + computeValueResults.getOuter(resultIndex)); + }; + std::pair> ret { + channelInfo, insertVal}; return ret; }; - return LazyInsertComputeResult(computeValueResults, insertNew, false); + return LazyInsertComputeResult(computeValueResults, insertNew); } }; diff --git a/src/PIM/Pass/PimCodegen/VerificationPass.cpp b/src/PIM/Pass/PimCodegen/VerificationPass.cpp index 0e67cc2..3f53abb 100644 --- a/src/PIM/Pass/PimCodegen/VerificationPass.cpp +++ b/src/PIM/Pass/PimCodegen/VerificationPass.cpp @@ -23,8 +23,7 @@ static bool isAddressOnlyHostOp(Operation* op) { memref::SubViewOp, memref::CastOp, memref::CollapseShapeOp, - memref::ExpandShapeOp, - spatial::SpatChannelNewOp>(op); + memref::ExpandShapeOp>(op); } static bool isCodegenAddressableValue(Value value) { @@ -38,6 +37,8 @@ static bool isCodegenAddressableValue(Value value) { static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) { if (isa(op)) return operandIndex == 1; + if (isa(op)) + return operandIndex == 1; if (isa(op)) return operandIndex == 0; return false; @@ -69,6 +70,12 @@ struct VerificationPass : PassWrapper> continue; } + if (auto coreBatchOp = dyn_cast(&op)) { + if (failed(verifyCoreWeights(moduleOp, coreBatchOp)) || failed(verifyCoreOperands(coreBatchOp))) + hasFailure = true; + continue; + } + if (auto returnOp = dyn_cast(&op)) { if (failed(verifyReturnOp(returnOp))) hasFailure = true; @@ -92,10 +99,11 @@ struct VerificationPass : PassWrapper> } private: - static LogicalResult verifyCoreWeights(ModuleOp moduleOp, pim::PimCoreOp coreOp) { + template + static LogicalResult verifyCoreWeights(ModuleOp moduleOp, CoreOpTy coreOp) { bool hasFailure = false; for (auto [weightIndex, weight] : llvm::enumerate(coreOp.getWeights())) { - auto getGlobalOp = weight.getDefiningOp(); + auto getGlobalOp = weight.template getDefiningOp(); if (!getGlobalOp) { coreOp.emitOpError() << "weight #" << weightIndex << " must be materialized as memref.get_global before JSON codegen"; @@ -131,7 +139,8 @@ private: return success(!hasFailure); } - static LogicalResult verifyCoreOperands(pim::PimCoreOp coreOp) { + template + static LogicalResult verifyCoreOperands(CoreOpTy coreOp) { return walkPimCoreBlock( coreOp.getBody().front(), StaticValueKnowledge {}, [](Operation& op, const StaticValueKnowledge& knowledge) { bool hasFailure = false;