extend operation support for conv and gemm
add more tests in validation
This commit is contained in:
@@ -15,6 +15,10 @@ def onnxToArithConstantOp : Pat<
|
||||
|
||||
// ONNXMatMulOp to ONNXGemmOp patterns
|
||||
|
||||
def IsRank2Result: Constraint<
|
||||
CPred<"cast<ShapedType>($0.getType()).getRank() == 2">,
|
||||
"Result is rank 2">;
|
||||
|
||||
def matMulAddToGemmPattern : Pat<
|
||||
(ONNXAddOp (ONNXMatMulOp:$matmulres $A, $B), $C),
|
||||
(ONNXGemmOp $A, $B, $C,
|
||||
@@ -22,19 +26,21 @@ def matMulAddToGemmPattern : Pat<
|
||||
/* beta = */ (NativeCodeCall<"$_builder.getF32FloatAttr(1)">),
|
||||
/* transA = */ (NativeCodeCall<"IntegerAttr::get($_builder.getIntegerType(64, true), 0)">),
|
||||
/* transB = */ (NativeCodeCall<"IntegerAttr::get($_builder.getIntegerType(64, true), 0)">)
|
||||
)
|
||||
),
|
||||
[(IsRank2Result $matmulres)]
|
||||
>;
|
||||
|
||||
def matMulToGemmPattern : Pat<
|
||||
(ONNXMatMulOp:$matmulres $A, $B),
|
||||
(
|
||||
ONNXGemmOp $A, $B,
|
||||
ONNXGemmOp $A, $B,
|
||||
/* C = */ (NativeCodeCall<"tensor::EmptyOp::create($_builder, $_loc, cast<ShapedType>(matmulres.getY().getType()).getShape(), cast<ShapedType>(matmulres.getY().getType()).getElementType());">),
|
||||
/* alpha = */ (NativeCodeCall<"$_builder.getF32FloatAttr(1)">),
|
||||
/* beta = */ (NativeCodeCall<"$_builder.getF32FloatAttr(0)">),
|
||||
/* transA = */ (NativeCodeCall<"IntegerAttr::get($_builder.getIntegerType(64, true), 0)">),
|
||||
/* transB = */ (NativeCodeCall<"IntegerAttr::get($_builder.getIntegerType(64, true), 0)">)
|
||||
)
|
||||
),
|
||||
[(IsRank2Result $matmulres)]
|
||||
>;
|
||||
|
||||
// ONNXConvOp + ONNXAddOp to ONNXConvOp pattern
|
||||
|
||||
Reference in New Issue
Block a user