// ONNX -> Spatial conversion pass for the PIM accelerator.
//
// This file defines ONNXToSpatialPass (argument "convert-onnx-to-spatial") and
// its factory createONNXToSpatialPass().
//
// NOTE(review): in this copy of the file the text between angle brackets looks
// like it has been stripped (e.g. `PassWrapper>`, a bare `#include`,
// `isa(op)`, `addIllegalOp()`, `std::make_unique()`). The template arguments
// and the system header name must be restored from the original source before
// this compiles; the comments below describe only what is still visible.

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_os_ostream.h"
// NOTE(review): header name missing here (presumably a <...> system header) -- confirm.
#include
#include "Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
#include "src/Compiler/CompilerOptions.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"

using namespace mlir;

namespace onnx_mlir {

// Declared here, defined elsewhere; not called directly in the code below
// (it may be referenced from the included .inc file -- TODO confirm).
bool haveSameStaticShape(Value lhs, Value rhs);

namespace {

// Included inside the anonymous namespace so its contents are file-local.
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.hpp.inc"

/// Pass that lowers ONNX ops to Spatial ops.
///
/// runOnOperation() obtains a ModuleOp from getOperation(), so the pass is
/// expected to be scheduled on a module.
struct ONNXToSpatialPass : PassWrapper> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ONNXToSpatialPass)

  /// Command-line argument used to select this pass.
  StringRef getArgument() const override { return "convert-onnx-to-spatial"; }

  /// One-line human-readable description of the pass.
  StringRef getDescription() const override { return "Lower ONNX ops to Spatial ops."; }

  ONNXToSpatialPass() = default;

  /// Copy constructor with an empty body: nothing is copied from `pass`.
  ONNXToSpatialPass(const ONNXToSpatialPass& pass) {}

  /// Runs the conversion pipeline on the module (defined below).
  void runOnOperation() override;

private:
  /// Marks constants in `funcOp` whose users all pass the isa check in the
  /// definition below.
  void annotateWeightsConstants(func::FuncOp funcOp) const;
};

} // namespace

/// Runs the ONNX -> Spatial lowering on the current module.
///
/// Steps, in order:
///   1. Greedily apply the "merge activation" rewrite patterns (best effort).
///   2. Look up the PIM entry function; fail the pass if it cannot be found.
///   3. Run a partial dialect conversion; fail the pass if it fails.
///   4. If `coresCount` is not -1, count matching ops in the entry function's
///      first block and fail the pass when the count exceeds `coresCount`.
///   5. Run the canonicalizer as a cleanup (best effort).
///   6. Annotate weight constants and dump the module under the name "spatial".
void ONNXToSpatialPass::runOnOperation() {
  ModuleOp moduleOp = getOperation();
  MLIRContext* ctx = &getContext();

  // Pre-conversion rewrites, applied greedily over the whole module.
  RewritePatternSet mergeActivationPatterns(ctx);
  mergeActivationPatterns.add(ctx);
  mergeActivationPatterns.add(ctx);
  mergeActivationPatterns.add(ctx);
  mergeActivationPatterns.add(ctx);
  mergeActivationPatterns.add(ctx);
  mergeActivationPatterns.add(ctx);

  populateMatMulRewritePatterns(mergeActivationPatterns, ctx);

  // A failure here is only logged; the pass keeps going.
  if (failed(applyPatternsGreedily(moduleOp, std::move(mergeActivationPatterns))))
    llvm::dbgs() << "Failed to merge activation patterns, continuing...\n";

  // NOTE(review): `rewriter` is not referenced again in this function as shown.
  IRRewriter rewriter(moduleOp);

  // The entry function is required by the later steps; without it the pass fails.
  auto entryFunc = getPimEntryFunc(moduleOp);
  if (failed(entryFunc)) {
    signalPassFailure();
    return;
  }

  ConversionTarget target(*ctx);
  target.addLegalDialect();
  // ONNXMatMulOp stays legal when the rank of its Y type is not 2.
  target.addDynamicallyLegalOp([](ONNXMatMulOp op) { return cast(op.getY().getType()).getRank() != 2; });
  // Ops that must not remain after the conversion.
  target.addIllegalOp();
  target.addIllegalOp();
  target.addIllegalOp();
  target.addIllegalOp();
  target.addIllegalOp();
  target.addIllegalOp();
  target.addIllegalOp();
  target.addIllegalOp();
  target.addIllegalOp();

  // Conversion patterns handed to applyPartialConversion below.
  RewritePatternSet patterns(ctx);
  patterns.add(ctx);

  populateConvOpPatterns(patterns, ctx);
  populatePoolTilingPattern(patterns, ctx);
  populateOnnxGemmOpPatterns(patterns, ctx);
  populateReshapeConversionPattern(patterns, ctx);
  populateONNXConcatToTensorConcatPattern(patterns, ctx);

  // Unlike the greedy rewrite above, a conversion failure aborts the pass.
  if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
    signalPassFailure();
    return;
  }

  // Count the number of compute ops and check they do not exceed the core count.
  // The check is skipped entirely when coresCount is -1. Only operations in the
  // first block of the entry function body are inspected.
  if (coresCount != -1) {
    int computeOpsCount = 0;
    for (auto& op : entryFunc->getFunctionBody().front().getOperations())
      if (isa(op))
        computeOpsCount++;

    if (computeOpsCount > coresCount) {
      // NOTE(review): the failure reason is written to llvm::dbgs() only.
      llvm::dbgs() << "Number of compute ops exceeds the core count\n";
      signalPassFailure();
      return;
    }
  }

  // Canonicalization cleanup through a separate pass manager; a failure is
  // logged but does not fail this pass.
  PassManager cleanupPM(ctx);
  cleanupPM.addPass(createCanonicalizerPass());
  if (failed(cleanupPM.run(moduleOp)))
    llvm::dbgs() << "Failed to run canonicalization cleanup, continuing...\n";

  annotateWeightsConstants(*entryFunc);

  // Dump to file for debug
  dumpModule(moduleOp, "spatial");
}

/// Walks every arith::ConstantOp under `funcOp` and calls markWeightAlways on
/// each one whose users all satisfy the isa check in the lambda.
///
/// Because llvm::all_of is true for an empty range, a constant with no users is
/// marked as well.
void ONNXToSpatialPass::annotateWeightsConstants(func::FuncOp funcOp) const {
  funcOp.walk([&](arith::ConstantOp constantOp) {
    bool isAlwaysWeight = llvm::all_of(constantOp->getUsers(), [](auto user) -> bool { return isa(user); });
    if (isAlwaysWeight)
      markWeightAlways(constantOp);
  });
}

/// Factory for the pass; returns a newly allocated instance.
std::unique_ptr createONNXToSpatialPass() { return std::make_unique(); }

} // namespace onnx_mlir