add relu lowering
Validate Operations / validate-operations (push) Failing after 2h50m56s

add relu validation
add spatial compute helper
minor refactors
This commit is contained in:
NiccoloN
2026-03-25 11:03:03 +01:00
parent 4e19650b80
commit 742df111e3
29 changed files with 258 additions and 116 deletions
@@ -8,7 +8,7 @@ include "src/Dialect/ONNX/ONNX.td"
include "src/Accelerators/PIM/Dialect/Spatial/Spatial.td"
#endif // OP_BASE
def onnxToArithConstantOp : Pat<
def onnxToArithConstant : Pat<
(ONNXConstantOp $sparse_value, $value, $value_float, $value_floats, $value_int, $value_ints, $value_string, $value_strings),
(Arith_ConstantOp $value)
>;
@@ -19,7 +19,7 @@ def IsRank2Result: Constraint<
CPred<"cast<ShapedType>($0.getType()).getRank() == 2">,
"Result is rank 2">;
def matMulAddToGemmPattern : Pat<
def matMulAddToGemm : Pat<
(ONNXAddOp (ONNXMatMulOp:$matmulres $A, $B), $C),
(ONNXGemmOp $A, $B, $C,
/* alpha = */ (NativeCodeCall<"$_builder.getF32FloatAttr(1)">),
@@ -30,7 +30,7 @@ def matMulAddToGemmPattern : Pat<
[(IsRank2Result $matmulres)]
>;
def matMulToGemmPattern : Pat<
def matMulToGemm : Pat<
(ONNXMatMulOp:$matmulres $A, $B),
(
ONNXGemmOp $A, $B,
@@ -45,14 +45,13 @@ def matMulToGemmPattern : Pat<
// ONNXConvOp + ONNXAddOp to ONNXConvOp pattern
// This pattern is used to fuse an ONNXConvOp and an ONNXAddOp into a single
// ONNXConvOp with a bias.
def convAddToConvWithBiasPatternLeft : Pat<
// This pattern is used to fuse an ONNXConvOp and an ONNXAddOp into a single ONNXConvOp with a bias.
def convAddToConvWithBiasLeft : Pat<
(ONNXAddOp $add_operand, (ONNXConvOp:$convres $x, $w, $bias, $auto_pad, $dilations, $group, $kernel_shape, $pad, $strides)),
(ONNXConvOp $x, $w, $add_operand, $auto_pad, $dilations, $group, $kernel_shape, $pad, $strides)
>;
def convAddToConvWithBiasPatternRight : Pat<
def convAddToConvWithBiasRight : Pat<
(ONNXAddOp (ONNXConvOp:$convres $x, $w, $bias, $auto_pad, $dilations, $group, $kernel_shape, $pad, $strides), $add_operand),
(ONNXConvOp $x, $w, $add_operand, $auto_pad, $dilations, $group, $kernel_shape, $pad, $strides)
>;
@@ -61,7 +60,7 @@ def convAddToConvWithBiasPatternRight : Pat<
def replaceWithOperationOfValue : NativeCodeCall<"$0">;
def removeLRNPattern : Pat<
def removeLRN : Pat<
(ONNXLRNOp $A, $_, $_, $_, $_),
(replaceWithOperationOfValue $A)
>;
@@ -70,10 +69,10 @@ def HaveSameStaticShape: Constraint<
CPred<"onnx_mlir::haveSameStaticShape($0, $1)">,
"Two tensors have the same static shape">;
def removeFlattenSameShapePattern : Pat<
def removeFlattenSameShape : Pat<
(ONNXFlattenOp:$flattenOp $A, $axis),
(replaceWithOperationOfValue $A),
[(HaveSameStaticShape $flattenOp, $A)]
>; // Add closing parenthesis here
>;
#endif // ONNX_TO_SPATIAL