This commit is contained in:
@@ -1340,6 +1340,118 @@ def split_uneven_channel_axis_4d():
|
||||
save_model(model, "split/uneven_channel_axis_4d", "split_uneven_channel_axis_4d.onnx")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Slice tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def slice_2d_basic():
|
||||
"""Slice a 2D tensor with explicit axes and unit steps."""
|
||||
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [4, 6])
|
||||
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 3])
|
||||
starts = make_int64_initializer("starts", [1, 2])
|
||||
ends = make_int64_initializer("ends", [3, 5])
|
||||
axes = make_int64_initializer("axes", [0, 1])
|
||||
steps = make_int64_initializer("steps", [1, 1])
|
||||
node = helper.make_node("Slice", ["X", "starts", "ends", "axes", "steps"], ["Y"])
|
||||
graph = helper.make_graph([node], "slice_2d_basic", [X], [Y], initializer=[starts, ends, axes, steps])
|
||||
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||
save_model(model, "slice/2d_basic", "slice_2d_basic.onnx")
|
||||
|
||||
|
||||
def slice_default_axes():
|
||||
"""Slice with omitted axes and steps using default positional axes."""
|
||||
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 3, 4])
|
||||
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 2, 4])
|
||||
starts = make_int64_initializer("starts", [0, 1, 0])
|
||||
ends = make_int64_initializer("ends", [2, 3, 4])
|
||||
node = helper.make_node("Slice", ["X", "starts", "ends"], ["Y"])
|
||||
graph = helper.make_graph([node], "slice_default_axes", [X], [Y], initializer=[starts, ends])
|
||||
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||
save_model(model, "slice/default_axes", "slice_default_axes.onnx")
|
||||
|
||||
|
||||
def slice_negative_axis():
|
||||
"""Slice using a negative axis."""
|
||||
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 3, 5])
|
||||
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 3, 3])
|
||||
starts = make_int64_initializer("starts", [1])
|
||||
ends = make_int64_initializer("ends", [4])
|
||||
axes = make_int64_initializer("axes", [-1])
|
||||
node = helper.make_node("Slice", ["X", "starts", "ends", "axes"], ["Y"])
|
||||
graph = helper.make_graph([node], "slice_negative_axis", [X], [Y], initializer=[starts, ends, axes])
|
||||
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||
save_model(model, "slice/negative_axis", "slice_negative_axis.onnx")
|
||||
|
||||
|
||||
def slice_negative_indices():
|
||||
"""Slice with negative indices along one axis."""
|
||||
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [3, 5])
|
||||
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [3, 3])
|
||||
starts = make_int64_initializer("starts", [-4])
|
||||
ends = make_int64_initializer("ends", [-1])
|
||||
axes = make_int64_initializer("axes", [1])
|
||||
node = helper.make_node("Slice", ["X", "starts", "ends", "axes"], ["Y"])
|
||||
graph = helper.make_graph([node], "slice_negative_indices", [X], [Y], initializer=[starts, ends, axes])
|
||||
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||
save_model(model, "slice/negative_indices", "slice_negative_indices.onnx")
|
||||
|
||||
|
||||
def slice_step2():
|
||||
"""Slice with a positive step greater than one."""
|
||||
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 3, 8])
|
||||
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 3, 4])
|
||||
starts = make_int64_initializer("starts", [0])
|
||||
ends = make_int64_initializer("ends", [8])
|
||||
axes = make_int64_initializer("axes", [2])
|
||||
steps = make_int64_initializer("steps", [2])
|
||||
node = helper.make_node("Slice", ["X", "starts", "ends", "axes", "steps"], ["Y"])
|
||||
graph = helper.make_graph([node], "slice_step2", [X], [Y], initializer=[starts, ends, axes, steps])
|
||||
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||
save_model(model, "slice/step2", "slice_step2.onnx")
|
||||
|
||||
|
||||
def slice_nchw_spatial_crop():
|
||||
"""Slice an NCHW tensor across the spatial axes."""
|
||||
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 3, 8, 8])
|
||||
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 3, 4, 6])
|
||||
starts = make_int64_initializer("starts", [2, 1])
|
||||
ends = make_int64_initializer("ends", [6, 7])
|
||||
axes = make_int64_initializer("axes", [2, 3])
|
||||
node = helper.make_node("Slice", ["X", "starts", "ends", "axes"], ["Y"])
|
||||
graph = helper.make_graph([node], "slice_nchw_spatial_crop", [X], [Y], initializer=[starts, ends, axes])
|
||||
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||
save_model(model, "slice/nchw_spatial_crop", "slice_nchw_spatial_crop.onnx")
|
||||
|
||||
|
||||
def slice_after_conv():
|
||||
"""Conv followed by a spatial crop using Slice."""
|
||||
rng = np.random.default_rng(108)
|
||||
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 3, 8, 8])
|
||||
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 4, 4, 4])
|
||||
W = numpy_helper.from_array(rng.uniform(-1, 1, (4, 3, 3, 3)).astype(np.float32), name="W")
|
||||
starts = make_int64_initializer("starts", [1, 1])
|
||||
ends = make_int64_initializer("ends", [5, 5])
|
||||
axes = make_int64_initializer("axes", [2, 3])
|
||||
conv = helper.make_node("Conv", ["X", "W"], ["C"], kernel_shape=[3, 3], strides=[1, 1], pads=[0, 0, 0, 0])
|
||||
slice_node = helper.make_node("Slice", ["C", "starts", "ends", "axes"], ["Y"])
|
||||
graph = helper.make_graph([conv, slice_node], "slice_after_conv", [X], [Y], initializer=[W, starts, ends, axes])
|
||||
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||
save_model(model, "slice/after_conv", "slice_after_conv.onnx")
|
||||
|
||||
|
||||
def slice_large_channel_1024():
|
||||
"""Slice a large channel range out of a 1024-channel tensor."""
|
||||
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 1024, 1, 1])
|
||||
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 512, 1, 1])
|
||||
starts = make_int64_initializer("starts", [128])
|
||||
ends = make_int64_initializer("ends", [640])
|
||||
axes = make_int64_initializer("axes", [1])
|
||||
node = helper.make_node("Slice", ["X", "starts", "ends", "axes"], ["Y"])
|
||||
graph = helper.make_graph([node], "slice_large_channel_1024", [X], [Y], initializer=[starts, ends, axes])
|
||||
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||
save_model(model, "slice/large_channel_1024", "slice_large_channel_1024.onnx")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Gather tests
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -1880,6 +1992,16 @@ if __name__ == "__main__":
|
||||
split_negative_axis()
|
||||
split_uneven_channel_axis_4d()
|
||||
|
||||
print("\nGenerating Slice tests:")
|
||||
slice_2d_basic()
|
||||
slice_default_axes()
|
||||
slice_negative_axis()
|
||||
slice_negative_indices()
|
||||
slice_step2()
|
||||
slice_nchw_spatial_crop()
|
||||
slice_after_conv()
|
||||
slice_large_channel_1024()
|
||||
|
||||
print("\nGenerating Softmax tests:")
|
||||
softmax_basic()
|
||||
softmax_3d_last_axis()
|
||||
|
||||
Reference in New Issue
Block a user