This commit is contained in:
@@ -124,6 +124,7 @@ void ONNXToSpatialPass::runOnOperation() {
|
||||
target.addIllegalOp<ONNXMatMulOp>();
|
||||
target.addIllegalOp<ONNXTransposeOp>();
|
||||
target.addIllegalOp<ONNXAddOp>();
|
||||
target.addIllegalOp<ONNXSubOp>();
|
||||
target.addIllegalOp<ONNXDivOp>();
|
||||
target.addIllegalOp<ONNXMulOp>();
|
||||
target.addIllegalOp<ONNXGemmOp>();
|
||||
|
||||
@@ -189,6 +189,7 @@ struct DivToSpatialCompute : OpConversionPattern<ONNXDivOp> {
|
||||
|
||||
void populateElementwisePatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||
patterns.add<BinaryElementwiseToSpatialCompute<ONNXAddOp, spatial::SpatVAddOp>>(ctx);
|
||||
patterns.add<BinaryElementwiseToSpatialCompute<ONNXSubOp, spatial::SpatVSubOp>>(ctx);
|
||||
patterns.add<BinaryElementwiseToSpatialCompute<ONNXMulOp, spatial::SpatVMulOp>>(ctx);
|
||||
patterns.add<DivToSpatialCompute>(ctx);
|
||||
}
|
||||
|
||||
@@ -27,6 +27,12 @@ def spatToPimVVAdd : Pat<
|
||||
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
||||
>;
|
||||
|
||||
def spatToPimVVSub : Pat<
|
||||
(SpatVSubOp:$srcOpRes $a, $b),
|
||||
(PimVVSubOp $a, $b,
|
||||
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
||||
>;
|
||||
|
||||
def spatToPimVVMul : Pat<
|
||||
(SpatVMulOp:$srcOpRes $a, $b),
|
||||
(PimVVMulOp $a, $b,
|
||||
|
||||
@@ -257,6 +257,25 @@ def SpatVAddOp : SpatOp<"vadd", []> {
|
||||
}];
|
||||
}
|
||||
|
||||
def SpatVSubOp : SpatOp<"vsub", []> {
|
||||
let summary = "Element-wise subtraction between two tensors; rhs must match lhs or be 1x1";
|
||||
|
||||
let arguments = (ins
|
||||
SpatTensor:$lhs,
|
||||
SpatTensor:$rhs
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
SpatTensor:$output
|
||||
);
|
||||
|
||||
let hasVerifier = 1;
|
||||
|
||||
let assemblyFormat = [{
|
||||
$lhs `,` $rhs attr-dict `:` `(` type($lhs) `,` type($rhs) `)` `->` type($output)
|
||||
}];
|
||||
}
|
||||
|
||||
def SpatVMulOp : SpatOp<"vmul", []> {
|
||||
let summary = "Element-wise multiplication between two tensors; rhs must match lhs or be 1x1";
|
||||
|
||||
|
||||
@@ -254,6 +254,12 @@ LogicalResult SpatVAddOp::verify() {
|
||||
return OpTrait::impl::verifySameOperandsAndResultType(*this);
|
||||
}
|
||||
|
||||
LogicalResult SpatVSubOp::verify() {
|
||||
if (failed(OpTrait::impl::verifyAtLeastNOperands(*this, 2)))
|
||||
return failure();
|
||||
return OpTrait::impl::verifySameOperandsAndResultType(*this);
|
||||
}
|
||||
|
||||
LogicalResult SpatVMaxOp::verify() {
|
||||
if (failed(OpTrait::impl::verifyAtLeastNOperands(*this, 2)))
|
||||
return failure();
|
||||
|
||||
@@ -1549,6 +1549,82 @@ def add_leading_dimension_broadcast():
|
||||
save_model(model, "add/leading_dimension_broadcast", "add_leading_dimension_broadcast.onnx")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Sub tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def sub_basic():
|
||||
"""Elementwise Sub on two runtime inputs with identical shapes."""
|
||||
A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [4, 8])
|
||||
B = helper.make_tensor_value_info("B", TensorProto.FLOAT, [4, 8])
|
||||
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [4, 8])
|
||||
node = helper.make_node("Sub", ["A", "B"], ["Y"])
|
||||
graph = helper.make_graph([node], "sub_basic", [A, B], [Y])
|
||||
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||
save_model(model, "sub/basic", "sub_basic.onnx")
|
||||
|
||||
|
||||
def sub_broadcast_row():
|
||||
"""Elementwise Sub with a broadcast row-vector RHS constant."""
|
||||
A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [4, 8])
|
||||
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [4, 8])
|
||||
B = numpy_helper.from_array(np.random.default_rng(103).uniform(-1, 1, (8,)).astype(np.float32), name="B")
|
||||
node = helper.make_node("Sub", ["A", "B"], ["Y"])
|
||||
graph = helper.make_graph([node], "sub_broadcast_row", [A], [Y], initializer=[B])
|
||||
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||
save_model(model, "sub/broadcast_row", "sub_broadcast_row.onnx")
|
||||
|
||||
|
||||
def sub_constant_lhs_broadcast():
|
||||
"""Elementwise Sub with a broadcast constant LHS to preserve operand order."""
|
||||
B = helper.make_tensor_value_info("B", TensorProto.FLOAT, [4, 8])
|
||||
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [4, 8])
|
||||
A = numpy_helper.from_array(np.random.default_rng(104).uniform(-1, 1, (8,)).astype(np.float32), name="A")
|
||||
node = helper.make_node("Sub", ["A", "B"], ["Y"])
|
||||
graph = helper.make_graph([node], "sub_constant_lhs_broadcast", [B], [Y], initializer=[A])
|
||||
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||
save_model(model, "sub/constant_lhs_broadcast", "sub_constant_lhs_broadcast.onnx")
|
||||
|
||||
|
||||
def sub_after_gemm():
|
||||
"""Gemm followed by Sub with a broadcast constant vector."""
|
||||
B, K, N = 4, 64, 32
|
||||
rng = np.random.default_rng(105)
|
||||
W = numpy_helper.from_array(rng.uniform(-1, 1, (K, N)).astype(np.float32), name="W")
|
||||
C = numpy_helper.from_array(rng.uniform(-1, 1, (N,)).astype(np.float32), name="C")
|
||||
S = numpy_helper.from_array(rng.uniform(-1, 1, (N,)).astype(np.float32), name="S")
|
||||
A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [B, K])
|
||||
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [B, N])
|
||||
gemm = helper.make_node("Gemm", ["A", "W", "C"], ["G"])
|
||||
sub = helper.make_node("Sub", ["G", "S"], ["Y"])
|
||||
graph = helper.make_graph([gemm, sub], "sub_after_gemm", [A], [Y], initializer=[W, C, S])
|
||||
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||
save_model(model, "sub/after_gemm", "sub_after_gemm.onnx")
|
||||
|
||||
|
||||
def sub_channel_broadcast_1024():
|
||||
"""Elementwise Sub with 1024-channel constant broadcasting."""
|
||||
A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [1, 1024, 1, 1])
|
||||
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 1024, 1, 1])
|
||||
B = numpy_helper.from_array(
|
||||
np.random.default_rng(106).uniform(-1, 1, (1, 1024, 1, 1)).astype(np.float32), name="B")
|
||||
node = helper.make_node("Sub", ["A", "B"], ["Y"])
|
||||
graph = helper.make_graph([node], "sub_channel_broadcast_1024", [A], [Y], initializer=[B])
|
||||
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||
save_model(model, "sub/channel_broadcast_1024", "sub_channel_broadcast_1024.onnx")
|
||||
|
||||
|
||||
def sub_leading_dimension_broadcast():
|
||||
"""Elementwise Sub with trailing-dimension constant broadcasting."""
|
||||
A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [2, 3, 4])
|
||||
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 3, 4])
|
||||
B = numpy_helper.from_array(np.random.default_rng(107).uniform(-1, 1, (4,)).astype(np.float32), name="B")
|
||||
node = helper.make_node("Sub", ["A", "B"], ["Y"])
|
||||
graph = helper.make_graph([node], "sub_leading_dimension_broadcast", [A], [Y], initializer=[B])
|
||||
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||
save_model(model, "sub/leading_dimension_broadcast", "sub_leading_dimension_broadcast.onnx")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Mul tests
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -1844,6 +1920,14 @@ if __name__ == "__main__":
|
||||
add_channel_broadcast_1024()
|
||||
add_leading_dimension_broadcast()
|
||||
|
||||
print("\nGenerating Sub tests:")
|
||||
sub_basic()
|
||||
sub_broadcast_row()
|
||||
sub_constant_lhs_broadcast()
|
||||
sub_after_gemm()
|
||||
sub_channel_broadcast_1024()
|
||||
sub_leading_dimension_broadcast()
|
||||
|
||||
print("\nGenerating Mul tests:")
|
||||
mul_basic()
|
||||
mul_scalar_constant()
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
Binary file not shown.
Reference in New Issue
Block a user