Single Concat Fix

This commit is contained in:
ilgeco
2026-05-07 16:47:01 +02:00
parent f2fe147961
commit 74931ad75b
4 changed files with 62 additions and 32 deletions

View File

@@ -49,6 +49,7 @@ private:
void annotateWeightsConstants(func::FuncOp funcOp) const;
LogicalResult encapsulateGlobalInstruction(func::FuncOp funcOp);
LogicalResult promoteConstantInputsToWeights(func::FuncOp funcOp);
void populateEmptyFunction(func::FuncOp funcOp);
};
} // namespace
@@ -64,7 +65,8 @@ static void foldSingleLaneComputeBatches(func::FuncOp funcOp) {
auto loc = batchOp.getLoc();
rewriter.setInsertionPoint(batchOp);
auto computeOp = spatial::SpatCompute::create(rewriter, loc, batchOp.getResultTypes(), batchOp.getWeights(), batchOp.getInputs());
auto computeOp =
spatial::SpatCompute::create(rewriter, loc, batchOp.getResultTypes(), batchOp.getWeights(), batchOp.getInputs());
computeOp.getProperties().setOperandSegmentSizes(
{static_cast<int>(batchOp.getWeights().size()), static_cast<int>(batchOp.getInputs().size())});
@@ -75,8 +77,8 @@ static void foldSingleLaneComputeBatches(func::FuncOp funcOp) {
blockArgTypes.push_back(arg.getType());
blockArgLocs.push_back(loc);
}
auto* newBlock = rewriter.createBlock(
&computeOp.getBody(), computeOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
auto* newBlock =
rewriter.createBlock(&computeOp.getBody(), computeOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
IRMapping mapper;
for (auto [oldArg, newArg] : llvm::zip(templateBlock.getArguments(), newBlock->getArguments()))
@@ -183,6 +185,8 @@ void ONNXToSpatialPass::runOnOperation() {
annotateWeightsConstants(*entryFunc);
populateEmptyFunction(*entryFunc);
if (failed(encapsulateGlobalInstruction(*entryFunc))) {
signalPassFailure();
return;
@@ -376,8 +380,7 @@ LogicalResult ONNXToSpatialPass::encapsulateGlobalInstruction(func::FuncOp funcO
while (keep) {
keep = false;
for (auto& instruction : llvm::make_early_inc_range(funcOp.getOps())) {
if (isa<spatial::SpatCompute, spatial::SpatComputeBatch, spatial::SpatConcatOp, spatial::SpatExtractRowsOp>(
instruction)
if (isa<spatial::SpatCompute, spatial::SpatComputeBatch, spatial::SpatExtractRowsOp>(instruction)
|| isa<func::ReturnOp>(instruction))
continue;
@@ -490,6 +493,47 @@ LogicalResult ONNXToSpatialPass::promoteConstantInputsToWeights(func::FuncOp fun
return success();
}
void ONNXToSpatialPass::populateEmptyFunction(func::FuncOp funcOp) {
IRRewriter rewriter(&getContext());
IRMapping mapper;
SmallVector<spatial::SpatCompute> computes(funcOp.getOps<spatial::SpatCompute>());
if (!computes.empty())
return;
auto returnOp = llvm::cast<func::ReturnOp>(funcOp.getRegion().front().getTerminator());
rewriter.setInsertionPoint(returnOp);
SmallVector<Type> sourceTypes;
SmallVector<Location> sourceLoc;
for (auto source : funcOp.getArguments()) {
sourceTypes.push_back(source.getType());
sourceLoc.push_back(source.getLoc());
}
auto newCompute = spatial::SpatCompute::create(
rewriter, returnOp.getLoc(), returnOp.getOperandTypes(), funcOp.getArguments(), {}, {});
auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), sourceTypes, sourceLoc);
for (auto [bbArg, computeArg] : llvm::zip(BB->getArguments(), newCompute.getOperands()))
mapper.map(computeArg, bbArg);
newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) sourceTypes.size()});
rewriter.setInsertionPointToEnd(BB);
for (Operation& inst : funcOp.getOps())
if (!isa<spatial::SpatCompute, func::ReturnOp>(&inst))
rewriter.clone(inst, mapper);
auto yield = spatial::SpatYieldOp::create(rewriter, funcOp.getLoc(), returnOp.getOperands());
for (size_t i = 0; i < yield.getNumOperands(); ++i)
yield.setOperand(i, mapper.lookupOrDefault(yield.getOperand(i)));
for (Operation& inst : llvm::make_early_inc_range(funcOp.getOps()))
if (!isa<spatial::SpatCompute, func::ReturnOp>(&inst)){
inst.dropAllUses();
rewriter.eraseOp(&inst);
}
for (auto [index, computeResult] : llvm::enumerate(newCompute.getResults()))
returnOp.setOperand(index, computeResult);
}
std::unique_ptr<Pass> createONNXToSpatialPass() { return std::make_unique<ONNXToSpatialPass>(); }
} // namespace onnx_mlir

View File

@@ -23,6 +23,16 @@ using namespace mlir;
namespace onnx_mlir {
namespace {
struct EraseSpatNopPattern : public mlir::OpRewritePattern<spatial::SpatNopOp> {
using OpRewritePattern<spatial::SpatNopOp>::OpRewritePattern;
mlir::LogicalResult matchAndRewrite(spatial::SpatNopOp op,
mlir::PatternRewriter &rewriter) const override {
rewriter.eraseOp(op);
return mlir::success();
}
};
static std::optional<unsigned> getDirectComputeInputIndex(Operation* owner, unsigned operandNumber) {
if (auto compute = dyn_cast<spatial::SpatCompute>(owner)) {
unsigned inputCount = compute.getInputs().size();
@@ -416,7 +426,7 @@ struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern<mlir::func::FuncO
} // namespace
void populateGlobalTensorToMemrefPatterns(RewritePatternSet& patterns) {
patterns.add<MoveExtractSliceIntoCompute, FuncOpArgToGlobalMemoryPattern, ArithConstToGlobalMemoryPattern>(
patterns.add<EraseSpatNopPattern, MoveExtractSliceIntoCompute, FuncOpArgToGlobalMemoryPattern, ArithConstToGlobalMemoryPattern>(
patterns.getContext());
}