Remap weight within cloned operation
All checks were successful
Validate Operations / validate-operations (push) Successful in 16m20s
All checks were successful
Validate Operations / validate-operations (push) Successful in 16m20s
This commit is contained in:
@@ -130,7 +130,8 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
for (auto computeNodetoRemove : llvm::make_early_inc_range(llvm::reverse(analysisResult.dominanceOrderCompute))) {
|
for (auto computeNodetoRemove : llvm::make_early_inc_range(llvm::reverse(analysisResult.dominanceOrderCompute))) {
|
||||||
for (auto users : computeNodetoRemove->getUsers()) users->dump();
|
for (auto users : computeNodetoRemove->getUsers())
|
||||||
|
users->dump();
|
||||||
computeNodetoRemove.erase();
|
computeNodetoRemove.erase();
|
||||||
}
|
}
|
||||||
func::FuncOp func = getOperation();
|
func::FuncOp func = getOperation();
|
||||||
@@ -232,7 +233,8 @@ private:
|
|||||||
assert(sizeW + 1 == toCompute.getWeights().size());
|
assert(sizeW + 1 == toCompute.getWeights().size());
|
||||||
assert(sizeI == toCompute.getInputs().size());
|
assert(sizeI == toCompute.getInputs().size());
|
||||||
assert(sizeW + sizeI + 1 == toCompute.getOperands().size());
|
assert(sizeW + sizeI + 1 == toCompute.getOperands().size());
|
||||||
}else {
|
}
|
||||||
|
else {
|
||||||
mapper.map(weight, *founded);
|
mapper.map(weight, *founded);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -270,7 +272,8 @@ private:
|
|||||||
|
|
||||||
toBB.addArgument(fromBB.getArgument(bbIndex).getType(), loc);
|
toBB.addArgument(fromBB.getArgument(bbIndex).getType(), loc);
|
||||||
mapper.map(fromBB.getArgument(bbIndex), toBB.getArguments().back());
|
mapper.map(fromBB.getArgument(bbIndex), toBB.getArguments().back());
|
||||||
}else {
|
}
|
||||||
|
else {
|
||||||
auto distance = std::distance(toCompute.getInputs().begin(), founded);
|
auto distance = std::distance(toCompute.getInputs().begin(), founded);
|
||||||
mapper.map(fromBB.getArgument(bbIndex), toBB.getArgument(distance));
|
mapper.map(fromBB.getArgument(bbIndex), toBB.getArgument(distance));
|
||||||
}
|
}
|
||||||
@@ -288,7 +291,20 @@ private:
|
|||||||
rewriter.clone(op, mapper);
|
rewriter.clone(op, mapper);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
rewriter.clone(op, mapper);
|
auto newInst = rewriter.clone(op, mapper);
|
||||||
|
//TODO Refactor in a lambda? same code just different cast, but templated lambda are C++20 and a free function is a bit too much
|
||||||
|
if (auto vmOp = llvm::dyn_cast<spatial::SpatWeightedMVMOp>(newInst)) {
|
||||||
|
auto oldIndex = vmOp.getWeightIndex();
|
||||||
|
auto newWeight = mapper.lookup(*std::next(fromCompute.getWeights().begin(), oldIndex));
|
||||||
|
auto newIndex = std::distance(toCompute.getWeights().begin(), llvm::find(toCompute.getWeights(), newWeight));
|
||||||
|
vmOp.setWeightIndex(newIndex);
|
||||||
|
}
|
||||||
|
if (auto vmOp = llvm::dyn_cast<spatial::SpatWeightedVMMOp>(newInst)) {
|
||||||
|
auto oldIndex = vmOp.getWeightIndex();
|
||||||
|
auto newWeight = mapper.lookup(*std::next(fromCompute.getWeights().begin(), oldIndex));
|
||||||
|
auto newIndex = std::distance(toCompute.getWeights().begin(), llvm::find(toCompute.getWeights(), newWeight));
|
||||||
|
vmOp.setWeightIndex(newIndex);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user