multiple-output spat computes
All checks were successful
Validate Operations / validate-operations (push) Successful in 22m38s
All checks were successful
Validate Operations / validate-operations (push) Successful in 22m38s
This commit is contained in:
@@ -133,7 +133,7 @@ void ONNXToSpatialPass::runOnOperation() {
|
||||
if (coresCount != -1) {
|
||||
int computeOpsCount = 0;
|
||||
for (auto& op : entryFunc->getFunctionBody().front().getOperations())
|
||||
if (isa<spatial::SpatWeightedCompute>(op))
|
||||
if (isa<spatial::SpatCompute>(op))
|
||||
computeOpsCount++;
|
||||
|
||||
if (computeOpsCount > coresCount) {
|
||||
@@ -167,16 +167,16 @@ bool encapsulator(IRRewriter& rewriter, Location loc, Operation* inst, std::func
|
||||
if (T toRemoveOp = llvm::dyn_cast_if_present<T>(inst)) {
|
||||
Value source = funcSource(toRemoveOp);
|
||||
rewriter.setInsertionPointAfter(toRemoveOp);
|
||||
if (isa_and_present<spatial::SpatWeightedCompute>(source.getDefiningOp())) {
|
||||
auto newCompute = spatial::SpatWeightedCompute::create(rewriter, loc, inst->getResultTypes().front(), source);
|
||||
if (isa_and_present<spatial::SpatCompute>(source.getDefiningOp())) {
|
||||
auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), source);
|
||||
auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), {source.getType()}, {loc});
|
||||
newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) 1});
|
||||
rewriter.setInsertionPointToEnd(BB);
|
||||
IRMapping mapper;
|
||||
mapper.map(source, BB->getArgument(0));
|
||||
auto newInst = rewriter.clone(*inst, mapper);
|
||||
spatial::SpatYieldOp::create(rewriter, loc, newInst->getResult(0));
|
||||
inst->replaceAllUsesWith(newCompute);
|
||||
spatial::SpatYieldOp::create(rewriter, loc, newInst->getResults());
|
||||
inst->replaceAllUsesWith(newCompute->getResults());
|
||||
inst->erase();
|
||||
return true;
|
||||
}
|
||||
@@ -189,8 +189,8 @@ bool encapsulateConcat(IRRewriter& rewriter, Location loc, Operation* inst) {
|
||||
auto sources = toRemoveOp.getInputs();
|
||||
rewriter.setInsertionPointAfter(toRemoveOp);
|
||||
if (llvm::any_of(
|
||||
sources, [](auto source) { return isa_and_present<spatial::SpatWeightedCompute>(source.getDefiningOp()); })) {
|
||||
auto newCompute = spatial::SpatWeightedCompute::create(rewriter, loc, inst->getResultTypes().front(), sources);
|
||||
sources, [](auto source) { return isa_and_present<spatial::SpatCompute>(source.getDefiningOp()); })) {
|
||||
auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), sources);
|
||||
SmallVector<Type> sourceTypes;
|
||||
SmallVector<Location> sourceLoc;
|
||||
for (auto source : sources) {
|
||||
@@ -204,8 +204,8 @@ bool encapsulateConcat(IRRewriter& rewriter, Location loc, Operation* inst) {
|
||||
for (auto [source, bbArg] : llvm::zip(sources, BB->getArguments()))
|
||||
mapper.map(source, bbArg);
|
||||
auto newConcat = rewriter.clone(*inst, mapper);
|
||||
spatial::SpatYieldOp::create(rewriter, loc, newConcat->getResult(0));
|
||||
inst->replaceAllUsesWith(newCompute);
|
||||
spatial::SpatYieldOp::create(rewriter, loc, newConcat->getResults());
|
||||
inst->replaceAllUsesWith(newCompute->getResults());
|
||||
inst->erase();
|
||||
return true;
|
||||
}
|
||||
@@ -298,14 +298,15 @@ void ONNXToSpatialPass::encapsulateGlobalInstruction(func::FuncOp funcOp) {
|
||||
void ONNXToSpatialPass::mergeTriviallyConnectedComputes(func::FuncOp funcOp) {
|
||||
Location loc = funcOp.getLoc();
|
||||
IRRewriter rewriter(&getContext());
|
||||
SmallVector<spatial::SpatWeightedCompute> trivialComputes;
|
||||
llvm::SmallSet<spatial::SpatWeightedCompute, 8> toErase;
|
||||
SmallVector<spatial::SpatCompute> trivialComputes;
|
||||
llvm::SmallSet<spatial::SpatCompute, 8> toErase;
|
||||
|
||||
for (auto compute : funcOp.getOps<spatial::SpatWeightedCompute>())
|
||||
for (auto compute : funcOp.getOps<spatial::SpatCompute>())
|
||||
if (compute->hasOneUse()) {
|
||||
auto user = dyn_cast<spatial::SpatWeightedCompute>(*compute->getUsers().begin());
|
||||
auto& use = *compute->getUses().begin();
|
||||
auto user = dyn_cast<spatial::SpatCompute>(use.getOwner());
|
||||
|
||||
if (user && user.getInputs().size() == 1)
|
||||
if (user && user.getInputs().size() == 1 && use.getOperandNumber() >= user.getWeights().size())
|
||||
trivialComputes.push_back(compute);
|
||||
}
|
||||
|
||||
@@ -317,12 +318,15 @@ void ONNXToSpatialPass::mergeTriviallyConnectedComputes(func::FuncOp funcOp) {
|
||||
trivialComputes.pop_back();
|
||||
continue;
|
||||
}
|
||||
auto child = cast<spatial::SpatWeightedCompute>(*compute->getUsers().begin());
|
||||
auto& computeUse = *compute->getUses().begin();
|
||||
auto child = cast<spatial::SpatCompute>(computeUse.getOwner());
|
||||
auto usedResult = cast<OpResult>(computeUse.get()).getResultNumber();
|
||||
auto childArgIndex = computeUse.getOperandNumber() - child.getWeights().size();
|
||||
|
||||
rewriter.setInsertionPointAfter(compute.getOperation());
|
||||
|
||||
auto newCompute =
|
||||
spatial::SpatWeightedCompute::create(rewriter, loc, child.getResultTypes(), compute.getOperands());
|
||||
spatial::SpatCompute::create(rewriter, loc, child.getResultTypes(), compute.getOperands());
|
||||
newCompute.getProperties().setOperandSegmentSizes(
|
||||
{static_cast<int>(compute.getWeights().size()), static_cast<int>(compute.getInputs().size())});
|
||||
|
||||
@@ -343,7 +347,7 @@ void ONNXToSpatialPass::mergeTriviallyConnectedComputes(func::FuncOp funcOp) {
|
||||
|
||||
compute.getBodyRegion().cloneInto(&newCompute.getBodyRegion(), mapper);
|
||||
auto newTerminator = newCompute.getBody().front().getTerminator();
|
||||
mapper.map(*child.getBody().front().getArguments().begin(), newTerminator->getOperand(0));
|
||||
mapper.map(child.getBody().front().getArgument(childArgIndex), newTerminator->getOperand(usedResult));
|
||||
newTerminator->erase();
|
||||
rewriter.setInsertionPoint(&newCompute.getBody().front(), newCompute.getBody().front().end());
|
||||
for (auto& op : child.getBody().front()) {
|
||||
@@ -371,14 +375,16 @@ void ONNXToSpatialPass::mergeTriviallyConnectedComputes(func::FuncOp funcOp) {
|
||||
toErase.insert(compute);
|
||||
|
||||
if (newCompute->hasOneUse()) {
|
||||
auto user = dyn_cast<spatial::SpatWeightedCompute>(*newCompute->getUsers().begin());
|
||||
if (user && user.getInputs().size() == 1)
|
||||
auto& use = *newCompute->getUses().begin();
|
||||
auto user = dyn_cast<spatial::SpatCompute>(use.getOwner());
|
||||
if (user && user.getInputs().size() == 1 && use.getOperandNumber() >= user.getWeights().size())
|
||||
trivialComputes.push_back(newCompute);
|
||||
}
|
||||
}
|
||||
|
||||
for (auto compute : toErase) {
|
||||
compute.getResult(0).dropAllUses();
|
||||
for (Value result : compute->getResults())
|
||||
result.dropAllUses();
|
||||
compute.erase();
|
||||
}
|
||||
}
|
||||
@@ -386,7 +392,7 @@ void ONNXToSpatialPass::mergeTriviallyConnectedComputes(func::FuncOp funcOp) {
|
||||
void ONNXToSpatialPass::annotateWeightsConstants(func::FuncOp funcOp) const {
|
||||
funcOp.walk([&](arith::ConstantOp constantOp) {
|
||||
bool isAlwaysWeight =
|
||||
llvm::all_of(constantOp->getUsers(), [](auto user) -> bool { return isa<spatial::SpatWeightedCompute>(user); });
|
||||
llvm::all_of(constantOp->getUsers(), [](auto user) -> bool { return isa<spatial::SpatCompute>(user); });
|
||||
if (isAlwaysWeight)
|
||||
markWeightAlways(constantOp);
|
||||
});
|
||||
@@ -394,7 +400,7 @@ void ONNXToSpatialPass::annotateWeightsConstants(func::FuncOp funcOp) const {
|
||||
|
||||
LogicalResult ONNXToSpatialPass::promoteConstantInputsToWeights(func::FuncOp funcOp) {
|
||||
IRRewriter rewriter(&getContext());
|
||||
SmallVector<spatial::SpatWeightedCompute> computes(funcOp.getOps<spatial::SpatWeightedCompute>());
|
||||
SmallVector<spatial::SpatCompute> computes(funcOp.getOps<spatial::SpatCompute>());
|
||||
|
||||
for (auto compute : computes) {
|
||||
SmallVector<bool> promoteInput(compute.getInputs().size(), false);
|
||||
@@ -430,7 +436,7 @@ LogicalResult ONNXToSpatialPass::promoteConstantInputsToWeights(func::FuncOp fun
|
||||
}
|
||||
|
||||
auto newCompute =
|
||||
spatial::SpatWeightedCompute::create(rewriter, compute.getLoc(), compute.getResultTypes(), newWeights, newInputs);
|
||||
spatial::SpatCompute::create(rewriter, compute.getLoc(), compute.getResultTypes(), newWeights, newInputs);
|
||||
auto* newBlock =
|
||||
rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), newInputTypes, newInputLocs);
|
||||
newCompute.getProperties().setOperandSegmentSizes(
|
||||
|
||||
Reference in New Issue
Block a user