fix much stuff

This commit is contained in:
NiccoloN
2026-05-22 18:53:38 +02:00
parent 8337a11ce9
commit 2c1da813b5
18 changed files with 502 additions and 191 deletions
+13
View File
@@ -185,6 +185,18 @@ def conv_depthwise_grouped():
# GEMM tests
# ---------------------------------------------------------------------------
def gemm_simple():
"""Simple GEMM with square weights: [10, 132] @ [132, 132]."""
B, K, N = 10, 132, 132
W = numpy_helper.from_array(np.random.default_rng(41).uniform(-1, 1, (K, N)).astype(np.float32), name="W")
A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [B, K])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [B, N])
node = helper.make_node("Gemm", ["A", "W"], ["Y"])
graph = helper.make_graph([node], "gemm_simple", [A], [Y], initializer=[W])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "gemm/simple", "gemm_simple.onnx")
def gemm_non_square():
"""GEMM with non-square weight matrix: [B, K] @ [K, N], K != N."""
B, K, N = 4, 128, 64
@@ -823,6 +835,7 @@ def div_after_gemm():
if __name__ == "__main__":
print("Generating GEMM tests:")
gemm_simple()
gemm_non_square()
gemm_with_bias()
gemm_transB()