#include "mlir/Transforms/DialectConversion.h"

#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"

using namespace mlir;

namespace onnx_mlir {
namespace {

/// Conversion pattern that replaces an `ONNXSigmoidOp` with the op returned by
/// `createSpatCompute`. The body callback passed to that helper builds a
/// `spatial::SpatSigmoidOp` on the value it receives and yields its result
/// through a `spatial::SpatYieldOp`.
struct SigmoidToSpatialCompute : OpConversionPattern<ONNXSigmoidOp> {
  using OpConversionPattern::OpConversionPattern;

  /// Rewrites `sigmoidOp` in place.
  ///
  /// @param sigmoidOp The ONNX sigmoid operation being converted.
  /// @param adaptor   Adaptor giving access to the operand `X`.
  /// @param rewriter  Rewriter used to create the new ops and replace the
  ///                  original one.
  /// @return Always `success()`; no match precondition is checked here.
  LogicalResult matchAndRewrite(ONNXSigmoidOp sigmoidOp,
      ONNXSigmoidOpAdaptor adaptor,
      ConversionPatternRewriter& rewriter) const override {
    Location loc = sigmoidOp.getLoc();
    // The replacement keeps the result type of the original ONNX op.
    Type resultType = sigmoidOp.getResult().getType();

    // Sigmoid is unary, so the body callback takes exactly one value.
    // NOTE(review): `numInputs` is assumed to be the explicit template
    // argument of `createSpatCompute` (declared in Common.hpp) -- confirm
    // against the helper's declaration.
    constexpr size_t numInputs = 1;
    auto computeOp = createSpatCompute<numInputs>(
        rewriter, loc, resultType, {}, adaptor.getX(), [&](Value x) {
          auto spatSigmoidOp =
              spatial::SpatSigmoidOp::create(rewriter, loc, resultType, x);
          spatial::SpatYieldOp::create(
              rewriter, loc, spatSigmoidOp.getResult());
        });

    rewriter.replaceOp(sigmoidOp, computeOp);
    return success();
  }
};

} // namespace

/// Adds the ONNX Sigmoid conversion pattern to `patterns`.
///
/// @param patterns Pattern set that receives `SigmoidToSpatialCompute`.
/// @param ctx      MLIR context forwarded to the pattern's constructor.
void populateSigmoidPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
  patterns.add<SigmoidToSpatialCompute>(ctx);
}

} // namespace onnx_mlir