fix much stuff

This commit is contained in:
NiccoloN
2026-05-22 18:53:38 +02:00
parent 8337a11ce9
commit 2c1da813b5
18 changed files with 502 additions and 191 deletions
+1 -1
View File
@@ -30,7 +30,7 @@ python3 validation/operations/gen_tests.py
| Test | Directory | A (input) | W (weight) | Output | transB | alpha | beta | Bias | Notes |
|---------------|-------------------------|-----------|------------|----------|--------|-------|------|-------|------------------------------|
| Default | `gemm/` | [10,132] | [132,132] | [10,132] | no | 1 | 1 | no | Hand-crafted, square weights |
| Simple | `gemm/simple` | [10,132] | [132,132] | [10,132] | no | 1 | 1 | no | Square weights |
| Non-square | `gemm/non_square` | [4,128] | [128,64] | [4,64] | no | 1 | 1 | no | K != N |
| With bias | `gemm/with_bias` | [4,128] | [128,128] | [4,128] | no | 1 | 1 | [128] | Bias vector |
| transB | `gemm/transB` | [4,128] | [64,128] | [4,64] | yes | 1 | 1 | no | Transposed weight |
+13
View File
@@ -185,6 +185,18 @@ def conv_depthwise_grouped():
# GEMM tests
# ---------------------------------------------------------------------------
def gemm_simple():
"""Simple GEMM with square weights: [10, 132] @ [132, 132]."""
B, K, N = 10, 132, 132
W = numpy_helper.from_array(np.random.default_rng(41).uniform(-1, 1, (K, N)).astype(np.float32), name="W")
A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [B, K])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [B, N])
node = helper.make_node("Gemm", ["A", "W"], ["Y"])
graph = helper.make_graph([node], "gemm_simple", [A], [Y], initializer=[W])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "gemm/simple", "gemm_simple.onnx")
def gemm_non_square():
"""GEMM with non-square weight matrix: [B, K] @ [K, N], K != N."""
B, K, N = 4, 128, 64
@@ -823,6 +835,7 @@ def div_after_gemm():
if __name__ == "__main__":
print("Generating GEMM tests:")
gemm_simple()
gemm_non_square()
gemm_with_bias()
gemm_transB()