add relu lowering
Some checks failed
Validate Operations / validate-operations (push) Failing after 2h50m56s
Some checks failed
Validate Operations / validate-operations (push) Failing after 2h50m56s
add relu validation; add spatial compute helper; minor refactors
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Generate ONNX test models for validating GEMM, Conv, and Pooling implementations."""
|
||||
"""Generate ONNX test models for validating GEMM, Conv, Pooling, and Relu implementations."""
|
||||
|
||||
import numpy as np
|
||||
import onnx
|
||||
@@ -327,6 +327,60 @@ def maxpool_after_conv():
|
||||
save_model(model, "pool/max_after_conv", "maxpool_after_conv.onnx")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Relu tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def relu_basic():
    """Emit an ONNX model holding a single Relu over a static 2D tensor.

    Saved under relu/basic as relu_basic.onnx via the shared save_model helper.
    """
    # Input and output share the same static shape; Relu is elementwise.
    shape = [4, 8]
    inp = helper.make_tensor_value_info("X", TensorProto.FLOAT, shape)
    out = helper.make_tensor_value_info("Y", TensorProto.FLOAT, shape)

    relu = helper.make_node("Relu", ["X"], ["Y"])
    graph = helper.make_graph([relu], "relu_basic", [inp], [out])
    # Pin opset 13 to match the other generators in this file.
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
    save_model(model, "relu/basic", "relu_basic.onnx")
|
||||
|
||||
|
||||
def relu_4d():
    """Emit an ONNX model holding a single Relu over an NCHW (4D) tensor.

    Saved under relu/4d as relu_4d.onnx via the shared save_model helper.
    """
    # Elementwise op, so input and output shapes are identical.
    shape = [2, 3, 4, 4]
    inp = helper.make_tensor_value_info("X", TensorProto.FLOAT, shape)
    out = helper.make_tensor_value_info("Y", TensorProto.FLOAT, shape)

    relu = helper.make_node("Relu", ["X"], ["Y"])
    graph = helper.make_graph([relu], "relu_4d", [inp], [out])
    # Pin opset 13 to match the other generators in this file.
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
    save_model(model, "relu/4d", "relu_4d.onnx")
|
||||
|
||||
|
||||
def relu_after_conv():
    """Emit an ONNX model chaining Conv -> Relu.

    A 1x3x5x5 input through a 3x3 valid (no-pad, stride-1) convolution with
    2 output channels yields the declared 1x2x3x3 output. Saved under
    relu/after_conv as relu_after_conv.onnx.
    """
    # Fixed seed keeps the generated weights reproducible across runs.
    rng = np.random.default_rng(60)

    x_info = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 3, 5, 5])
    y_info = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 2, 3, 3])

    # NOTE: W is drawn before B — the draw order must not change, or the
    # generated initializer values (and any golden outputs) change too.
    weight = numpy_helper.from_array(
        rng.uniform(-1, 1, (2, 3, 3, 3)).astype(np.float32), name="W")
    bias = numpy_helper.from_array(
        rng.uniform(-1, 1, (2,)).astype(np.float32), name="B")

    conv = helper.make_node(
        "Conv", ["X", "W", "B"], ["C"],
        kernel_shape=[3, 3], strides=[1, 1], pads=[0, 0, 0, 0])
    relu = helper.make_node("Relu", ["C"], ["Y"])

    graph = helper.make_graph(
        [conv, relu], "relu_after_conv", [x_info], [y_info],
        initializer=[weight, bias])
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
    save_model(model, "relu/after_conv", "relu_after_conv.onnx")
|
||||
|
||||
|
||||
def relu_after_gemm():
    """Emit an ONNX model chaining Gemm -> Relu.

    A (4, 64) input times a (64, 32) weight plus a (32,) bias (Gemm with
    default transA/transB) produces the declared (4, 32) output. Saved under
    relu/after_gemm as relu_after_gemm.onnx.
    """
    batch, k_dim, n_dim = 4, 64, 32
    # Fixed seed keeps the generated weights reproducible across runs.
    rng = np.random.default_rng(61)

    # NOTE: W is drawn before C — the draw order must not change, or the
    # generated initializer values (and any golden outputs) change too.
    weight = numpy_helper.from_array(
        rng.uniform(-1, 1, (k_dim, n_dim)).astype(np.float32), name="W")
    bias = numpy_helper.from_array(
        rng.uniform(-1, 1, (n_dim,)).astype(np.float32), name="C")

    a_info = helper.make_tensor_value_info("A", TensorProto.FLOAT, [batch, k_dim])
    y_info = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [batch, n_dim])

    gemm = helper.make_node("Gemm", ["A", "W", "C"], ["G"])
    relu = helper.make_node("Relu", ["G"], ["Y"])

    graph = helper.make_graph(
        [gemm, relu], "relu_after_gemm", [a_info], [y_info],
        initializer=[weight, bias])
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
    save_model(model, "relu/after_gemm", "relu_after_gemm.onnx")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -361,4 +415,10 @@ if __name__ == "__main__":
|
||||
avgpool_include_pad()
|
||||
maxpool_after_conv()
|
||||
|
||||
print("\nGenerating Relu tests:")
|
||||
relu_basic()
|
||||
relu_4d()
|
||||
relu_after_conv()
|
||||
relu_after_gemm()
|
||||
|
||||
print("\nDone.")
|
||||
|
||||
Reference in New Issue
Block a user