|
|
|
@@ -49,6 +49,7 @@ private:
|
|
|
|
|
void annotateWeightsConstants(func::FuncOp funcOp) const;
|
|
|
|
|
LogicalResult encapsulateGlobalInstruction(func::FuncOp funcOp);
|
|
|
|
|
LogicalResult promoteConstantInputsToWeights(func::FuncOp funcOp);
|
|
|
|
|
void populateEmptyFunction(func::FuncOp funcOp);
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
} // namespace
|
|
|
|
@@ -64,7 +65,8 @@ static void foldSingleLaneComputeBatches(func::FuncOp funcOp) {
|
|
|
|
|
|
|
|
|
|
auto loc = batchOp.getLoc();
|
|
|
|
|
rewriter.setInsertionPoint(batchOp);
|
|
|
|
|
auto computeOp = spatial::SpatCompute::create(rewriter, loc, batchOp.getResultTypes(), batchOp.getWeights(), batchOp.getInputs());
|
|
|
|
|
auto computeOp =
|
|
|
|
|
spatial::SpatCompute::create(rewriter, loc, batchOp.getResultTypes(), batchOp.getWeights(), batchOp.getInputs());
|
|
|
|
|
computeOp.getProperties().setOperandSegmentSizes(
|
|
|
|
|
{static_cast<int>(batchOp.getWeights().size()), static_cast<int>(batchOp.getInputs().size())});
|
|
|
|
|
|
|
|
|
@@ -75,8 +77,8 @@ static void foldSingleLaneComputeBatches(func::FuncOp funcOp) {
|
|
|
|
|
blockArgTypes.push_back(arg.getType());
|
|
|
|
|
blockArgLocs.push_back(loc);
|
|
|
|
|
}
|
|
|
|
|
auto* newBlock = rewriter.createBlock(
|
|
|
|
|
&computeOp.getBody(), computeOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
|
|
|
|
|
auto* newBlock =
|
|
|
|
|
rewriter.createBlock(&computeOp.getBody(), computeOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
|
|
|
|
|
|
|
|
|
|
IRMapping mapper;
|
|
|
|
|
for (auto [oldArg, newArg] : llvm::zip(templateBlock.getArguments(), newBlock->getArguments()))
|
|
|
|
@@ -183,6 +185,8 @@ void ONNXToSpatialPass::runOnOperation() {
|
|
|
|
|
|
|
|
|
|
annotateWeightsConstants(*entryFunc);
|
|
|
|
|
|
|
|
|
|
populateEmptyFunction(*entryFunc);
|
|
|
|
|
|
|
|
|
|
if (failed(encapsulateGlobalInstruction(*entryFunc))) {
|
|
|
|
|
signalPassFailure();
|
|
|
|
|
return;
|
|
|
|
@@ -376,8 +380,7 @@ LogicalResult ONNXToSpatialPass::encapsulateGlobalInstruction(func::FuncOp funcO
|
|
|
|
|
while (keep) {
|
|
|
|
|
keep = false;
|
|
|
|
|
for (auto& instruction : llvm::make_early_inc_range(funcOp.getOps())) {
|
|
|
|
|
if (isa<spatial::SpatCompute, spatial::SpatComputeBatch, spatial::SpatConcatOp, spatial::SpatExtractRowsOp>(
|
|
|
|
|
instruction)
|
|
|
|
|
if (isa<spatial::SpatCompute, spatial::SpatComputeBatch, spatial::SpatExtractRowsOp>(instruction)
|
|
|
|
|
|| isa<func::ReturnOp>(instruction))
|
|
|
|
|
continue;
|
|
|
|
|
|
|
|
|
@@ -490,6 +493,47 @@ LogicalResult ONNXToSpatialPass::promoteConstantInputsToWeights(func::FuncOp fun
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ONNXToSpatialPass::populateEmptyFunction(func::FuncOp funcOp) {
|
|
|
|
|
IRRewriter rewriter(&getContext());
|
|
|
|
|
IRMapping mapper;
|
|
|
|
|
SmallVector<spatial::SpatCompute> computes(funcOp.getOps<spatial::SpatCompute>());
|
|
|
|
|
if (!computes.empty())
|
|
|
|
|
return;
|
|
|
|
|
auto returnOp = llvm::cast<func::ReturnOp>(funcOp.getRegion().front().getTerminator());
|
|
|
|
|
rewriter.setInsertionPoint(returnOp);
|
|
|
|
|
|
|
|
|
|
SmallVector<Type> sourceTypes;
|
|
|
|
|
SmallVector<Location> sourceLoc;
|
|
|
|
|
for (auto source : funcOp.getArguments()) {
|
|
|
|
|
sourceTypes.push_back(source.getType());
|
|
|
|
|
sourceLoc.push_back(source.getLoc());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto newCompute = spatial::SpatCompute::create(
|
|
|
|
|
rewriter, returnOp.getLoc(), returnOp.getOperandTypes(), funcOp.getArguments(), {}, {});
|
|
|
|
|
auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), sourceTypes, sourceLoc);
|
|
|
|
|
for (auto [bbArg, computeArg] : llvm::zip(BB->getArguments(), newCompute.getOperands()))
|
|
|
|
|
mapper.map(computeArg, bbArg);
|
|
|
|
|
newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) sourceTypes.size()});
|
|
|
|
|
rewriter.setInsertionPointToEnd(BB);
|
|
|
|
|
for (Operation& inst : funcOp.getOps())
|
|
|
|
|
if (!isa<spatial::SpatCompute, func::ReturnOp>(&inst))
|
|
|
|
|
rewriter.clone(inst, mapper);
|
|
|
|
|
|
|
|
|
|
auto yield = spatial::SpatYieldOp::create(rewriter, funcOp.getLoc(), returnOp.getOperands());
|
|
|
|
|
for (size_t i = 0; i < yield.getNumOperands(); ++i)
|
|
|
|
|
yield.setOperand(i, mapper.lookupOrDefault(yield.getOperand(i)));
|
|
|
|
|
|
|
|
|
|
for (Operation& inst : llvm::make_early_inc_range(funcOp.getOps()))
|
|
|
|
|
if (!isa<spatial::SpatCompute, func::ReturnOp>(&inst)){
|
|
|
|
|
inst.dropAllUses();
|
|
|
|
|
rewriter.eraseOp(&inst);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (auto [index, computeResult] : llvm::enumerate(newCompute.getResults()))
|
|
|
|
|
returnOp.setOperand(index, computeResult);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::unique_ptr<Pass> createONNXToSpatialPass() { return std::make_unique<ONNXToSpatialPass>(); }
|
|
|
|
|
|
|
|
|
|
} // namespace onnx_mlir
|
|
|
|
|