Remap weight within cloned operation
All checks were successful
Validate Operations / validate-operations (push) Successful in 16m20s

This commit is contained in:
ilgeco
2026-04-09 10:24:03 +02:00
parent 3f870fb74b
commit ece24867e4

View File

@@ -130,7 +130,8 @@ public:
} }
for (auto computeNodetoRemove : llvm::make_early_inc_range(llvm::reverse(analysisResult.dominanceOrderCompute))) { for (auto computeNodetoRemove : llvm::make_early_inc_range(llvm::reverse(analysisResult.dominanceOrderCompute))) {
for (auto users : computeNodetoRemove->getUsers()) users->dump(); for (auto users : computeNodetoRemove->getUsers())
users->dump();
computeNodetoRemove.erase(); computeNodetoRemove.erase();
} }
func::FuncOp func = getOperation(); func::FuncOp func = getOperation();
@@ -220,19 +221,20 @@ private:
auto weightMutableIter = toCompute.getWeightsMutable(); auto weightMutableIter = toCompute.getWeightsMutable();
for (auto weight : fromCompute.getWeights()) { for (auto weight : fromCompute.getWeights()) {
//TODO non clonare weight gia' presenti e poi vanno rimappate le nuove OP con i nuovi weight // TODO non clonare weight gia' presenti e poi vanno rimappate le nuove OP con i nuovi weight
auto founded = llvm::find(toCompute.getWeights(), weight); auto founded = llvm::find(toCompute.getWeights(), weight);
if(founded == toCompute.getWeights().end()){ if (founded == toCompute.getWeights().end()) {
size_t sizeW = toCompute.getWeights().size(); size_t sizeW = toCompute.getWeights().size();
size_t sizeI = toCompute.getInputs().size(); size_t sizeI = toCompute.getInputs().size();
weightMutableIter.append(weight); weightMutableIter.append(weight);
auto last = weightMutableIter.end(); auto last = weightMutableIter.end();
last = std::prev(last,1); last = std::prev(last, 1);
mapper.map(weight, last->get()); mapper.map(weight, last->get());
assert(sizeW + 1 == toCompute.getWeights().size()); assert(sizeW + 1 == toCompute.getWeights().size());
assert(sizeI == toCompute.getInputs().size()); assert(sizeI == toCompute.getInputs().size());
assert(sizeW + sizeI + 1 == toCompute.getOperands().size()); assert(sizeW + sizeI + 1 == toCompute.getOperands().size());
}else { }
else {
mapper.map(weight, *founded); mapper.map(weight, *founded);
} }
} }
@@ -259,18 +261,19 @@ private:
} }
else { else {
auto founded = llvm::find(toCompute.getInputs(), arg); auto founded = llvm::find(toCompute.getInputs(), arg);
if(founded == toCompute.getInputs().end()){ if (founded == toCompute.getInputs().end()) {
size_t sizeW = toCompute.getWeights().size(); size_t sizeW = toCompute.getWeights().size();
size_t sizeI = toCompute.getInputs().size(); size_t sizeI = toCompute.getInputs().size();
inputeArgMutable.append(arg); inputeArgMutable.append(arg);
assert(sizeW == toCompute.getWeights().size()); assert(sizeW == toCompute.getWeights().size());
assert(sizeI + 1 == toCompute.getInputs().size()); assert(sizeI + 1 == toCompute.getInputs().size());
assert(sizeW + sizeI + 1 == toCompute.getOperands().size()); assert(sizeW + sizeI + 1 == toCompute.getOperands().size());
toBB.addArgument(fromBB.getArgument(bbIndex).getType(), loc); toBB.addArgument(fromBB.getArgument(bbIndex).getType(), loc);
mapper.map(fromBB.getArgument(bbIndex), toBB.getArguments().back()); mapper.map(fromBB.getArgument(bbIndex), toBB.getArguments().back());
}else { }
else {
auto distance = std::distance(toCompute.getInputs().begin(), founded); auto distance = std::distance(toCompute.getInputs().begin(), founded);
mapper.map(fromBB.getArgument(bbIndex), toBB.getArgument(distance)); mapper.map(fromBB.getArgument(bbIndex), toBB.getArgument(distance));
} }
@@ -288,7 +291,20 @@ private:
rewriter.clone(op, mapper); rewriter.clone(op, mapper);
} }
else { else {
rewriter.clone(op, mapper); auto newInst = rewriter.clone(op, mapper);
//TODO Refactor in a lambda? same code just different cast, but templated lambda are C++20 and a free function is a bit too much
if (auto vmOp = llvm::dyn_cast<spatial::SpatWeightedMVMOp>(newInst)) {
auto oldIndex = vmOp.getWeightIndex();
auto newWeight = mapper.lookup(*std::next(fromCompute.getWeights().begin(), oldIndex));
auto newIndex = std::distance(toCompute.getWeights().begin(), llvm::find(toCompute.getWeights(), newWeight));
vmOp.setWeightIndex(newIndex);
}
if (auto vmOp = llvm::dyn_cast<spatial::SpatWeightedVMMOp>(newInst)) {
auto oldIndex = vmOp.getWeightIndex();
auto newWeight = mapper.lookup(*std::next(fromCompute.getWeights().begin(), oldIndex));
auto newIndex = std::distance(toCompute.getWeights().begin(), llvm::find(toCompute.getWeights(), newWeight));
vmOp.setWeightIndex(newIndex);
}
} }
} }