37 lines
1.2 KiB
C++
37 lines
1.2 KiB
C++
#include "mlir/Transforms/DialectConversion.h"
|
|
|
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
|
|
|
using namespace mlir;
|
|
|
|
namespace onnx_mlir {
|
|
namespace {
|
|
|
|
/// Conversion pattern that rewrites `ONNXSigmoidOp` into a Spatial compute op.
///
/// The compute op is built by the project helper `createSpatCompute` with a
/// single input (the converted `X` operand). Its body applies
/// `spatial::SpatSigmoidOp` to the body's input value and terminates with
/// `spatial::SpatYieldOp`. The original ONNX op is then replaced by the
/// compute op.
struct SigmoidToSpatialCompute : OpConversionPattern<ONNXSigmoidOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(ONNXSigmoidOp sigmoidOp,
      ONNXSigmoidOpAdaptor adaptor,
      ConversionPatternRewriter& rewriter) const override {
    Location opLoc = sigmoidOp.getLoc();
    Type outType = sigmoidOp.getResult().getType();

    // Populates the compute body. `input` is presumably the body's argument
    // corresponding to X (verify against createSpatCompute). The inner sigmoid
    // reuses the outer result type, which assumes the body argument has the
    // same type as the op result -- TODO confirm.
    auto buildBody = [&](Value input) {
      Value activated =
          spatial::SpatSigmoidOp::create(rewriter, opLoc, outType, input)
              .getResult();
      spatial::SpatYieldOp::create(rewriter, opLoc, activated);
    };

    // Sigmoid is unary: the compute op receives exactly one input.
    constexpr size_t kSingleInput = 1;
    auto wrapper = createSpatCompute<kSingleInput>(
        rewriter, opLoc, outType, {}, adaptor.getX(), buildBody);

    rewriter.replaceOp(sigmoidOp, wrapper);
    return success();
  }
};
|
|
|
|
} // namespace
|
|
|
|
/// Registers the ONNX Sigmoid -> Spatial conversion pattern.
///
/// \param patterns  pattern set that receives `SigmoidToSpatialCompute`.
/// \param ctx       MLIR context used to construct the pattern.
void populateSigmoidPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
  // A benefit of 1 is OpConversionPattern's default; it is spelled out here so
  // the pattern's priority is visible at the registration site.
  patterns.add<SigmoidToSpatialCompute>(ctx, PatternBenefit(1));
}
|
|
|
|
} // namespace onnx_mlir
|