#include "Patterns.hpp"

#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include <memory>
#include <utility>

#include "src/Accelerators/PIM/Common/PimCommon.hpp"

using namespace mlir;

namespace onnx_mlir {

namespace {

/// Module pass that runs, to a fixpoint, every canonicalization pattern known
/// to the context plus the PIM constant-folding patterns (constant and subview
/// folding; populate helpers presumably declared in Patterns.hpp), then dumps
/// the folded module under the tag "pim2_folded".
struct PimConstantFoldingPass
    : PassWrapper<PimConstantFoldingPass, OperationPass<ModuleOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PimConstantFoldingPass)

  StringRef getArgument() const override { return "pim-constant-folding-pass"; }

  StringRef getDescription() const override {
    return "Fold host-side constant expressions before PIM verification";
  }

  /// Builds the pattern set once, before the pass runs.
  ///
  /// Collects canonicalization patterns from every loaded dialect and every
  /// registered operation, appends the PIM-specific folding patterns, and
  /// freezes the result so it can be applied repeatedly.
  ///
  /// @param context the MLIR context whose dialects/ops supply the patterns.
  /// @return always success().
  LogicalResult initialize(MLIRContext* context) override {
    RewritePatternSet owningPatterns(context);

    // Generic canonicalizations, mirroring MLIR's own canonicalizer pass.
    for (auto* dialect : context->getLoadedDialects())
      dialect->getCanonicalizationPatterns(owningPatterns);
    for (RegisteredOperationName op : context->getRegisteredOperations())
      op.getCanonicalizationPatterns(owningPatterns, context);

    // PIM-specific constant folding.
    populateConstantFoldingConstantPatterns(owningPatterns);
    populateConstantFoldingSubviewPatterns(owningPatterns);

    patterns =
        std::make_shared<FrozenRewritePatternSet>(std::move(owningPatterns));
    return success();
  }

  /// Applies the frozen patterns greedily to the module, with op fold() hooks
  /// enabled. Non-convergence is treated as a hard error: a diagnostic is
  /// emitted on the module and the pass fails. On success the folded module
  /// is dumped.
  void runOnOperation() override {
    GreedyRewriteConfig config;
    // Request folding explicitly rather than relying on the config default.
    config.enableFolding();

    ModuleOp module = getOperation();
    if (failed(applyPatternsGreedily(module, *patterns, config))) {
      // Report why the pass failed instead of failing silently.
      module.emitError(
          "pim-constant-folding-pass: greedy pattern rewrite did not converge");
      signalPassFailure();
      return;
    }

    dumpModule(module, "pim2_folded");
  }

  /// Immutable pattern set built in initialize(). It is held by shared_ptr so
  /// copies of this pass object share one set instead of duplicating it.
  std::shared_ptr<const FrozenRewritePatternSet> patterns;
};

} // namespace

/// Creates a new instance of the PIM constant-folding pass.
std::unique_ptr<Pass> createPimConstantFoldingPass() {
  return std::make_unique<PimConstantFoldingPass>();
}

} // namespace onnx_mlir