#include "mlir/Transforms/DialectConversion.h"

#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"

using namespace mlir;

namespace onnx_mlir {
namespace {

/// Conversion pattern that lowers `ONNXReluOp` into the Spatial dialect.
///
/// The Relu is wrapped in a compute op built by `createSpatCompute`. Inside
/// that op's body, a `spatial::SpatReluOp` is applied to the (converted) input
/// `X`, and its result is yielded with `spatial::SpatYieldOp`. The original
/// ONNX op is then replaced by the compute op.
struct ReluToSpatialCompute : OpConversionPattern<ONNXReluOp> {
  using OpConversionPattern::OpConversionPattern;

  /// Rewrites `reluOp` into a Spatial compute op.
  ///
  /// @param reluOp   The ONNX Relu being converted.
  /// @param adaptor  Gives access to the already-converted operands; only `X`
  ///                 is used.
  /// @param rewriter Conversion rewriter used to build the new ops and to
  ///                 replace `reluOp`.
  /// @return Always `success()`; this pattern places no extra constraints on
  ///         the matched op.
  LogicalResult matchAndRewrite(ONNXReluOp reluOp, ONNXReluOpAdaptor adaptor,
      ConversionPatternRewriter& rewriter) const override {
    Location loc = reluOp.getLoc();
    // The Spatial ops keep exactly the result type of the ONNX op.
    Type resultType = reluOp.getResult().getType();

    // Number of values passed to the body callback below. Relu has a single
    // data operand, so the callback takes one `Value`.
    constexpr size_t numInputs = 1;

    auto computeOp = createSpatCompute<numInputs>(rewriter, loc, resultType,
        // NOTE(review): presumably "no weight operands" -- confirm against the
        // createSpatCompute signature in Common.hpp.
        {}, adaptor.getX(), [&](Value x) {
          auto spatReluOp =
              spatial::SpatReluOp::create(rewriter, loc, resultType, x);
          spatial::SpatYieldOp::create(rewriter, loc, spatReluOp.getResult());
        });

    rewriter.replaceOp(reluOp, computeOp);
    return success();
  }
};

} // namespace

/// Registers the ONNX Relu -> Spatial conversion pattern in `patterns`.
///
/// @param patterns Pattern set to extend.
/// @param ctx      MLIR context used to construct the pattern.
void populateReluPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
  patterns.add<ReluToSpatialCompute>(ctx);
}

} // namespace onnx_mlir