#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/WalkPatternRewriteDriver.h"

#include "llvm/ADT/SmallVector.h"

#include "Common/Common.hpp"
#include "Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostLegality.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PostPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PrePatterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Compiler/CompilerOptions.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"

using namespace mlir;

namespace onnx_mlir {

namespace {

/// Module-level pass that lowers ONNX ops to Spatial ops.
///
/// Exposed on the command line as `convert-onnx-to-spatial`. All of the work
/// happens in runOnOperation().
///
/// NOTE(review): the base-class template arguments were reconstructed from the
/// CRTP type-id macro and from runOnOperation() assigning getOperation() to a
/// ModuleOp — confirm against version control.
struct ONNXToSpatialPass : PassWrapper<ONNXToSpatialPass, OperationPass<ModuleOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ONNXToSpatialPass)

  StringRef getArgument() const override { return "convert-onnx-to-spatial"; }
  StringRef getDescription() const override { return "Lower ONNX ops to Spatial ops."; }

  ONNXToSpatialPass() = default;
  // The copy constructor deliberately copies nothing: the pass has no members.
  ONNXToSpatialPass(const ONNXToSpatialPass& pass) {}

  void runOnOperation() override;
};

} // namespace

/// Wraps the whole body of `funcOp` in a single spat.compute when the function
/// contains no SpatCompute / SpatComputeBatch op of its own.
///
/// If either kind of op is already present the function is left untouched.
/// Otherwise a new SpatCompute is created just before the terminator. It takes
/// every function argument as an operand and produces one result per returned
/// value. The existing ops are cloned into its body, a SpatYieldOp yields the
/// (remapped) returned values, the originals are erased, and the terminator is
/// rewired to the compute results.
///
/// NOTE(review): the template arguments in this function (`SmallVector<...>`,
/// `getOps<...>`, `cast<...>`, `isa<...>`, `static_cast<...>`) were
/// reconstructed from how the values are used — confirm against version
/// control.
static void populateEmptyFunction(func::FuncOp funcOp) {
  IRRewriter rewriter(funcOp.getContext());
  IRMapping mapper;

  SmallVector<spatial::SpatCompute> computes(funcOp.getOps<spatial::SpatCompute>());
  SmallVector<spatial::SpatComputeBatch> computeBatches(funcOp.getOps<spatial::SpatComputeBatch>());
  if (!computes.empty() || !computeBatches.empty())
    return;

  auto returnOp = cast<func::ReturnOp>(funcOp.getFunctionBody().front().getTerminator());
  rewriter.setInsertionPoint(returnOp);

  // The new block gets one argument per function argument (same type and loc).
  SmallVector<Type> sourceTypes;
  SmallVector<Location> sourceLocs;
  sourceTypes.reserve(funcOp.getNumArguments());
  sourceLocs.reserve(funcOp.getNumArguments());
  for (Value source : funcOp.getArguments()) {
    sourceTypes.push_back(source.getType());
    sourceLocs.push_back(source.getLoc());
  }

  auto newCompute = spatial::SpatCompute::create(
      rewriter, returnOp.getLoc(), returnOp.getOperandTypes(), funcOp.getArguments(), {}, {});
  auto* newBlock = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), sourceTypes, sourceLocs);

  // Inside the compute body, uses of a function argument must refer to the
  // matching block argument instead.
  for (auto [blockArg, computeArg] : llvm::zip(newBlock->getArguments(), newCompute.getOperands()))
    mapper.map(computeArg, blockArg);

  // First operand segment is empty; all function arguments go in the second.
  // NOTE(review): presumably the segments are (weights, inputs) — confirm
  // against the SpatCompute op definition.
  newCompute.getProperties().setOperandSegmentSizes({0, static_cast<int32_t>(sourceTypes.size())});
  rewriter.setInsertionPointToEnd(newBlock);

  // NOTE(review): the `isa` type list is a reconstruction. The new compute op
  // and the terminator are the minimum that must be skipped here (the compute
  // must not be cloned into itself and the return stays in the function); the
  // original list may have contained additional op types.
  for (Operation& op : funcOp.getOps())
    if (!isa<spatial::SpatCompute, func::ReturnOp>(&op))
      rewriter.clone(op, mapper);

  // Yield what the function used to return, remapped to the cloned values.
  auto yield = spatial::SpatYieldOp::create(rewriter, funcOp.getLoc(), returnOp.getOperands());
  for (size_t i = 0; i < yield.getNumOperands(); ++i)
    yield.setOperand(i, mapper.lookupOrDefault(yield.getOperand(i)));

  // Remove the originals now that their clones live inside the compute body.
  // Early-inc iteration is required because ops are erased while iterating.
  for (Operation& op : llvm::make_early_inc_range(funcOp.getOps()))
    if (!isa<spatial::SpatCompute, func::ReturnOp>(&op)) {
      op.dropAllUses();
      rewriter.eraseOp(&op);
    }

  for (auto [index, computeResult] : llvm::enumerate(newCompute.getResults()))
    returnOp.setOperand(index, computeResult);
}

