Remap weight within cloned operation
All checks were successful
Validate Operations / validate-operations (push) Successful in 16m20s
All checks were successful
Validate Operations / validate-operations (push) Successful in 16m20s
This commit is contained in:
@@ -130,7 +130,8 @@ public:
|
||||
}
|
||||
|
||||
for (auto computeNodetoRemove : llvm::make_early_inc_range(llvm::reverse(analysisResult.dominanceOrderCompute))) {
|
||||
for (auto users : computeNodetoRemove->getUsers()) users->dump();
|
||||
for (auto users : computeNodetoRemove->getUsers())
|
||||
users->dump();
|
||||
computeNodetoRemove.erase();
|
||||
}
|
||||
func::FuncOp func = getOperation();
|
||||
@@ -220,19 +221,20 @@ private:
|
||||
|
||||
auto weightMutableIter = toCompute.getWeightsMutable();
|
||||
for (auto weight : fromCompute.getWeights()) {
|
||||
//TODO non clonare weight gia' presenti e poi vanno rimappate le nuove OP con i nuovi weight
|
||||
// TODO non clonare weight gia' presenti e poi vanno rimappate le nuove OP con i nuovi weight
|
||||
auto founded = llvm::find(toCompute.getWeights(), weight);
|
||||
if(founded == toCompute.getWeights().end()){
|
||||
if (founded == toCompute.getWeights().end()) {
|
||||
size_t sizeW = toCompute.getWeights().size();
|
||||
size_t sizeI = toCompute.getInputs().size();
|
||||
weightMutableIter.append(weight);
|
||||
auto last = weightMutableIter.end();
|
||||
last = std::prev(last,1);
|
||||
auto last = weightMutableIter.end();
|
||||
last = std::prev(last, 1);
|
||||
mapper.map(weight, last->get());
|
||||
assert(sizeW + 1 == toCompute.getWeights().size());
|
||||
assert(sizeI == toCompute.getInputs().size());
|
||||
assert(sizeW + sizeI + 1 == toCompute.getOperands().size());
|
||||
}else {
|
||||
}
|
||||
else {
|
||||
mapper.map(weight, *founded);
|
||||
}
|
||||
}
|
||||
@@ -259,18 +261,19 @@ private:
|
||||
}
|
||||
else {
|
||||
|
||||
auto founded = llvm::find(toCompute.getInputs(), arg);
|
||||
if(founded == toCompute.getInputs().end()){
|
||||
size_t sizeW = toCompute.getWeights().size();
|
||||
size_t sizeI = toCompute.getInputs().size();
|
||||
inputeArgMutable.append(arg);
|
||||
assert(sizeW == toCompute.getWeights().size());
|
||||
assert(sizeI + 1 == toCompute.getInputs().size());
|
||||
assert(sizeW + sizeI + 1 == toCompute.getOperands().size());
|
||||
auto founded = llvm::find(toCompute.getInputs(), arg);
|
||||
if (founded == toCompute.getInputs().end()) {
|
||||
size_t sizeW = toCompute.getWeights().size();
|
||||
size_t sizeI = toCompute.getInputs().size();
|
||||
inputeArgMutable.append(arg);
|
||||
assert(sizeW == toCompute.getWeights().size());
|
||||
assert(sizeI + 1 == toCompute.getInputs().size());
|
||||
assert(sizeW + sizeI + 1 == toCompute.getOperands().size());
|
||||
|
||||
toBB.addArgument(fromBB.getArgument(bbIndex).getType(), loc);
|
||||
mapper.map(fromBB.getArgument(bbIndex), toBB.getArguments().back());
|
||||
}else {
|
||||
toBB.addArgument(fromBB.getArgument(bbIndex).getType(), loc);
|
||||
mapper.map(fromBB.getArgument(bbIndex), toBB.getArguments().back());
|
||||
}
|
||||
else {
|
||||
auto distance = std::distance(toCompute.getInputs().begin(), founded);
|
||||
mapper.map(fromBB.getArgument(bbIndex), toBB.getArgument(distance));
|
||||
}
|
||||
@@ -288,7 +291,20 @@ private:
|
||||
rewriter.clone(op, mapper);
|
||||
}
|
||||
else {
|
||||
rewriter.clone(op, mapper);
|
||||
auto newInst = rewriter.clone(op, mapper);
|
||||
//TODO Refactor in a lambda? same code just different cast, but templated lambda are C++20 and a free function is a bit too much
|
||||
if (auto vmOp = llvm::dyn_cast<spatial::SpatWeightedMVMOp>(newInst)) {
|
||||
auto oldIndex = vmOp.getWeightIndex();
|
||||
auto newWeight = mapper.lookup(*std::next(fromCompute.getWeights().begin(), oldIndex));
|
||||
auto newIndex = std::distance(toCompute.getWeights().begin(), llvm::find(toCompute.getWeights(), newWeight));
|
||||
vmOp.setWeightIndex(newIndex);
|
||||
}
|
||||
if (auto vmOp = llvm::dyn_cast<spatial::SpatWeightedVMMOp>(newInst)) {
|
||||
auto oldIndex = vmOp.getWeightIndex();
|
||||
auto newWeight = mapper.lookup(*std::next(fromCompute.getWeights().begin(), oldIndex));
|
||||
auto newIndex = std::distance(toCompute.getWeights().begin(), llvm::find(toCompute.getWeights(), newWeight));
|
||||
vmOp.setWeightIndex(newIndex);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user