From 628dc630a4ba9986c3db1f7b11eaf8f02008df5c Mon Sep 17 00:00:00 2001 From: NiccoloN Date: Tue, 12 May 2026 13:35:25 +0200 Subject: [PATCH] compact syntax for spatial tensor ops better IR compaction after dcp merge remove pim.mvm op better memory report --- src/PIM/Common/IR/WeightUtils.cpp | 9 +- src/PIM/Compiler/PimCodeGen.cpp | 87 +++++-- src/PIM/Compiler/PimCodeGen.hpp | 17 +- src/PIM/Compiler/PimWeightEmitter.cpp | 1 - .../BatchCoreLoweringPatterns.cpp | 14 +- .../Conversion/SpatialToPim/SpatialToPim.td | 6 - .../SpatialToPim/SpatialToPimPass.cpp | 46 +--- .../SpatialToPim/TensorPackingPatterns.cpp | 86 ++++-- .../SpatialToPim/TensorPackingPatterns.hpp | 14 +- src/PIM/Dialect/Pim/Pim.td | 24 -- .../OpBufferizationInterfaces.cpp | 28 -- src/PIM/Dialect/Spatial/Spatial.td | 20 +- src/PIM/Dialect/Spatial/SpatialOpsAsm.cpp | 113 ++++++++ .../MergeComputeNodesPass.cpp | 15 ++ .../MergeComputeNodes/RegularOpCompaction.cpp | 244 +++++++++--------- 15 files changed, 419 insertions(+), 305 deletions(-) diff --git a/src/PIM/Common/IR/WeightUtils.cpp b/src/PIM/Common/IR/WeightUtils.cpp index 63ebcdf..d7079b9 100644 --- a/src/PIM/Common/IR/WeightUtils.cpp +++ b/src/PIM/Common/IR/WeightUtils.cpp @@ -88,7 +88,14 @@ bool hasOnlySpatialMvmVmmWeightUses(mlir::Value value) { void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref callback) { assert(root && "expected valid root op"); - root->walk([&](pim::PimCoreOp coreOp) { walkMvmVmmWeightUses(coreOp, callback); }); + root->walk([&](pim::PimCoreOp coreOp) { + coreOp.walk([&](pim::PimVMMOp vmmOp) { + auto weights = coreOp.getWeights(); + unsigned weightIndex = vmmOp.getWeightIndex(); + if (weightIndex < weights.size()) + callback(coreOp->getOpOperand(weightIndex)); + }); + }); root->walk([&](pim::PimCoreBatchOp coreBatchOp) { auto weights = coreBatchOp.getWeights(); for (auto weight : weights) diff --git a/src/PIM/Compiler/PimCodeGen.cpp b/src/PIM/Compiler/PimCodeGen.cpp index 6894558..d911d4b 100644 --- a/src/PIM/Compiler/PimCodeGen.cpp +++ b/src/PIM/Compiler/PimCodeGen.cpp @@ -134,6 +134,15 @@ static void printMemoryReportRow(raw_ostream& os, const MemoryReportRow& row) { os << "\tGlobal memory: " << formatMemory(row.sizeGlobal) << "\n"; } +static MemoryReportRow addMemoryReportRows(const MemoryReportRow& lhs, const MemoryReportRow& rhs) { + MemoryReportRow result = lhs; + result.numAlloca += rhs.numAlloca; + result.sizeAlloca += rhs.sizeAlloca; + result.numGlobal += rhs.numGlobal; + result.sizeGlobal += rhs.sizeGlobal; + return result; +} + MemoryReportRow PimMemory::getReportRow() const { MemoryReportRow row; for (auto& [val, memEntry] : globalMemEntriesMap) { @@ -201,8 +210,17 @@ void PimAcceleratorMemory::reportHost() { hostReportRow = hostMem.getReportRow(); } -void PimAcceleratorMemory::reportCore(size_t coreId) { - coreReportRows.push_back({coreId, deviceMem.at(coreId).getReportRow()}); +void PimAcceleratorMemory::recordCoreReport(size_t coreId, const MemoryReportRow& row) { + reportEntries.push_back({MemoryReportEntry::Kind::Core, coreId, {static_cast(coreId)}, row}); +} + +void PimAcceleratorMemory::recordBatchReport(uint64_t batchId, ArrayRef coreIds, const MemoryReportRow& row) { + MemoryReportEntry entry; + entry.kind = MemoryReportEntry::Kind::Batch; + entry.id = batchId; + llvm::append_range(entry.coreIds, coreIds); + entry.row = row; + reportEntries.push_back(std::move(entry)); } void PimAcceleratorMemory::flushReport() { @@ -215,13 +233,16 @@ void PimAcceleratorMemory::flushReport() { printMemoryReportRow(os, *hostReportRow); } - if (!coreReportRows.empty()) { + if (!reportEntries.empty()) { if (hostReportRow.has_value()) os << "\n"; - llvm::stable_sort(coreReportRows, [](const auto& lhs, const auto& rhs) { - const MemoryReportRow& lhsRow = lhs.second; - const MemoryReportRow& rhsRow = rhs.second; + llvm::stable_sort(reportEntries, [](const MemoryReportEntry& lhs, const MemoryReportEntry& rhs) { + if (lhs.kind != rhs.kind) + return lhs.kind == MemoryReportEntry::Kind::Batch; + + const MemoryReportRow& lhsRow = lhs.row; + const MemoryReportRow& rhsRow = rhs.row; if (lhsRow.sizeAlloca != rhsRow.sizeAlloca) return lhsRow.sizeAlloca > rhsRow.sizeAlloca; if (lhsRow.numAlloca != rhsRow.numAlloca) @@ -230,24 +251,36 @@ void PimAcceleratorMemory::flushReport() { return lhsRow.sizeGlobal > rhsRow.sizeGlobal; if (lhsRow.numGlobal != rhsRow.numGlobal) return lhsRow.numGlobal > rhsRow.numGlobal; - return lhs.first < rhs.first; + return lhs.id < rhs.id; }); - for (size_t index = 0; index < coreReportRows.size();) { + for (size_t index = 0; index < reportEntries.size();) { size_t runEnd = index + 1; - while (runEnd < coreReportRows.size() && coreReportRows[runEnd].second == coreReportRows[index].second) + while (runEnd < reportEntries.size() && reportEntries[runEnd].kind == reportEntries[index].kind + && reportEntries[runEnd].row == reportEntries[index].row) { ++runEnd; + } - llvm::SmallVector coreIds; - coreIds.reserve(runEnd - index); - for (size_t coreIndex = index; coreIndex < runEnd; ++coreIndex) - coreIds.push_back(coreReportRows[coreIndex].first); - - os << "Core "; - printCompressedIntegerEntries(os, ArrayRef(coreIds)); + if (reportEntries[index].kind == MemoryReportEntry::Kind::Batch) { + os << "Batch "; + for (size_t batchIndex = index; batchIndex < runEnd; ++batchIndex) { + if (batchIndex != index) + os << ",\n "; + os << reportEntries[batchIndex].id << " (cores "; + printCompressedIntegerEntries(os, ArrayRef(reportEntries[batchIndex].coreIds)); + os << ")"; + } + } + else { + llvm::SmallVector coreIds; + for (size_t coreIndex = index; coreIndex < runEnd; ++coreIndex) + coreIds.push_back(reportEntries[coreIndex].coreIds.front()); + os << "Core "; + printCompressedIntegerEntries(os, ArrayRef(coreIds)); + } os << ":\n"; - printMemoryReportRow(os, coreReportRows[index].second); - if (runEnd < coreReportRows.size()) + printMemoryReportRow(os, reportEntries[index].row); + if (runEnd < reportEntries.size()) os << "\n"; index = runEnd; @@ -678,7 +711,6 @@ static SmallVector getUsedWeightIndices(Block& block) { indices.push_back(weightIndex); }; - block.walk([&](pim::PimMVMOp mvmOp) { addIndex(mvmOp.getWeightIndex()); }); block.walk([&](pim::PimVMMOp vmmOp) { addIndex(vmmOp.getWeightIndex()); }); llvm::sort(indices); return indices; @@ -753,8 +785,6 @@ static int64_t codeGenCoreOps(Block& block, PimCodeGen& coreCodeGen) { coreCodeGen.codeGenConcatOp(concatOp, knowledge); else if (auto vmmOp = dyn_cast(op)) coreCodeGen.codeGenMVMLikeOp(vmmOp.getWeightIndex(), vmmOp, true, knowledge); - else if (auto mvmOp = dyn_cast(op)) - coreCodeGen.codeGenMVMLikeOp(mvmOp.getWeightIndex(), mvmOp, false, knowledge); else if (auto transposeOp = dyn_cast(op)) coreCodeGen.codeGenTransposeOp(transposeOp, knowledge); else if (auto vvaddOp = dyn_cast(op)) @@ -816,6 +846,7 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std:: // This implementation always assigns one crossbar per group. json::Object xbarsPerArrayGroup; size_t maxCoreId = 0; + uint64_t nextBatchReportId = 0; // Create Weight Folder auto mapCoreWeightToFileName = createAndPopulateWeightFolder(funcOp, outputDirPath); @@ -859,7 +890,6 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std:: PimCodeGen coreCodeGen(memory, coreFileStream, emittedCoreIds); aliasMaterializedHostGlobals(moduleOp, funcOp, coreOp, memory); memory.getOrCreateDeviceMem(coreId).allocateCore(coreOp); - memory.reportCore(coreId); int64_t processedOperations = codeGenCoreOps(coreOp.getBody().front(), coreCodeGen); if (processedOperations < 0) @@ -905,18 +935,31 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std:: if (auto coreOp = dyn_cast(op)) { if (auto err = emitCore(coreOp, false)) return err; + memory.recordCoreReport(emittedCoreIds.lookup(static_cast(coreOp.getCoreId())), + memory.getOrCreateDeviceMem(emittedCoreIds.lookup(static_cast(coreOp.getCoreId()))) + .getReportRow()); continue; } auto coreBatchOp = cast(op); + auto batchCoreIds = getBatchCoreIds(coreBatchOp); + SmallVector reportedCoreIds; + reportedCoreIds.reserve(batchCoreIds.size()); + MemoryReportRow batchRow; for (unsigned lane = 0; lane < static_cast(coreBatchOp.getLaneCount()); ++lane) { OnnxMlirCompilerErrorCodes laneResult = CompilerSuccess; if (failed(withScalarCoreFromBatchLane(coreBatchOp, lane, [&](pim::PimCoreOp coreOp) { + size_t originalCoreId = static_cast(batchCoreIds[lane]); + size_t coreId = emittedCoreIds.lookup(originalCoreId); + reportedCoreIds.push_back(static_cast(coreId)); laneResult = emitCore(coreOp, true); + if (laneResult == CompilerSuccess) + batchRow = addMemoryReportRows(batchRow, memory.getOrCreateDeviceMem(coreId).getReportRow()); return laneResult == CompilerSuccess ? success() : failure(); }))) return laneResult == CompilerSuccess ? CompilerFailure : laneResult; } + memory.recordBatchReport(nextBatchReportId++, reportedCoreIds, batchRow); } memory.flushReport(); diff --git a/src/PIM/Compiler/PimCodeGen.hpp b/src/PIM/Compiler/PimCodeGen.hpp index 4a6e0f5..8536b4a 100644 --- a/src/PIM/Compiler/PimCodeGen.hpp +++ b/src/PIM/Compiler/PimCodeGen.hpp @@ -33,6 +33,18 @@ struct MemoryReportRow { } }; +struct MemoryReportEntry { + enum class Kind { + Core, + Batch + }; + + Kind kind = Kind::Core; + uint64_t id = 0; + llvm::SmallVector coreIds; + MemoryReportRow row; +}; + class PimMemory { llvm::SmallVector, 32> memEntries; llvm::SmallDenseMap& globalMemEntriesMap; @@ -66,7 +78,7 @@ private: llvm::SmallDenseMap deviceMem; std::fstream fileReport; std::optional hostReportRow; - llvm::SmallVector, 32> coreReportRows; + llvm::SmallVector reportEntries; public: PimAcceleratorMemory() @@ -86,7 +98,8 @@ public: size_t getValueAddress(mlir::Value value, const StaticValueKnowledge& knowledge = {}) const; void reportHost(); - void reportCore(size_t coreId); + void recordCoreReport(size_t coreId, const MemoryReportRow& row); + void recordBatchReport(uint64_t batchId, llvm::ArrayRef coreIds, const MemoryReportRow& row); void flushReport(); void clean(mlir::Operation* op); }; diff --git a/src/PIM/Compiler/PimWeightEmitter.cpp b/src/PIM/Compiler/PimWeightEmitter.cpp index 7545d5d..2e318ee 100644 --- a/src/PIM/Compiler/PimWeightEmitter.cpp +++ b/src/PIM/Compiler/PimWeightEmitter.cpp @@ -103,7 +103,6 @@ SmallVector getUsedWeightIndices(Block& block) { indices.push_back(weightIndex); }; - block.walk([&](pim::PimMVMOp mvmOp) { addIndex(mvmOp.getWeightIndex()); }); block.walk([&](pim::PimVMMOp vmmOp) { addIndex(vmmOp.getWeightIndex()); }); llvm::sort(indices); return indices; diff --git a/src/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.cpp b/src/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.cpp index f66fc8b..58a0c66 100644 --- a/src/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.cpp +++ b/src/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.cpp @@ -7,7 +7,6 @@ #include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.hpp" #include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp" -#include "src/Accelerators/PIM/Conversion/SpatialToPim/TensorPackingPatterns.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" using namespace mlir; @@ -37,15 +36,10 @@ static void lowerChannelSendTensorBatch(spatial::SpatChannelSendTensorBatchOp se for (int32_t targetCoreId : sendTensorBatchOp.getTargetCoreIds()) targetCoreIds.push_back(translateSpatialCoreIdToPimCoreId(targetCoreId)); - Value input = mapper.lookup(sendTensorBatchOp.getInput()); - if (auto concatOp = input.getDefiningOp()) - if (concatOp.getDim() == 0) - if (Value packedInput = - createPackedExtractSliceTensor(concatOp.getInputs(), rewriter, sendTensorBatchOp.getLoc())) - input = packedInput; - - pim::PimSendTensorBatchOp::create( - rewriter, sendTensorBatchOp.getLoc(), input, rewriter.getDenseI32ArrayAttr(targetCoreIds)); + pim::PimSendTensorBatchOp::create(rewriter, + sendTensorBatchOp.getLoc(), + mapper.lookup(sendTensorBatchOp.getInput()), + rewriter.getDenseI32ArrayAttr(targetCoreIds)); } static void lowerChannelReceiveTensorBatch(spatial::SpatChannelReceiveTensorBatchOp receiveTensorBatchOp, diff --git a/src/PIM/Conversion/SpatialToPim/SpatialToPim.td b/src/PIM/Conversion/SpatialToPim/SpatialToPim.td index 4b73be5..d79dd66 100644 --- a/src/PIM/Conversion/SpatialToPim/SpatialToPim.td +++ b/src/PIM/Conversion/SpatialToPim/SpatialToPim.td @@ -21,12 +21,6 @@ def spatToPimVMM : Pat< (NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes)) >; -def spatToPimMVM : Pat< - (SpatMVMOp:$srcOpRes $weightIndex, $vector), - (PimMVMOp $weightIndex, $vector, - (NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes)) ->; - def spatToPimVVAdd : Pat< (SpatVAddOp:$srcOpRes $a, $b), (PimVVAddOp $a, $b, diff --git a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp index 1475c2f..7a48fca 100644 --- a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp +++ b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp @@ -11,7 +11,6 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/IR/Value.h" #include "mlir/Pass/Pass.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/WalkPatternRewriteDriver.h" #include "llvm/ADT/StringRef.h" @@ -105,12 +104,8 @@ static void lowerChannelSendTensor(spatial::SpatChannelSendTensorOp sendTensorOp targetCoreIds.push_back(translateSpatialCoreIdToPimCoreId(targetCoreId)); rewriter.setInsertionPoint(sendTensorOp); - Value input = sendTensorOp.getInput(); - if (auto concatOp = input.getDefiningOp()) - if (concatOp.getDim() == 0) - if (Value packedInput = createPackedExtractSliceTensor(concatOp.getInputs(), rewriter, sendTensorOp.getLoc())) - input = packedInput; - PimSendTensorOp::create(rewriter, sendTensorOp.getLoc(), input, rewriter.getDenseI32ArrayAttr(targetCoreIds)); + PimSendTensorOp::create( + rewriter, sendTensorOp.getLoc(), sendTensorOp.getInput(), rewriter.getDenseI32ArrayAttr(targetCoreIds)); rewriter.eraseOp(sendTensorOp); } @@ -152,38 +147,6 @@ static void lowerExtractRows(spatial::SpatExtractRowsOp extractRowsOp, IRRewrite rewriter.replaceOp(extractRowsOp, replacements); } -static Value createPackedExtractRowsSlice( - spatial::SpatExtractRowsOp extractRowsOp, unsigned startIndex, unsigned count, IRRewriter& rewriter, Location loc) { - auto rowType = dyn_cast(extractRowsOp.getOutputs()[startIndex].getType()); - auto inputType = dyn_cast(extractRowsOp.getInput().getType()); - if (!rowType || !inputType || !rowType.hasStaticShape() || !inputType.hasStaticShape() || rowType.getRank() == 0) - return {}; - - int64_t rowsPerValue = rowType.getDimSize(0); - if (ShapedType::isDynamic(rowsPerValue)) - return {}; - - auto packedType = getPackedTensorType(rowType, static_cast(count)); - SmallVector offsets; - SmallVector sizes; - SmallVector strides; - offsets.reserve(inputType.getRank()); - sizes.reserve(inputType.getRank()); - strides.reserve(inputType.getRank()); - - offsets.push_back(rewriter.getIndexAttr(static_cast(startIndex) * rowsPerValue)); - sizes.push_back(rewriter.getIndexAttr(static_cast(count) * rowsPerValue)); - strides.push_back(rewriter.getIndexAttr(1)); - for (int64_t dim = 1; dim < inputType.getRank(); ++dim) { - offsets.push_back(rewriter.getIndexAttr(0)); - sizes.push_back(rewriter.getIndexAttr(inputType.getDimSize(dim))); - strides.push_back(rewriter.getIndexAttr(1)); - } - - return tensor::ExtractSliceOp::create(rewriter, loc, packedType, extractRowsOp.getInput(), offsets, sizes, strides) - .getResult(); -} - static void compactSpatialTensorGroups(func::FuncOp funcOp, IRRewriter& rewriter) { SmallVector concatOps; funcOp.walk([&](spatial::SpatConcatOp concatOp) { concatOps.push_back(concatOp); }); @@ -262,11 +225,6 @@ static void compactSpatialTensorGroups(func::FuncOp funcOp, IRRewriter& rewriter .getResult()); rewriter.replaceOp(concatOp, newConcat.getOutput()); } - - RewritePatternSet tensorPackingPatterns(funcOp.getContext()); - populateTensorPackingPatterns(tensorPackingPatterns); - (void) applyPatternsGreedily(funcOp, std::move(tensorPackingPatterns)); - auto eraseUnusedOps = [&](auto tag) { using OpTy = decltype(tag); SmallVector ops; diff --git a/src/PIM/Conversion/SpatialToPim/TensorPackingPatterns.cpp b/src/PIM/Conversion/SpatialToPim/TensorPackingPatterns.cpp index 4307cdf..2dd3043 100644 --- a/src/PIM/Conversion/SpatialToPim/TensorPackingPatterns.cpp +++ b/src/PIM/Conversion/SpatialToPim/TensorPackingPatterns.cpp @@ -3,26 +3,6 @@ using namespace mlir; namespace onnx_mlir { -namespace { - -// Replaces concat-of-adjacent-slices with one packed slice to keep batch sends compact. -struct FoldConcatOfContiguousSlices : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(tensor::ConcatOp op, PatternRewriter& rewriter) const override { - if (op.getDim() != 0) - return failure(); - - Value packed = createPackedExtractSliceTensor(op.getInputs(), rewriter, op.getLoc()); - if (!packed) - return failure(); - - rewriter.replaceOp(op, packed); - return success(); - } -}; - -} // namespace RankedTensorType getPackedTensorType(RankedTensorType elementType, int64_t count) { SmallVector packedShape(elementType.getShape().begin(), elementType.getShape().end()); @@ -30,6 +10,67 @@ RankedTensorType getPackedTensorType(RankedTensorType elementType, int64_t count return RankedTensorType::get(packedShape, elementType.getElementType()); } +Value extractPackedChunk( + Value packedValue, RankedTensorType chunkType, unsigned index, OpBuilder& builder, Location loc) { + auto packedType = dyn_cast(packedValue.getType()); + if (packedType && packedType == chunkType && index == 0) + return packedValue; + + SmallVector offsets; + SmallVector sizes; + SmallVector strides; + offsets.reserve(chunkType.getRank()); + sizes.reserve(chunkType.getRank()); + strides.reserve(chunkType.getRank()); + + offsets.push_back(builder.getIndexAttr(static_cast(index) * chunkType.getDimSize(0))); + sizes.push_back(builder.getIndexAttr(chunkType.getDimSize(0))); + strides.push_back(builder.getIndexAttr(1)); + for (int64_t dim = 1; dim < chunkType.getRank(); ++dim) { + offsets.push_back(builder.getIndexAttr(0)); + sizes.push_back(builder.getIndexAttr(chunkType.getDimSize(dim))); + strides.push_back(builder.getIndexAttr(1)); + } + + return tensor::ExtractSliceOp::create(builder, loc, chunkType, packedValue, offsets, sizes, strides).getResult(); +} + +Value createPackedExtractRowsSlice( + spatial::SpatExtractRowsOp extractRowsOp, unsigned startIndex, unsigned count, OpBuilder& builder, Location loc) { + auto rowType = dyn_cast(extractRowsOp.getOutputs()[startIndex].getType()); + auto inputType = dyn_cast(extractRowsOp.getInput().getType()); + if (!rowType || !inputType || !rowType.hasStaticShape() || !inputType.hasStaticShape() || rowType.getRank() == 0) + return {}; + + int64_t rowsPerValue = rowType.getDimSize(0); + if (ShapedType::isDynamic(rowsPerValue)) + return {}; + + auto packedType = getPackedTensorType(rowType, static_cast(count)); + SmallVector offsets; + SmallVector sizes; + SmallVector strides; + offsets.reserve(inputType.getRank()); + sizes.reserve(inputType.getRank()); + strides.reserve(inputType.getRank()); + + offsets.push_back(builder.getIndexAttr(static_cast(startIndex) * rowsPerValue)); + sizes.push_back(builder.getIndexAttr(static_cast(count) * rowsPerValue)); + strides.push_back(builder.getIndexAttr(1)); + for (int64_t dim = 1; dim < inputType.getRank(); ++dim) { + offsets.push_back(builder.getIndexAttr(0)); + sizes.push_back(builder.getIndexAttr(inputType.getDimSize(dim))); + strides.push_back(builder.getIndexAttr(1)); + } + + bool coversWholeSource = packedType == inputType && startIndex == 0; + if (coversWholeSource) + return extractRowsOp.getInput(); + + return tensor::ExtractSliceOp::create(builder, loc, packedType, extractRowsOp.getInput(), offsets, sizes, strides) + .getResult(); +} + Value createPackedExtractSliceTensor(ValueRange values, OpBuilder& builder, Location loc) { if (values.empty()) return {}; @@ -105,9 +146,4 @@ Value createPackedExtractSliceTensor(ValueRange values, OpBuilder& builder, Loca return tensor::ExtractSliceOp::create(builder, loc, packedType, firstSliceOp.getSource(), offsets, sizes, strides) .getResult(); } - -void populateTensorPackingPatterns(RewritePatternSet& patterns) { - patterns.add(patterns.getContext()); -} - } // namespace onnx_mlir diff --git a/src/PIM/Conversion/SpatialToPim/TensorPackingPatterns.hpp b/src/PIM/Conversion/SpatialToPim/TensorPackingPatterns.hpp index 7b34544..f338514 100644 --- a/src/PIM/Conversion/SpatialToPim/TensorPackingPatterns.hpp +++ b/src/PIM/Conversion/SpatialToPim/TensorPackingPatterns.hpp @@ -3,11 +3,21 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/PatternMatch.h" +#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" + namespace onnx_mlir { mlir::RankedTensorType getPackedTensorType(mlir::RankedTensorType elementType, int64_t count); +mlir::Value extractPackedChunk(mlir::Value packedValue, + mlir::RankedTensorType chunkType, + unsigned index, + mlir::OpBuilder& builder, + mlir::Location loc); +mlir::Value createPackedExtractRowsSlice(spatial::SpatExtractRowsOp extractRowsOp, + unsigned startIndex, + unsigned count, + mlir::OpBuilder& builder, + mlir::Location loc); mlir::Value createPackedExtractSliceTensor(mlir::ValueRange values, mlir::OpBuilder& builder, mlir::Location loc); -void populateTensorPackingPatterns(mlir::RewritePatternSet& patterns); - } // namespace onnx_mlir diff --git a/src/PIM/Dialect/Pim/Pim.td b/src/PIM/Dialect/Pim/Pim.td index 429b5a1..c97f174 100644 --- a/src/PIM/Dialect/Pim/Pim.td +++ b/src/PIM/Dialect/Pim/Pim.td @@ -394,30 +394,6 @@ def PimVMMOp : PimOp<"vmm", [DestinationStyleOpInterface]> { }]; } -def PimMVMOp : PimOp<"mvm", [DestinationStyleOpInterface]> { - let summary = "Matrix-vector multiplication: c = a * b"; - - let arguments = (ins - I32Attr:$weightIndex, - PimTensor:$input, - PimTensor:$outputBuffer - ); - - let results = (outs - PimTensor:$output - ); - - let extraClassDeclaration = [{ - mlir::MutableOperandRange getDpsInitsMutable() { - return getOutputBufferMutable(); - } - }]; - - let assemblyFormat = [{ - `(` $input `,` $outputBuffer `)` attr-dict `:` `(` type($input) `,` type($outputBuffer) `)` `->` type($output) - }]; -} - def PimVVAddOp : PimOp<"vvadd", [DestinationStyleOpInterface]> { let summary = "Element-wise addition: c = a + b"; diff --git a/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp b/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp index 35d3d94..1fad5b5 100644 --- a/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp +++ b/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp @@ -538,33 +538,6 @@ struct VMMOpInterface : DstBufferizableOpInterfaceExternalModel { - bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { - return !cast(op).isDpsInit(&opOperand); - } - - LogicalResult bufferize(Operation* op, - RewriterBase& rewriter, - const BufferizationOptions& options, - BufferizationState& state) const { - auto mvmOp = cast(op); - - auto inputOpt = getBufferOrValue(rewriter, mvmOp.getInput(), options, state); - if (failed(inputOpt)) - return failure(); - - auto outputBufferOpt = getBufferOrValue(rewriter, mvmOp.getOutputBuffer(), options, state); - if (failed(outputBufferOpt)) - return failure(); - - Value contiguousInput = materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter); - - replaceOpWithNewBufferizedOp( - rewriter, op, outputBufferOpt->getType(), mvmOp.getWeightIndexAttr(), contiguousInput, *outputBufferOpt); - return success(); - } -}; - template struct BinaryDstOpInterface : DstBufferizableOpInterfaceExternalModel, OpTy> { bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const { @@ -655,7 +628,6 @@ void registerOpBufferizationInterfaces(DialectRegistry& registry) { PimMemCopyDevToHostOp::attachInterface(*ctx); PimTransposeOp::attachInterface(*ctx); PimVMMOp::attachInterface(*ctx); - PimMVMOp::attachInterface(*ctx); PimVVAddOp::attachInterface>(*ctx); PimVVSubOp::attachInterface>(*ctx); diff --git a/src/PIM/Dialect/Spatial/Spatial.td b/src/PIM/Dialect/Spatial/Spatial.td index 9414609..80e9a4d 100644 --- a/src/PIM/Dialect/Spatial/Spatial.td +++ b/src/PIM/Dialect/Spatial/Spatial.td @@ -150,10 +150,7 @@ def SpatChannelSendTensorOp : SpatOp<"channel_send_tensor", []> { ); let hasVerifier = 1; - - let assemblyFormat = [{ - $input attr-dict `:` type($input) - }]; + let hasCustomAssemblyFormat = 1; } def SpatChannelReceiveTensorOp : SpatOp<"channel_receive_tensor", []> { @@ -170,10 +167,7 @@ def SpatChannelReceiveTensorOp : SpatOp<"channel_receive_tensor", []> { ); let hasVerifier = 1; - - let assemblyFormat = [{ - attr-dict `:` type($output) - }]; + let hasCustomAssemblyFormat = 1; } def SpatChannelSendBatchOp : SpatOp<"channel_send_batch", []> { @@ -201,10 +195,7 @@ def SpatChannelSendTensorBatchOp : SpatOp<"channel_send_tensor_batch", []> { ); let hasVerifier = 1; - - let assemblyFormat = [{ - $input attr-dict `:` type($input) - }]; + let hasCustomAssemblyFormat = 1; } def SpatChannelReceiveBatchOp : SpatOp<"channel_receive_batch", []> { @@ -238,10 +229,7 @@ def SpatChannelReceiveTensorBatchOp : SpatOp<"channel_receive_tensor_batch", []> ); let hasVerifier = 1; - - let assemblyFormat = [{ - attr-dict `:` type($output) - }]; + let hasCustomAssemblyFormat = 1; } //===----------------------------------------------------------------------===// diff --git a/src/PIM/Dialect/Spatial/SpatialOpsAsm.cpp b/src/PIM/Dialect/Spatial/SpatialOpsAsm.cpp index f86ffa9..00e4fc4 100644 --- a/src/PIM/Dialect/Spatial/SpatialOpsAsm.cpp +++ b/src/PIM/Dialect/Spatial/SpatialOpsAsm.cpp @@ -47,6 +47,95 @@ static IntegerAttr getI32Attr(OpAsmParser& parser, int32_t value) { return parser.getBuilder().getI32IntegerAttr(value); } +template +static void printTensorSendOp(OpAsmPrinter& printer, TensorSendOpTy op) { + printer << " "; + printer.printOperand(op.getInput()); + printChannelMetadata(printer, op.getChannelIds(), op.getSourceCoreIds(), op.getTargetCoreIds()); + printer.printOptionalAttrDict( + op->getAttrs(), + {op.getChannelIdsAttrName().getValue(), op.getSourceCoreIdsAttrName().getValue(), op.getTargetCoreIdsAttrName().getValue()}); + printer << " : "; + printer.printType(op.getInput().getType()); +} + +template +static void printTensorReceiveOp(OpAsmPrinter& printer, TensorReceiveOpTy op) { + printChannelMetadata(printer, op.getChannelIds(), op.getSourceCoreIds(), op.getTargetCoreIds()); + printer.printOptionalAttrDict( + op->getAttrs(), + {op.getChannelIdsAttrName().getValue(), op.getSourceCoreIdsAttrName().getValue(), op.getTargetCoreIdsAttrName().getValue()}); + printer << " : "; + printer.printType(op.getOutput().getType()); +} + +static ParseResult parseTensorSendOp(OpAsmParser& parser, OperationState& result) { + OpAsmParser::UnresolvedOperand input; + Type inputType; + SmallVector channelIds; + SmallVector sourceCoreIds; + SmallVector targetCoreIds; + + if (parser.parseOperand(input)) + return failure(); + + bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels")); + if (hasMetadata) { + if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from") + || parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to") + || parseCompressedIntegerList(parser, targetCoreIds)) + return failure(); + } + + if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(inputType)) + return failure(); + + if (hasMetadata + && (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds") + || result.attributes.get("targetCoreIds"))) + return parser.emitError(parser.getCurrentLocation(), + "channel metadata cannot be specified both positionally and in attr-dict"); + if (hasMetadata) { + result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds)); + result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds)); + result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds)); + } + + return parser.resolveOperand(input, inputType, result.operands); +} + +static ParseResult parseTensorReceiveOp(OpAsmParser& parser, OperationState& result) { + Type outputType; + SmallVector channelIds; + SmallVector sourceCoreIds; + SmallVector targetCoreIds; + + bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels")); + if (hasMetadata) { + if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from") + || parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to") + || parseCompressedIntegerList(parser, targetCoreIds)) + return failure(); + } + + if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(outputType)) + return failure(); + + if (hasMetadata + && (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds") + || result.attributes.get("targetCoreIds"))) + return parser.emitError(parser.getCurrentLocation(), + "channel metadata cannot be specified both positionally and in attr-dict"); + if (hasMetadata) { + result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds)); + result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds)); + result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds)); + } + + result.addTypes(outputType); + return success(); +} + } // namespace void SpatYieldOp::print(OpAsmPrinter& printer) { @@ -316,6 +405,12 @@ ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result) return parser.parseRegion(*body, regionArgs); } +void SpatChannelSendTensorOp::print(OpAsmPrinter& printer) { printTensorSendOp(printer, *this); } + +ParseResult SpatChannelSendTensorOp::parse(OpAsmParser& parser, OperationState& result) { + return parseTensorSendOp(parser, result); +} + void SpatChannelSendBatchOp::print(OpAsmPrinter& printer) { printer << " "; printer.printOperand(getInput()); @@ -362,6 +457,18 @@ ParseResult SpatChannelSendBatchOp::parse(OpAsmParser& parser, OperationState& r return parser.resolveOperand(input, inputType, result.operands); } +void SpatChannelSendTensorBatchOp::print(OpAsmPrinter& printer) { printTensorSendOp(printer, *this); } + +ParseResult SpatChannelSendTensorBatchOp::parse(OpAsmParser& parser, OperationState& result) { + return parseTensorSendOp(parser, result); +} + +void SpatChannelReceiveTensorOp::print(OpAsmPrinter& printer) { printTensorReceiveOp(printer, *this); } + +ParseResult SpatChannelReceiveTensorOp::parse(OpAsmParser& parser, OperationState& result) { + return parseTensorReceiveOp(parser, result); +} + void SpatChannelReceiveBatchOp::print(OpAsmPrinter& printer) { printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds()); printer.printOptionalAttrDict( @@ -403,5 +510,11 @@ ParseResult SpatChannelReceiveBatchOp::parse(OpAsmParser& parser, OperationState return success(); } +void SpatChannelReceiveTensorBatchOp::print(OpAsmPrinter& printer) { printTensorReceiveOp(printer, *this); } + +ParseResult SpatChannelReceiveTensorBatchOp::parse(OpAsmParser& parser, OperationState& result) { + return parseTensorReceiveOp(parser, result); +} + } // namespace spatial } // namespace onnx_mlir diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp index c9a4865..eb6d712 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp @@ -1436,6 +1436,21 @@ public: compactBatchChannelRuns(func); compactRegularOpRuns(func); compactRowWiseWvmmRuns(func); + compactScalarChannelRuns(func, nextChannelId); + compactBatchChannelRuns(func); + + auto eraseUnusedOps = [&](auto tag) { + using OpTy = decltype(tag); + SmallVector ops; + func.walk([&](OpTy op) { ops.push_back(op); }); + for (auto op : llvm::reverse(ops)) + if (op->use_empty()) + op.erase(); + }; + eraseUnusedOps(tensor::ExtractSliceOp {}); + eraseUnusedOps(spatial::SpatConcatOp {}); + eraseUnusedOps(spatial::SpatExtractRowsOp {}); + if (!sortTopologically(&func.getBody().front())) { func.emitOpError("failed to topologically order merged Spatial IR"); signalPassFailure(); diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/RegularOpCompaction.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/RegularOpCompaction.cpp index 74082fd..8a92beb 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/RegularOpCompaction.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/RegularOpCompaction.cpp @@ -13,6 +13,7 @@ #include #include "RegularOpCompaction.hpp" +#include "src/Accelerators/PIM/Conversion/SpatialToPim/TensorPackingPatterns.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" using namespace mlir; @@ -41,133 +42,75 @@ struct RegularChunk { Value output; }; -static RankedTensorType getPackedTensorType(RankedTensorType elementType, int64_t count) { - SmallVector packedShape(elementType.getShape().begin(), elementType.getShape().end()); - packedShape[0] *= count; - return RankedTensorType::get(packedShape, elementType.getElementType()); -} +static spatial::SpatConcatOp getContiguousConcatUse(ValueRange values, unsigned& startOperandIndex) { + if (values.empty() || !values.front().hasOneUse()) + return {}; -static Value -extractPackedChunk(Value packedValue, RankedTensorType chunkType, unsigned index, IRRewriter& rewriter, Location loc) { - SmallVector offsets; - SmallVector sizes; - SmallVector strides; - offsets.reserve(chunkType.getRank()); - sizes.reserve(chunkType.getRank()); - strides.reserve(chunkType.getRank()); + OpOperand& firstUse = *values.front().getUses().begin(); + auto concatOp = dyn_cast(firstUse.getOwner()); + if (!concatOp) + return {}; - offsets.push_back(rewriter.getIndexAttr(static_cast(index) * chunkType.getDimSize(0))); - sizes.push_back(rewriter.getIndexAttr(chunkType.getDimSize(0))); - strides.push_back(rewriter.getIndexAttr(1)); - for (int64_t dim = 1; dim < chunkType.getRank(); ++dim) { - offsets.push_back(rewriter.getIndexAttr(0)); - sizes.push_back(rewriter.getIndexAttr(chunkType.getDimSize(dim))); - strides.push_back(rewriter.getIndexAttr(1)); + startOperandIndex = firstUse.getOperandNumber(); + for (auto [index, value] : llvm::enumerate(values)) { + if (!value.hasOneUse()) + return {}; + OpOperand& use = *value.getUses().begin(); + if (use.getOwner() != concatOp || use.getOperandNumber() != startOperandIndex + index) + return {}; } - return tensor::ExtractSliceOp::create(rewriter, loc, chunkType, packedValue, offsets, sizes, strides).getResult(); + return concatOp; } -static Value createPackedExtractRowsSlice( - spatial::SpatExtractRowsOp extractRowsOp, unsigned startIndex, unsigned count, IRRewriter& rewriter, Location loc) { - auto rowType = dyn_cast(extractRowsOp.getOutputs()[startIndex].getType()); - auto inputType = dyn_cast(extractRowsOp.getInput().getType()); - if (!rowType || !inputType || !rowType.hasStaticShape() || !inputType.hasStaticShape() || rowType.getRank() == 0) - return {}; - - int64_t rowsPerValue = rowType.getDimSize(0); - if (ShapedType::isDynamic(rowsPerValue)) - return {}; - - auto packedType = getPackedTensorType(rowType, static_cast(count)); - SmallVector offsets; - SmallVector sizes; - SmallVector strides; - offsets.reserve(inputType.getRank()); - sizes.reserve(inputType.getRank()); - strides.reserve(inputType.getRank()); - - offsets.push_back(rewriter.getIndexAttr(static_cast(startIndex) * rowsPerValue)); - sizes.push_back(rewriter.getIndexAttr(static_cast(count) * rowsPerValue)); - strides.push_back(rewriter.getIndexAttr(1)); - for (int64_t dim = 1; dim < inputType.getRank(); ++dim) { - offsets.push_back(rewriter.getIndexAttr(0)); - sizes.push_back(rewriter.getIndexAttr(inputType.getDimSize(dim))); - strides.push_back(rewriter.getIndexAttr(1)); +static void replaceConcatRunWithPackedValue(spatial::SpatConcatOp concatOp, + unsigned startOperandIndex, + unsigned operandCount, + Value packedValue, + IRRewriter& rewriter) { + SmallVector newInputs; + newInputs.reserve(concatOp.getInputs().size() - operandCount + 1); + for (auto [operandIndex, operand] : llvm::enumerate(concatOp.getInputs())) { + if (operandIndex == startOperandIndex) + newInputs.push_back(packedValue); + if (operandIndex < startOperandIndex || operandIndex >= startOperandIndex + operandCount) + newInputs.push_back(operand); } - - return tensor::ExtractSliceOp::create(rewriter, loc, packedType, extractRowsOp.getInput(), offsets, sizes, strides) - .getResult(); + if (newInputs.size() == 1 && newInputs.front().getType() == concatOp.getOutput().getType()) { + rewriter.replaceOp(concatOp, newInputs.front()); + return; + } + rewriter.modifyOpInPlace(concatOp, [&] { concatOp->setOperands(newInputs); }); } -static Value createPackedExtractSliceTensor(ValueRange values, IRRewriter& rewriter, Location loc) { - if (values.empty()) - return {}; - if (values.size() == 1) - return values.front(); - - auto firstSliceOp = values.front().getDefiningOp(); - if (!firstSliceOp) +static RankedTensorType +getPackedConcatSliceType(spatial::SpatConcatOp concatOp, unsigned startOperandIndex, unsigned operandCount) { + auto firstType = dyn_cast(concatOp.getInputs()[startOperandIndex].getType()); + if (!firstType || !firstType.hasStaticShape()) return {}; - auto firstType = dyn_cast(firstSliceOp.getResult().getType()); - auto sourceType = dyn_cast(firstSliceOp.getSource().getType()); - if (!firstType || !sourceType || !firstType.hasStaticShape() || !sourceType.hasStaticShape() - || firstType.getRank() == 0) + int64_t axis = concatOp.getAxis(); + if (axis < 0 || axis >= firstType.getRank()) return {}; - auto hasStaticValues = [](ArrayRef values) { - return llvm::all_of(values, [](int64_t value) { return !ShapedType::isDynamic(value); }); - }; - if (!hasStaticValues(firstSliceOp.getStaticOffsets()) || !hasStaticValues(firstSliceOp.getStaticSizes()) - || !hasStaticValues(firstSliceOp.getStaticStrides())) - return {}; - - ArrayRef firstOffsets = firstSliceOp.getStaticOffsets(); - ArrayRef firstSizes = firstSliceOp.getStaticSizes(); - ArrayRef firstStrides = firstSliceOp.getStaticStrides(); - int64_t rowsPerValue = firstSizes[0]; - if (ShapedType::isDynamic(rowsPerValue)) - return {}; - - for (size_t index = 1; index < values.size(); ++index) { - auto sliceOp = values[index].getDefiningOp(); - if (!sliceOp || sliceOp.getSource() != firstSliceOp.getSource() - || sliceOp.getResult().getType() != firstSliceOp.getResult().getType() - || !hasStaticValues(sliceOp.getStaticOffsets()) || !hasStaticValues(sliceOp.getStaticSizes()) - || !hasStaticValues(sliceOp.getStaticStrides())) + SmallVector shape(firstType.getShape().begin(), firstType.getShape().end()); + shape[axis] = 0; + for (unsigned index = 0; index < operandCount; ++index) { + auto operandType = dyn_cast(concatOp.getInputs()[startOperandIndex + index].getType()); + if (!operandType || !operandType.hasStaticShape() || operandType.getRank() != firstType.getRank()) return {}; - if (sliceOp.getStaticSizes() != firstSizes || sliceOp.getStaticStrides() != firstStrides) - return {}; - - if (sliceOp.getStaticOffsets()[0] != firstOffsets[0] + static_cast(index) * rowsPerValue) - return {}; - - for (int64_t dim = 1; dim < firstType.getRank(); ++dim) - if (sliceOp.getStaticOffsets()[dim] != firstOffsets[dim]) + for (int64_t dim = 0; dim < firstType.getRank(); ++dim) { + if (dim == axis) + continue; + if (operandType.getShape()[dim] != shape[dim]) return {}; + } + + shape[axis] += operandType.getShape()[axis]; } - auto packedType = getPackedTensorType(firstType, static_cast(values.size())); - SmallVector offsets; - SmallVector sizes; - SmallVector strides; - offsets.reserve(firstType.getRank()); - sizes.reserve(firstType.getRank()); - strides.reserve(firstType.getRank()); - - offsets.push_back(rewriter.getIndexAttr(firstOffsets[0])); - sizes.push_back(rewriter.getIndexAttr(rowsPerValue * static_cast(values.size()))); - strides.push_back(rewriter.getIndexAttr(firstStrides[0])); - for (int64_t dim = 1; dim < firstType.getRank(); ++dim) { - offsets.push_back(rewriter.getIndexAttr(firstOffsets[dim])); - sizes.push_back(rewriter.getIndexAttr(firstSizes[dim])); - strides.push_back(rewriter.getIndexAttr(firstStrides[dim])); - } - - return tensor::ExtractSliceOp::create(rewriter, loc, packedType, firstSliceOp.getSource(), offsets, sizes, strides) - .getResult(); + return RankedTensorType::get(shape, firstType.getElementType()); } static bool getContiguousOpResults(ValueRange values, Operation*& owner, unsigned& startIndex) { @@ -207,8 +150,7 @@ static Value createPackedTensorForValues(ValueRange values, IRRewriter& rewriter return {}; if (!llvm::all_of(values.drop_front(), [&](Value value) { return value.getType() == firstType; })) return {}; - - return tensor::ConcatOp::create(rewriter, loc, /*dim=*/0, values).getResult(); + return {}; } static bool areEquivalentRegularSteps(const RegularStep& lhs, const RegularStep& rhs) { @@ -346,11 +288,28 @@ static void compactRegularChunkRun(IRRewriter& rewriter, ArrayRef scf::YieldOp::create(rewriter, anchorChunk.startOp->getLoc(), inserted.getResult()); } - for (auto [index, chunk] : llvm::enumerate(run)) { - Value replacement = extractPackedChunk( - loop.getResult(0), outputType, static_cast(index), rewriter, chunk.startOp->getLoc()); - Value output = chunk.output; - output.replaceAllUsesWith(replacement); + SmallVector outputs; + outputs.reserve(run.size()); + for (const RegularChunk& chunk : run) + outputs.push_back(chunk.output); + + unsigned concatStartIndex = 0; + auto concatOp = getContiguousConcatUse(ValueRange(outputs), concatStartIndex); + auto concatPackedType = concatOp + ? getPackedConcatSliceType(concatOp, concatStartIndex, static_cast(outputs.size())) + : RankedTensorType {}; + + if (concatOp && concatPackedType == packedOutputType) { + replaceConcatRunWithPackedValue( + concatOp, concatStartIndex, static_cast(outputs.size()), loop.getResult(0), rewriter); + } + else { + for (auto [index, chunk] : llvm::enumerate(run)) { + Value replacement = extractPackedChunk( + loop.getResult(0), outputType, static_cast(index), rewriter, chunk.startOp->getLoc()); + Value output = chunk.output; + output.replaceAllUsesWith(replacement); + } } SmallVector opsToErase; @@ -412,7 +371,18 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) { } auto rowType = cast(run.front().getOutput().getType()); - auto packedType = getPackedTensorType(rowType, static_cast(sortedEntries.size())); + auto fallbackPackedType = getPackedTensorType(rowType, static_cast(sortedEntries.size())); + SmallVector sortedOutputs; + sortedOutputs.reserve(sortedEntries.size()); + for (ReceiveEntry& entry : sortedEntries) + sortedOutputs.push_back(entry.op.getOutput()); + + unsigned concatStartIndex = 0; + auto concatOp = getContiguousConcatUse(ValueRange(sortedOutputs), concatStartIndex); + auto concatPackedType = + concatOp ? getPackedConcatSliceType(concatOp, concatStartIndex, static_cast(sortedOutputs.size())) + : RankedTensorType {}; + auto packedType = concatPackedType ? concatPackedType : fallbackPackedType; rewriter.setInsertionPoint(run.front()); auto compactReceive = spatial::SpatChannelReceiveTensorOp::create(rewriter, @@ -421,9 +391,18 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) { rewriter.getDenseI64ArrayAttr(channelIds), rewriter.getDenseI32ArrayAttr(sourceCoreIds), rewriter.getDenseI32ArrayAttr(targetCoreIds)); - for (auto [sortedIndex, entry] : llvm::enumerate(sortedEntries)) - entry.op.getOutput().replaceAllUsesWith(extractPackedChunk( - compactReceive.getOutput(), rowType, static_cast(sortedIndex), rewriter, entry.op.getLoc())); + if (concatOp && concatPackedType) { + replaceConcatRunWithPackedValue(concatOp, + concatStartIndex, + static_cast(sortedOutputs.size()), + compactReceive.getOutput(), + rewriter); + } + else { + for (auto [sortedIndex, entry] : llvm::enumerate(sortedEntries)) + entry.op.getOutput().replaceAllUsesWith(extractPackedChunk( + compactReceive.getOutput(), rowType, static_cast(sortedIndex), rewriter, entry.op.getLoc())); + } for (auto op : run) rewriter.eraseOp(op); @@ -531,7 +510,18 @@ void compactBatchChannelRuns(func::FuncOp funcOp) { } auto rowType = cast(run.front().getOutput().getType()); - auto packedType = getPackedTensorType(rowType, static_cast(run.size())); + auto fallbackPackedType = getPackedTensorType(rowType, static_cast(run.size())); + SmallVector outputs; + outputs.reserve(run.size()); + for (auto op : run) + outputs.push_back(op.getOutput()); + + unsigned concatStartIndex = 0; + auto concatOp = getContiguousConcatUse(ValueRange(outputs), concatStartIndex); + auto concatPackedType = + concatOp ? getPackedConcatSliceType(concatOp, concatStartIndex, static_cast(outputs.size())) + : RankedTensorType {}; + auto packedType = concatPackedType ? concatPackedType : fallbackPackedType; rewriter.setInsertionPoint(run.front()); auto compactReceive = spatial::SpatChannelReceiveTensorBatchOp::create(rewriter, @@ -540,9 +530,15 @@ void compactBatchChannelRuns(func::FuncOp funcOp) { rewriter.getDenseI64ArrayAttr(channelIds), rewriter.getDenseI32ArrayAttr(sourceCoreIds), rewriter.getDenseI32ArrayAttr(targetCoreIds)); - for (auto [index, op] : llvm::enumerate(run)) - op.getOutput().replaceAllUsesWith(extractPackedChunk( - compactReceive.getOutput(), rowType, static_cast(index), rewriter, op.getLoc())); + if (concatOp && concatPackedType) { + replaceConcatRunWithPackedValue( + concatOp, concatStartIndex, static_cast(outputs.size()), compactReceive.getOutput(), rewriter); + } + else { + for (auto [index, op] : llvm::enumerate(run)) + op.getOutput().replaceAllUsesWith(extractPackedChunk( + compactReceive.getOutput(), rowType, static_cast(index), rewriter, op.getLoc())); + } for (auto op : run) rewriter.eraseOp(op);