From 74931ad75b494fb439132e54d90057cb01fcc762 Mon Sep 17 00:00:00 2001 From: ilgeco Date: Thu, 7 May 2026 16:47:01 +0200 Subject: [PATCH] Single Concat Fix --- .gitignore | 2 + src/PIM/Compiler/PimCodeGen.cpp | 26 --------- .../ONNXToSpatial/ONNXToSpatialPass.cpp | 54 +++++++++++++++++-- src/PIM/Conversion/SpatialToPim/Patterns.cpp | 12 ++++- 4 files changed, 62 insertions(+), 32 deletions(-) diff --git a/.gitignore b/.gitignore index bf5221d..a273593 100644 --- a/.gitignore +++ b/.gitignore @@ -9,7 +9,9 @@ AGENTS.md CMakeUserPresets.json build +build_release cmake-build-debug cmake-build-release +compile.sh **/__* diff --git a/src/PIM/Compiler/PimCodeGen.cpp b/src/PIM/Compiler/PimCodeGen.cpp index fed460b..b12cf87 100644 --- a/src/PIM/Compiler/PimCodeGen.cpp +++ b/src/PIM/Compiler/PimCodeGen.cpp @@ -169,32 +169,6 @@ void PimMemory::report(llvm::raw_ostream& file) { } } -// void PimMemory::report(llvm::raw_ostream& file) { -// std::vector orderedList(globalMemEntriesMap.begin(), globalMemEntriesMap.end()); -// std::sort( -// orderedList.begin(), orderedList.end(), [](auto lft, auto rgt) { return lft.second.address < rgt.second.address; -// }); -// auto newEnd = std::unique(orderedList.begin(), orderedList.end(), [](auto lft, auto rgt) { -// return (lft.first.getDefiningOp() == rgt.first.getDefiningOp()) && (lft.second.address == rgt.second.address); -// }); -// orderedList.erase(newEnd, orderedList.end()); -// mlir::OpPrintingFlags flags; -// flags.assumeVerified(true); -// for (auto& [value, memEntry] : orderedList) { -// if (auto op = value.getDefiningOp()) { -// file.indent(4) << op << ": "; -// op->print(file, flags); -// file << "\n"; -// file.indent(6) << "Address: " << llvm::format_hex(memEntry.address, 10) << "\n"; -// file.indent(6) << "Memory: " << formatMemory(memEntry.size) << "\n"; -// } -// else { -// file.indent(4) << value << "\n"; -// file.indent(6) << "Address: " << llvm::format_hex(memEntry.address, 10) << "\n"; -// file.indent(6) << "Memory: " << formatMemory(memEntry.size) << "\n"; -// } -// } -// } void PimMemory::remove(mlir::Value val) { if (auto removeIter = globalMemEntriesMap.find(val); removeIter != globalMemEntriesMap.end()) diff --git a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp index 07ff217..c08fa4e 100644 --- a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp @@ -49,6 +49,7 @@ private: void annotateWeightsConstants(func::FuncOp funcOp) const; LogicalResult encapsulateGlobalInstruction(func::FuncOp funcOp); LogicalResult promoteConstantInputsToWeights(func::FuncOp funcOp); + void populateEmptyFunction(func::FuncOp funcOp); }; } // namespace @@ -64,7 +65,8 @@ static void foldSingleLaneComputeBatches(func::FuncOp funcOp) { auto loc = batchOp.getLoc(); rewriter.setInsertionPoint(batchOp); - auto computeOp = spatial::SpatCompute::create(rewriter, loc, batchOp.getResultTypes(), batchOp.getWeights(), batchOp.getInputs()); + auto computeOp = + spatial::SpatCompute::create(rewriter, loc, batchOp.getResultTypes(), batchOp.getWeights(), batchOp.getInputs()); computeOp.getProperties().setOperandSegmentSizes( {static_cast(batchOp.getWeights().size()), static_cast(batchOp.getInputs().size())}); @@ -75,8 +77,8 @@ static void foldSingleLaneComputeBatches(func::FuncOp funcOp) { blockArgTypes.push_back(arg.getType()); blockArgLocs.push_back(loc); } - auto* newBlock = rewriter.createBlock( - &computeOp.getBody(), computeOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs); + auto* newBlock = + rewriter.createBlock(&computeOp.getBody(), computeOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs); IRMapping mapper; for (auto [oldArg, newArg] : llvm::zip(templateBlock.getArguments(), newBlock->getArguments())) @@ -183,6 +185,8 @@ void ONNXToSpatialPass::runOnOperation() { annotateWeightsConstants(*entryFunc); + populateEmptyFunction(*entryFunc); + if (failed(encapsulateGlobalInstruction(*entryFunc))) { signalPassFailure(); return; @@ -376,8 +380,7 @@ LogicalResult ONNXToSpatialPass::encapsulateGlobalInstruction(func::FuncOp funcO while (keep) { keep = false; for (auto& instruction : llvm::make_early_inc_range(funcOp.getOps())) { - if (isa( - instruction) + if (isa(instruction) || isa(instruction)) continue; @@ -490,6 +493,47 @@ LogicalResult ONNXToSpatialPass::promoteConstantInputsToWeights(func::FuncOp fun return success(); } +void ONNXToSpatialPass::populateEmptyFunction(func::FuncOp funcOp) { + IRRewriter rewriter(&getContext()); + IRMapping mapper; + SmallVector computes(funcOp.getOps()); + if (!computes.empty()) + return; + auto returnOp = llvm::cast(funcOp.getRegion().front().getTerminator()); + rewriter.setInsertionPoint(returnOp); + + SmallVector sourceTypes; + SmallVector sourceLoc; + for (auto source : funcOp.getArguments()) { + sourceTypes.push_back(source.getType()); + sourceLoc.push_back(source.getLoc()); + } + + auto newCompute = spatial::SpatCompute::create( + rewriter, returnOp.getLoc(), returnOp.getOperandTypes(), funcOp.getArguments(), {}, {}); + auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), sourceTypes, sourceLoc); + for (auto [bbArg, computeArg] : llvm::zip(BB->getArguments(), newCompute.getOperands())) + mapper.map(computeArg, bbArg); + newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) sourceTypes.size()}); + rewriter.setInsertionPointToEnd(BB); + for (Operation& inst : funcOp.getOps()) + if (!isa(&inst)) + rewriter.clone(inst, mapper); + + auto yield = spatial::SpatYieldOp::create(rewriter, funcOp.getLoc(), returnOp.getOperands()); + for (size_t i = 0; i < yield.getNumOperands(); ++i) + yield.setOperand(i, mapper.lookupOrDefault(yield.getOperand(i))); + + for (Operation& inst : llvm::make_early_inc_range(funcOp.getOps())) + if (!isa(&inst)){ + inst.dropAllUses(); + rewriter.eraseOp(&inst); + } + + for (auto [index, computeResult] : llvm::enumerate(newCompute.getResults())) + returnOp.setOperand(index, computeResult); +} + std::unique_ptr createONNXToSpatialPass() { return std::make_unique(); } } // namespace onnx_mlir diff --git a/src/PIM/Conversion/SpatialToPim/Patterns.cpp b/src/PIM/Conversion/SpatialToPim/Patterns.cpp index 52228c9..8c43f6a 100644 --- a/src/PIM/Conversion/SpatialToPim/Patterns.cpp +++ b/src/PIM/Conversion/SpatialToPim/Patterns.cpp @@ -23,6 +23,16 @@ using namespace mlir; namespace onnx_mlir { namespace { +struct EraseSpatNopPattern : public mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + mlir::LogicalResult matchAndRewrite(spatial::SpatNopOp op, + mlir::PatternRewriter &rewriter) const override { + rewriter.eraseOp(op); + return mlir::success(); + } +}; + + static std::optional getDirectComputeInputIndex(Operation* owner, unsigned operandNumber) { if (auto compute = dyn_cast(owner)) { unsigned inputCount = compute.getInputs().size(); @@ -416,7 +426,7 @@ struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern( + patterns.add( patterns.getContext()); }