Remap weight within cloned operation
All checks were successful
Validate Operations / validate-operations (push) Successful in 16m20s
All checks were successful
Validate Operations / validate-operations (push) Successful in 16m20s
This commit is contained in:
@@ -130,7 +130,8 @@ public:
|
||||
}
|
||||
|
||||
for (auto computeNodetoRemove : llvm::make_early_inc_range(llvm::reverse(analysisResult.dominanceOrderCompute))) {
|
||||
for (auto users : computeNodetoRemove->getUsers()) users->dump();
|
||||
for (auto users : computeNodetoRemove->getUsers())
|
||||
users->dump();
|
||||
computeNodetoRemove.erase();
|
||||
}
|
||||
func::FuncOp func = getOperation();
|
||||
@@ -232,7 +233,8 @@ private:
|
||||
assert(sizeW + 1 == toCompute.getWeights().size());
|
||||
assert(sizeI == toCompute.getInputs().size());
|
||||
assert(sizeW + sizeI + 1 == toCompute.getOperands().size());
|
||||
}else {
|
||||
}
|
||||
else {
|
||||
mapper.map(weight, *founded);
|
||||
}
|
||||
}
|
||||
@@ -270,7 +272,8 @@ private:
|
||||
|
||||
toBB.addArgument(fromBB.getArgument(bbIndex).getType(), loc);
|
||||
mapper.map(fromBB.getArgument(bbIndex), toBB.getArguments().back());
|
||||
}else {
|
||||
}
|
||||
else {
|
||||
auto distance = std::distance(toCompute.getInputs().begin(), founded);
|
||||
mapper.map(fromBB.getArgument(bbIndex), toBB.getArgument(distance));
|
||||
}
|
||||
@@ -288,7 +291,20 @@ private:
|
||||
rewriter.clone(op, mapper);
|
||||
}
|
||||
else {
|
||||
rewriter.clone(op, mapper);
|
||||
auto newInst = rewriter.clone(op, mapper);
|
||||
//TODO Refactor in a lambda? same code just different cast, but templated lambda are C++20 and a free function is a bit too much
|
||||
if (auto vmOp = llvm::dyn_cast<spatial::SpatWeightedMVMOp>(newInst)) {
|
||||
auto oldIndex = vmOp.getWeightIndex();
|
||||
auto newWeight = mapper.lookup(*std::next(fromCompute.getWeights().begin(), oldIndex));
|
||||
auto newIndex = std::distance(toCompute.getWeights().begin(), llvm::find(toCompute.getWeights(), newWeight));
|
||||
vmOp.setWeightIndex(newIndex);
|
||||
}
|
||||
if (auto vmOp = llvm::dyn_cast<spatial::SpatWeightedVMMOp>(newInst)) {
|
||||
auto oldIndex = vmOp.getWeightIndex();
|
||||
auto newWeight = mapper.lookup(*std::next(fromCompute.getWeights().begin(), oldIndex));
|
||||
auto newIndex = std::distance(toCompute.getWeights().begin(), llvm::find(toCompute.getWeights(), newWeight));
|
||||
vmOp.setWeightIndex(newIndex);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user