batched matmul pattern
Validate Operations / validate-operations (push) Has been cancelled

add conv helpers
new validation tests for matmul
This commit is contained in:
NiccoloN
2026-05-29 19:07:24 +02:00
parent 8bb0babf1b
commit a41f694cf0
18 changed files with 877 additions and 192 deletions
+10 -6
View File
@@ -48,12 +48,16 @@ python3 validation/operations/gen_tests.py
## MatMul
| Test | Directory | A input | B tensor | Output | Notes |
|------------|---------------------|---------|----------|---------|------------------------------------|
| Basic | `matmul/basic` | [2,3] | [3,4] | [2,4] | Direct 2D MatMul rewrite path |
| Left constant | `matmul/left_constant` | [2,3] | [3,4] | [2,4] | Constant LHS transpose rewrite path |
| Dynamic | `matmul/dynamic` | [2,3] | [3,4] | [2,4] | Runtime matrix operands |
| Batched 3D | `matmul/batched_3d` | [2,2,3] | [2,3,4] | [2,2,4] | Matching-batch MatMul rewrite path |
| Test | Directory | A input | B tensor | Output | Notes |
|---------------------|----------------------------------|----------|----------|---------|-------------------------------------------------|
| Basic | `matmul/basic` | [2,3] | [3,4] | [2,4] | Direct 2D MatMul rewrite path |
| Left constant | `matmul/left_constant` | [2,3] | [3,4] | [2,4] | Constant LHS transpose rewrite path |
| Dynamic | `matmul/dynamic` | [2,3] | [3,4] | [2,4] | Runtime matrix operands |
| Batched 3D | `matmul/batched_3d` | [2,2,3] | [2,3,4] | [2,2,4] | Matching-batch direct batched lowering |
| Batched 3D dynamic | `matmul/batched_3d_dynamic` | [2,2,3] | [2,3,4] | [2,2,4] | Batched runtime operands |
| Batched left const | `matmul/batched_left_constant` | [2,2,3] | [2,3,4] | [2,2,4] | Batched constant-LHS transpose path |
| Batched RHS broadcast | `matmul/batched_rhs_broadcast` | [2,2,3] | [3,4] | [2,2,4] | Rank-2 RHS broadcast across batch |
| Batched LHS broadcast | `matmul/batched_lhs_broadcast` | [2,3] | [2,3,4] | [2,2,4] | Rank-2 LHS broadcast across batched RHS |
## Gemv
Binary file not shown.
Binary file not shown.
Binary file not shown.
+51
View File
@@ -421,6 +421,53 @@ def matmul_batched_3d():
save_model(model, "matmul/batched_3d", "matmul_batched_3d.onnx")
def matmul_batched_3d_dynamic():
    """Generate the batched MatMul model whose two operands are graph inputs.

    A is declared as a FLOAT [2, 2, 3] input, B as a FLOAT [2, 3, 4] input and
    Y as the FLOAT [2, 2, 4] output; the graph has no initializers. The model
    is built with opset 13 of the "" domain and handed to ``save_model``.
    """
    opset = helper.make_opsetid("", 13)

    # Both operands are declared as graph inputs, so neither is constant.
    lhs_info = helper.make_tensor_value_info("A", TensorProto.FLOAT, [2, 2, 3])
    rhs_info = helper.make_tensor_value_info("B", TensorProto.FLOAT, [2, 3, 4])
    out_info = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 2, 4])

    matmul = helper.make_node("MatMul", ["A", "B"], ["Y"])
    graph = helper.make_graph(
        [matmul],
        "matmul_batched_3d_dynamic",
        [lhs_info, rhs_info],
        [out_info],
    )
    save_model(
        helper.make_model(graph, opset_imports=[opset]),
        "matmul/batched_3d_dynamic",
        "matmul_batched_3d_dynamic.onnx",
    )