/// Moves every non-host-foldable ONNX transpose found directly in the entry
/// block of `funcOp` into its own single-input spat.compute.
///
/// For each such op a compute is built with the transpose's data operand as
/// input. A fresh ONNXTransposeOp with the same permutation attribute and
/// result type is created inside it and yielded, and the original op is
/// replaced by the compute's first result.
///
/// NOTE(review): the `dyn_cast` template argument was reconstructed from the
/// accessors used on the result (`getData`, `getPermAttr`) and from the
/// ONNXTransposeOp created in the body — confirm against version control.
static void wrapTopLevelRuntimeTransposes(func::FuncOp funcOp) {
  IRRewriter rewriter(funcOp.getContext());
  Block& entryBlock = funcOp.getFunctionBody().front();

  // Early-inc iteration: the current op is replaced inside the loop.
  for (Operation& op : llvm::make_early_inc_range(entryBlock)) {
    auto transposeOp = dyn_cast<ONNXTransposeOp>(&op);
    if (!transposeOp || isHostFoldableOp(transposeOp))
      continue;

    // Transpose stays globally legal because constant/view-only cases are
    // allowed on the host. Any residual runtime transpose must be sunk into
    // spat.compute before the host legality check.
    auto resultType = transposeOp.getResult().getType();
    rewriter.setInsertionPoint(transposeOp);
    auto computeOp = createSpatCompute<1>(
        rewriter, transposeOp.getLoc(), TypeRange {resultType}, {}, ValueRange {transposeOp.getData()},
        [&](Value input) {
          Value transposed =
              ONNXTransposeOp::create(rewriter, transposeOp.getLoc(), resultType, input, transposeOp.getPermAttr());
          spatial::SpatYieldOp::create(rewriter, transposeOp.getLoc(), transposed);
        });
    rewriter.replaceOp(transposeOp, computeOp.getResult(0));
  }
}

