This commit is contained in:
@@ -1340,6 +1340,118 @@ def split_uneven_channel_axis_4d():
|
||||
save_model(model, "split/uneven_channel_axis_4d", "split_uneven_channel_axis_4d.onnx")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Slice tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def slice_2d_basic():
|
||||
"""Slice a 2D tensor with explicit axes and unit steps."""
|
||||
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [4, 6])
|
||||
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 3])
|
||||
starts = make_int64_initializer("starts", [1, 2])
|
||||
ends = make_int64_initializer("ends", [3, 5])
|
||||
axes = make_int64_initializer("axes", [0, 1])
|
||||
steps = make_int64_initializer("steps", [1, 1])
|
||||
node = helper.make_node("Slice", ["X", "starts", "ends", "axes", "steps"], ["Y"])
|
||||
graph = helper.make_graph([node], "slice_2d_basic", [X], [Y], initializer=[starts, ends, axes, steps])
|
||||
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||
save_model(model, "slice/2d_basic", "slice_2d_basic.onnx")
|
||||
|
||||
|
||||
def slice_default_axes():
|
||||
"""Slice with omitted axes and steps using default positional axes."""
|
||||
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 3, 4])
|
||||
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 2, 4])
|
||||
starts = make_int64_initializer("starts", [0, 1, 0])
|
||||
ends = make_int64_initializer("ends", [2, 3, 4])
|
||||
node = helper.make_node("Slice", ["X", "starts", "ends"], ["Y"])
|
||||
graph = helper.make_graph([node], "slice_default_axes", [X], [Y], initializer=[starts, ends])
|
||||
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||
save_model(model, "slice/default_axes", "slice_default_axes.onnx")
|
||||
|
||||
|
||||
def slice_negative_axis():
|
||||
"""Slice using a negative axis."""
|
||||
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 3, 5])
|
||||
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 3, 3])
|
||||
starts = make_int64_initializer("starts", [1])
|
||||
ends = make_int64_initializer("ends", [4])
|
||||
axes = make_int64_initializer("axes", [-1])
|
||||
node = helper.make_node("Slice", ["X", "starts", "ends", "axes"], ["Y"])
|
||||
graph = helper.make_graph([node], "slice_negative_axis", [X], [Y], initializer=[starts, ends, axes])
|
||||
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||
save_model(model, "slice/negative_axis", "slice_negative_axis.onnx")
|
||||
|
||||
|
||||
def slice_negative_indices():
|
||||
"""Slice with negative indices along one axis."""
|
||||
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [3, 5])
|
||||
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [3, 3])
|
||||
starts = make_int64_initializer("starts", [-4])
|
||||
ends = make_int64_initializer("ends", [-1])
|
||||
axes = make_int64_initializer("axes", [1])
|
||||
node = helper.make_node("Slice", ["X", "starts", "ends", "axes"], ["Y"])
|
||||
graph = helper.make_graph([node], "slice_negative_indices", [X], [Y], initializer=[starts, ends, axes])
|
||||
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||
save_model(model, "slice/negative_indices", "slice_negative_indices.onnx")
|
||||
|
||||
|
||||
def slice_step2():
|
||||
"""Slice with a positive step greater than one."""
|
||||
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 3, 8])
|
||||
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 3, 4])
|
||||
starts = make_int64_initializer("starts", [0])
|
||||
ends = make_int64_initializer("ends", [8])
|
||||
axes = make_int64_initializer("axes", [2])
|
||||
steps = make_int64_initializer("steps", [2])
|
||||
node = helper.make_node("Slice", ["X", "starts", "ends", "axes", "steps"], ["Y"])
|
||||
graph = helper.make_graph([node], "slice_step2", [X], [Y], initializer=[starts, ends, axes, steps])
|
||||
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||
save_model(model, "slice/step2", "slice_step2.onnx")
|
||||
|
||||
|
||||
def slice_nchw_spatial_crop():
|
||||
"""Slice an NCHW tensor across the spatial axes."""
|
||||
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 3, 8, 8])
|
||||
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 3, 4, 6])
|
||||
starts = make_int64_initializer("starts", [2, 1])
|
||||
ends = make_int64_initializer("ends", [6, 7])
|
||||
axes = make_int64_initializer("axes", [2, 3])
|
||||
node = helper.make_node("Slice", ["X", "starts", "ends", "axes"], ["Y"])
|
||||
graph = helper.make_graph([node], "slice_nchw_spatial_crop", [X], [Y], initializer=[starts, ends, axes])
|
||||
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||
save_model(model, "slice/nchw_spatial_crop", "slice_nchw_spatial_crop.onnx")
|
||||
|
||||
|
||||
def slice_after_conv():
|
||||
"""Conv followed by a spatial crop using Slice."""
|
||||
rng = np.random.default_rng(108)
|
||||
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 3, 8, 8])
|
||||
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 4, 4, 4])
|
||||
W = numpy_helper.from_array(rng.uniform(-1, 1, (4, 3, 3, 3)).astype(np.float32), name="W")
|
||||
starts = make_int64_initializer("starts", [1, 1])
|
||||
ends = make_int64_initializer("ends", [5, 5])
|
||||
axes = make_int64_initializer("axes", [2, 3])
|
||||
conv = helper.make_node("Conv", ["X", "W"], ["C"], kernel_shape=[3, 3], strides=[1, 1], pads=[0, 0, 0, 0])
|
||||
slice_node = helper.make_node("Slice", ["C", "starts", "ends", "axes"], ["Y"])
|
||||
graph = helper.make_graph([conv, slice_node], "slice_after_conv", [X], [Y], initializer=[W, starts, ends, axes])
|
||||
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||
save_model(model, "slice/after_conv", "slice_after_conv.onnx")
|
||||
|
||||
|
||||
def slice_large_channel_1024():
|
||||
"""Slice a large channel range out of a 1024-channel tensor."""
|
||||
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 1024, 1, 1])
|
||||
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 512, 1, 1])
|
||||
starts = make_int64_initializer("starts", [128])
|
||||
ends = make_int64_initializer("ends", [640])
|
||||
axes = make_int64_initializer("axes", [1])
|
||||
node = helper.make_node("Slice", ["X", "starts", "ends", "axes"], ["Y"])
|
||||
graph = helper.make_graph([node], "slice_large_channel_1024", [X], [Y], initializer=[starts, ends, axes])
|
||||
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||
save_model(model, "slice/large_channel_1024", "slice_large_channel_1024.onnx")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Gather tests
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -1880,6 +1992,16 @@ if __name__ == "__main__":
|
||||
split_negative_axis()
|
||||
split_uneven_channel_axis_4d()
|
||||
|
||||
print("\nGenerating Slice tests:")
|
||||
slice_2d_basic()
|
||||
slice_default_axes()
|
||||
slice_negative_axis()
|
||||
slice_negative_indices()
|
||||
slice_step2()
|
||||
slice_nchw_spatial_crop()
|
||||
slice_after_conv()
|
||||
slice_large_channel_1024()
|
||||
|
||||
print("\nGenerating Softmax tests:")
|
||||
softmax_basic()
|
||||
softmax_3d_last_axis()
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Reference in New Issue
Block a user