Fix trivialmerge considering also weight attributes
All checks were successful
Validate Operations / validate-operations (push) Successful in 2h8m9s
All checks were successful
Validate Operations / validate-operations (push) Successful in 2h8m9s
This commit is contained in:
@@ -239,8 +239,9 @@ void ONNXToSpatialPass::mergeTriviallyConnectedComputes(func::FuncOp funcOp) {
|
|||||||
|
|
||||||
for (auto compute : funcOp.getOps<spatial::SpatWeightedCompute>())
|
for (auto compute : funcOp.getOps<spatial::SpatWeightedCompute>())
|
||||||
if (compute->hasOneUse()) {
|
if (compute->hasOneUse()) {
|
||||||
auto user = *compute->getUsers().begin();
|
auto user = dyn_cast<spatial::SpatWeightedCompute>(*compute->getUsers().begin());
|
||||||
if (llvm::isa<spatial::SpatWeightedCompute>(user) && user->getNumOperands() == 1)
|
|
||||||
|
if (user && user.getInputs().size() == 1)
|
||||||
trivialComputes.push_back(compute);
|
trivialComputes.push_back(compute);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -255,19 +256,48 @@ void ONNXToSpatialPass::mergeTriviallyConnectedComputes(func::FuncOp funcOp) {
|
|||||||
auto child = cast<spatial::SpatWeightedCompute>(*compute->getUsers().begin());
|
auto child = cast<spatial::SpatWeightedCompute>(*compute->getUsers().begin());
|
||||||
|
|
||||||
rewriter.setInsertionPointAfter(compute.getOperation());
|
rewriter.setInsertionPointAfter(compute.getOperation());
|
||||||
|
|
||||||
auto newCompute =
|
auto newCompute =
|
||||||
spatial::SpatWeightedCompute::create(rewriter, loc, child.getResultTypes(), compute.getOperands());
|
spatial::SpatWeightedCompute::create(rewriter, loc, child.getResultTypes(), compute.getOperands());
|
||||||
newCompute.getProperties().setOperandSegmentSizes(
|
newCompute.getProperties().setOperandSegmentSizes(
|
||||||
{static_cast<int>(compute.getWeights().size()), static_cast<int>(compute.getInputs().size())});
|
{static_cast<int>(compute.getWeights().size()), static_cast<int>(compute.getInputs().size())});
|
||||||
|
|
||||||
IRMapping mapper;
|
IRMapping mapper;
|
||||||
|
auto weightMutableIter = newCompute.getWeightsMutable();
|
||||||
|
for (auto weight : child.getWeights()) {
|
||||||
|
auto founded = llvm::find(newCompute.getWeights(), weight);
|
||||||
|
if (founded == newCompute.getWeights().end()) {
|
||||||
|
weightMutableIter.append(weight);
|
||||||
|
auto last = weightMutableIter.end();
|
||||||
|
last = std::prev(last, 1);
|
||||||
|
mapper.map(weight, last->get());
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
mapper.map(weight, *founded);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
compute.getBodyRegion().cloneInto(&newCompute.getBodyRegion(), mapper);
|
compute.getBodyRegion().cloneInto(&newCompute.getBodyRegion(), mapper);
|
||||||
auto newTerminator = newCompute.getBody().front().getTerminator();
|
auto newTerminator = newCompute.getBody().front().getTerminator();
|
||||||
mapper.map(*child.getBody().front().getArguments().begin(), newTerminator->getOperand(0));
|
mapper.map(*child.getBody().front().getArguments().begin(), newTerminator->getOperand(0));
|
||||||
newTerminator->erase();
|
newTerminator->erase();
|
||||||
rewriter.setInsertionPoint(&newCompute.getBody().front(), newCompute.getBody().front().end());
|
rewriter.setInsertionPoint(&newCompute.getBody().front(), newCompute.getBody().front().end());
|
||||||
for (auto& op : child.getBody().front())
|
for (auto& op : child.getBody().front()) {
|
||||||
rewriter.clone(op, mapper);
|
auto newInst = rewriter.clone(op, mapper);
|
||||||
|
|
||||||
|
if (auto vmOp = llvm::dyn_cast<spatial::SpatWeightedMVMOp>(newInst)) {
|
||||||
|
auto oldIndex = vmOp.getWeightIndex();
|
||||||
|
auto newWeight = mapper.lookup(*std::next(child.getWeights().begin(), oldIndex));
|
||||||
|
auto newIndex = std::distance(newCompute.getWeights().begin(), llvm::find(newCompute.getWeights(), newWeight));
|
||||||
|
vmOp.setWeightIndex(newIndex);
|
||||||
|
}
|
||||||
|
if (auto vmOp = llvm::dyn_cast<spatial::SpatWeightedVMMOp>(newInst)) {
|
||||||
|
auto oldIndex = vmOp.getWeightIndex();
|
||||||
|
auto newWeight = mapper.lookup(*std::next(child.getWeights().begin(), oldIndex));
|
||||||
|
auto newIndex = std::distance(newCompute.getWeights().begin(), llvm::find(newCompute.getWeights(), newWeight));
|
||||||
|
vmOp.setWeightIndex(newIndex);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
child.replaceAllUsesWith(newCompute);
|
child.replaceAllUsesWith(newCompute);
|
||||||
toErase.insert(child);
|
toErase.insert(child);
|
||||||
@@ -277,8 +307,8 @@ void ONNXToSpatialPass::mergeTriviallyConnectedComputes(func::FuncOp funcOp) {
|
|||||||
toErase.insert(compute);
|
toErase.insert(compute);
|
||||||
|
|
||||||
if (newCompute->hasOneUse()) {
|
if (newCompute->hasOneUse()) {
|
||||||
auto user = *newCompute->getUsers().begin();
|
auto user = dyn_cast<spatial::SpatWeightedCompute>(*newCompute->getUsers().begin());
|
||||||
if (llvm::isa<spatial::SpatWeightedCompute>(user) && user->getNumOperands() == 1)
|
if (user && user.getInputs().size() == 1)
|
||||||
trivialComputes.push_back(newCompute);
|
trivialComputes.push_back(newCompute);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -219,7 +219,6 @@ private:
|
|||||||
|
|
||||||
auto weightMutableIter = toCompute.getWeightsMutable();
|
auto weightMutableIter = toCompute.getWeightsMutable();
|
||||||
for (auto weight : fromCompute.getWeights()) {
|
for (auto weight : fromCompute.getWeights()) {
|
||||||
// TODO non clonare weight gia' presenti e poi vanno rimappate le nuove OP con i nuovi weight
|
|
||||||
auto founded = llvm::find(toCompute.getWeights(), weight);
|
auto founded = llvm::find(toCompute.getWeights(), weight);
|
||||||
if (founded == toCompute.getWeights().end()) {
|
if (founded == toCompute.getWeights().end()) {
|
||||||
size_t sizeW = toCompute.getWeights().size();
|
size_t sizeW = toCompute.getWeights().size();
|
||||||
|
|||||||
Reference in New Issue
Block a user