// NOTE(review): this file had been collapsed onto a single physical line.
// Preprocessor directives end at the end of their line, so every `#include`
// after the first was ignored as "extra tokens", and `#define DEBUG_TYPE`
// swallowed the rest of the file (including this function) into its
// replacement list. One directive/declaration per line restores the intended
// translation unit.
#include "mlir/Transforms/Passes.h"

#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerUtils.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
#include "src/Compiler/CompilerPasses.hpp"

#define DEBUG_TYPE "PimCompilerUtils"

using namespace mlir;
using namespace onnx_mlir;

namespace onnx_mlir {

/// Append the PIM compilation pipeline to `pm`.
///
/// Stages are gated with `>=` comparisons on `pimEmissionTarget`. Selecting a
/// later stage therefore also adds the passes of every earlier stage. This
/// assumes the PIM emission enumerators are declared in pipeline order
/// (Spatial < Pim < PimBufferized < PimCodegen); TODO confirm in the enum.
///
/// @param module          The module being compiled. Not read by this function.
/// @param pm              Pass manager the passes are appended to, in order.
/// @param emissionTarget  Generic onnx-mlir emission target. It is only
///                        compared against `EmitONNXIR` to decide whether to
///                        add the standard ONNX-to-MLIR passes. It is not
///                        modified here.
/// @param outputNameNoExt Output base name. Not used by this function.
void addPassesPim(OwningOpRef<ModuleOp>& module, PassManager& pm,
                  EmissionTargetType& emissionTarget,
                  std::string outputNameNoExt) {
  if (pimOnlyCodegen) {
    // Skip all the lowering passes and directly generate code for PIM.
    // NOTE(review): no passes at all are added in this mode, not even
    // createEmitPimJsonPass(). Code generation for this mode is presumably
    // driven elsewhere; TODO confirm.
    return;
  }

  // Standard ONNX-level preprocessing, configured for a non-CPU target.
  if (emissionTarget >= EmitONNXIR)
    addONNXToMLIRPasses(pm, /*target CPU*/ false);

  // ONNX -> Spatial.
  if (pimEmissionTarget >= EmitSpatial) {
    pm.addPass(createONNXToSpatialPass());
    // pm.addPass(createCountInstructionPass());
    pm.addPass(createMessagePass("Onnx lowered to Spatial"));
  }

  // Spatial -> Pim.
  if (pimEmissionTarget >= EmitPim) {
    pm.addPass(createSpatialToPimPass());
    // pm.addPass(createCountInstructionPass());
    pm.addPass(createMessagePass("Spatial lowered to Pim"));
  }

  // Bufferization of the Pim-level IR.
  if (pimEmissionTarget >= EmitPimBufferized) {
    pm.addPass(createPimBufferizationPass());
    // pm.addPass(createCountInstructionPass());
    pm.addPass(createMessagePass("Pim bufferized"));
  }

  // Constant handling, verification, then JSON code emission.
  if (pimEmissionTarget >= EmitPimCodegen) {
    pm.addPass(createPimConstantFoldingPass());
    pm.addPass(createMessagePass("Pim constants folded"));
    pm.addPass(createPimMaterializeConstantsPass());
    pm.addPass(createPimVerificationPass());
    pm.addPass(createMessagePass("Pim verified"));
    pm.addPass(createEmitPimJsonPass());
    // pm.addPass(createCountInstructionPass());
    pm.addPass(createMessagePass("Pim json code emitted"));
  }
}

} // namespace onnx_mlir