Dynamic gemm/conv

This commit is contained in:
ilgeco
2026-05-28 18:00:14 +02:00
parent cbf7b235f1
commit 1ab489fe0a
17 changed files with 704 additions and 69 deletions
+112
View File
@@ -181,6 +181,18 @@ def conv_depthwise_grouped():
save_model(model, "conv/depthwise_grouped", "conv_depthwise_grouped.onnx")
def conv_dynamic():
    """Build and save a Conv model whose data and kernel are both graph inputs.

    X is declared as [1, 1, 4, 4], W as [1, 1, 3, 3] and the output Y as
    [1, 1, 2, 2]. The node uses a 3x3 kernel, unit strides and zero padding.
    No initializer is attached to the graph, so W is not a stored constant.
    """
    float_type = TensorProto.FLOAT
    graph_inputs = [
        helper.make_tensor_value_info("X", float_type, [1, 1, 4, 4]),
        helper.make_tensor_value_info("W", float_type, [1, 1, 3, 3]),
    ]
    graph_outputs = [
        helper.make_tensor_value_info("Y", float_type, [1, 1, 2, 2]),
    ]
    conv_node = helper.make_node(
        "Conv",
        ["X", "W"],
        ["Y"],
        kernel_shape=[3, 3],
        strides=[1, 1],
        pads=[0, 0, 0, 0],
    )
    conv_graph = helper.make_graph(
        [conv_node], "conv_dynamic", graph_inputs, graph_outputs
    )
    opset = helper.make_opsetid("", 13)
    save_model(
        helper.make_model(conv_graph, opset_imports=[opset]),
        "conv/dynamic",
        "conv_dynamic.onnx",
    )
# ---------------------------------------------------------------------------
# GEMM tests
# ---------------------------------------------------------------------------
@@ -291,6 +303,75 @@ def gemm_transB_with_bias():
save_model(model, "gemm/transB_with_bias", "gemm_transB_with_bias.onnx")
def gemm_dynamic():
    """Build and save a Gemm model where A and B are both graph inputs.

    A is declared as [2, 8], B as [8, 4] and the output Y as [2, 4]. The node
    carries no attributes and no bias input, and the graph has no initializer.
    """
    float_type = TensorProto.FLOAT
    lhs_info = helper.make_tensor_value_info("A", float_type, [2, 8])
    rhs_info = helper.make_tensor_value_info("B", float_type, [8, 4])
    out_info = helper.make_tensor_value_info("Y", float_type, [2, 4])
    gemm_node = helper.make_node("Gemm", ["A", "B"], ["Y"])
    gemm_graph = helper.make_graph(
        [gemm_node], "gemm_dynamic", [lhs_info, rhs_info], [out_info]
    )
    opset = helper.make_opsetid("", 13)
    save_model(
        helper.make_model(gemm_graph, opset_imports=[opset]),
        "gemm/dynamic",
        "gemm_dynamic.onnx",
    )
def gemm_dynamic_transB():
    """Build and save a Gemm model with graph-input operands and transB=1.

    A is declared as [2, 8] and B as [4, 8]; the transB attribute is set to 1
    on the node, and the output Y is declared as [2, 4].
    """
    float_type = TensorProto.FLOAT
    lhs_info = helper.make_tensor_value_info("A", float_type, [2, 8])
    rhs_info = helper.make_tensor_value_info("B", float_type, [4, 8])
    out_info = helper.make_tensor_value_info("Y", float_type, [2, 4])
    gemm_node = helper.make_node("Gemm", ["A", "B"], ["Y"], transB=1)
    gemm_graph = helper.make_graph(
        [gemm_node], "gemm_dynamic_transB", [lhs_info, rhs_info], [out_info]
    )
    opset = helper.make_opsetid("", 13)
    save_model(
        helper.make_model(gemm_graph, opset_imports=[opset]),
        "gemm/dynamic_transB",
        "gemm_dynamic_transB.onnx",
    )
