diff --git a/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt b/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt index 0b7e8cc..32964aa 100644 --- a/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt +++ b/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt @@ -22,6 +22,7 @@ add_pim_library(OMONNXToSpatial Patterns/Tensor/Gather.cpp Patterns/Tensor/Resize.cpp Patterns/Tensor/Reshape.cpp + Patterns/Tensor/Slice.cpp Patterns/Tensor/Split.cpp Patterns/Tensor/Transpose.cpp ONNXToSpatialPass.cpp diff --git a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp index edf311e..c3d42f7 100644 --- a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp @@ -138,6 +138,7 @@ void ONNXToSpatialPass::runOnOperation() { target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns.cpp index ffa0b1f..0a747e9 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns.cpp @@ -22,6 +22,7 @@ void populateConversionPatterns(RewritePatternSet& patterns, MLIRContext* ctx) { populateGatherPatterns(patterns, ctx); populateResizePatterns(patterns, ctx); populateReshapePatterns(patterns, ctx); + populateSlicePatterns(patterns, ctx); populateSplitPatterns(patterns, ctx); populateTransposePatterns(patterns, ctx); } diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns.hpp b/src/PIM/Conversion/ONNXToSpatial/Patterns.hpp index e58729e..c040536 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns.hpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns.hpp @@ -29,6 +29,7 @@ void populateConcatPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext void populateGatherPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populateResizePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populateReshapePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); +void populateSlicePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populateSplitPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populateTransposePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Slice.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Slice.cpp new file mode 100644 index 0000000..9d0eae1 --- /dev/null +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Slice.cpp @@ -0,0 +1,189 @@ +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/Transforms/DialectConversion.h" + +#include "llvm/ADT/SmallVector.h" + +#include +#include + +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" +#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" +#include "src/Dialect/ONNX/ONNXOps.hpp" + +using namespace mlir; + +namespace onnx_mlir { +namespace { + +static FailureOr> getConstantIntValues(Value value) { + auto denseAttr = dyn_cast_or_null(getHostConstDenseElementsAttr(value)); + if (!denseAttr) + return failure(); + return SmallVector(denseAttr.getValues().begin(), denseAttr.getValues().end()); +} + +static bool isNoneValueLike(Value value) { return isa_and_nonnull(value.getDefiningOp()); } + +static FailureOr buildSlice(Value data, + RankedTensorType dataType, + RankedTensorType resultType, + ArrayRef starts, + ArrayRef ends, + std::optional> axes, + std::optional> steps, + ConversionPatternRewriter& rewriter, + Location loc) { + int64_t rank = dataType.getRank(); + if (!dataType.hasStaticShape() || !resultType.hasStaticShape() || resultType.getRank() != rank) + return failure(); + + if (starts.size() != ends.size()) + return failure(); + if (axes && axes->size() != starts.size()) + return failure(); + if (steps && steps->size() != starts.size()) + return failure(); + + SmallVector normalizedAxes; + if (axes) { + SmallVector seenAxes(rank, false); + normalizedAxes.reserve(axes->size()); + for (int64_t axis : *axes) { + auto normalizedAxis = normalizeAxisChecked(axis, rank); + if (failed(normalizedAxis)) + return failure(); + if (seenAxes[*normalizedAxis]) + return failure(); + seenAxes[*normalizedAxis] = true; + normalizedAxes.push_back(*normalizedAxis); + } + } + else { + if (starts.size() > static_cast(rank)) + return failure(); + normalizedAxes.reserve(starts.size()); + for (size_t i = 0; i < starts.size(); ++i) + normalizedAxes.push_back(static_cast(i)); + } + + SmallVector normalizedSteps; + if (steps) + normalizedSteps.assign(steps->begin(), steps->end()); + else + normalizedSteps.assign(starts.size(), 1); + + SmallVector computedShape(dataType.getShape().begin(), dataType.getShape().end()); + SmallVector offsets = getZeroOffsets(rewriter, rank); + SmallVector sizes = getStaticSizes(rewriter, dataType.getShape()); + SmallVector strides = getUnitStrides(rewriter, rank); + + for (auto [sliceIndex, axis] : llvm::enumerate(normalizedAxes)) { + int64_t step = normalizedSteps[sliceIndex]; + if (step <= 0) + return failure(); + + int64_t dimSize = dataType.getShape()[axis]; + int64_t start = starts[sliceIndex]; + int64_t end = ends[sliceIndex]; + + start = normalizeIndex(start, dimSize); + end = normalizeIndex(end, dimSize); + + start = std::clamp(start, int64_t {0}, dimSize); + end = std::clamp(end, int64_t {0}, dimSize); + + int64_t extent = std::max(end - start, int64_t {0}); + int64_t size = (extent + step - 1) / step; + + offsets[axis] = rewriter.getIndexAttr(start); + sizes[axis] = rewriter.getIndexAttr(size); + strides[axis] = rewriter.getIndexAttr(step); + computedShape[axis] = size; + } + + if (llvm::ArrayRef(computedShape) != resultType.getShape()) + return failure(); + + return tensor::ExtractSliceOp::create(rewriter, loc, resultType, data, offsets, sizes, strides).getResult(); +} + +struct Slice final : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(ONNXSliceOp sliceOp, + ONNXSliceOpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + auto dataType = dyn_cast(adaptor.getData().getType()); + auto resultType = dyn_cast(sliceOp.getResult().getType()); + if (!dataType || !resultType || !dataType.hasStaticShape() || !resultType.hasStaticShape()) + return failure(); + + auto starts = getConstantIntValues(adaptor.getStarts()); + auto ends = getConstantIntValues(adaptor.getEnds()); + if (failed(starts)) + return rewriter.notifyMatchFailure(sliceOp, "requires compile-time constant starts"); + if (failed(ends)) + return rewriter.notifyMatchFailure(sliceOp, "requires compile-time constant ends"); + + std::optional> axes; + if (!isNoneValueLike(adaptor.getAxes())) { + auto parsedAxes = getConstantIntValues(adaptor.getAxes()); + if (failed(parsedAxes)) + return rewriter.notifyMatchFailure(sliceOp, "requires compile-time constant axes when present"); + axes = std::move(*parsedAxes); + } + + std::optional> steps; + if (!isNoneValueLike(adaptor.getSteps())) { + auto parsedSteps = getConstantIntValues(adaptor.getSteps()); + if (failed(parsedSteps)) + return rewriter.notifyMatchFailure(sliceOp, "requires compile-time constant steps when present"); + steps = std::move(*parsedSteps); + if (llvm::any_of(*steps, [](int64_t step) { return step <= 0; })) + return rewriter.notifyMatchFailure(sliceOp, "supports only positive constant steps"); + } + + ArrayRef startsRef = *starts; + ArrayRef endsRef = *ends; + std::optional> axesRef = axes ? std::optional>(ArrayRef(*axes)) + : std::nullopt; + std::optional> stepsRef = steps ? std::optional>(ArrayRef(*steps)) + : std::nullopt; + + Location loc = sliceOp.getLoc(); + auto tryBuildSlice = [&](Value data) { + return buildSlice(data, dataType, resultType, startsRef, endsRef, axesRef, stepsRef, rewriter, loc); + }; + + if (isCompileTimeComputable(adaptor.getData())) { + auto sliced = tryBuildSlice(adaptor.getData()); + if (failed(sliced)) + return rewriter.notifyMatchFailure(sliceOp, "failed to normalize static slice parameters"); + rewriter.replaceOp(sliceOp, *sliced); + return success(); + } + + auto computeOp = + createSpatCompute<1>(rewriter, loc, TypeRange {resultType}, {}, adaptor.getData(), [&](Value data) { + auto sliced = tryBuildSlice(data); + if (failed(sliced)) + return failure(); + spatial::SpatYieldOp::create(rewriter, loc, *sliced); + return success(); + }); + if (failed(computeOp)) + return rewriter.notifyMatchFailure(sliceOp, "failed to build runtime tensor.extract_slice lowering"); + + rewriter.replaceOp(sliceOp, computeOp->getResults()); + return success(); + } +}; + +} // namespace + +void populateSlicePatterns(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.add(ctx); } + +} // namespace onnx_mlir diff --git a/validation/operations/gen_tests.py b/validation/operations/gen_tests.py index 1d5adce..ad107b3 100644 --- a/validation/operations/gen_tests.py +++ b/validation/operations/gen_tests.py @@ -1340,6 +1340,118 @@ def split_uneven_channel_axis_4d(): save_model(model, "split/uneven_channel_axis_4d", "split_uneven_channel_axis_4d.onnx") +# --------------------------------------------------------------------------- +# Slice tests +# --------------------------------------------------------------------------- + +def slice_2d_basic(): + """Slice a 2D tensor with explicit axes and unit steps.""" + X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [4, 6]) + Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 3]) + starts = make_int64_initializer("starts", [1, 2]) + ends = make_int64_initializer("ends", [3, 5]) + axes = make_int64_initializer("axes", [0, 1]) + steps = make_int64_initializer("steps", [1, 1]) + node = helper.make_node("Slice", ["X", "starts", "ends", "axes", "steps"], ["Y"]) + graph = helper.make_graph([node], "slice_2d_basic", [X], [Y], initializer=[starts, ends, axes, steps]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + save_model(model, "slice/2d_basic", "slice_2d_basic.onnx") + + +def slice_default_axes(): + """Slice with omitted axes and steps using default positional axes.""" + X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 3, 4]) + Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 2, 4]) + starts = make_int64_initializer("starts", [0, 1, 0]) + ends = make_int64_initializer("ends", [2, 3, 4]) + node = helper.make_node("Slice", ["X", "starts", "ends"], ["Y"]) + graph = helper.make_graph([node], "slice_default_axes", [X], [Y], initializer=[starts, ends]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + save_model(model, "slice/default_axes", "slice_default_axes.onnx") + + +def slice_negative_axis(): + """Slice using a negative axis.""" + X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 3, 5]) + Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 3, 3]) + starts = make_int64_initializer("starts", [1]) + ends = make_int64_initializer("ends", [4]) + axes = make_int64_initializer("axes", [-1]) + node = helper.make_node("Slice", ["X", "starts", "ends", "axes"], ["Y"]) + graph = helper.make_graph([node], "slice_negative_axis", [X], [Y], initializer=[starts, ends, axes]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + save_model(model, "slice/negative_axis", "slice_negative_axis.onnx") + + +def slice_negative_indices(): + """Slice with negative indices along one axis.""" + X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [3, 5]) + Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [3, 3]) + starts = make_int64_initializer("starts", [-4]) + ends = make_int64_initializer("ends", [-1]) + axes = make_int64_initializer("axes", [1]) + node = helper.make_node("Slice", ["X", "starts", "ends", "axes"], ["Y"]) + graph = helper.make_graph([node], "slice_negative_indices", [X], [Y], initializer=[starts, ends, axes]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + save_model(model, "slice/negative_indices", "slice_negative_indices.onnx") + + +def slice_step2(): + """Slice with a positive step greater than one.""" + X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 3, 8]) + Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 3, 4]) + starts = make_int64_initializer("starts", [0]) + ends = make_int64_initializer("ends", [8]) + axes = make_int64_initializer("axes", [2]) + steps = make_int64_initializer("steps", [2]) + node = helper.make_node("Slice", ["X", "starts", "ends", "axes", "steps"], ["Y"]) + graph = helper.make_graph([node], "slice_step2", [X], [Y], initializer=[starts, ends, axes, steps]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + save_model(model, "slice/step2", "slice_step2.onnx") + + +def slice_nchw_spatial_crop(): + """Slice an NCHW tensor across the spatial axes.""" + X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 3, 8, 8]) + Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 3, 4, 6]) + starts = make_int64_initializer("starts", [2, 1]) + ends = make_int64_initializer("ends", [6, 7]) + axes = make_int64_initializer("axes", [2, 3]) + node = helper.make_node("Slice", ["X", "starts", "ends", "axes"], ["Y"]) + graph = helper.make_graph([node], "slice_nchw_spatial_crop", [X], [Y], initializer=[starts, ends, axes]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + save_model(model, "slice/nchw_spatial_crop", "slice_nchw_spatial_crop.onnx") + + +def slice_after_conv(): + """Conv followed by a spatial crop using Slice.""" + rng = np.random.default_rng(108) + X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 3, 8, 8]) + Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 4, 4, 4]) + W = numpy_helper.from_array(rng.uniform(-1, 1, (4, 3, 3, 3)).astype(np.float32), name="W") + starts = make_int64_initializer("starts", [1, 1]) + ends = make_int64_initializer("ends", [5, 5]) + axes = make_int64_initializer("axes", [2, 3]) + conv = helper.make_node("Conv", ["X", "W"], ["C"], kernel_shape=[3, 3], strides=[1, 1], pads=[0, 0, 0, 0]) + slice_node = helper.make_node("Slice", ["C", "starts", "ends", "axes"], ["Y"]) + graph = helper.make_graph([conv, slice_node], "slice_after_conv", [X], [Y], initializer=[W, starts, ends, axes]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + save_model(model, "slice/after_conv", "slice_after_conv.onnx") + + +def slice_large_channel_1024(): + """Slice a large channel range out of a 1024-channel tensor.""" + X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 1024, 1, 1]) + Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 512, 1, 1]) + starts = make_int64_initializer("starts", [128]) + ends = make_int64_initializer("ends", [640]) + axes = make_int64_initializer("axes", [1]) + node = helper.make_node("Slice", ["X", "starts", "ends", "axes"], ["Y"]) + graph = helper.make_graph([node], "slice_large_channel_1024", [X], [Y], initializer=[starts, ends, axes]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + save_model(model, "slice/large_channel_1024", "slice_large_channel_1024.onnx") + + # --------------------------------------------------------------------------- # Gather tests # --------------------------------------------------------------------------- @@ -1880,6 +1992,16 @@ if __name__ == "__main__": split_negative_axis() split_uneven_channel_axis_4d() + print("\nGenerating Slice tests:") + slice_2d_basic() + slice_default_axes() + slice_negative_axis() + slice_negative_indices() + slice_step2() + slice_nchw_spatial_crop() + slice_after_conv() + slice_large_channel_1024() + print("\nGenerating Softmax tests:") softmax_basic() softmax_3d_last_axis() diff --git a/validation/operations/slice/2d_basic/slice_2d_basic.onnx b/validation/operations/slice/2d_basic/slice_2d_basic.onnx new file mode 100644 index 0000000..6409387 Binary files /dev/null and b/validation/operations/slice/2d_basic/slice_2d_basic.onnx differ diff --git a/validation/operations/slice/after_conv/slice_after_conv.onnx b/validation/operations/slice/after_conv/slice_after_conv.onnx new file mode 100644 index 0000000..dbcbed7 Binary files /dev/null and b/validation/operations/slice/after_conv/slice_after_conv.onnx differ diff --git a/validation/operations/slice/default_axes/slice_default_axes.onnx b/validation/operations/slice/default_axes/slice_default_axes.onnx new file mode 100644 index 0000000..0dcf32f Binary files /dev/null and b/validation/operations/slice/default_axes/slice_default_axes.onnx differ diff --git a/validation/operations/slice/large_channel_1024/slice_large_channel_1024.onnx b/validation/operations/slice/large_channel_1024/slice_large_channel_1024.onnx new file mode 100644 index 0000000..9cc9201 Binary files /dev/null and b/validation/operations/slice/large_channel_1024/slice_large_channel_1024.onnx differ diff --git a/validation/operations/slice/nchw_spatial_crop/slice_nchw_spatial_crop.onnx b/validation/operations/slice/nchw_spatial_crop/slice_nchw_spatial_crop.onnx new file mode 100644 index 0000000..3edfb7a Binary files /dev/null and b/validation/operations/slice/nchw_spatial_crop/slice_nchw_spatial_crop.onnx differ diff --git a/validation/operations/slice/negative_axis/slice_negative_axis.onnx b/validation/operations/slice/negative_axis/slice_negative_axis.onnx new file mode 100644 index 0000000..2fc969b Binary files /dev/null and b/validation/operations/slice/negative_axis/slice_negative_axis.onnx differ diff --git a/validation/operations/slice/negative_indices/slice_negative_indices.onnx b/validation/operations/slice/negative_indices/slice_negative_indices.onnx new file mode 100644 index 0000000..2a8d3d8 Binary files /dev/null and b/validation/operations/slice/negative_indices/slice_negative_indices.onnx differ diff --git a/validation/operations/slice/step2/slice_step2.onnx b/validation/operations/slice/step2/slice_step2.onnx new file mode 100644 index 0000000..be0a876 Binary files /dev/null and b/validation/operations/slice/step2/slice_step2.onnx differ