Added support for SliceOp
Validate Operations / validate-operations (push) Has been cancelled

This commit is contained in:
ilgeco
2026-06-05 17:36:51 +02:00
parent 90c4339808
commit 8ddbbcecfa
14 changed files with 315 additions and 0 deletions
@@ -22,6 +22,7 @@ add_pim_library(OMONNXToSpatial
Patterns/Tensor/Gather.cpp
Patterns/Tensor/Resize.cpp
Patterns/Tensor/Reshape.cpp
Patterns/Tensor/Slice.cpp
Patterns/Tensor/Split.cpp
Patterns/Tensor/Transpose.cpp
ONNXToSpatialPass.cpp
@@ -138,6 +138,7 @@ void ONNXToSpatialPass::runOnOperation() {
target.addIllegalOp<ONNXGatherOp>();
target.addIllegalOp<ONNXReshapeOp>();
target.addIllegalOp<ONNXResizeOp>();
target.addIllegalOp<ONNXSliceOp>();
target.addIllegalOp<ONNXLRNOp>();
target.addIllegalOp<ONNXReduceMeanV13Op>();
target.addIllegalOp<ONNXSplitOp>();
@@ -22,6 +22,7 @@ void populateConversionPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
populateGatherPatterns(patterns, ctx);
populateResizePatterns(patterns, ctx);
populateReshapePatterns(patterns, ctx);
populateSlicePatterns(patterns, ctx);
populateSplitPatterns(patterns, ctx);
populateTransposePatterns(patterns, ctx);
}
@@ -29,6 +29,7 @@ void populateConcatPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext
void populateGatherPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateResizePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateReshapePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateSlicePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateSplitPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateTransposePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
@@ -0,0 +1,189 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallVector.h"
#include <algorithm>
#include <optional>
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
static FailureOr<SmallVector<int64_t>> getConstantIntValues(Value value) {
auto denseAttr = dyn_cast_or_null<DenseIntElementsAttr>(getHostConstDenseElementsAttr(value));
if (!denseAttr)
return failure();
return SmallVector<int64_t>(denseAttr.getValues<int64_t>().begin(), denseAttr.getValues<int64_t>().end());
}
static bool isNoneValueLike(Value value) { return isa_and_nonnull<ONNXNoneOp>(value.getDefiningOp()); }
static FailureOr<Value> buildSlice(Value data,
RankedTensorType dataType,
RankedTensorType resultType,
ArrayRef<int64_t> starts,
ArrayRef<int64_t> ends,
std::optional<ArrayRef<int64_t>> axes,
std::optional<ArrayRef<int64_t>> steps,
ConversionPatternRewriter& rewriter,
Location loc) {
int64_t rank = dataType.getRank();
if (!dataType.hasStaticShape() || !resultType.hasStaticShape() || resultType.getRank() != rank)
return failure();
if (starts.size() != ends.size())
return failure();
if (axes && axes->size() != starts.size())
return failure();
if (steps && steps->size() != starts.size())
return failure();
SmallVector<int64_t> normalizedAxes;
if (axes) {
SmallVector<bool> seenAxes(rank, false);
normalizedAxes.reserve(axes->size());
for (int64_t axis : *axes) {
auto normalizedAxis = normalizeAxisChecked(axis, rank);
if (failed(normalizedAxis))
return failure();
if (seenAxes[*normalizedAxis])
return failure();
seenAxes[*normalizedAxis] = true;
normalizedAxes.push_back(*normalizedAxis);
}
}
else {
if (starts.size() > static_cast<size_t>(rank))
return failure();
normalizedAxes.reserve(starts.size());
for (size_t i = 0; i < starts.size(); ++i)
normalizedAxes.push_back(static_cast<int64_t>(i));
}
SmallVector<int64_t> normalizedSteps;
if (steps)
normalizedSteps.assign(steps->begin(), steps->end());
else
normalizedSteps.assign(starts.size(), 1);
SmallVector<int64_t> computedShape(dataType.getShape().begin(), dataType.getShape().end());
SmallVector<OpFoldResult> offsets = getZeroOffsets(rewriter, rank);
SmallVector<OpFoldResult> sizes = getStaticSizes(rewriter, dataType.getShape());
SmallVector<OpFoldResult> strides = getUnitStrides(rewriter, rank);
for (auto [sliceIndex, axis] : llvm::enumerate(normalizedAxes)) {
int64_t step = normalizedSteps[sliceIndex];
if (step <= 0)
return failure();
int64_t dimSize = dataType.getShape()[axis];
int64_t start = starts[sliceIndex];
int64_t end = ends[sliceIndex];
start = normalizeIndex(start, dimSize);
end = normalizeIndex(end, dimSize);
start = std::clamp(start, int64_t {0}, dimSize);
end = std::clamp(end, int64_t {0}, dimSize);
int64_t extent = std::max(end - start, int64_t {0});
int64_t size = (extent + step - 1) / step;
offsets[axis] = rewriter.getIndexAttr(start);
sizes[axis] = rewriter.getIndexAttr(size);
strides[axis] = rewriter.getIndexAttr(step);
computedShape[axis] = size;
}
if (llvm::ArrayRef(computedShape) != resultType.getShape())
return failure();
return tensor::ExtractSliceOp::create(rewriter, loc, resultType, data, offsets, sizes, strides).getResult();
}
struct Slice final : OpConversionPattern<ONNXSliceOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(ONNXSliceOp sliceOp,
ONNXSliceOpAdaptor adaptor,
ConversionPatternRewriter& rewriter) const override {
auto dataType = dyn_cast<RankedTensorType>(adaptor.getData().getType());
auto resultType = dyn_cast<RankedTensorType>(sliceOp.getResult().getType());
if (!dataType || !resultType || !dataType.hasStaticShape() || !resultType.hasStaticShape())
return failure();
auto starts = getConstantIntValues(adaptor.getStarts());
auto ends = getConstantIntValues(adaptor.getEnds());
if (failed(starts))
return rewriter.notifyMatchFailure(sliceOp, "requires compile-time constant starts");
if (failed(ends))
return rewriter.notifyMatchFailure(sliceOp, "requires compile-time constant ends");
std::optional<SmallVector<int64_t>> axes;
if (!isNoneValueLike(adaptor.getAxes())) {
auto parsedAxes = getConstantIntValues(adaptor.getAxes());
if (failed(parsedAxes))
return rewriter.notifyMatchFailure(sliceOp, "requires compile-time constant axes when present");
axes = std::move(*parsedAxes);
}
std::optional<SmallVector<int64_t>> steps;
if (!isNoneValueLike(adaptor.getSteps())) {
auto parsedSteps = getConstantIntValues(adaptor.getSteps());
if (failed(parsedSteps))
return rewriter.notifyMatchFailure(sliceOp, "requires compile-time constant steps when present");
steps = std::move(*parsedSteps);
if (llvm::any_of(*steps, [](int64_t step) { return step <= 0; }))
return rewriter.notifyMatchFailure(sliceOp, "supports only positive constant steps");
}
ArrayRef<int64_t> startsRef = *starts;
ArrayRef<int64_t> endsRef = *ends;
std::optional<ArrayRef<int64_t>> axesRef = axes ? std::optional<ArrayRef<int64_t>>(ArrayRef<int64_t>(*axes))
: std::nullopt;
std::optional<ArrayRef<int64_t>> stepsRef = steps ? std::optional<ArrayRef<int64_t>>(ArrayRef<int64_t>(*steps))
: std::nullopt;
Location loc = sliceOp.getLoc();
auto tryBuildSlice = [&](Value data) {
return buildSlice(data, dataType, resultType, startsRef, endsRef, axesRef, stepsRef, rewriter, loc);
};
if (isCompileTimeComputable(adaptor.getData())) {
auto sliced = tryBuildSlice(adaptor.getData());
if (failed(sliced))
return rewriter.notifyMatchFailure(sliceOp, "failed to normalize static slice parameters");
rewriter.replaceOp(sliceOp, *sliced);
return success();
}
auto computeOp =
createSpatCompute<1>(rewriter, loc, TypeRange {resultType}, {}, adaptor.getData(), [&](Value data) {
auto sliced = tryBuildSlice(data);
if (failed(sliced))
return failure();
spatial::SpatYieldOp::create(rewriter, loc, *sliced);
return success();
});
if (failed(computeOp))
return rewriter.notifyMatchFailure(sliceOp, "failed to build runtime tensor.extract_slice lowering");
rewriter.replaceOp(sliceOp, computeOp->getResults());
return success();
}
};
} // namespace
void populateSlicePatterns(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.add<Slice>(ctx); }
} // namespace onnx_mlir
+122
View File
@@ -1340,6 +1340,118 @@ def split_uneven_channel_axis_4d():
save_model(model, "split/uneven_channel_axis_4d", "split_uneven_channel_axis_4d.onnx")
# ---------------------------------------------------------------------------
# Slice tests
# ---------------------------------------------------------------------------
def slice_2d_basic():
"""Slice a 2D tensor with explicit axes and unit steps."""
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [4, 6])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 3])
starts = make_int64_initializer("starts", [1, 2])
ends = make_int64_initializer("ends", [3, 5])
axes = make_int64_initializer("axes", [0, 1])
steps = make_int64_initializer("steps", [1, 1])
node = helper.make_node("Slice", ["X", "starts", "ends", "axes", "steps"], ["Y"])
graph = helper.make_graph([node], "slice_2d_basic", [X], [Y], initializer=[starts, ends, axes, steps])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "slice/2d_basic", "slice_2d_basic.onnx")
def slice_default_axes():
"""Slice with omitted axes and steps using default positional axes."""
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 3, 4])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 2, 4])
starts = make_int64_initializer("starts", [0, 1, 0])
ends = make_int64_initializer("ends", [2, 3, 4])
node = helper.make_node("Slice", ["X", "starts", "ends"], ["Y"])
graph = helper.make_graph([node], "slice_default_axes", [X], [Y], initializer=[starts, ends])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "slice/default_axes", "slice_default_axes.onnx")
def slice_negative_axis():
"""Slice using a negative axis."""
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 3, 5])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 3, 3])
starts = make_int64_initializer("starts", [1])
ends = make_int64_initializer("ends", [4])
axes = make_int64_initializer("axes", [-1])
node = helper.make_node("Slice", ["X", "starts", "ends", "axes"], ["Y"])
graph = helper.make_graph([node], "slice_negative_axis", [X], [Y], initializer=[starts, ends, axes])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "slice/negative_axis", "slice_negative_axis.onnx")
def slice_negative_indices():
"""Slice with negative indices along one axis."""
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [3, 5])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [3, 3])
starts = make_int64_initializer("starts", [-4])
ends = make_int64_initializer("ends", [-1])
axes = make_int64_initializer("axes", [1])
node = helper.make_node("Slice", ["X", "starts", "ends", "axes"], ["Y"])
graph = helper.make_graph([node], "slice_negative_indices", [X], [Y], initializer=[starts, ends, axes])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "slice/negative_indices", "slice_negative_indices.onnx")
def slice_step2():
"""Slice with a positive step greater than one."""
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 3, 8])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 3, 4])
starts = make_int64_initializer("starts", [0])
ends = make_int64_initializer("ends", [8])
axes = make_int64_initializer("axes", [2])
steps = make_int64_initializer("steps", [2])
node = helper.make_node("Slice", ["X", "starts", "ends", "axes", "steps"], ["Y"])
graph = helper.make_graph([node], "slice_step2", [X], [Y], initializer=[starts, ends, axes, steps])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "slice/step2", "slice_step2.onnx")
def slice_nchw_spatial_crop():
"""Slice an NCHW tensor across the spatial axes."""
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 3, 8, 8])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 3, 4, 6])
starts = make_int64_initializer("starts", [2, 1])
ends = make_int64_initializer("ends", [6, 7])
axes = make_int64_initializer("axes", [2, 3])
node = helper.make_node("Slice", ["X", "starts", "ends", "axes"], ["Y"])
graph = helper.make_graph([node], "slice_nchw_spatial_crop", [X], [Y], initializer=[starts, ends, axes])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "slice/nchw_spatial_crop", "slice_nchw_spatial_crop.onnx")
def slice_after_conv():
"""Conv followed by a spatial crop using Slice."""
rng = np.random.default_rng(108)
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 3, 8, 8])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 4, 4, 4])
W = numpy_helper.from_array(rng.uniform(-1, 1, (4, 3, 3, 3)).astype(np.float32), name="W")
starts = make_int64_initializer("starts", [1, 1])
ends = make_int64_initializer("ends", [5, 5])
axes = make_int64_initializer("axes", [2, 3])
conv = helper.make_node("Conv", ["X", "W"], ["C"], kernel_shape=[3, 3], strides=[1, 1], pads=[0, 0, 0, 0])
slice_node = helper.make_node("Slice", ["C", "starts", "ends", "axes"], ["Y"])
graph = helper.make_graph([conv, slice_node], "slice_after_conv", [X], [Y], initializer=[W, starts, ends, axes])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "slice/after_conv", "slice_after_conv.onnx")
def slice_large_channel_1024():
"""Slice a large channel range out of a 1024-channel tensor."""
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 1024, 1, 1])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 512, 1, 1])
starts = make_int64_initializer("starts", [128])
ends = make_int64_initializer("ends", [640])
axes = make_int64_initializer("axes", [1])
node = helper.make_node("Slice", ["X", "starts", "ends", "axes"], ["Y"])
graph = helper.make_graph([node], "slice_large_channel_1024", [X], [Y], initializer=[starts, ends, axes])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
save_model(model, "slice/large_channel_1024", "slice_large_channel_1024.onnx")
# ---------------------------------------------------------------------------
# Gather tests
# ---------------------------------------------------------------------------
@@ -1880,6 +1992,16 @@ if __name__ == "__main__":
split_negative_axis()
split_uneven_channel_axis_4d()
print("\nGenerating Slice tests:")
slice_2d_basic()
slice_default_axes()
slice_negative_axis()
slice_negative_indices()
slice_step2()
slice_nchw_spatial_crop()
slice_after_conv()
slice_large_channel_1024()
print("\nGenerating Softmax tests:")
softmax_basic()
softmax_3d_last_axis()
Binary file not shown.