31 lines
1001 B
C++
31 lines
1001 B
C++
#include "mlir/Transforms/DialectConversion.h"
|
|
|
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
|
|
|
using namespace mlir;
|
|
|
|
namespace onnx_mlir {
|
|
namespace {
|
|
|
|
/// Conversion pattern that replaces an `ONNXReluOp` with a
/// `spatial::SpatReluPlanOp`.
///
/// The replacement op is created at the source op's location, is given the
/// source op's result type unchanged, and receives the adaptor's `X` operand
/// together with a string attribute holding "nchw". The pattern has no
/// failure path: it always rewrites and reports success.
struct ReluToSpatialCompute : OpConversionPattern<ONNXReluOp> {
  // Inherit the base-class constructors.
  using OpConversionPattern::OpConversionPattern;

  /// Rewrites `reluOp` into a `spatial::SpatReluPlanOp`.
  ///
  /// @param reluOp   The ONNX Relu operation being converted.
  /// @param adaptor  Adaptor whose `getX()` supplies the operand forwarded to
  ///                 the new op.
  /// @param rewriter Rewriter used to build the new op and replace `reluOp`.
  /// @return Always `success()`.
  LogicalResult matchAndRewrite(ONNXReluOp reluOp,
                                ONNXReluOpAdaptor adaptor,
                                ConversionPatternRewriter& rewriter) const override {
    // NOTE(review): "nchw" looks like a tensor layout tag; confirm its meaning
    // against the SpatReluPlanOp definition.
    auto layoutAttr = rewriter.getStringAttr("nchw");
    auto input = adaptor.getX();
    Type outputType = reluOp.getResult().getType();

    auto planOp =
        spatial::SpatReluPlanOp::create(rewriter, reluOp.getLoc(), outputType, input, layoutAttr);

    // Redirect all uses of the original Relu result to the new op's result.
    rewriter.replaceOp(reluOp, planOp.getResult());
    return success();
  }
};
|
|
|
|
} // namespace
|
|
|
|
/// Adds the `ReluToSpatialCompute` conversion pattern to `patterns`.
///
/// @param patterns Pattern set that receives the new pattern.
/// @param ctx      MLIR context passed to the pattern's constructor.
void populateReluPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
  patterns.add<ReluToSpatialCompute>(ctx);
}
|
|
|
|
} // namespace onnx_mlir
|