/// Runs the ONNX-to-Spatial lowering on the module.
///
/// Steps, in order; each hard failure emits an error on the module, signals
/// pass failure, and returns:
///   1. Pre-rewrite partial conversion (populatePrePatterns).
///   2. Locate the PIM entry function.
///   3. Walk-apply the MatMul rewrite patterns; fail if any ONNXMatMulOp is
///      still present afterwards.
///   4. Main partial conversion (populateConversionPatterns).
///   5. Canonicalizer cleanup; a failure here only emits a warning.
///   6. annotateWeightsConstants on the entry function.
///   7. Post-rewrite partial conversion on the entry function
///      (populatePostPatterns), driven by requiresPostRewrite.
///   8. wrapTopLevelRuntimeTransposes, then host-legality verification.
///   9. populateEmptyFunction, then dumpModule(moduleOp, "spatial0").
///
/// NOTE(review): the template arguments of every `addLegalDialect()` and
/// `addIllegalOp()` call below are missing and cannot be recovered from this
/// file alone. They are left exactly as found; this function will not compile
/// until they are restored from version control.
void ONNXToSpatialPass::runOnOperation() {
  ModuleOp moduleOp = getOperation();
  MLIRContext* ctx = &getContext();

  // 1. Pre-rewrites.
  ConversionTarget preTarget(*ctx);
  preTarget.addLegalDialect(); // NOTE(review): dialect list missing.
  preTarget.addIllegalOp();    // NOTE(review): op list missing.

  RewritePatternSet prePatterns(ctx);
  populatePrePatterns(prePatterns, ctx);
  if (failed(applyPartialConversion(moduleOp, preTarget, std::move(prePatterns)))) {
    moduleOp.emitError("failed to apply ONNX-to-Spatial pre-rewrites");
    signalPassFailure();
    return;
  }

  // 2. Entry function lookup.
  auto entryFunc = getPimEntryFunc(moduleOp);
  if (failed(entryFunc)) {
    moduleOp.emitError("failed to locate the PIM entry function during ONNX-to-Spatial lowering");
    signalPassFailure();
    return;
  }

  // 3. MatMul rewrites; every ONNX MatMul must be gone afterwards.
  RewritePatternSet matmulPatterns(ctx);
  populateMatMulRewritePatterns(matmulPatterns, ctx);
  walkAndApplyPatterns(moduleOp, std::move(matmulPatterns));

  bool hasUnloweredMatMul = false;
  moduleOp.walk([&](ONNXMatMulOp matmulOp) {
    hasUnloweredMatMul = true;
    matmulOp.emitOpError("remaining ONNX MatMul before the required ONNX-to-Spatial conversion");
  });
  if (hasUnloweredMatMul) {
    moduleOp.emitError("failed to lower all ONNX MatMul ops before ONNX-to-Spatial conversion");
    signalPassFailure();
    return;
  }

  // 4. Main ONNX -> Spatial conversion.
  ConversionTarget target(*ctx);
  target.addLegalDialect(); // NOTE(review): dialect list missing.
  // NOTE(review): eighteen separate addIllegalOp calls, each of which
  // originally named its own op type(s); all template arguments are missing.
  target.addIllegalOp();
  target.addIllegalOp();
  target.addIllegalOp();
  target.addIllegalOp();
  target.addIllegalOp();
  target.addIllegalOp();
  target.addIllegalOp();
  target.addIllegalOp();
  target.addIllegalOp();
  target.addIllegalOp();
  target.addIllegalOp();
  target.addIllegalOp();
  target.addIllegalOp();
  target.addIllegalOp();
  target.addIllegalOp();
  target.addIllegalOp();
  target.addIllegalOp();
  target.addIllegalOp();

  RewritePatternSet conversionPatterns(ctx);
  populateConversionPatterns(conversionPatterns, ctx);
  if (failed(applyPartialConversion(moduleOp, target, std::move(conversionPatterns)))) {
    moduleOp.emitError("failed to convert required ONNX ops to Spatial ops");
    signalPassFailure();
    return;
  }

  // NOTE(review): earlyPostTarget is configured but never used in the visible
  // code — confirm whether it is dead and can be removed.
  ConversionTarget earlyPostTarget(*ctx);
  earlyPostTarget.addLegalDialect(); // NOTE(review): dialect list missing.

  // 5. Best-effort canonicalization: a failure is reported but not fatal.
  PassManager cleanupPM(ctx);
  cleanupPM.addPass(createCanonicalizerPass());
  if (failed(cleanupPM.run(moduleOp)))
    moduleOp.emitWarning("failed to run ONNX-to-Spatial canonicalization cleanup; continuing");

  // 6. Weight annotation.
  annotateWeightsConstants(*entryFunc);

  // 7. Post-rewrites: a compute op is legal once it no longer needs one.
  // NOTE(review): the addDynamicallyLegalOp template arguments were
  // reconstructed from the lambda parameter types.
  ConversionTarget postTarget(*ctx);
  postTarget.addLegalDialect(); // NOTE(review): dialect list missing.
  postTarget.addDynamicallyLegalOp<spatial::SpatCompute>(
      [](spatial::SpatCompute computeOp) { return !requiresPostRewrite(computeOp); });
  postTarget.addDynamicallyLegalOp<spatial::SpatComputeBatch>(
      [](spatial::SpatComputeBatch computeOp) { return !requiresPostRewrite(computeOp); });

  RewritePatternSet postPatterns(ctx);
  populatePostPatterns(postPatterns, ctx);
  if (failed(applyPartialConversion(*entryFunc, postTarget, std::move(postPatterns)))) {
    moduleOp.emitError("failed to normalize weight-like Spatial compute operands before Spatial-to-PIM lowering");
    signalPassFailure();
    return;
  }

  // 8. Sink residual runtime transposes, then check what is left on the host.
  wrapTopLevelRuntimeTransposes(*entryFunc);

  if (failed(verifyONNXToSpatialHostLegality(*entryFunc))) {
    moduleOp.emitError("ONNX-to-Spatial host legality verification failed");
    signalPassFailure();
    return;
  }

  // 9. Give a compute-less function a single enclosing compute, then dump.
  populateEmptyFunction(*entryFunc);

  dumpModule(moduleOp, "spatial0");
}

/// Creates a new instance of the ONNX-to-Spatial pass.
///
/// NOTE(review): the `std::unique_ptr` / `std::make_unique` template arguments
/// were reconstructed (the only pass type defined in this file is
/// ONNXToSpatialPass) — confirm against the declaring header.
std::unique_ptr<Pass> createONNXToSpatialPass() { return std::make_unique<ONNXToSpatialPass>(); }

} // namespace onnx_mlir