This commit is contained in:
ilgeco
2026-06-24 15:52:07 +02:00
parent 2b4115699a
commit 62dd40ee89
47 changed files with 7993 additions and 1100 deletions
@@ -18,6 +18,7 @@
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
#include "ONNXToSpatialVerifier.hpp"
using namespace mlir;
@@ -41,10 +42,16 @@ struct ONNXToSpatialPass : PassWrapper<ONNXToSpatialPass, OperationPass<ModuleOp
static void populateEmptyFunction(func::FuncOp funcOp) {
IRRewriter rewriter(funcOp.getContext());
IRMapping mapper;
SmallVector<spatial::SpatCompute> computes(funcOp.getOps<spatial::SpatCompute>());
SmallVector<spatial::SpatComputeBatch> computeBatches(funcOp.getOps<spatial::SpatComputeBatch>());
if (!computes.empty() || !computeBatches.empty())
SmallVector<spatial::SpatGraphCompute> computes(funcOp.getOps<spatial::SpatGraphCompute>());
SmallVector<spatial::SpatGraphComputeBatch> computeBatches(funcOp.getOps<spatial::SpatGraphComputeBatch>());
SmallVector<spatial::SpatConv2DPlanOp> convPlans(funcOp.getOps<spatial::SpatConv2DPlanOp>());
SmallVector<spatial::SpatReluPlanOp> reluPlans(funcOp.getOps<spatial::SpatReluPlanOp>());
SmallVector<spatial::SpatReconciliatorOp> reconciliators(funcOp.getOps<spatial::SpatReconciliatorOp>());
SmallVector<spatial::SpatMaterializeLayoutOp> materializers(funcOp.getOps<spatial::SpatMaterializeLayoutOp>());
if (!computes.empty() || !computeBatches.empty() || !convPlans.empty() || !reluPlans.empty() || !reconciliators.empty()
|| !materializers.empty()) {
return;
}
auto returnOp = cast<func::ReturnOp>(funcOp.getFunctionBody().front().getTerminator());
rewriter.setInsertionPoint(returnOp);
@@ -58,7 +65,7 @@ static void populateEmptyFunction(func::FuncOp funcOp) {
sourceLocs.push_back(source.getLoc());
}
auto newCompute = spatial::SpatCompute::create(
auto newCompute = spatial::SpatGraphCompute::create(
rewriter, returnOp.getLoc(), returnOp.getOperandTypes(), funcOp.getArguments(), {}, {});
auto* newBlock = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), sourceTypes, sourceLocs);
for (auto [blockArg, computeArg] : llvm::zip(newBlock->getArguments(), newCompute.getOperands()))
@@ -67,7 +74,7 @@ static void populateEmptyFunction(func::FuncOp funcOp) {
rewriter.setInsertionPointToEnd(newBlock);
for (Operation& op : funcOp.getOps())
if (!isa<spatial::SpatCompute, func::ReturnOp>(&op))
if (!isa<spatial::SpatGraphCompute, func::ReturnOp>(&op))
rewriter.clone(op, mapper);
auto yield = spatial::SpatYieldOp::create(rewriter, funcOp.getLoc(), returnOp.getOperands());
@@ -75,7 +82,7 @@ static void populateEmptyFunction(func::FuncOp funcOp) {
yield.setOperand(i, mapper.lookupOrDefault(yield.getOperand(i)));
for (Operation& op : llvm::make_early_inc_range(funcOp.getOps()))
if (!isa<spatial::SpatCompute, func::ReturnOp>(&op)) {
if (!isa<spatial::SpatGraphCompute, func::ReturnOp>(&op)) {
op.dropAllUses();
rewriter.eraseOp(&op);
}
@@ -152,6 +159,11 @@ void ONNXToSpatialPass::runOnOperation() {
return;
}
if (failed(verifyLogicalSpatialGraphInvariants(*entryFunc))) {
moduleOp.emitError("RAPTOR_PHASE_CHECK logical Spatial graph verification failed after ONNX conversion");
signalPassFailure();
return;
}
ConversionTarget earlyPostTarget(*ctx);
earlyPostTarget.addLegalDialect<spatial::SpatialDialect,
ONNXDialect,
@@ -168,6 +180,11 @@ void ONNXToSpatialPass::runOnOperation() {
annotateWeightsConstants(*entryFunc);
if (failed(verifyLogicalSpatialGraphInvariants(*entryFunc))) {
moduleOp.emitError("RAPTOR_PHASE_CHECK logical Spatial graph verification failed after weight annotation");
signalPassFailure();
return;
}
ConversionTarget postTarget(*ctx);
postTarget.addLegalDialect<spatial::SpatialDialect,
ONNXDialect,
@@ -176,11 +193,16 @@ void ONNXToSpatialPass::runOnOperation() {
affine::AffineDialect,
arith::ArithDialect,
scf::SCFDialect>();
postTarget.addDynamicallyLegalOp<spatial::SpatCompute>(
[](spatial::SpatCompute computeOp) { return !requiresPostRewrite(computeOp); });
postTarget.addDynamicallyLegalOp<spatial::SpatComputeBatch>(
[](spatial::SpatComputeBatch computeOp) { return !requiresPostRewrite(computeOp); });
postTarget.addDynamicallyLegalOp<spatial::SpatGraphCompute>(
[](spatial::SpatGraphCompute computeOp) { return !requiresPostRewrite(computeOp); });
postTarget.addDynamicallyLegalOp<spatial::SpatGraphComputeBatch>(
[](spatial::SpatGraphComputeBatch computeOp) { return !requiresPostRewrite(computeOp); });
if (failed(verifyLogicalSpatialGraphInvariants(*entryFunc))) {
moduleOp.emitError("RAPTOR_PHASE_CHECK logical Spatial graph verification failed before post rewrites");
signalPassFailure();
return;
}
RewritePatternSet postPatterns(ctx);
populatePostPatterns(postPatterns, ctx);
if (failed(applyPartialConversion(*entryFunc, postTarget, std::move(postPatterns)))) {
@@ -191,6 +213,11 @@ void ONNXToSpatialPass::runOnOperation() {
populateEmptyFunction(*entryFunc);
if (failed(verifyLogicalSpatialGraphInvariants(*entryFunc))) {
moduleOp.emitError("RAPTOR_PHASE_CHECK logical Spatial graph verification failed after ONNX-to-Spatial");
signalPassFailure();
return;
}
dumpModule(moduleOp, "spatial0");
if (failed(verifyONNXToSpatial(*entryFunc))) {