// NOTE(review): in this copy of the file every template argument list appears
// to have been stripped (e.g. `PassWrapper>`, `cast(...)`, `isa(&op)`,
// `SmallVector computes(...)`, `static_cast(...)`, `addIllegalOp()`,
// `std::make_unique()`), so it will not build as-is. Restore the original
// argument lists from version control; the comments below hedge every claim
// that depends on the lost types.

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"

#include "Common/Common.hpp"
#include "Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostLegality.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PostPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PrePatterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Compiler/CompilerOptions.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"

using namespace mlir;

namespace onnx_mlir {

namespace {

/// Pass that lowers ONNX ops to the Spatial dialect.
/// Registered on the command line as `convert-onnx-to-spatial`.
// NOTE(review): the PassWrapper base arguments were lost; presumably
// <ONNXToSpatialPass, OperationPass<ModuleOp>> since runOnOperation reads the
// operation as a ModuleOp — confirm.
struct ONNXToSpatialPass : PassWrapper> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ONNXToSpatialPass)
  StringRef getArgument() const override { return "convert-onnx-to-spatial"; }
  StringRef getDescription() const override { return "Lower ONNX ops to Spatial ops."; }

  ONNXToSpatialPass() = default;
  // Copy constructor (used when the pass is cloned). The pass holds no
  // per-instance state, so nothing is copied; `pass` is intentionally unused.
  ONNXToSpatialPass(const ONNXToSpatialPass& pass) {}

  void runOnOperation() override;
};

} // namespace

/// If the function contains no compute ops yet, wrap its whole body into a
/// single new spatial::SpatCompute.
///
/// The new compute takes the function arguments as operands and produces one
/// result per operand of the function's terminator. Every non-terminator op of
/// the body is cloned into the compute's block. The results are yielded via
/// spatial::SpatYieldOp, the originals are erased, and the terminator is
/// rewired to the compute's results.
///
/// Does nothing when `computes` is non-empty.
// NOTE(review): `computes` is built from funcOp.getOps() with a lost template
// argument — presumably spatial::SpatCompute; confirm.
static void populateEmptyFunction(func::FuncOp funcOp) {
  IRRewriter rewriter(funcOp.getContext());
  IRMapping mapper;
  SmallVector computes(funcOp.getOps());
  if (!computes.empty())
    return;

  // The terminator of the function's first block (cast to the return op type;
  // template argument lost — presumably func::ReturnOp). New ops go before it.
  auto returnOp = cast(funcOp.getFunctionBody().front().getTerminator());
  rewriter.setInsertionPoint(returnOp);

  // One block argument per function argument, with matching type and location.
  SmallVector sourceTypes;
  SmallVector sourceLocs;
  sourceTypes.reserve(funcOp.getNumArguments());
  sourceLocs.reserve(funcOp.getNumArguments());
  for (Value source : funcOp.getArguments()) {
    sourceTypes.push_back(source.getType());
    sourceLocs.push_back(source.getLoc());
  }

  // Result types mirror what the function returns; operands are the function
  // arguments.
  auto newCompute = spatial::SpatCompute::create(
      rewriter, returnOp.getLoc(), returnOp.getOperandTypes(), funcOp.getArguments(), {}, {});
  auto* newBlock = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), sourceTypes, sourceLocs);
  // Map each compute operand (i.e. each function argument) to its block
  // argument, so cloned ops reference the block arguments instead.
  for (auto [blockArg, computeArg] : llvm::zip(newBlock->getArguments(), newCompute.getOperands()))
    mapper.map(computeArg, blockArg);
  // Force the operand segments to {0, N}: all N operands go in the second
  // segment.
  // NOTE(review): presumably the second segment holds regular inputs and the
  // first holds weights — confirm against the SpatCompute op definition.
  newCompute.getProperties().setOperandSegmentSizes({0, static_cast(sourceTypes.size())});

  // Clone the function body into the compute block.
  // NOTE(review): newCompute was inserted before returnOp, so it is itself one
  // of funcOp.getOps() here. The (lost) isa<> filter must skip both the return
  // op and SpatCompute, or the compute would be cloned into its own block.
  // Same applies to the erase loop below — confirm.
  rewriter.setInsertionPointToEnd(newBlock);
  for (Operation& op : funcOp.getOps())
    if (!isa(&op))
      rewriter.clone(op, mapper);

  // Yield the values the function returned, remapped to their clones inside
  // the block; values absent from the mapper are kept unchanged.
  auto yield = spatial::SpatYieldOp::create(rewriter, funcOp.getLoc(), returnOp.getOperands());
  for (size_t i = 0; i < yield.getNumOperands(); ++i)
    yield.setOperand(i, mapper.lookupOrDefault(yield.getOperand(i)));

  // Erase the original (now cloned) ops. Uses are dropped first so an op can be
  // erased while later ops or the terminator still reference its results;
  // early-inc iteration keeps the walk valid while erasing.
  for (Operation& op : llvm::make_early_inc_range(funcOp.getOps()))
    if (!isa(&op)) {
      op.dropAllUses();
      rewriter.eraseOp(&op);
    }

  // Rewire the terminator to the compute results. The counts match because the
  // compute's result types were taken from returnOp.getOperandTypes().
  for (auto [index, computeResult] : llvm::enumerate(newCompute.getResults()))
    returnOp.setOperand(index, computeResult);
}

