Single Concat Fix
This commit is contained in:
@@ -49,6 +49,7 @@ private:
|
||||
void annotateWeightsConstants(func::FuncOp funcOp) const;
|
||||
LogicalResult encapsulateGlobalInstruction(func::FuncOp funcOp);
|
||||
LogicalResult promoteConstantInputsToWeights(func::FuncOp funcOp);
|
||||
void populateEmptyFunction(func::FuncOp funcOp);
|
||||
};
|
||||
|
||||
} // namespace
|
||||
@@ -64,7 +65,8 @@ static void foldSingleLaneComputeBatches(func::FuncOp funcOp) {
|
||||
|
||||
auto loc = batchOp.getLoc();
|
||||
rewriter.setInsertionPoint(batchOp);
|
||||
auto computeOp = spatial::SpatCompute::create(rewriter, loc, batchOp.getResultTypes(), batchOp.getWeights(), batchOp.getInputs());
|
||||
auto computeOp =
|
||||
spatial::SpatCompute::create(rewriter, loc, batchOp.getResultTypes(), batchOp.getWeights(), batchOp.getInputs());
|
||||
computeOp.getProperties().setOperandSegmentSizes(
|
||||
{static_cast<int>(batchOp.getWeights().size()), static_cast<int>(batchOp.getInputs().size())});
|
||||
|
||||
@@ -75,8 +77,8 @@ static void foldSingleLaneComputeBatches(func::FuncOp funcOp) {
|
||||
blockArgTypes.push_back(arg.getType());
|
||||
blockArgLocs.push_back(loc);
|
||||
}
|
||||
auto* newBlock = rewriter.createBlock(
|
||||
&computeOp.getBody(), computeOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
|
||||
auto* newBlock =
|
||||
rewriter.createBlock(&computeOp.getBody(), computeOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
|
||||
|
||||
IRMapping mapper;
|
||||
for (auto [oldArg, newArg] : llvm::zip(templateBlock.getArguments(), newBlock->getArguments()))
|
||||
@@ -183,6 +185,8 @@ void ONNXToSpatialPass::runOnOperation() {
|
||||
|
||||
annotateWeightsConstants(*entryFunc);
|
||||
|
||||
populateEmptyFunction(*entryFunc);
|
||||
|
||||
if (failed(encapsulateGlobalInstruction(*entryFunc))) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
@@ -376,8 +380,7 @@ LogicalResult ONNXToSpatialPass::encapsulateGlobalInstruction(func::FuncOp funcO
|
||||
while (keep) {
|
||||
keep = false;
|
||||
for (auto& instruction : llvm::make_early_inc_range(funcOp.getOps())) {
|
||||
if (isa<spatial::SpatCompute, spatial::SpatComputeBatch, spatial::SpatConcatOp, spatial::SpatExtractRowsOp>(
|
||||
instruction)
|
||||
if (isa<spatial::SpatCompute, spatial::SpatComputeBatch, spatial::SpatExtractRowsOp>(instruction)
|
||||
|| isa<func::ReturnOp>(instruction))
|
||||
continue;
|
||||
|
||||
@@ -490,6 +493,47 @@ LogicalResult ONNXToSpatialPass::promoteConstantInputsToWeights(func::FuncOp fun
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Wraps the body of `funcOp` into a single spatial::SpatCompute when the
/// function does not already contain one. Every non-terminator top-level op
/// is cloned into the compute region, a SpatYieldOp forwards the original
/// return values, the originals are erased, and the func.return is rewired
/// to consume the compute's results.
void ONNXToSpatialPass::populateEmptyFunction(func::FuncOp funcOp) {
  IRRewriter rewriter(&getContext());
  IRMapping mapper;

  // Nothing to do if the function already carries at least one SpatCompute.
  SmallVector<spatial::SpatCompute> computes(funcOp.getOps<spatial::SpatCompute>());
  if (!computes.empty())
    return;

  auto returnOp = llvm::cast<func::ReturnOp>(funcOp.getRegion().front().getTerminator());
  rewriter.setInsertionPoint(returnOp);

  // Mirror the function arguments as block arguments of the compute region.
  SmallVector<Type> sourceTypes;
  SmallVector<Location> sourceLoc;
  for (auto source : funcOp.getArguments()) {
    sourceTypes.push_back(source.getType());
    sourceLoc.push_back(source.getLoc());
  }

  // All function arguments are fed as inputs; the weights segment is empty.
  auto newCompute = spatial::SpatCompute::create(
      rewriter, returnOp.getLoc(), returnOp.getOperandTypes(), funcOp.getArguments(), {}, {});
  auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), sourceTypes, sourceLoc);
  for (auto [bbArg, computeArg] : llvm::zip(BB->getArguments(), newCompute.getOperands()))
    mapper.map(computeArg, bbArg);
  // static_cast for consistency with the other setOperandSegmentSizes call
  // site in this file (previously C-style `(int)` casts).
  newCompute.getProperties().setOperandSegmentSizes({0, static_cast<int>(sourceTypes.size())});
  rewriter.setInsertionPointToEnd(BB);
  for (Operation& inst : funcOp.getOps())
    if (!isa<spatial::SpatCompute, func::ReturnOp>(&inst))
      rewriter.clone(inst, mapper);

  // Yield the remapped values the function originally returned.
  auto yield = spatial::SpatYieldOp::create(rewriter, funcOp.getLoc(), returnOp.getOperands());
  for (size_t i = 0; i < yield.getNumOperands(); ++i)
    yield.setOperand(i, mapper.lookupOrDefault(yield.getOperand(i)));

  // Erase the now-duplicated top-level ops; drop their uses first so erasure
  // does not trip on remaining users. early_inc_range keeps iteration valid
  // while erasing.
  for (Operation& inst : llvm::make_early_inc_range(funcOp.getOps()))
    if (!isa<spatial::SpatCompute, func::ReturnOp>(&inst)) {
      inst.dropAllUses();
      rewriter.eraseOp(&inst);
    }

  // Rewire the return to consume the compute's results.
  for (auto [index, computeResult] : llvm::enumerate(newCompute.getResults()))
    returnOp.setOperand(index, computeResult);
}
|
||||
|
||||
std::unique_ptr<Pass> createONNXToSpatialPass() { return std::make_unique<ONNXToSpatialPass>(); }
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -23,6 +23,16 @@ using namespace mlir;
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
/// Rewrite pattern that unconditionally deletes spatial::SpatNopOp.
/// A SpatNop carries no results or semantics, so removal is always legal.
struct EraseSpatNopPattern : public mlir::OpRewritePattern<spatial::SpatNopOp> {
  using OpRewritePattern<spatial::SpatNopOp>::OpRewritePattern;

  mlir::LogicalResult matchAndRewrite(
      spatial::SpatNopOp op, mlir::PatternRewriter& rewriter) const override {
    // Matches every SpatNopOp; simply erase it and report success.
    rewriter.eraseOp(op);
    return mlir::success();
  }
};
|
||||
|
||||
|
||||
static std::optional<unsigned> getDirectComputeInputIndex(Operation* owner, unsigned operandNumber) {
|
||||
if (auto compute = dyn_cast<spatial::SpatCompute>(owner)) {
|
||||
unsigned inputCount = compute.getInputs().size();
|
||||
@@ -416,7 +426,7 @@ struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern<mlir::func::FuncO
|
||||
|
||||
} // namespace
|
||||
/// Registers the global-tensor-to-memref lowering patterns into `patterns`.
/// Fix: the extracted text contained both the pre- and post-change
/// `patterns.add` lines (invalid C++); keep only the updated registration,
/// which also adds EraseSpatNopPattern so stray SpatNopOps left behind by
/// earlier rewrites get cleaned up during pattern application.
void populateGlobalTensorToMemrefPatterns(RewritePatternSet& patterns) {
  patterns.add<EraseSpatNopPattern, MoveExtractSliceIntoCompute, FuncOpArgToGlobalMemoryPattern,
      ArithConstToGlobalMemoryPattern>(patterns.getContext());
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user