fix much stuff
This commit is contained in:
@@ -30,7 +30,7 @@ python3 validation/operations/gen_tests.py
|
||||
|
||||
| Test | Directory | A (input) | W (weight) | Output | transB | alpha | beta | Bias | Notes |
|
||||
|---------------|-------------------------|-----------|------------|----------|--------|-------|------|-------|------------------------------|
|
||||
| Default | `gemm/` | [10,132] | [132,132] | [10,132] | no | 1 | 1 | no | Hand-crafted, square weights |
|
||||
| Simple | `gemm/simple` | [10,132] | [132,132] | [10,132] | no | 1 | 1 | no | Square weights |
|
||||
| Non-square | `gemm/non_square` | [4,128] | [128,64] | [4,64] | no | 1 | 1 | no | K != N |
|
||||
| With bias | `gemm/with_bias` | [4,128] | [128,128] | [4,128] | no | 1 | 1 | [128] | Bias vector |
|
||||
| transB | `gemm/transB` | [4,128] | [64,128] | [4,64] | yes | 1 | 1 | no | Transposed weight |
|
||||
|
||||
@@ -185,6 +185,18 @@ def conv_depthwise_grouped():
|
||||
# GEMM tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def gemm_simple():
|
||||
"""Simple GEMM with square weights: [10, 132] @ [132, 132]."""
|
||||
B, K, N = 10, 132, 132
|
||||
W = numpy_helper.from_array(np.random.default_rng(41).uniform(-1, 1, (K, N)).astype(np.float32), name="W")
|
||||
A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [B, K])
|
||||
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [B, N])
|
||||
node = helper.make_node("Gemm", ["A", "W"], ["Y"])
|
||||
graph = helper.make_graph([node], "gemm_simple", [A], [Y], initializer=[W])
|
||||
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||
save_model(model, "gemm/simple", "gemm_simple.onnx")
|
||||
|
||||
|
||||
def gemm_non_square():
|
||||
"""GEMM with non-square weight matrix: [B, K] @ [K, N], K != N."""
|
||||
B, K, N = 4, 128, 64
|
||||
@@ -823,6 +835,7 @@ def div_after_gemm():
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("Generating GEMM tests:")
|
||||
gemm_simple()
|
||||
gemm_non_square()
|
||||
gemm_with_bias()
|
||||
gemm_transB()
|
||||
|
||||
Reference in New Issue
Block a user