def matmul_batched_left_constant():
    """Generate the batched MatMul model with an initializer LHS.

    A is a FLOAT32 [2, 2, 3] initializer drawn uniformly from [-1, 1) with a
    generator seeded at 70; B is a FLOAT [2, 3, 4] graph input and Y the
    FLOAT [2, 2, 4] output. The model is built with opset 13 of the ""
    domain and handed to ``save_model``.
    """
    opset = helper.make_opsetid("", 13)

    # Fixed seed keeps the stored LHS weights reproducible between runs.
    rng = np.random.default_rng(70)
    lhs_values = rng.uniform(-1, 1, (2, 2, 3)).astype(np.float32)
    lhs_const = numpy_helper.from_array(lhs_values, name="A")

    rhs_info = helper.make_tensor_value_info("B", TensorProto.FLOAT, [2, 3, 4])
    out_info = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 2, 4])

    matmul = helper.make_node("MatMul", ["A", "B"], ["Y"])
    graph = helper.make_graph(
        [matmul],
        "matmul_batched_left_constant",
        [rhs_info],
        [out_info],
        initializer=[lhs_const],
    )
    save_model(
        helper.make_model(graph, opset_imports=[opset]),
        "matmul/batched_left_constant",
        "matmul_batched_left_constant.onnx",
    )
def matmul_batched_rhs_broadcast():
    """Generate the MatMul model pairing a 3D input LHS with a 2D initializer RHS.

    A is a FLOAT [2, 2, 3] graph input; B is a FLOAT32 [3, 4] initializer
    drawn uniformly from [-1, 1) with a generator seeded at 71; Y is the
    FLOAT [2, 2, 4] output. The model is built with opset 13 of the ""
    domain and handed to ``save_model``.
    """
    opset = helper.make_opsetid("", 13)

    # Fixed seed keeps the stored RHS weights reproducible between runs.
    rng = np.random.default_rng(71)
    rhs_values = rng.uniform(-1, 1, (3, 4)).astype(np.float32)
    rhs_const = numpy_helper.from_array(rhs_values, name="B")

    lhs_info = helper.make_tensor_value_info("A", TensorProto.FLOAT, [2, 2, 3])
    out_info = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 2, 4])

    matmul = helper.make_node("MatMul", ["A", "B"], ["Y"])
    graph = helper.make_graph(
        [matmul],
        "matmul_batched_rhs_broadcast",
        [lhs_info],
        [out_info],
        initializer=[rhs_const],
    )
    save_model(
        helper.make_model(graph, opset_imports=[opset]),
        "matmul/batched_rhs_broadcast",
        "matmul_batched_rhs_broadcast.onnx",
    )
def matmul_batched_lhs_broadcast():
    """Generate the MatMul model pairing a 2D input LHS with a 3D initializer RHS.

    A is a FLOAT [2, 3] graph input; B is a FLOAT32 [2, 3, 4] initializer
    drawn uniformly from [-1, 1) with a generator seeded at 72; Y is the
    FLOAT [2, 2, 4] output. The model is built with opset 13 of the ""
    domain and handed to ``save_model``.
    """
    opset = helper.make_opsetid("", 13)

    # Fixed seed keeps the stored RHS weights reproducible between runs.
    rng = np.random.default_rng(72)
    rhs_values = rng.uniform(-1, 1, (2, 3, 4)).astype(np.float32)
    rhs_const = numpy_helper.from_array(rhs_values, name="B")

    lhs_info = helper.make_tensor_value_info("A", TensorProto.FLOAT, [2, 3])
    out_info = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 2, 4])

    matmul = helper.make_node("MatMul", ["A", "B"], ["Y"])
    graph = helper.make_graph(
        [matmul],
        "matmul_batched_lhs_broadcast",
        [lhs_info],
        [out_info],
        initializer=[rhs_const],
    )
    save_model(
        helper.make_model(graph, opset_imports=[opset]),
        "matmul/batched_lhs_broadcast",
        "matmul_batched_lhs_broadcast.onnx",
    )
# ---------------------------------------------------------------------------
# Pooling tests
# ---------------------------------------------------------------------------
@@ -972,6 +1019,10 @@ if __name__ == "__main__":
matmul_left_constant()
matmul_dynamic()
matmul_batched_3d()
matmul_batched_3d_dynamic()
matmul_batched_left_constant()
matmul_batched_rhs_broadcast()
matmul_batched_lhs_broadcast()
print("\nGenerating Pooling tests:")
maxpool_basic()