Files
Raptor/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp
T
ilgeco 78e97f9fd8
Validate Operations / validate-operations (push) Waiting to run
Bose
2026-06-26 17:45:27 +02:00

232 lines
9.2 KiB
C++

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/SmallVector.h"
#include "Common/Common.hpp"
#include "Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialVerifier.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
#include "ONNXToSpatialVerifier.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
struct ONNXToSpatialPass : PassWrapper<ONNXToSpatialPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ONNXToSpatialPass)
StringRef getArgument() const override { return "convert-onnx-to-spatial"; }
StringRef getDescription() const override { return "Lower ONNX ops to Spatial ops."; }
ONNXToSpatialPass() = default;
ONNXToSpatialPass(const ONNXToSpatialPass& pass) {}
void runOnOperation() override;
};
} // namespace
static void populateEmptyFunction(func::FuncOp funcOp) {
IRRewriter rewriter(funcOp.getContext());
IRMapping mapper;
SmallVector<spatial::SpatGraphCompute> computes(funcOp.getOps<spatial::SpatGraphCompute>());
SmallVector<spatial::SpatGraphComputeBatch> computeBatches(funcOp.getOps<spatial::SpatGraphComputeBatch>());
SmallVector<spatial::SpatConv2DPlanOp> convPlans(funcOp.getOps<spatial::SpatConv2DPlanOp>());
SmallVector<spatial::SpatReluPlanOp> reluPlans(funcOp.getOps<spatial::SpatReluPlanOp>());
SmallVector<spatial::SpatBlueprintOp> blueprints(funcOp.getOps<spatial::SpatBlueprintOp>());
SmallVector<spatial::SpatMaterializeLayoutOp> materializers(funcOp.getOps<spatial::SpatMaterializeLayoutOp>());
if (!computes.empty() || !computeBatches.empty() || !convPlans.empty() || !reluPlans.empty() || !blueprints.empty()
|| !materializers.empty()) {
return;
}
auto returnOp = cast<func::ReturnOp>(funcOp.getFunctionBody().front().getTerminator());
rewriter.setInsertionPoint(returnOp);
SmallVector<Type> sourceTypes;
SmallVector<Location> sourceLocs;
sourceTypes.reserve(funcOp.getNumArguments());
sourceLocs.reserve(funcOp.getNumArguments());
for (Value source : funcOp.getArguments()) {
sourceTypes.push_back(source.getType());
sourceLocs.push_back(source.getLoc());
}
auto newCompute = spatial::SpatGraphCompute::create(
rewriter, returnOp.getLoc(), returnOp.getOperandTypes(), funcOp.getArguments(), {}, {});
auto* newBlock = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), sourceTypes, sourceLocs);
for (auto [blockArg, computeArg] : llvm::zip(newBlock->getArguments(), newCompute.getOperands()))
mapper.map(computeArg, blockArg);
newCompute.getProperties().setOperandSegmentSizes({0, static_cast<int>(sourceTypes.size())});
rewriter.setInsertionPointToEnd(newBlock);
for (Operation& op : funcOp.getOps())
if (!isa<spatial::SpatGraphCompute, func::ReturnOp>(&op))
rewriter.clone(op, mapper);
auto yield = spatial::SpatYieldOp::create(rewriter, funcOp.getLoc(), returnOp.getOperands());
for (size_t i = 0; i < yield.getNumOperands(); ++i)
yield.setOperand(i, mapper.lookupOrDefault(yield.getOperand(i)));
for (Operation& op : llvm::make_early_inc_range(funcOp.getOps()))
if (!isa<spatial::SpatGraphCompute, func::ReturnOp>(&op)) {
op.dropAllUses();
rewriter.eraseOp(&op);
}
for (auto [index, computeResult] : llvm::enumerate(newCompute.getResults()))
returnOp.setOperand(index, computeResult);
}
void ONNXToSpatialPass::runOnOperation() {
ModuleOp moduleOp = getOperation();
MLIRContext* ctx = &getContext();
ConversionTarget preTarget(*ctx);
preTarget.addLegalDialect<spatial::SpatialDialect,
ONNXDialect,
linalg::LinalgDialect,
tensor::TensorDialect,
affine::AffineDialect,
arith::ArithDialect,
scf::SCFDialect>();
preTarget.addIllegalOp<ONNXConstantOp, ONNXFlattenOp>();
RewritePatternSet prePatterns(ctx);
populatePrePatterns(prePatterns, ctx);
if (failed(applyPartialConversion(moduleOp, preTarget, std::move(prePatterns)))) {
moduleOp.emitError("failed to apply ONNX-to-Spatial pre-rewrites");
signalPassFailure();
return;
}
auto entryFunc = getPimEntryFunc(moduleOp);
if (failed(entryFunc)) {
moduleOp.emitError("failed to locate the PIM entry function during ONNX-to-Spatial lowering");
signalPassFailure();
return;
}
ConversionTarget target(*ctx);
target.addLegalDialect<spatial::SpatialDialect,
ONNXDialect,
linalg::LinalgDialect,
tensor::TensorDialect,
affine::AffineDialect,
arith::ArithDialect,
scf::SCFDialect>();
target.addIllegalOp<ONNXMatMulOp>();
target.addIllegalOp<ONNXTransposeOp>();
target.addIllegalOp<ONNXAddOp>();
target.addIllegalOp<ONNXSubOp>();
target.addIllegalOp<ONNXDivOp>();
target.addIllegalOp<ONNXMulOp>();
target.addIllegalOp<ONNXGemmOp>();
target.addIllegalOp<ONNXConvOp>();
target.addIllegalOp<ONNXMaxPoolSingleOutOp>();
target.addIllegalOp<ONNXAveragePoolOp>();
target.addIllegalOp<ONNXReluOp>();
target.addIllegalOp<ONNXSigmoidOp>();
target.addIllegalOp<ONNXSoftmaxOp>();
target.addIllegalOp<ONNXConcatOp>();
target.addIllegalOp<ONNXGatherOp>();
target.addIllegalOp<ONNXReshapeOp>();
target.addIllegalOp<ONNXResizeOp>();
target.addIllegalOp<ONNXSliceOp>();
target.addIllegalOp<ONNXLRNOp>();
target.addIllegalOp<ONNXReduceMeanOp>();
target.addIllegalOp<ONNXReduceMeanV13Op>();
target.addIllegalOp<ONNXSplitOp>();
RewritePatternSet conversionPatterns(ctx);
populateConversionPatterns(conversionPatterns, ctx);
if (failed(applyPartialConversion(moduleOp, target, std::move(conversionPatterns)))) {
moduleOp.emitError("failed to convert required ONNX ops to Spatial ops");
signalPassFailure();
return;
}
if (failed(verifyLogicalSpatialGraphInvariants(*entryFunc))) {
moduleOp.emitError("logical Spatial graph verification failed after ONNX conversion");
signalPassFailure();
return;
}
ConversionTarget earlyPostTarget(*ctx);
earlyPostTarget.addLegalDialect<spatial::SpatialDialect,
ONNXDialect,
linalg::LinalgDialect,
tensor::TensorDialect,
affine::AffineDialect,
arith::ArithDialect,
scf::SCFDialect>();
PassManager cleanupPM(ctx);
cleanupPM.addPass(createCanonicalizerPass());
if (failed(cleanupPM.run(moduleOp)))
moduleOp.emitWarning("failed to run ONNX-to-Spatial canonicalization cleanup; continuing");
annotateWeightsConstants(*entryFunc);
if (failed(verifyLogicalSpatialGraphInvariants(*entryFunc))) {
moduleOp.emitError("logical Spatial graph verification failed after weight annotation");
signalPassFailure();
return;
}
ConversionTarget postTarget(*ctx);
postTarget.addLegalDialect<spatial::SpatialDialect,
ONNXDialect,
linalg::LinalgDialect,
tensor::TensorDialect,
affine::AffineDialect,
arith::ArithDialect,
scf::SCFDialect>();
postTarget.addDynamicallyLegalOp<spatial::SpatGraphCompute>(
[](spatial::SpatGraphCompute computeOp) { return !requiresPostRewrite(computeOp); });
postTarget.addDynamicallyLegalOp<spatial::SpatGraphComputeBatch>(
[](spatial::SpatGraphComputeBatch computeOp) { return !requiresPostRewrite(computeOp); });
if (failed(verifyLogicalSpatialGraphInvariants(*entryFunc))) {
moduleOp.emitError("logical Spatial graph verification failed before post rewrites");
signalPassFailure();
return;
}
RewritePatternSet postPatterns(ctx);
populatePostPatterns(postPatterns, ctx);
if (failed(applyPartialConversion(*entryFunc, postTarget, std::move(postPatterns)))) {
moduleOp.emitError("failed to normalize weight-like Spatial compute operands before Spatial-to-PIM lowering");
signalPassFailure();
return;
}
populateEmptyFunction(*entryFunc);
if (failed(verifyLogicalSpatialGraphInvariants(*entryFunc))) {
moduleOp.emitError("logical Spatial graph verification failed after ONNX-to-Spatial");
signalPassFailure();
return;
}
dumpModule(moduleOp, "spatial0");
if (failed(verifyONNXToSpatial(*entryFunc))) {
moduleOp.emitError("ONNX-to-Spatial host legality verification failed");
signalPassFailure();
}
}
std::unique_ptr<Pass> createONNXToSpatialPass() { return std::make_unique<ONNXToSpatialPass>(); }
} // namespace onnx_mlir