From a7dee5b8403e07fa025697d0e46e7648e75df5d6 Mon Sep 17 00:00:00 2001 From: NiccoloN Date: Tue, 14 Apr 2026 15:13:45 +0200 Subject: [PATCH] fix mergeTriviallyConnectedComputes --- .../ONNXToSpatial/ONNXToSpatialPass.cpp | 74 +++++++++---------- 1 file changed, 35 insertions(+), 39 deletions(-) diff --git a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp index 862fa72..b503837 100644 --- a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp @@ -8,6 +8,7 @@ #include "mlir/Transforms/Passes.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" @@ -22,8 +23,8 @@ #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" -#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" +#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.hpp" #include "src/Accelerators/PIM/Pass/PIMPasses.h" #include "src/Compiler/CompilerOptions.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" @@ -51,7 +52,7 @@ struct ONNXToSpatialPass : PassWrapper computeSingleChild; +void ONNXToSpatialPass::mergeTriviallyConnectedComputes(func::FuncOp funcOp) { Location loc = funcOp.getLoc(); IRRewriter rewriter(&getContext()); + SmallVector trivialComputes; + llvm::SmallSet toErase; + for (auto compute : funcOp.getOps()) - if (std::distance(compute->getUses().begin(), compute->getUses().end()) == 1) { + if (compute->hasOneUse()) { auto user = *compute->getUsers().begin(); - if (user->getNumOperands() == 1) - if (llvm::isa(user)) - computeSingleChild.push_back(compute); + if (llvm::isa(user) && user->getNumOperands() == 1) + trivialComputes.push_back(compute); } - IRMapping mapper; - while (!computeSingleChild.empty()) { - auto compute = computeSingleChild.front(); - auto child = dyn_cast_if_present(*compute->getUsers().begin()); - assert(child && "Child required!"); + while (!trivialComputes.empty()) { + auto compute = trivialComputes.front(); + + if (compute.use_empty()) { + std::swap(trivialComputes.front(), trivialComputes.back()); + trivialComputes.pop_back(); + continue; + } + auto child = cast(*compute->getUsers().begin()); + rewriter.setInsertionPointAfter(compute.getOperation()); auto newCompute = spatial::SpatWeightedCompute::create(rewriter, loc, child.getResultTypes(), compute.getOperands()); newCompute.getProperties().setOperandSegmentSizes( - {(int) compute.getWeights().size(), (int) compute.getInputs().size()}); - llvm::dbgs() << "After Creation\n"; - newCompute.dump(); + {static_cast(compute.getWeights().size()), static_cast(compute.getInputs().size())}); + IRMapping mapper; compute.getBodyRegion().cloneInto(&newCompute.getBodyRegion(), mapper); - llvm::dbgs() << "After Clone\n"; - newCompute.dump(); auto newTerminator = newCompute.getBody().front().getTerminator(); mapper.map(*child.getBody().front().getArguments().begin(), newTerminator->getOperand(0)); newTerminator->erase(); - llvm::dbgs() << "After terminator\n"; - newCompute.dump(); rewriter.setInsertionPoint(&newCompute.getBody().front(), newCompute.getBody().front().end()); - for (auto& op : child.getBody().front()) rewriter.clone(op, mapper); child.replaceAllUsesWith(newCompute); - assert(child->getUses().empty() && "It's not obvius"); - llvm::dbgs() << "Node\n"; - newCompute.dump(); + toErase.insert(child); - llvm::dbgs() << "Parent\n"; - compute.dump(); + std::swap(trivialComputes.front(), trivialComputes.back()); + trivialComputes.pop_back(); + toErase.insert(compute); - llvm::dbgs() << "Child\n"; - child.dump(); - - child.erase(); - compute.erase(); - - if (std::distance(newCompute->getUses().begin(), newCompute->getUses().end()) == 1) { + if (newCompute->hasOneUse()) { auto user = *newCompute->getUsers().begin(); - if (user->getNumOperands() == 1) - if (llvm::isa(user)) - computeSingleChild.push_back(newCompute); + if (llvm::isa(user) && user->getNumOperands() == 1) + trivialComputes.push_back(newCompute); } - std::swap(computeSingleChild.front(), computeSingleChild.back()); - computeSingleChild.pop_back(); + } + + for (auto compute : toErase) { + compute.getResult(0).dropAllUses(); + compute.erase(); } }