#ifndef ONNX_TO_SPATIAL
#define ONNX_TO_SPATIAL

#ifndef OP_BASE
include "mlir/Dialect/Tensor/IR/TensorOps.td"
include "mlir/Dialect/Arith/IR/ArithOps.td"
include "src/Dialect/ONNX/ONNX.td"
include "src/Accelerators/PIM/Dialect/Spatial/Spatial.td"
#endif // OP_BASE

// Rewrites an ONNXConstantOp into an Arith_ConstantOp built from `$value`.
// All other attributes are matched but not forwarded.
// NOTE(review): confirm that constants expressed through `$sparse_value`,
// `$value_float`, etc. are normalised to `$value` before this pattern runs.
def onnxToArithConstantOp : Pat<
  (ONNXConstantOp $sparse_value, $value, $value_float, $value_floats,
                  $value_int, $value_ints, $value_string, $value_strings),
  (Arith_ConstantOp $value)
>;

// ONNXMatMulOp to ONNXGemmOp patterns

// Holds when the value's type is a ranked shaped type of rank 2.
// `hasRank()` is checked first because `getRank()` must not be called on an
// unranked type. `cast` needs its target type spelled out explicitly.
def IsRank2Result: Constraint<
  CPred<"mlir::cast<mlir::ShapedType>($0.getType()).hasRank() && mlir::cast<mlir::ShapedType>($0.getType()).getRank() == 2">,
  "Result is rank 2">;

// Fuses Add(MatMul(A, B), C) into Gemm(A, B, C) with alpha = 1, beta = 1 and
// no transposition. Only fires when the MatMul result is rank 2, and only
// when the MatMul is the first operand of the Add.
def matMulAddToGemmPattern : Pat<
  (ONNXAddOp (ONNXMatMulOp:$matmulres $A, $B), $C),
  (ONNXGemmOp $A, $B, $C,
    /* alpha = */ (NativeCodeCall<"$_builder.getF32FloatAttr(1)">),
    /* beta = */ (NativeCodeCall<"$_builder.getF32FloatAttr(1)">),
    /* transA = */ (NativeCodeCall<"IntegerAttr::get($_builder.getIntegerType(64, true), 0)">),
    /* transB = */ (NativeCodeCall<"IntegerAttr::get($_builder.getIntegerType(64, true), 0)">)
  ),
  [(IsRank2Result $matmulres)]
>;

// Rewrites a rank-2 MatMul(A, B) into Gemm(A, B, C) with alpha = 1, beta = 0
// and no transposition. C is a tensor.empty with the shape and element type
// of the MatMul result; with beta = 0 it only serves as a placeholder operand.
// NOTE(review): the shape from `getShape()` is passed straight to the
// tensor.empty builder; confirm results with dynamic dimensions cannot reach
// this pattern.
def matMulToGemmPattern : Pat<
  (ONNXMatMulOp:$matmulres $A, $B),
  (ONNXGemmOp $A, $B,
    /* C = */ (NativeCodeCall<"tensor::EmptyOp::create($_builder, $_loc, mlir::cast<mlir::ShapedType>(matmulres.getY().getType()).getShape(), mlir::cast<mlir::ShapedType>(matmulres.getY().getType()).getElementType())">),
    /* alpha = */ (NativeCodeCall<"$_builder.getF32FloatAttr(1)">),
    /* beta = */ (NativeCodeCall<"$_builder.getF32FloatAttr(0)">),
    /* transA = */ (NativeCodeCall<"IntegerAttr::get($_builder.getIntegerType(64, true), 0)">),
    /* transB = */ (NativeCodeCall<"IntegerAttr::get($_builder.getIntegerType(64, true), 0)">)
  ),
  [(IsRank2Result $matmulres)]
>;

// ONNXConvOp + ONNXAddOp to ONNXConvOp pattern
// This pattern is used to fuse an ONNXConvOp and an ONNXAddOp into a single
// ONNXConvOp with a bias.
// Holds when the Conv bias operand is absent (its type is NoneType).
// The Conv + Add fusions below overwrite the bias operand with the Add
// operand, so they may only fire when there is no existing bias to lose.
def ConvHasNoBias : Constraint<
  CPred<"mlir::isa<mlir::NoneType>($0.getType())">,
  "Conv has no bias">;

// Fuses Add(operand, Conv(x, w, <none>, ...)) into Conv(x, w, operand, ...).
// NOTE(review): the Add operand is used as the Conv bias without any shape
// check; confirm it always has a shape that is valid as a Conv bias.
def convAddToConvWithBiasPatternLeft : Pat<
  (ONNXAddOp $add_operand,
    (ONNXConvOp:$convres $x, $w, $bias, $auto_pad, $dilations, $group,
                         $kernel_shape, $pad, $strides)),
  (ONNXConvOp $x, $w, $add_operand, $auto_pad, $dilations, $group,
              $kernel_shape, $pad, $strides),
  [(ConvHasNoBias $bias)]
>;

// Mirror of the pattern above for Add(Conv(x, w, <none>, ...), operand).
def convAddToConvWithBiasPatternRight : Pat<
  (ONNXAddOp
    (ONNXConvOp:$convres $x, $w, $bias, $auto_pad, $dilations, $group,
                         $kernel_shape, $pad, $strides),
    $add_operand),
  (ONNXConvOp $x, $w, $add_operand, $auto_pad, $dilations, $group,
              $kernel_shape, $pad, $strides),
  [(ConvHasNoBias $bias)]
>;

// Operation to ignore (i.e. remove)

// Forwards its single argument unchanged, so the matched op is replaced by
// that value.
def replaceWithOperationOfValue : NativeCodeCall<"$0">;

// Drops ONNXLRNOp by replacing it with its input; the four remaining
// arguments are ignored.
def removeLRNPattern : Pat<
  (ONNXLRNOp $A, $_, $_, $_, $_),
  (replaceWithOperationOfValue $A)
>;

// Delegates to onnx_mlir::haveSameStaticShape on the two values.
def HaveSameStaticShape: Constraint<
  CPred<"onnx_mlir::haveSameStaticShape($0, $1)">,
  "Two tensors have the same static shape">;

// Drops an ONNXFlattenOp whose result and input satisfy HaveSameStaticShape,
// replacing it with its input.
def removeFlattenSameShapePattern : Pat<
  (ONNXFlattenOp:$flattenOp $A, $axis),
  (replaceWithOperationOfValue $A),
  [(HaveSameStaticShape $flattenOp, $A)]
>;

#endif // ONNX_TO_SPATIAL