diff --git a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp index cbcc67a..edf311e 100644 --- a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp @@ -124,6 +124,7 @@ void ONNXToSpatialPass::runOnOperation() { target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Elementwise.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Elementwise.cpp index 2912e49..506456f 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Elementwise.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Elementwise.cpp @@ -189,6 +189,7 @@ struct DivToSpatialCompute : OpConversionPattern { void populateElementwisePatterns(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.add>(ctx); + patterns.add>(ctx); patterns.add>(ctx); patterns.add(ctx); } diff --git a/src/PIM/Conversion/SpatialToPim/SpatialToPim.td b/src/PIM/Conversion/SpatialToPim/SpatialToPim.td index 31a4e13..dfc6c47 100644 --- a/src/PIM/Conversion/SpatialToPim/SpatialToPim.td +++ b/src/PIM/Conversion/SpatialToPim/SpatialToPim.td @@ -27,6 +27,12 @@ def spatToPimVVAdd : Pat< (NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes)) >; +def spatToPimVVSub : Pat< + (SpatVSubOp:$srcOpRes $a, $b), + (PimVVSubOp $a, $b, + (NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes)) +>; + def spatToPimVVMul : Pat< (SpatVMulOp:$srcOpRes $a, $b), (PimVVMulOp $a, $b, diff --git a/src/PIM/Dialect/Spatial/Spatial.td b/src/PIM/Dialect/Spatial/Spatial.td index cdd3464..56347d8 100644 --- a/src/PIM/Dialect/Spatial/Spatial.td +++ b/src/PIM/Dialect/Spatial/Spatial.td @@ -257,6 +257,25 @@ def SpatVAddOp : SpatOp<"vadd", []> { }]; } +def SpatVSubOp : SpatOp<"vsub", []> { + let summary = "Element-wise subtraction between two tensors; rhs must match lhs or be 1x1"; + + let arguments = (ins + SpatTensor:$lhs, + SpatTensor:$rhs + ); + + let results = (outs + SpatTensor:$output + ); + + let hasVerifier = 1; + + let assemblyFormat = [{ + $lhs `,` $rhs attr-dict `:` `(` type($lhs) `,` type($rhs) `)` `->` type($output) + }]; +} + def SpatVMulOp : SpatOp<"vmul", []> { let summary = "Element-wise multiplication between two tensors; rhs must match lhs or be 1x1"; diff --git a/src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp b/src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp index 2639423..02bc62d 100644 --- a/src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp +++ b/src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp @@ -254,6 +254,12 @@ LogicalResult SpatVAddOp::verify() { return OpTrait::impl::verifySameOperandsAndResultType(*this); } +LogicalResult SpatVSubOp::verify() { + if (failed(OpTrait::impl::verifyAtLeastNOperands(*this, 2))) + return failure(); + return OpTrait::impl::verifySameOperandsAndResultType(*this); +} + LogicalResult SpatVMaxOp::verify() { if (failed(OpTrait::impl::verifyAtLeastNOperands(*this, 2))) return failure(); diff --git a/validation/operations/gen_tests.py b/validation/operations/gen_tests.py index ba09a7a..1d5adce 100644 --- a/validation/operations/gen_tests.py +++ b/validation/operations/gen_tests.py @@ -1549,6 +1549,82 @@ def add_leading_dimension_broadcast(): save_model(model, "add/leading_dimension_broadcast", "add_leading_dimension_broadcast.onnx") +# --------------------------------------------------------------------------- +# Sub tests +# --------------------------------------------------------------------------- + +def sub_basic(): + """Elementwise Sub on two runtime inputs with identical shapes.""" + A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [4, 8]) + B = helper.make_tensor_value_info("B", TensorProto.FLOAT, [4, 8]) + Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [4, 8]) + node = helper.make_node("Sub", ["A", "B"], ["Y"]) + graph = helper.make_graph([node], "sub_basic", [A, B], [Y]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + save_model(model, "sub/basic", "sub_basic.onnx") + + +def sub_broadcast_row(): + """Elementwise Sub with a broadcast row-vector RHS constant.""" + A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [4, 8]) + Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [4, 8]) + B = numpy_helper.from_array(np.random.default_rng(103).uniform(-1, 1, (8,)).astype(np.float32), name="B") + node = helper.make_node("Sub", ["A", "B"], ["Y"]) + graph = helper.make_graph([node], "sub_broadcast_row", [A], [Y], initializer=[B]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + save_model(model, "sub/broadcast_row", "sub_broadcast_row.onnx") + + +def sub_constant_lhs_broadcast(): + """Elementwise Sub with a broadcast constant LHS to preserve operand order.""" + B = helper.make_tensor_value_info("B", TensorProto.FLOAT, [4, 8]) + Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [4, 8]) + A = numpy_helper.from_array(np.random.default_rng(104).uniform(-1, 1, (8,)).astype(np.float32), name="A") + node = helper.make_node("Sub", ["A", "B"], ["Y"]) + graph = helper.make_graph([node], "sub_constant_lhs_broadcast", [B], [Y], initializer=[A]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + save_model(model, "sub/constant_lhs_broadcast", "sub_constant_lhs_broadcast.onnx") + + +def sub_after_gemm(): + """Gemm followed by Sub with a broadcast constant vector.""" + B, K, N = 4, 64, 32 + rng = np.random.default_rng(105) + W = numpy_helper.from_array(rng.uniform(-1, 1, (K, N)).astype(np.float32), name="W") + C = numpy_helper.from_array(rng.uniform(-1, 1, (N,)).astype(np.float32), name="C") + S = numpy_helper.from_array(rng.uniform(-1, 1, (N,)).astype(np.float32), name="S") + A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [B, K]) + Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [B, N]) + gemm = helper.make_node("Gemm", ["A", "W", "C"], ["G"]) + sub = helper.make_node("Sub", ["G", "S"], ["Y"]) + graph = helper.make_graph([gemm, sub], "sub_after_gemm", [A], [Y], initializer=[W, C, S]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + save_model(model, "sub/after_gemm", "sub_after_gemm.onnx") + + +def sub_channel_broadcast_1024(): + """Elementwise Sub with 1024-channel constant broadcasting.""" + A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [1, 1024, 1, 1]) + Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 1024, 1, 1]) + B = numpy_helper.from_array( + np.random.default_rng(106).uniform(-1, 1, (1, 1024, 1, 1)).astype(np.float32), name="B") + node = helper.make_node("Sub", ["A", "B"], ["Y"]) + graph = helper.make_graph([node], "sub_channel_broadcast_1024", [A], [Y], initializer=[B]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + save_model(model, "sub/channel_broadcast_1024", "sub_channel_broadcast_1024.onnx") + + +def sub_leading_dimension_broadcast(): + """Elementwise Sub with trailing-dimension constant broadcasting.""" + A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [2, 3, 4]) + Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 3, 4]) + B = numpy_helper.from_array(np.random.default_rng(107).uniform(-1, 1, (4,)).astype(np.float32), name="B") + node = helper.make_node("Sub", ["A", "B"], ["Y"]) + graph = helper.make_graph([node], "sub_leading_dimension_broadcast", [A], [Y], initializer=[B]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + save_model(model, "sub/leading_dimension_broadcast", "sub_leading_dimension_broadcast.onnx") + + # --------------------------------------------------------------------------- # Mul tests # --------------------------------------------------------------------------- @@ -1844,6 +1920,14 @@ if __name__ == "__main__": add_channel_broadcast_1024() add_leading_dimension_broadcast() + print("\nGenerating Sub tests:") + sub_basic() + sub_broadcast_row() + sub_constant_lhs_broadcast() + sub_after_gemm() + sub_channel_broadcast_1024() + sub_leading_dimension_broadcast() + print("\nGenerating Mul tests:") mul_basic() mul_scalar_constant() diff --git a/validation/operations/sub/after_gemm/sub_after_gemm.onnx b/validation/operations/sub/after_gemm/sub_after_gemm.onnx new file mode 100644 index 0000000..0f79fa9 Binary files /dev/null and b/validation/operations/sub/after_gemm/sub_after_gemm.onnx differ diff --git a/validation/operations/sub/basic/sub_basic.onnx b/validation/operations/sub/basic/sub_basic.onnx new file mode 100644 index 0000000..eda1d67 Binary files /dev/null and b/validation/operations/sub/basic/sub_basic.onnx differ diff --git a/validation/operations/sub/broadcast_row/sub_broadcast_row.onnx b/validation/operations/sub/broadcast_row/sub_broadcast_row.onnx new file mode 100644 index 0000000..39e80a6 Binary files /dev/null and b/validation/operations/sub/broadcast_row/sub_broadcast_row.onnx differ diff --git a/validation/operations/sub/channel_broadcast_1024/sub_channel_broadcast_1024.onnx b/validation/operations/sub/channel_broadcast_1024/sub_channel_broadcast_1024.onnx new file mode 100644 index 0000000..f37e916 Binary files /dev/null and b/validation/operations/sub/channel_broadcast_1024/sub_channel_broadcast_1024.onnx differ diff --git a/validation/operations/sub/constant_lhs_broadcast/sub_constant_lhs_broadcast.onnx b/validation/operations/sub/constant_lhs_broadcast/sub_constant_lhs_broadcast.onnx new file mode 100644 index 0000000..ec9995f Binary files /dev/null and b/validation/operations/sub/constant_lhs_broadcast/sub_constant_lhs_broadcast.onnx differ diff --git a/validation/operations/sub/leading_dimension_broadcast/sub_leading_dimension_broadcast.onnx b/validation/operations/sub/leading_dimension_broadcast/sub_leading_dimension_broadcast.onnx new file mode 100644 index 0000000..b757a01 Binary files /dev/null and b/validation/operations/sub/leading_dimension_broadcast/sub_leading_dimension_broadcast.onnx differ