Fix trivialmerge considering also weight attributes
All checks were successful
Validate Operations / validate-operations (push) Successful in 2h8m9s

This commit is contained in:
ilgeco
2026-04-14 16:45:14 +02:00
parent a7dee5b840
commit 2151e322ca
2 changed files with 36 additions and 7 deletions

View File

@@ -239,8 +239,9 @@ void ONNXToSpatialPass::mergeTriviallyConnectedComputes(func::FuncOp funcOp) {
for (auto compute : funcOp.getOps<spatial::SpatWeightedCompute>()) for (auto compute : funcOp.getOps<spatial::SpatWeightedCompute>())
if (compute->hasOneUse()) { if (compute->hasOneUse()) {
auto user = *compute->getUsers().begin(); auto user = dyn_cast<spatial::SpatWeightedCompute>(*compute->getUsers().begin());
if (llvm::isa<spatial::SpatWeightedCompute>(user) && user->getNumOperands() == 1)
if (user && user.getInputs().size() == 1)
trivialComputes.push_back(compute); trivialComputes.push_back(compute);
} }
@@ -255,19 +256,48 @@ void ONNXToSpatialPass::mergeTriviallyConnectedComputes(func::FuncOp funcOp) {
auto child = cast<spatial::SpatWeightedCompute>(*compute->getUsers().begin()); auto child = cast<spatial::SpatWeightedCompute>(*compute->getUsers().begin());
rewriter.setInsertionPointAfter(compute.getOperation()); rewriter.setInsertionPointAfter(compute.getOperation());
auto newCompute = auto newCompute =
spatial::SpatWeightedCompute::create(rewriter, loc, child.getResultTypes(), compute.getOperands()); spatial::SpatWeightedCompute::create(rewriter, loc, child.getResultTypes(), compute.getOperands());
newCompute.getProperties().setOperandSegmentSizes( newCompute.getProperties().setOperandSegmentSizes(
{static_cast<int>(compute.getWeights().size()), static_cast<int>(compute.getInputs().size())}); {static_cast<int>(compute.getWeights().size()), static_cast<int>(compute.getInputs().size())});
IRMapping mapper; IRMapping mapper;
auto weightMutableIter = newCompute.getWeightsMutable();
for (auto weight : child.getWeights()) {
auto founded = llvm::find(newCompute.getWeights(), weight);
if (founded == newCompute.getWeights().end()) {
weightMutableIter.append(weight);
auto last = weightMutableIter.end();
last = std::prev(last, 1);
mapper.map(weight, last->get());
}
else {
mapper.map(weight, *founded);
}
}
compute.getBodyRegion().cloneInto(&newCompute.getBodyRegion(), mapper); compute.getBodyRegion().cloneInto(&newCompute.getBodyRegion(), mapper);
auto newTerminator = newCompute.getBody().front().getTerminator(); auto newTerminator = newCompute.getBody().front().getTerminator();
mapper.map(*child.getBody().front().getArguments().begin(), newTerminator->getOperand(0)); mapper.map(*child.getBody().front().getArguments().begin(), newTerminator->getOperand(0));
newTerminator->erase(); newTerminator->erase();
rewriter.setInsertionPoint(&newCompute.getBody().front(), newCompute.getBody().front().end()); rewriter.setInsertionPoint(&newCompute.getBody().front(), newCompute.getBody().front().end());
for (auto& op : child.getBody().front()) for (auto& op : child.getBody().front()) {
rewriter.clone(op, mapper); auto newInst = rewriter.clone(op, mapper);
if (auto vmOp = llvm::dyn_cast<spatial::SpatWeightedMVMOp>(newInst)) {
auto oldIndex = vmOp.getWeightIndex();
auto newWeight = mapper.lookup(*std::next(child.getWeights().begin(), oldIndex));
auto newIndex = std::distance(newCompute.getWeights().begin(), llvm::find(newCompute.getWeights(), newWeight));
vmOp.setWeightIndex(newIndex);
}
if (auto vmOp = llvm::dyn_cast<spatial::SpatWeightedVMMOp>(newInst)) {
auto oldIndex = vmOp.getWeightIndex();
auto newWeight = mapper.lookup(*std::next(child.getWeights().begin(), oldIndex));
auto newIndex = std::distance(newCompute.getWeights().begin(), llvm::find(newCompute.getWeights(), newWeight));
vmOp.setWeightIndex(newIndex);
}
}
child.replaceAllUsesWith(newCompute); child.replaceAllUsesWith(newCompute);
toErase.insert(child); toErase.insert(child);
@@ -277,8 +307,8 @@ void ONNXToSpatialPass::mergeTriviallyConnectedComputes(func::FuncOp funcOp) {
toErase.insert(compute); toErase.insert(compute);
if (newCompute->hasOneUse()) { if (newCompute->hasOneUse()) {
auto user = *newCompute->getUsers().begin(); auto user = dyn_cast<spatial::SpatWeightedCompute>(*newCompute->getUsers().begin());
if (llvm::isa<spatial::SpatWeightedCompute>(user) && user->getNumOperands() == 1) if (user && user.getInputs().size() == 1)
trivialComputes.push_back(newCompute); trivialComputes.push_back(newCompute);
} }
} }

View File

@@ -219,7 +219,6 @@ private:
auto weightMutableIter = toCompute.getWeightsMutable(); auto weightMutableIter = toCompute.getWeightsMutable();
for (auto weight : fromCompute.getWeights()) { for (auto weight : fromCompute.getWeights()) {
// TODO non clonare weight gia' presenti e poi vanno rimappate le nuove OP con i nuovi weight
auto founded = llvm::find(toCompute.getWeights(), weight); auto founded = llvm::find(toCompute.getWeights(), weight);
if (founded == toCompute.getWeights().end()) { if (founded == toCompute.getWeights().end()) {
size_t sizeW = toCompute.getWeights().size(); size_t sizeW = toCompute.getWeights().size();