Files
Raptor/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Relu.cpp
T
2026-06-24 15:52:07 +02:00

31 lines
1001 B
C++

#include "mlir/Transforms/DialectConversion.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
/// Conversion pattern that replaces an `ONNXReluOp` with a single
/// `spatial::SpatReluPlanOp`.
///
/// The new op consumes the remapped input operand from the adaptor. It keeps
/// the result type of the original ONNX op and carries the string attribute
/// "nchw". The ONNX op is then replaced by the new op's result.
struct ReluToSpatialCompute : OpConversionPattern<ONNXReluOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ONNXReluOp reluOp, ONNXReluOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override {
    // Take the operand from the adaptor (post-remapping), not from reluOp directly.
    Value input = adaptor.getX();

    // NOTE(review): the layout tag is always "nchw", independent of the input's
    // rank -- confirm this is the intended contract of SpatReluPlanOp.
    StringAttr layoutAttr = rewriter.getStringAttr("nchw");

    auto planOp = spatial::SpatReluPlanOp::create(
        rewriter, reluOp.getLoc(), reluOp.getResult().getType(), input, layoutAttr);

    rewriter.replaceOp(reluOp, planOp.getResult());
    return success();
  }
};
} // namespace
/// Appends the ONNX Relu -> Spatial conversion pattern to `patterns`.
///
/// @param patterns  Pattern set that receives the new pattern.
/// @param ctx       MLIR context used to construct the pattern.
void populateReluPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
  patterns.add<ReluToSpatialCompute>(ctx);
}
} // namespace onnx_mlir