Fix trivialmerge considering also weight attributes
All checks were successful
Validate Operations / validate-operations (push) Successful in 2h8m9s
All checks were successful
Validate Operations / validate-operations (push) Successful in 2h8m9s
This commit is contained in:
@@ -239,8 +239,9 @@ void ONNXToSpatialPass::mergeTriviallyConnectedComputes(func::FuncOp funcOp) {
|
||||
|
||||
for (auto compute : funcOp.getOps<spatial::SpatWeightedCompute>())
|
||||
if (compute->hasOneUse()) {
|
||||
auto user = *compute->getUsers().begin();
|
||||
if (llvm::isa<spatial::SpatWeightedCompute>(user) && user->getNumOperands() == 1)
|
||||
auto user = dyn_cast<spatial::SpatWeightedCompute>(*compute->getUsers().begin());
|
||||
|
||||
if (user && user.getInputs().size() == 1)
|
||||
trivialComputes.push_back(compute);
|
||||
}
|
||||
|
||||
@@ -255,19 +256,48 @@ void ONNXToSpatialPass::mergeTriviallyConnectedComputes(func::FuncOp funcOp) {
|
||||
auto child = cast<spatial::SpatWeightedCompute>(*compute->getUsers().begin());
|
||||
|
||||
rewriter.setInsertionPointAfter(compute.getOperation());
|
||||
|
||||
auto newCompute =
|
||||
spatial::SpatWeightedCompute::create(rewriter, loc, child.getResultTypes(), compute.getOperands());
|
||||
newCompute.getProperties().setOperandSegmentSizes(
|
||||
{static_cast<int>(compute.getWeights().size()), static_cast<int>(compute.getInputs().size())});
|
||||
|
||||
IRMapping mapper;
|
||||
auto weightMutableIter = newCompute.getWeightsMutable();
|
||||
for (auto weight : child.getWeights()) {
|
||||
auto founded = llvm::find(newCompute.getWeights(), weight);
|
||||
if (founded == newCompute.getWeights().end()) {
|
||||
weightMutableIter.append(weight);
|
||||
auto last = weightMutableIter.end();
|
||||
last = std::prev(last, 1);
|
||||
mapper.map(weight, last->get());
|
||||
}
|
||||
else {
|
||||
mapper.map(weight, *founded);
|
||||
}
|
||||
}
|
||||
|
||||
compute.getBodyRegion().cloneInto(&newCompute.getBodyRegion(), mapper);
|
||||
auto newTerminator = newCompute.getBody().front().getTerminator();
|
||||
mapper.map(*child.getBody().front().getArguments().begin(), newTerminator->getOperand(0));
|
||||
newTerminator->erase();
|
||||
rewriter.setInsertionPoint(&newCompute.getBody().front(), newCompute.getBody().front().end());
|
||||
for (auto& op : child.getBody().front())
|
||||
rewriter.clone(op, mapper);
|
||||
for (auto& op : child.getBody().front()) {
|
||||
auto newInst = rewriter.clone(op, mapper);
|
||||
|
||||
if (auto vmOp = llvm::dyn_cast<spatial::SpatWeightedMVMOp>(newInst)) {
|
||||
auto oldIndex = vmOp.getWeightIndex();
|
||||
auto newWeight = mapper.lookup(*std::next(child.getWeights().begin(), oldIndex));
|
||||
auto newIndex = std::distance(newCompute.getWeights().begin(), llvm::find(newCompute.getWeights(), newWeight));
|
||||
vmOp.setWeightIndex(newIndex);
|
||||
}
|
||||
if (auto vmOp = llvm::dyn_cast<spatial::SpatWeightedVMMOp>(newInst)) {
|
||||
auto oldIndex = vmOp.getWeightIndex();
|
||||
auto newWeight = mapper.lookup(*std::next(child.getWeights().begin(), oldIndex));
|
||||
auto newIndex = std::distance(newCompute.getWeights().begin(), llvm::find(newCompute.getWeights(), newWeight));
|
||||
vmOp.setWeightIndex(newIndex);
|
||||
}
|
||||
}
|
||||
|
||||
child.replaceAllUsesWith(newCompute);
|
||||
toErase.insert(child);
|
||||
@@ -277,8 +307,8 @@ void ONNXToSpatialPass::mergeTriviallyConnectedComputes(func::FuncOp funcOp) {
|
||||
toErase.insert(compute);
|
||||
|
||||
if (newCompute->hasOneUse()) {
|
||||
auto user = *newCompute->getUsers().begin();
|
||||
if (llvm::isa<spatial::SpatWeightedCompute>(user) && user->getNumOperands() == 1)
|
||||
auto user = dyn_cast<spatial::SpatWeightedCompute>(*newCompute->getUsers().begin());
|
||||
if (user && user.getInputs().size() == 1)
|
||||
trivialComputes.push_back(newCompute);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -219,7 +219,6 @@ private:
|
||||
|
||||
auto weightMutableIter = toCompute.getWeightsMutable();
|
||||
for (auto weight : fromCompute.getWeights()) {
|
||||
// TODO non clonare weight gia' presenti e poi vanno rimappate le nuove OP con i nuovi weight
|
||||
auto founded = llvm::find(toCompute.getWeights(), weight);
|
||||
if (founded == toCompute.getWeights().end()) {
|
||||
size_t sizeW = toCompute.getWeights().size();
|
||||
|
||||
Reference in New Issue
Block a user