/// Runs the ONNX -> Spatial lowering on the module:
///   1. greedy pre-patterns on the whole module (failure only warns);
///   2. lookup of the PIM entry function (failure aborts);
///   3. partial dialect conversion with the configured legal/illegal set;
///   4. greedy early post-patterns on the entry function;
///   5. optional check that the number of compute ops fits `coresCount`;
///   6. canonicalizer cleanup (failure only warns);
///   7. weight-constant annotation, post-patterns, and host-legality checks;
///   8. wrapping of a compute-free entry function via populateEmptyFunction;
///   9. a module dump tagged "spatial0".
void ONNXToSpatialPass::runOnOperation() {
  ModuleOp moduleOp = getOperation();
  MLIRContext* ctx = &getContext();

  // Pre-patterns are best-effort: a failed greedy application is reported as a
  // warning and the pass continues.
  RewritePatternSet prePatterns(ctx);
  populatePrePatterns(prePatterns, ctx);
  if (failed(applyPatternsGreedily(moduleOp, std::move(prePatterns))))
    moduleOp.emitWarning("failed to apply ONNX-to-Spatial pre-patterns; continuing");

  auto entryFunc = getPimEntryFunc(moduleOp);
  if (failed(entryFunc)) {
    signalPassFailure();
    return;
  }

  // NOTE(review): the legal dialect and the 18 illegal op types below lost
  // their template arguments in this copy; restore them from history.
  ConversionTarget target(*ctx);
  target.addLegalDialect();
  target.addIllegalOp();
  target.addIllegalOp();
  target.addIllegalOp();
  target.addIllegalOp();
  target.addIllegalOp();
  target.addIllegalOp();
  target.addIllegalOp();
  target.addIllegalOp();
  target.addIllegalOp();
  target.addIllegalOp();
  target.addIllegalOp();
  target.addIllegalOp();
  target.addIllegalOp();
  target.addIllegalOp();
  target.addIllegalOp();
  target.addIllegalOp();
  target.addIllegalOp();
  target.addIllegalOp();

  RewritePatternSet conversionPatterns(ctx);
  populateConversionPatterns(conversionPatterns, ctx);
  if (failed(applyPartialConversion(moduleOp, target, std::move(conversionPatterns)))) {
    signalPassFailure();
    return;
  }

  RewritePatternSet earlyPostPatterns(ctx);
  populateEarlyPostPatterns(earlyPostPatterns, ctx);
  if (failed(applyPatternsGreedily(*entryFunc, std::move(earlyPostPatterns)))) {
    signalPassFailure();
    return;
  }

  // A coresCount of -1 skips this check (presumably meaning "no core limit";
  // coresCount is declared outside this file — confirm). Otherwise, count the
  // ops in the entry block matching the (lost) isa<> type, presumably
  // spatial::SpatCompute, and fail if there are more than coresCount.
  if (coresCount != -1) {
    int computeOpsCount = 0;
    for (Operation& op : entryFunc->getFunctionBody().front().getOperations())
      if (isa(op))
        computeOpsCount++;

    if (computeOpsCount > coresCount) {
      entryFunc->emitError() << "number of compute ops (" << computeOpsCount << ") exceeds the core count ("
                             << coresCount << ")";
      signalPassFailure();
      return;
    }
  }

  // Canonicalization is a best-effort cleanup: failure only warns.
  PassManager cleanupPM(ctx);
  cleanupPM.addPass(createCanonicalizerPass());
  if (failed(cleanupPM.run(moduleOp)))
    moduleOp.emitWarning("failed to run ONNX-to-Spatial canonicalization cleanup; continuing");

  annotateWeightsConstants(*entryFunc);

  RewritePatternSet postPatterns(ctx);
  populatePostPatterns(postPatterns, ctx);
  if (failed(applyPatternsGreedily(*entryFunc, std::move(postPatterns)))) {
    signalPassFailure();
    return;
  }

  if (failed(verifyONNXToSpatialHostLegality(*entryFunc))) {
    signalPassFailure();
    return;
  }

  // If lowering left the entry function without any compute op, wrap its body
  // into one (no-op otherwise).
  populateEmptyFunction(*entryFunc);

  // NOTE(review): dumpModule is defined elsewhere; presumably a debug dump
  // tagged "spatial0" — confirm whether it is gated by an option.
  dumpModule(moduleOp, "spatial0");
}

/// Factory for the ONNX -> Spatial pass.
// NOTE(review): template arguments lost; presumably std::unique_ptr<Pass> and
// std::make_unique<ONNXToSpatialPass>().
std::unique_ptr createONNXToSpatialPass() { return std::make_unique(); }

} // namespace onnx_mlir