Remap weight within cloned operation
All checks were successful
Validate Operations / validate-operations (push) Successful in 16m20s

This commit is contained in:
ilgeco
2026-04-09 10:24:03 +02:00
parent 3f870fb74b
commit ece24867e4

View File

@@ -130,7 +130,8 @@ public:
}
for (auto computeNodetoRemove : llvm::make_early_inc_range(llvm::reverse(analysisResult.dominanceOrderCompute))) {
for (auto users : computeNodetoRemove->getUsers()) users->dump();
for (auto users : computeNodetoRemove->getUsers())
users->dump();
computeNodetoRemove.erase();
}
func::FuncOp func = getOperation();
@@ -232,7 +233,8 @@ private:
assert(sizeW + 1 == toCompute.getWeights().size());
assert(sizeI == toCompute.getInputs().size());
assert(sizeW + sizeI + 1 == toCompute.getOperands().size());
}else {
}
else {
mapper.map(weight, *founded);
}
}
@@ -270,7 +272,8 @@ private:
toBB.addArgument(fromBB.getArgument(bbIndex).getType(), loc);
mapper.map(fromBB.getArgument(bbIndex), toBB.getArguments().back());
}else {
}
else {
auto distance = std::distance(toCompute.getInputs().begin(), founded);
mapper.map(fromBB.getArgument(bbIndex), toBB.getArgument(distance));
}
@@ -288,7 +291,20 @@ private:
rewriter.clone(op, mapper);
}
else {
rewriter.clone(op, mapper);
auto newInst = rewriter.clone(op, mapper);
//TODO Refactor in a lambda? same code just different cast, but templated lambda are C++20 and a free function is a bit too much
if (auto vmOp = llvm::dyn_cast<spatial::SpatWeightedMVMOp>(newInst)) {
auto oldIndex = vmOp.getWeightIndex();
auto newWeight = mapper.lookup(*std::next(fromCompute.getWeights().begin(), oldIndex));
auto newIndex = std::distance(toCompute.getWeights().begin(), llvm::find(toCompute.getWeights(), newWeight));
vmOp.setWeightIndex(newIndex);
}
if (auto vmOp = llvm::dyn_cast<spatial::SpatWeightedVMMOp>(newInst)) {
auto oldIndex = vmOp.getWeightIndex();
auto newWeight = mapper.lookup(*std::next(fromCompute.getWeights().begin(), oldIndex));
auto newIndex = std::distance(toCompute.getWeights().begin(), llvm::find(toCompute.getWeights(), newWeight));
vmOp.setWeightIndex(newIndex);
}
}
}