fix much stuff
This commit is contained in:
@@ -185,6 +185,18 @@ def conv_depthwise_grouped():
|
||||
# GEMM tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def gemm_simple():
|
||||
"""Simple GEMM with square weights: [10, 132] @ [132, 132]."""
|
||||
B, K, N = 10, 132, 132
|
||||
W = numpy_helper.from_array(np.random.default_rng(41).uniform(-1, 1, (K, N)).astype(np.float32), name="W")
|
||||
A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [B, K])
|
||||
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [B, N])
|
||||
node = helper.make_node("Gemm", ["A", "W"], ["Y"])
|
||||
graph = helper.make_graph([node], "gemm_simple", [A], [Y], initializer=[W])
|
||||
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||
save_model(model, "gemm/simple", "gemm_simple.onnx")
|
||||
|
||||
|
||||
def gemm_non_square():
|
||||
"""GEMM with non-square weight matrix: [B, K] @ [K, N], K != N."""
|
||||
B, K, N = 4, 128, 64
|
||||
@@ -823,6 +835,7 @@ def div_after_gemm():
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("Generating GEMM tests:")
|
||||
gemm_simple()
|
||||
gemm_non_square()
|
||||
gemm_with_bias()
|
||||
gemm_transB()
|
||||
|
||||
Reference in New Issue
Block a user