add relu lowering
Some checks failed
Validate Operations / validate-operations (push) Failing after 2h50m56s

add relu validation
add spatial compute helper
minor refactors
This commit is contained in:
NiccoloN
2026-03-25 11:03:03 +01:00
parent 4e19650b80
commit 742df111e3
29 changed files with 258 additions and 116 deletions

View File

@@ -50,12 +50,12 @@ void ONNXToSpatialPass::runOnOperation() {
MLIRContext* ctx = &getContext();
RewritePatternSet mergeActivationPatterns(ctx);
mergeActivationPatterns.add<onnxToArithConstantOp>(ctx);
mergeActivationPatterns.add<convAddToConvWithBiasPatternLeft>(ctx);
mergeActivationPatterns.add<convAddToConvWithBiasPatternRight>(ctx);
mergeActivationPatterns.add<matMulAddToGemmPattern>(ctx);
mergeActivationPatterns.add<matMulToGemmPattern>(ctx);
mergeActivationPatterns.add<removeFlattenSameShapePattern>(ctx);
mergeActivationPatterns.add<onnxToArithConstant>(ctx);
mergeActivationPatterns.add<convAddToConvWithBiasLeft>(ctx);
mergeActivationPatterns.add<convAddToConvWithBiasRight>(ctx);
mergeActivationPatterns.add<matMulAddToGemm>(ctx);
mergeActivationPatterns.add<matMulToGemm>(ctx);
mergeActivationPatterns.add<removeFlattenSameShape>(ctx);
populateMatMulRewritePatterns(mergeActivationPatterns, ctx);
if (failed(applyPatternsGreedily(moduleOp, std::move(mergeActivationPatterns))))
@@ -74,23 +74,24 @@ void ONNXToSpatialPass::runOnOperation() {
[](ONNXMatMulOp op) { return cast<ShapedType>(op.getY().getType()).getRank() != 2; });
target.addIllegalOp<ONNXGemmOp>();
target.addIllegalOp<ONNXConvOp>();
target.addIllegalOp<ONNXLRNOp>();
target.addIllegalOp<ONNXMaxPoolSingleOutOp>();
target.addIllegalOp<ONNXAveragePoolOp>();
target.addIllegalOp<ONNXConcatOp>();
target.addIllegalOp<ONNXReluOp>();
target.addIllegalOp<ONNXSoftmaxOp>();
target.addIllegalOp<ONNXReduceMeanV13Op>();
target.addIllegalOp<ONNXConcatOp>();
target.addIllegalOp<ONNXReshapeOp>();
target.addIllegalOp<ONNXLRNOp>();
target.addIllegalOp<ONNXReduceMeanV13Op>();
RewritePatternSet patterns(ctx);
patterns.add<removeLRNPattern>(ctx);
patterns.add<removeLRN>(ctx);
populateConvOpPatterns(patterns, ctx);
populatePoolTilingPattern(patterns, ctx);
populateOnnxGemmOpPatterns(patterns, ctx);
populateReshapeConversionPattern(patterns, ctx);
populateONNXConcatToTensorConcatPattern(patterns, ctx);
populateGemmPatterns(patterns, ctx);
populateConvPatterns(patterns, ctx);
populatePoolPatterns(patterns, ctx);
populateReluPatterns(patterns, ctx);
populateConcatPatterns(patterns, ctx);
populateReshapePatterns(patterns, ctx);
if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
signalPassFailure();