From 3f870fb74ba21ca6e87c5a973b4e1fcb725cb0e8 Mon Sep 17 00:00:00 2001
From: ilgeco
Date: Wed, 8 Apr 2026 20:39:01 +0200
Subject: [PATCH] MergeDCP pass: all tests pass

---
 .../ONNXToSpatial/ONNXToSpatialPass.cpp       |  82 ++++++++++++++
 .../Dialect/Spatial/DCPGraph/DCPAnalysis.cpp  |  17 ++-
 .../MergeComputeNode/MergeComputeNodePass.cpp | 102 ++++++++++--------
 3 files changed, 157 insertions(+), 44 deletions(-)

diff --git a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp
index d74ae85..1649c34 100644
--- a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp
+++ b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp
@@ -5,11 +5,15 @@
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/Passes.h"

+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Casting.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_os_ostream.h"

 #include

+#include "Common.hpp"
 #include "Common/PimCommon.hpp"
 #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
 #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
@@ -42,6 +46,7 @@ struct ONNXToSpatialPass : PassWrapper
+
+template <typename T>
+bool encapsulator(IRRewriter& rewriter, Location loc, Operation* inst, std::function<Value(T)> funcSource) {
+  if (T toRemoveOp = llvm::dyn_cast_if_present<T>(inst)) {
+    Value source = funcSource(toRemoveOp);
+    rewriter.setInsertionPointAfter(toRemoveOp);
+    if (isa_and_present(source.getDefiningOp())) {
+      auto newCompute = spatial::SpatWeightedCompute::create(rewriter, loc, inst->getResultTypes().front(), source);
+      auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), {source.getType()}, {loc});
+      newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) 1});
+      rewriter.setInsertionPointToEnd(BB);
+      IRMapping mapper;
+      mapper.map(source, BB->getArgument(0));
+      auto newInst = rewriter.clone(*inst, mapper);
+      spatial::SpatYieldOp::create(rewriter, loc, newInst->getResult(0));
+      inst->replaceAllUsesWith(newCompute);
+      inst->erase();
+      return true;
+    }
+  }
+  return false;
+}
+
+bool encapsulateConcat(IRRewriter& rewriter, Location loc, Operation* inst) {
+  if (auto toRemoveOp = llvm::dyn_cast_if_present<ONNXConcatOp>(inst)) {
+    auto sources = toRemoveOp.getInputs();
+    rewriter.setInsertionPointAfter(toRemoveOp);
+    if (llvm::any_of(
+            sources, [](auto source) { return isa_and_present(source.getDefiningOp()); })) {
+      auto newCompute = spatial::SpatWeightedCompute::create(rewriter, loc, inst->getResultTypes().front(), sources);
+      llvm::SmallVector<Type> sourceTypes;
+      llvm::SmallVector<Location> sourceLoc;
+      for (auto source : sources) {
+        sourceTypes.push_back(source.getType());
+        sourceLoc.push_back(loc);
+      }
+      auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), sourceTypes, sourceLoc);
+      newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) sources.size()});
+      rewriter.setInsertionPointToEnd(BB);
+      IRMapping mapper;
+      for (auto [source, bbArg] : llvm::zip(sources, BB->getArguments()))
+        mapper.map(source, bbArg);
+      auto newConcat = rewriter.clone(*inst, mapper);
+      spatial::SpatYieldOp::create(rewriter, loc, newConcat->getResult(0));
+      inst->replaceAllUsesWith(newCompute);
+      inst->erase();
+      return true;
+    }
+  }
+  return false;
+}
+
+// TODO: what do we want to keep in global scope?
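+// Driver for the helpers above: repeatedly sweeps the function and wraps
+// qualifying tensor::ExtractSliceOp, tensor::ExpandShapeOp, ONNXTransposeOp,
+// tensor::CollapseShapeOp, and concat operations into their own
+// SpatWeightedCompute regions, looping until a full sweep makes no change.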
+void ONNXToSpatialPass::encapsulateGlobalInstruction(func::FuncOp funcOp) {
+  Location loc = funcOp.getLoc();
+  IRRewriter rewriter(&getContext());
+  bool keep = true;
+  while (keep) {
+    keep = false;
+    for (auto& instruction : llvm::make_early_inc_range(funcOp.getOps())) {
+      keep |= encapsulator<tensor::ExtractSliceOp>(
+          rewriter, loc, &instruction, [](tensor::ExtractSliceOp extract) { return extract.getSource(); });
+
+      keep |= encapsulator<tensor::ExpandShapeOp>(
+          rewriter, loc, &instruction, [](tensor::ExpandShapeOp expand) { return expand.getSrc(); });
+
+      keep |= encapsulator<ONNXTransposeOp>(
+          rewriter, loc, &instruction, [](ONNXTransposeOp transpose) { return transpose.getData(); });
+
+      keep |= encapsulator<tensor::CollapseShapeOp>(
+          rewriter, loc, &instruction, [](tensor::CollapseShapeOp collapse) { return collapse.getSrc(); });
+
+      keep |= encapsulateConcat(rewriter, loc, &instruction);
+    }
+  }
+}
+
 void ONNXToSpatialPass::annotateWeightsConstants(func::FuncOp funcOp) const {
   funcOp.walk([&](arith::ConstantOp constantOp) {
     bool isAlwaysWeight =
diff --git a/src/PIM/Dialect/Spatial/DCPGraph/DCPAnalysis.cpp b/src/PIM/Dialect/Spatial/DCPGraph/DCPAnalysis.cpp
index fa078a4..7a0ed07 100644
--- a/src/PIM/Dialect/Spatial/DCPGraph/DCPAnalysis.cpp
+++ b/src/PIM/Dialect/Spatial/DCPGraph/DCPAnalysis.cpp
@@ -1,3 +1,4 @@
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/Value.h"
 #include "mlir/IR/ValueRange.h"
@@ -17,6 +18,19 @@ namespace spatial {

 using namespace mlir;

+SpatWeightedCompute getOriginalSpatWeightCompute(Operation* op) {
+  if (!op)
+    return {};
+  while (auto extract = llvm::dyn_cast<tensor::ExtractSliceOp>(op)) {
+    op = extract.getSource().getDefiningOp();
+    if (!op)
+      return {};
+  }
+  if (auto res = llvm::dyn_cast<SpatWeightedCompute>(op))
+    return res;
+  return {};
+}
+
 DCPAnalysisResult DCPAnalysis::runAnalysis() {
   using EdgesIndex = std::tuple;
   llvm::SmallVector<SpatWeightedCompute> spatWeightedComputes;
@@ -27,8 +41,7 @@ DCPAnalysisResult DCPAnalysis::runAnalysis() {

   for (auto [indexEndEdge, spatWeightedCompute] : llvm::enumerate(spatWeightedComputes)) {
     for (Value input : spatWeightedCompute.getInputs()) {
-      if (auto spatWeightedComputeArgOp = llvm::dyn_cast_if_present<SpatWeightedCompute>(input.getDefiningOp());
-          spatWeightedComputeArgOp) {
+      if (auto spatWeightedComputeArgOp = getOriginalSpatWeightCompute(input.getDefiningOp())) {
         auto elemIter = llvm::find(spatWeightedComputes, spatWeightedComputeArgOp);
         assert(elemIter != spatWeightedComputes.end());
         auto indexStartEdge = std::distance(spatWeightedComputes.begin(), elemIter);
diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNode/MergeComputeNodePass.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNode/MergeComputeNodePass.cpp
index 3d5da4d..9f39e18 100644
--- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNode/MergeComputeNodePass.cpp
+++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNode/MergeComputeNodePass.cpp
@@ -11,9 +11,12 @@

 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Debug.h"
+#include
 #include

+#include
 #include

 #include "src/Accelerators/PIM/Common/PimCommon.hpp"
@@ -62,15 +65,12 @@ public:
       auto [first, second] = channelNewInserter();
       channelNewOpVal = first;
       channelSendInserter = second;
-      auto op = computeResults.innerValue.getDefiningOp();
-      if (op) {
-        insertPointSend = InsertPoint(op->getBlock(), ++Block::iterator(op));
-      }
-      else {
-        auto BB = computeResults.innerValue.getParentBlock();
-        insertPointSend = InsertPoint(BB, BB->begin());
-      }
     }
+    auto BB = computeResults.innerValue.getParentBlock();
+    if (!BB->empty() && isa<spatial::SpatYieldOp>(BB->back()))
+      insertPointSend = InsertPoint(BB, --BB->end());
+    else
+      insertPointSend = InsertPoint(BB, BB->end());
     if (spatWeightedCompute) {
       for (auto& BB : spatWeightedCompute.getBody())
         if (&BB == insertPointSend.getBlock())
@@ -129,8 +129,10 @@ public:
       }
     }

-    for (auto computeNodetoRemove : llvm::make_early_inc_range(llvm::reverse(analysisResult.dominanceOrderCompute)))
+    for (auto computeNodetoRemove : llvm::make_early_inc_range(llvm::reverse(analysisResult.dominanceOrderCompute))) {
+      for (auto users : computeNodetoRemove->getUsers()) users->dump();
       computeNodetoRemove.erase();
+    }
     func::FuncOp func = getOperation();
     dumpModule(cast<ModuleOp>(func->getParentOp()), "SpatialDCPMerged");
   }
@@ -154,36 +156,40 @@ private:
       newComputeOperand.push_back(arg);

     for (auto arg : oldWeightedCompute.getInputs())
-      if (!llvm::isa<SpatWeightedCompute>(arg.getDefiningOp())) {
+      if (!llvm::isa_and_present<SpatWeightedCompute>(arg.getDefiningOp())) {
         newComputeOperand.push_back(arg);
         newBBOperandType.push_back(arg.getType());
         newBBLocations.push_back(loc);
       }

     auto newWeightedCompute = SpatWeightedCompute::create(rewriter, loc, newWeightedComputeType, newComputeOperand);
+
     rewriter.createBlock(
         &newWeightedCompute.getBody(), newWeightedCompute.getBody().end(), newBBOperandType, newBBLocations);
     newWeightedCompute.getProperties().setOperandSegmentSizes(
         {(int) oldWeightedCompute.getWeights().size(), (int) newBBOperandType.size()});
-    rewriter.setInsertionPointToEnd(&newWeightedCompute.getBody().front());
+
+    auto& newBB = newWeightedCompute.getBody().front();
+    auto& oldBB = oldWeightedCompute.getBody().front();
+    rewriter.setInsertionPointToEnd(&newBB);

     int indexNew = 0;
-    int indexOld = oldWeightedCompute.getWeights().size();
-    int indexOldStart = oldWeightedCompute.getWeights().size();
+    size_t indexOld = oldWeightedCompute.getWeights().size();
+    size_t indexOldStart = oldWeightedCompute.getWeights().size();
     for (; indexOld < oldWeightedCompute.getNumOperands(); ++indexOld) {
-      if (!llvm::isa<SpatWeightedCompute>(oldWeightedCompute.getOperand(indexOld).getDefiningOp())) {
-        mapper.map(oldWeightedCompute.getBody().front().getArgument(indexOld - indexOldStart),
-            newWeightedCompute.getBody().front().getArgument(indexNew++));
+      if (!llvm::isa_and_present<SpatWeightedCompute>(oldWeightedCompute.getOperand(indexOld).getDefiningOp())) {
+        mapper.map(oldBB.getArgument(indexOld - indexOldStart), newBB.getArgument(indexNew++));
       }
       else {
        auto argWeightCompute =
            llvm::dyn_cast_if_present<SpatWeightedCompute>(oldWeightedCompute.getOperand(indexOld).getDefiningOp());
        LazyInsertComputeResult& lazyArgWeight = newComputeNodeResults.at(argWeightCompute);
-        auto [channelVal, _] = lazyArgWeight.getAsChannelValueAndInsertSender();
+        auto [channelVal, isChannel] = lazyArgWeight.getAsChannelValueAndInsertSender();
+        assert(isChannel);
         spatial::SpatChannelReceiveOp reciveOp =
-            spatial::SpatChannelReceiveOp::create(rewriter, loc, channelVal.getType(), channelVal);
-        mapper.map(oldWeightedCompute.getBody().front().getArgument(indexOld - indexOldStart), reciveOp);
+            spatial::SpatChannelReceiveOp::create(rewriter, loc, argWeightCompute.getType(0), channelVal);
+        mapper.map(oldBB.getArgument(indexOld - indexOldStart), reciveOp);
       }
     }

@@ -214,19 +220,28 @@ private:

     auto weightMutableIter = toCompute.getWeightsMutable();
     for (auto weight : fromCompute.getWeights()) {
-      int sizeW = toCompute.getWeights().size();
-      int sizeI = toCompute.getInputs().size();
-      weightMutableIter.append(weight);
-      assert(sizeW + 1 == toCompute.getWeights().size());
-      assert(sizeI == toCompute.getInputs().size());
-      assert(sizeW + sizeI + 1 == toCompute.getOperands().size());
+      // TODO: do not clone weights that are already present; the new ops must then be remapped to the new weights
+      auto founded = llvm::find(toCompute.getWeights(), weight);
+      if (founded == toCompute.getWeights().end()) {
+        size_t sizeW = toCompute.getWeights().size();
+        size_t sizeI = toCompute.getInputs().size();
+        weightMutableIter.append(weight);
+        auto last = weightMutableIter.end();
+        last = std::prev(last, 1);
+        mapper.map(weight, last->get());
+        assert(sizeW + 1 == toCompute.getWeights().size());
+        assert(sizeI == toCompute.getInputs().size());
+        assert(sizeW + sizeI + 1 == toCompute.getOperands().size());
+      } else {
+        mapper.map(weight, *founded);
+      }
     }
-
+    auto& toBB = toCompute.getBody().front();
+    auto& fromBB = fromCompute.getBody().front();
     auto inputeArgMutable = toCompute.getInputsMutable();

     // Insert reciveOp
-    rewriter.setInsertionPointToEnd(&toCompute.getBody().front());
-    int newBBindex = toCompute.getBody().front().getArguments().size();
+    rewriter.setInsertionPointToEnd(&toBB);
     for (auto [bbIndex, arg] : llvm::enumerate(fromCompute.getInputs())) {
       if (auto argWeightCompute = llvm::dyn_cast_if_present<SpatWeightedCompute>(arg.getDefiningOp())) {
         LazyInsertComputeResult& lazyArgWeight = newComputeNodeResults.at(argWeightCompute);
@@ -236,30 +251,33 @@ private:
         if (channelOrLocal.isChannel) {
           spatial::SpatChannelReceiveOp reciveOp =
               spatial::SpatChannelReceiveOp::create(rewriter, loc, argWeightCompute.getType(0), channelOrLocal.data);
-          mapper.map(fromCompute.getBody().front().getArgument(bbIndex), reciveOp.getResult());
+          mapper.map(fromBB.getArgument(bbIndex), reciveOp.getResult());
         }
         else {
-          mapper.map(fromCompute.getBody().front().getArgument(bbIndex), channelOrLocal.data);
+          mapper.map(fromBB.getArgument(bbIndex), channelOrLocal.data);
         }
       }
       else {
-        int sizeW = toCompute.getWeights().size();
-        int sizeI = toCompute.getInputs().size();
-        inputeArgMutable.append(arg);
-        assert(sizeW == toCompute.getWeights().size());
-        assert(sizeI + 1 == toCompute.getInputs().size());
-        assert(sizeW + sizeI + 1 == toCompute.getOperands().size());
+        auto founded = llvm::find(toCompute.getInputs(), arg);
+        if (founded == toCompute.getInputs().end()) {
+          size_t sizeW = toCompute.getWeights().size();
+          size_t sizeI = toCompute.getInputs().size();
+          inputeArgMutable.append(arg);
+          assert(sizeW == toCompute.getWeights().size());
+          assert(sizeI + 1 == toCompute.getInputs().size());
+          assert(sizeW + sizeI + 1 == toCompute.getOperands().size());

-        toCompute.getBody().front().addArgument(
-            fromCompute.getBody().front().getArgument(bbIndex).getType(),loc);
-
-        mapper.map(fromCompute.getBody().front().getArgument(bbIndex),
-            toCompute.getBody().front().getArgument(newBBindex++));
+          toBB.addArgument(fromBB.getArgument(bbIndex).getType(), loc);
+          mapper.map(fromBB.getArgument(bbIndex), toBB.getArguments().back());
+        } else {
+          auto distance = std::distance(toCompute.getInputs().begin(), founded);
+          mapper.map(fromBB.getArgument(bbIndex), toBB.getArgument(distance));
+        }
       }
     }

-    for (auto oldBBarg : fromCompute.getBody().front().getArguments())
+    for (auto oldBBarg : fromBB.getArguments())
       assert(mapper.contains(oldBBarg));

     ComputeValueResults computeValueResults;