From ece24867e4f1480a936a64a58e2a7e145d052269 Mon Sep 17 00:00:00 2001 From: ilgeco Date: Thu, 9 Apr 2026 10:24:03 +0200 Subject: [PATCH] Remap weight within cloned operation --- .../MergeComputeNode/MergeComputeNodePass.cpp | 52 ++++++++++++------- 1 file changed, 34 insertions(+), 18 deletions(-) diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNode/MergeComputeNodePass.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNode/MergeComputeNodePass.cpp index 9f39e18..0614414 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNode/MergeComputeNodePass.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNode/MergeComputeNodePass.cpp @@ -130,7 +130,8 @@ public: } for (auto computeNodetoRemove : llvm::make_early_inc_range(llvm::reverse(analysisResult.dominanceOrderCompute))) { - for (auto users : computeNodetoRemove->getUsers()) users->dump(); + for (auto users : computeNodetoRemove->getUsers()) + users->dump(); computeNodetoRemove.erase(); } func::FuncOp func = getOperation(); @@ -220,19 +221,20 @@ private: auto weightMutableIter = toCompute.getWeightsMutable(); for (auto weight : fromCompute.getWeights()) { - //TODO non clonare weight gia' presenti e poi vanno rimappate le nuove OP con i nuovi weight + // TODO non clonare weight gia' presenti e poi vanno rimappate le nuove OP con i nuovi weight auto founded = llvm::find(toCompute.getWeights(), weight); - if(founded == toCompute.getWeights().end()){ + if (founded == toCompute.getWeights().end()) { size_t sizeW = toCompute.getWeights().size(); size_t sizeI = toCompute.getInputs().size(); weightMutableIter.append(weight); - auto last = weightMutableIter.end(); - last = std::prev(last,1); + auto last = weightMutableIter.end(); + last = std::prev(last, 1); mapper.map(weight, last->get()); assert(sizeW + 1 == toCompute.getWeights().size()); assert(sizeI == toCompute.getInputs().size()); assert(sizeW + sizeI + 1 == toCompute.getOperands().size()); - }else { + } + else { mapper.map(weight, *founded); } } @@ -259,18 +261,19 @@ private: } else { - auto founded = llvm::find(toCompute.getInputs(), arg); - if(founded == toCompute.getInputs().end()){ - size_t sizeW = toCompute.getWeights().size(); - size_t sizeI = toCompute.getInputs().size(); - inputeArgMutable.append(arg); - assert(sizeW == toCompute.getWeights().size()); - assert(sizeI + 1 == toCompute.getInputs().size()); - assert(sizeW + sizeI + 1 == toCompute.getOperands().size()); + auto founded = llvm::find(toCompute.getInputs(), arg); + if (founded == toCompute.getInputs().end()) { + size_t sizeW = toCompute.getWeights().size(); + size_t sizeI = toCompute.getInputs().size(); + inputeArgMutable.append(arg); + assert(sizeW == toCompute.getWeights().size()); + assert(sizeI + 1 == toCompute.getInputs().size()); + assert(sizeW + sizeI + 1 == toCompute.getOperands().size()); - toBB.addArgument(fromBB.getArgument(bbIndex).getType(), loc); - mapper.map(fromBB.getArgument(bbIndex), toBB.getArguments().back()); - }else { + toBB.addArgument(fromBB.getArgument(bbIndex).getType(), loc); + mapper.map(fromBB.getArgument(bbIndex), toBB.getArguments().back()); + } + else { auto distance = std::distance(toCompute.getInputs().begin(), founded); mapper.map(fromBB.getArgument(bbIndex), toBB.getArgument(distance)); } @@ -288,7 +291,20 @@ private: rewriter.clone(op, mapper); } else { - rewriter.clone(op, mapper); + auto newInst = rewriter.clone(op, mapper); + //TODO Refactor in a lambda? same code just different cast, but templated lambda are C++20 and a free function is a bit too much + if (auto vmOp = llvm::dyn_cast(newInst)) { + auto oldIndex = vmOp.getWeightIndex(); + auto newWeight = mapper.lookup(*std::next(fromCompute.getWeights().begin(), oldIndex)); + auto newIndex = std::distance(toCompute.getWeights().begin(), llvm::find(toCompute.getWeights(), newWeight)); + vmOp.setWeightIndex(newIndex); + } + if (auto vmOp = llvm::dyn_cast(newInst)) { + auto oldIndex = vmOp.getWeightIndex(); + auto newWeight = mapper.lookup(*std::next(fromCompute.getWeights().begin(), oldIndex)); + auto newIndex = std::distance(toCompute.getWeights().begin(), llvm::find(toCompute.getWeights(), newWeight)); + vmOp.setWeightIndex(newIndex); + } } }