def gemm_dynamic_bias():
    """Build and save a Gemm model where A, B and the bias C are graph inputs.

    A is declared as [2, 8], B as [8, 4], C as [4] and the output Y as [2, 4].
    The node carries no attributes.
    """
    float_type = TensorProto.FLOAT
    graph_inputs = [
        helper.make_tensor_value_info("A", float_type, [2, 8]),
        helper.make_tensor_value_info("B", float_type, [8, 4]),
        helper.make_tensor_value_info("C", float_type, [4]),
    ]
    out_info = helper.make_tensor_value_info("Y", float_type, [2, 4])
    gemm_node = helper.make_node("Gemm", ["A", "B", "C"], ["Y"])
    gemm_graph = helper.make_graph(
        [gemm_node], "gemm_dynamic_bias", graph_inputs, [out_info]
    )
    opset = helper.make_opsetid("", 13)
    save_model(
        helper.make_model(gemm_graph, opset_imports=[opset]),
        "gemm/dynamic_bias",
        "gemm_dynamic_bias.onnx",
    )
def gemm_dynamic_alpha():
    """Build and save a Gemm model with graph-input operands and alpha=0.5.

    A ([2, 8]) and B ([8, 4]) are graph inputs; alpha is fixed to 0.5 as a
    node attribute rather than supplied as an input. Y is declared as [2, 4].
    """
    float_type = TensorProto.FLOAT
    lhs_info = helper.make_tensor_value_info("A", float_type, [2, 8])
    rhs_info = helper.make_tensor_value_info("B", float_type, [8, 4])
    out_info = helper.make_tensor_value_info("Y", float_type, [2, 4])
    gemm_node = helper.make_node("Gemm", ["A", "B"], ["Y"], alpha=0.5)
    gemm_graph = helper.make_graph(
        [gemm_node], "gemm_dynamic_alpha", [lhs_info, rhs_info], [out_info]
    )
    opset = helper.make_opsetid("", 13)
    save_model(
        helper.make_model(gemm_graph, opset_imports=[opset]),
        "gemm/dynamic_alpha",
        "gemm_dynamic_alpha.onnx",
    )
def gemm_dynamic_beta():
    """Build and save a Gemm model with graph-input A, B, C and beta=2.0.

    A ([2, 8]), B ([8, 4]) and the bias C ([4]) are graph inputs; beta is
    fixed to 2.0 as a node attribute rather than supplied as an input.
    Y is declared as [2, 4].
    """
    float_type = TensorProto.FLOAT
    graph_inputs = [
        helper.make_tensor_value_info("A", float_type, [2, 8]),
        helper.make_tensor_value_info("B", float_type, [8, 4]),
        helper.make_tensor_value_info("C", float_type, [4]),
    ]
    out_info = helper.make_tensor_value_info("Y", float_type, [2, 4])
    gemm_node = helper.make_node("Gemm", ["A", "B", "C"], ["Y"], beta=2.0)
    gemm_graph = helper.make_graph(
        [gemm_node], "gemm_dynamic_beta", graph_inputs, [out_info]
    )
    opset = helper.make_opsetid("", 13)
    save_model(
        helper.make_model(gemm_graph, opset_imports=[opset]),
        "gemm/dynamic_beta",
        "gemm_dynamic_beta.onnx",
    )
