Added support for SliceOp
Validate Operations / validate-operations (push) Has been cancelled

This commit is contained in:
ilgeco
2026-06-05 17:36:51 +02:00
parent 90c4339808
commit 8ddbbcecfa
14 changed files with 315 additions and 0 deletions
@@ -22,6 +22,7 @@ add_pim_library(OMONNXToSpatial
Patterns/Tensor/Gather.cpp
Patterns/Tensor/Resize.cpp
Patterns/Tensor/Reshape.cpp
Patterns/Tensor/Slice.cpp
Patterns/Tensor/Split.cpp
Patterns/Tensor/Transpose.cpp
ONNXToSpatialPass.cpp
@@ -138,6 +138,7 @@ void ONNXToSpatialPass::runOnOperation() {
target.addIllegalOp<ONNXGatherOp>();
target.addIllegalOp<ONNXReshapeOp>();
target.addIllegalOp<ONNXResizeOp>();
target.addIllegalOp<ONNXSliceOp>();
target.addIllegalOp<ONNXLRNOp>();
target.addIllegalOp<ONNXReduceMeanV13Op>();
target.addIllegalOp<ONNXSplitOp>();
@@ -22,6 +22,7 @@ void populateConversionPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
populateGatherPatterns(patterns, ctx);
populateResizePatterns(patterns, ctx);
populateReshapePatterns(patterns, ctx);
populateSlicePatterns(patterns, ctx);
populateSplitPatterns(patterns, ctx);
populateTransposePatterns(patterns, ctx);
}
@@ -29,6 +29,7 @@ void populateConcatPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext
void populateGatherPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateResizePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateReshapePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateSlicePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateSplitPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateTransposePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
@@ -0,0 +1,189 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallVector.h"
#include <algorithm>
#include <optional>
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
static FailureOr<SmallVector<int64_t>> getConstantIntValues(Value value) {
auto denseAttr = dyn_cast_or_null<DenseIntElementsAttr>(getHostConstDenseElementsAttr(value));
if (!denseAttr)
return failure();
return SmallVector<int64_t>(denseAttr.getValues<int64_t>().begin(), denseAttr.getValues<int64_t>().end());
}
static bool isNoneValueLike(Value value) { return isa_and_nonnull<ONNXNoneOp>(value.getDefiningOp()); }
static FailureOr<Value> buildSlice(Value data,
RankedTensorType dataType,
RankedTensorType resultType,
ArrayRef<int64_t> starts,
ArrayRef<int64_t> ends,
std::optional<ArrayRef<int64_t>> axes,
std::optional<ArrayRef<int64_t>> steps,
ConversionPatternRewriter& rewriter,
Location loc) {
int64_t rank = dataType.getRank();
if (!dataType.hasStaticShape() || !resultType.hasStaticShape() || resultType.getRank() != rank)
return failure();
if (starts.size() != ends.size())
return failure();
if (axes && axes->size() != starts.size())
return failure();
if (steps && steps->size() != starts.size())
return failure();
SmallVector<int64_t> normalizedAxes;
if (axes) {
SmallVector<bool> seenAxes(rank, false);
normalizedAxes.reserve(axes->size());
for (int64_t axis : *axes) {
auto normalizedAxis = normalizeAxisChecked(axis, rank);
if (failed(normalizedAxis))
return failure();
if (seenAxes[*normalizedAxis])
return failure();
seenAxes[*normalizedAxis] = true;
normalizedAxes.push_back(*normalizedAxis);
}
}
else {
if (starts.size() > static_cast<size_t>(rank))
return failure();
normalizedAxes.reserve(starts.size());
for (size_t i = 0; i < starts.size(); ++i)
normalizedAxes.push_back(static_cast<int64_t>(i));
}
SmallVector<int64_t> normalizedSteps;
if (steps)
normalizedSteps.assign(steps->begin(), steps->end());
else
normalizedSteps.assign(starts.size(), 1);
SmallVector<int64_t> computedShape(dataType.getShape().begin(), dataType.getShape().end());
SmallVector<OpFoldResult> offsets = getZeroOffsets(rewriter, rank);
SmallVector<OpFoldResult> sizes = getStaticSizes(rewriter, dataType.getShape());
SmallVector<OpFoldResult> strides = getUnitStrides(rewriter, rank);
for (auto [sliceIndex, axis] : llvm::enumerate(normalizedAxes)) {
int64_t step = normalizedSteps[sliceIndex];
if (step <= 0)
return failure();
int64_t dimSize = dataType.getShape()[axis];
int64_t start = starts[sliceIndex];
int64_t end = ends[sliceIndex];
start = normalizeIndex(start, dimSize);
end = normalizeIndex(end, dimSize);
start = std::clamp(start, int64_t {0}, dimSize);
end = std::clamp(end, int64_t {0}, dimSize);
int64_t extent = std::max(end - start, int64_t {0});
int64_t size = (extent + step - 1) / step;
offsets[axis] = rewriter.getIndexAttr(start);
sizes[axis] = rewriter.getIndexAttr(size);
strides[axis] = rewriter.getIndexAttr(step);
computedShape[axis] = size;
}
if (llvm::ArrayRef(computedShape) != resultType.getShape())
return failure();
return tensor::ExtractSliceOp::create(rewriter, loc, resultType, data, offsets, sizes, strides).getResult();
}
struct Slice final : OpConversionPattern<ONNXSliceOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(ONNXSliceOp sliceOp,
ONNXSliceOpAdaptor adaptor,
ConversionPatternRewriter& rewriter) const override {
auto dataType = dyn_cast<RankedTensorType>(adaptor.getData().getType());
auto resultType = dyn_cast<RankedTensorType>(sliceOp.getResult().getType());
if (!dataType || !resultType || !dataType.hasStaticShape() || !resultType.hasStaticShape())
return failure();
auto starts = getConstantIntValues(adaptor.getStarts());
auto ends = getConstantIntValues(adaptor.getEnds());
if (failed(starts))
return rewriter.notifyMatchFailure(sliceOp, "requires compile-time constant starts");
if (failed(ends))
return rewriter.notifyMatchFailure(sliceOp, "requires compile-time constant ends");
std::optional<SmallVector<int64_t>> axes;
if (!isNoneValueLike(adaptor.getAxes())) {
auto parsedAxes = getConstantIntValues(adaptor.getAxes());
if (failed(parsedAxes))
return rewriter.notifyMatchFailure(sliceOp, "requires compile-time constant axes when present");
axes = std::move(*parsedAxes);
}
std::optional<SmallVector<int64_t>> steps;
if (!isNoneValueLike(adaptor.getSteps())) {
auto parsedSteps = getConstantIntValues(adaptor.getSteps());
if (failed(parsedSteps))
return rewriter.notifyMatchFailure(sliceOp, "requires compile-time constant steps when present");
steps = std::move(*parsedSteps);
if (llvm::any_of(*steps, [](int64_t step) { return step <= 0; }))
return rewriter.notifyMatchFailure(sliceOp, "supports only positive constant steps");
}
ArrayRef<int64_t> startsRef = *starts;
ArrayRef<int64_t> endsRef = *ends;
std::optional<ArrayRef<int64_t>> axesRef = axes ? std::optional<ArrayRef<int64_t>>(ArrayRef<int64_t>(*axes))
: std::nullopt;
std::optional<ArrayRef<int64_t>> stepsRef = steps ? std::optional<ArrayRef<int64_t>>(ArrayRef<int64_t>(*steps))
: std::nullopt;
Location loc = sliceOp.getLoc();
auto tryBuildSlice = [&](Value data) {
return buildSlice(data, dataType, resultType, startsRef, endsRef, axesRef, stepsRef, rewriter, loc);
};
if (isCompileTimeComputable(adaptor.getData())) {
auto sliced = tryBuildSlice(adaptor.getData());
if (failed(sliced))
return rewriter.notifyMatchFailure(sliceOp, "failed to normalize static slice parameters");
rewriter.replaceOp(sliceOp, *sliced);
return success();
}
auto computeOp =
createSpatCompute<1>(rewriter, loc, TypeRange {resultType}, {}, adaptor.getData(), [&](Value data) {
auto sliced = tryBuildSlice(data);
if (failed(sliced))
return failure();
spatial::SpatYieldOp::create(rewriter, loc, *sliced);
return success();
});
if (failed(computeOp))
return rewriter.notifyMatchFailure(sliceOp, "failed to build runtime tensor.extract_slice lowering");
rewriter.replaceOp(sliceOp, computeOp->getResults());
return success();
}
};
} // namespace
void populateSlicePatterns(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.add<Slice>(ctx); }
} // namespace onnx_mlir