add relu lowering
Validate Operations / validate-operations (push) Failing after 2h50m56s

add relu validation
add spatial compute helper
minor refactors
This commit is contained in:
NiccoloN
2026-03-25 11:03:03 +01:00
parent 4e19650b80
commit 742df111e3
29 changed files with 258 additions and 116 deletions
@@ -260,6 +260,6 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
return success();
}
void populateConvOpPatterns(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.insert<ConvToGemm>(ctx); }
void populateConvPatterns(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.insert<ConvToGemm>(ctx); }
} // namespace onnx_mlir
@@ -58,8 +58,7 @@ struct GemmToManyGemv : OpConversionPattern<ONNXGemmOp> {
};
struct GemvToSpatialCompute : OpConversionPattern<ONNXGemmOp> {
GemvToSpatialCompute(MLIRContext* ctx)
: OpConversionPattern(ctx, 1) {}
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(ONNXGemmOp gemmOp,
ONNXGemmOpAdaptor gemmOpAdaptor,
@@ -352,7 +351,7 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
return success();
}
void populateOnnxGemmOpPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
void populateGemmPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.insert<GemmToManyGemv>(ctx);
patterns.insert<GemvToSpatialCompute>(ctx);
}
@@ -257,7 +257,7 @@ struct PoolToSpatialCompute<ONNXAveragePoolOp>
} // namespace
void populatePoolTilingPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
void populatePoolPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.insert<PoolToSpatialCompute<ONNXMaxPoolSingleOutOp>>(ctx);
patterns.insert<PoolToSpatialCompute<ONNXAveragePoolOp>>(ctx);
}
@@ -0,0 +1,33 @@
#include "mlir/Transforms/DialectConversion.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
struct ReluToSpatialCompute : OpConversionPattern<ONNXReluOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(ONNXReluOp reluOp, ONNXReluOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override {
Location loc = reluOp.getLoc();
Type resultType = reluOp.getResult().getType();
constexpr size_t numInputs = 1;
auto computeOp = createSpatCompute<numInputs>(rewriter, loc, resultType, {}, adaptor.getX(), [&](Value x) {
auto spatReluOp = spatial::SpatReluOp::create(rewriter, loc, resultType, x);
spatial::SpatYieldOp::create(rewriter, loc, spatReluOp.getResult());
});
rewriter.replaceOp(reluOp, computeOp);
return success();
}
};
} // namespace
void populateReluPatterns(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.add<ReluToSpatialCompute>(ctx); }
} // namespace onnx_mlir
@@ -9,8 +9,7 @@ using namespace mlir;
namespace onnx_mlir {
struct Concat : public OpConversionPattern<ONNXConcatOp> {
Concat(MLIRContext* ctx)
: OpConversionPattern(ctx) {}
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(ONNXConcatOp maxpoolOp,
ONNXConcatOpAdaptor adaptor,
@@ -24,7 +23,7 @@ struct Concat : public OpConversionPattern<ONNXConcatOp> {
}
};
void populateONNXConcatToTensorConcatPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
void populateConcatPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.insert<Concat>(ctx);
}
@@ -114,6 +114,6 @@ struct Reshape : OpConversionPattern<ONNXReshapeOp> {
} // namespace
void populateReshapeConversionPattern(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.insert<Reshape>(ctx); }
void populateReshapePatterns(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.insert<Reshape>(ctx); }
} // namespace onnx_mlir