def gemm_dynamic_bias_alpha_beta():
    """Build and save a Gemm model with graph-input A, B, C plus alpha and beta.

    A ([2, 8]), B ([8, 4]) and the bias C ([4]) are graph inputs; alpha=0.5
    and beta=2.0 are fixed as node attributes. Y is declared as [2, 4].
    """
    float_type = TensorProto.FLOAT
    graph_inputs = [
        helper.make_tensor_value_info("A", float_type, [2, 8]),
        helper.make_tensor_value_info("B", float_type, [8, 4]),
        helper.make_tensor_value_info("C", float_type, [4]),
    ]
    out_info = helper.make_tensor_value_info("Y", float_type, [2, 4])
    gemm_node = helper.make_node(
        "Gemm",
        ["A", "B", "C"],
        ["Y"],
        alpha=0.5,
        beta=2.0,
    )
    gemm_graph = helper.make_graph(
        [gemm_node], "gemm_dynamic_bias_alpha_beta", graph_inputs, [out_info]
    )
    opset = helper.make_opsetid("", 13)
    save_model(
        helper.make_model(gemm_graph, opset_imports=[opset]),
        "gemm/dynamic_bias_alpha_beta",
        "gemm_dynamic_bias_alpha_beta.onnx",
    )
# ---------------------------------------------------------------------------
# MatMul tests
# ---------------------------------------------------------------------------
@@ -306,6 +387,28 @@ def matmul_basic():
save_model(model, "matmul/basic", "matmul_basic.onnx")
def matmul_left_constant():
    """Build and save a 2D MatMul model whose left operand is an initializer.

    A is a (2, 3) float32 array sampled with a fixed seed (69) from
    uniform(-1, 1) and stored as a graph initializer. B ([3, 4]) is the only
    graph input and Y is declared as [2, 4].
    """
    lhs_values = np.random.default_rng(69).uniform(-1, 1, (2, 3))
    lhs_init = numpy_helper.from_array(lhs_values.astype(np.float32), name="A")
    rhs_info = helper.make_tensor_value_info("B", TensorProto.FLOAT, [3, 4])
    out_info = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 4])
    matmul_node = helper.make_node("MatMul", ["A", "B"], ["Y"])
    matmul_graph = helper.make_graph(
        [matmul_node],
        "matmul_left_constant",
        [rhs_info],
        [out_info],
        initializer=[lhs_init],
    )
    opset = helper.make_opsetid("", 13)
    save_model(
        helper.make_model(matmul_graph, opset_imports=[opset]),
        "matmul/left_constant",
        "matmul_left_constant.onnx",
    )
def matmul_dynamic():
    """Build and save a 2D MatMul model where A and B are both graph inputs.

    A is declared as [2, 3], B as [3, 4] and the output Y as [2, 4]. The
    graph has no initializer.
    """
    float_type = TensorProto.FLOAT
    lhs_info = helper.make_tensor_value_info("A", float_type, [2, 3])
    rhs_info = helper.make_tensor_value_info("B", float_type, [3, 4])
    out_info = helper.make_tensor_value_info("Y", float_type, [2, 4])
    matmul_node = helper.make_node("MatMul", ["A", "B"], ["Y"])
    matmul_graph = helper.make_graph(
        [matmul_node], "matmul_dynamic", [lhs_info, rhs_info], [out_info]
    )
    opset = helper.make_opsetid("", 13)
    save_model(
        helper.make_model(matmul_graph, opset_imports=[opset]),
        "matmul/dynamic",
        "matmul_dynamic.onnx",
    )
def matmul_batched_3d():
"""Batched 3D MatMul with matching batch dimensions."""
rng = np.random.default_rng(50)
@@ -843,6 +946,12 @@ if __name__ == "__main__":
gemm_small()
gemm_large()
gemm_transB_with_bias()
gemm_dynamic()
gemm_dynamic_transB()
gemm_dynamic_bias()
gemm_dynamic_alpha()
gemm_dynamic_beta()
gemm_dynamic_bias_alpha_beta()
print("\nGenerating Conv tests:")
conv_3x3_kernel()
@@ -856,9 +965,12 @@ if __name__ == "__main__":
conv_large_spatial()
conv_grouped_two_groups()
conv_depthwise_grouped()
conv_dynamic()
print("\nGenerating MatMul tests:")
matmul_basic()
matmul_left_constant()
matmul_dynamic()
matmul_batched_3d()
print("\nGenerating Pooling tests:")