Files
Raptor/diff.txt
T
ilgeco 75fb70712f
Validate Operations / validate-operations (push) Has been cancelled
CodexWorkaround
2026-06-08 11:33:36 +02:00

255 lines
11 KiB
Plaintext

diff --git a/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt b/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt
index 0b7e8cc..32964aa 100644
--- a/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt
+++ b/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt
@@ -22,6 +22,7 @@ add_pim_library(OMONNXToSpatial
Patterns/Tensor/Gather.cpp
Patterns/Tensor/Resize.cpp
Patterns/Tensor/Reshape.cpp
+ Patterns/Tensor/Slice.cpp
Patterns/Tensor/Split.cpp
Patterns/Tensor/Transpose.cpp
ONNXToSpatialPass.cpp
diff --git a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp
index edf311e..c3d42f7 100644
--- a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp
+++ b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp
@@ -138,6 +138,7 @@ void ONNXToSpatialPass::runOnOperation() {
target.addIllegalOp<ONNXGatherOp>();
target.addIllegalOp<ONNXReshapeOp>();
target.addIllegalOp<ONNXResizeOp>();
+ target.addIllegalOp<ONNXSliceOp>();
target.addIllegalOp<ONNXLRNOp>();
target.addIllegalOp<ONNXReduceMeanV13Op>();
target.addIllegalOp<ONNXSplitOp>();
diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns.cpp
index ffa0b1f..0a747e9 100644
--- a/src/PIM/Conversion/ONNXToSpatial/Patterns.cpp
+++ b/src/PIM/Conversion/ONNXToSpatial/Patterns.cpp
@@ -22,6 +22,7 @@ void populateConversionPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
populateGatherPatterns(patterns, ctx);
populateResizePatterns(patterns, ctx);
populateReshapePatterns(patterns, ctx);
+ populateSlicePatterns(patterns, ctx);
populateSplitPatterns(patterns, ctx);
populateTransposePatterns(patterns, ctx);
}
diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns.hpp b/src/PIM/Conversion/ONNXToSpatial/Patterns.hpp
index e58729e..c040536 100644
--- a/src/PIM/Conversion/ONNXToSpatial/Patterns.hpp
+++ b/src/PIM/Conversion/ONNXToSpatial/Patterns.hpp
@@ -29,6 +29,7 @@ void populateConcatPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext
void populateGatherPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateResizePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateReshapePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
+void populateSlicePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateSplitPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateTransposePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Slice.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Slice.cpp
new file mode 100644
index 0000000..3f8867f
--- /dev/null
+++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Slice.cpp
@@ -0,0 +1,200 @@
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+#include "llvm/ADT/APInt.h"
+#include "llvm/ADT/SmallVector.h"
+
+#include <algorithm>
+#include <optional>
+
+#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
+#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
+#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
+#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
+#include "src/Dialect/ONNX/ONNXOps.hpp"
+
+using namespace mlir;
+
+namespace onnx_mlir {
+namespace {
+
+/// Returns the dense payload of `value` when it is produced by an
+/// `arith.constant` or an `onnx.Constant` op.
+///
+/// Returns a null attribute when `value` has no defining op (e.g. a block
+/// argument), is produced by any other op, or when the constant's attribute
+/// is missing or not a DenseElementsAttr.
+static DenseElementsAttr getDenseConstantAttr(Value value) {
+  Operation* producer = value.getDefiningOp();
+  if (!producer)
+    return nullptr;
+
+  if (auto arithConstant = dyn_cast<arith::ConstantOp>(producer))
+    return dyn_cast<DenseElementsAttr>(arithConstant.getValue());
+
+  auto onnxConstant = dyn_cast<ONNXConstantOp>(producer);
+  if (!onnxConstant)
+    return nullptr;
+  // The ONNX constant's value attribute is optional, hence the null-tolerant cast.
+  return dyn_cast_or_null<DenseElementsAttr>(onnxConstant.getValueAttr());
+}
+
+/// Reads the elements of a dense integer constant (as found by
+/// getDenseConstantAttr) into a vector of int64_t.
+///
+/// ONNX index operands (e.g. Slice starts/ends/axes/steps) may be int32 or
+/// int64 tensors. `DenseElementsAttr::getValues<int64_t>()` is only valid when
+/// the element bit width is exactly 64, so each element is read as an APInt
+/// and sign-extended instead.
+///
+/// Returns failure if `value` is not a dense integer constant, or if an
+/// element does not fit in a signed 64-bit integer.
+static FailureOr<SmallVector<int64_t>> getConstantIntValues(Value value) {
+  auto denseAttr = dyn_cast_or_null<DenseIntElementsAttr>(getDenseConstantAttr(value));
+  if (!denseAttr)
+    return failure();
+
+  SmallVector<int64_t> values;
+  values.reserve(denseAttr.getNumElements());
+  for (const APInt& element : denseAttr.getValues<APInt>()) {
+    // Guard getSExtValue(), which requires the value to fit in 64 bits.
+    if (!element.isSignedIntN(64))
+      return failure();
+    values.push_back(element.getSExtValue());
+  }
+  return values;
+}
+
+/// Returns true when `value` stands for an omitted optional ONNX operand:
+/// either its type is `none`, or it is produced by an `onnx.NoneOp`.
+static bool isNoneValueLike(Value value) {
+  // Checking the type also covers none-typed values that are not directly
+  // defined by an ONNXNoneOp (e.g. block arguments), which previously were
+  // mistaken for operands that must be compile-time constants.
+  return isa<NoneType>(value.getType()) || isa_and_nonnull<ONNXNoneOp>(value.getDefiningOp());
+}
+
+/// Lowers a fully static ONNX Slice to a single tensor.extract_slice on `data`.
+///
+/// `starts`/`ends` (and `axes`/`steps` when present) are interpreted with ONNX
+/// Slice semantics for positive steps:
+///   - negative starts/ends count from the end of the dimension and are then
+///     clamped into [0, dimSize];
+///   - missing `axes` defaults to [0, starts.size());
+///   - missing `steps` defaults to all ones.
+///
+/// Returns failure, without creating any IR, when a shape is not static, the
+/// result rank differs, the parameter lists disagree in length, an axis is
+/// rejected by normalizeAxisChecked or repeated, a step is non-positive, or
+/// the computed slice shape does not equal `resultType`'s shape.
+static FailureOr<Value> buildSlice(Value data,
+                                   RankedTensorType dataType,
+                                   RankedTensorType resultType,
+                                   ArrayRef<int64_t> starts,
+                                   ArrayRef<int64_t> ends,
+                                   std::optional<ArrayRef<int64_t>> axes,
+                                   std::optional<ArrayRef<int64_t>> steps,
+                                   ConversionPatternRewriter& rewriter,
+                                   Location loc) {
+  int64_t rank = dataType.getRank();
+  if (!dataType.hasStaticShape() || !resultType.hasStaticShape() || resultType.getRank() != rank)
+    return failure();
+
+  // Every per-axis parameter list must describe the same set of sliced axes.
+  if (starts.size() != ends.size())
+    return failure();
+  if (axes && axes->size() != starts.size())
+    return failure();
+  if (steps && steps->size() != starts.size())
+    return failure();
+
+  SmallVector<int64_t> normalizedAxes;
+  normalizedAxes.reserve(starts.size());
+  if (axes) {
+    // Normalize each axis through normalizeAxisChecked (failing if it rejects
+    // the axis) and reject axes that are sliced more than once.
+    SmallVector<bool> seenAxes(rank, false);
+    for (int64_t axis : *axes) {
+      auto normalizedAxis = normalizeAxisChecked(axis, rank);
+      if (failed(normalizedAxis))
+        return failure();
+      if (seenAxes[*normalizedAxis])
+        return failure();
+      seenAxes[*normalizedAxis] = true;
+      normalizedAxes.push_back(*normalizedAxis);
+    }
+  }
+  else {
+    // Without explicit axes, the parameters apply to the leading dimensions.
+    if (starts.size() > static_cast<size_t>(rank))
+      return failure();
+    for (size_t i = 0; i < starts.size(); ++i)
+      normalizedAxes.push_back(static_cast<int64_t>(i));
+  }
+
+  SmallVector<int64_t> normalizedSteps;
+  if (steps)
+    normalizedSteps.assign(steps->begin(), steps->end());
+  else
+    normalizedSteps.assign(starts.size(), 1);
+
+  // Initialize offsets/sizes/strides with the project helpers (by their names:
+  // zero offsets, the full static sizes, unit strides) and override only the
+  // axes named by the ONNX parameters.
+  ArrayRef<int64_t> dataShape = dataType.getShape();
+  SmallVector<int64_t> computedShape(dataShape.begin(), dataShape.end());
+  SmallVector<OpFoldResult> offsets = getZeroOffsets(rewriter, rank);
+  SmallVector<OpFoldResult> sizes = getStaticSizes(rewriter, dataShape);
+  SmallVector<OpFoldResult> strides = getUnitStrides(rewriter, rank);
+
+  for (auto [sliceIndex, axis] : llvm::enumerate(normalizedAxes)) {
+    int64_t step = normalizedSteps[sliceIndex];
+    if (step <= 0)
+      return failure();
+
+    int64_t dimSize = dataShape[axis];
+    int64_t start = starts[sliceIndex];
+    int64_t end = ends[sliceIndex];
+
+    // Negative indices count from the end; the result is clamped into
+    // [0, dimSize] (so e.g. an INT64_MAX `end` selects up to the end).
+    if (start < 0)
+      start += dimSize;
+    if (end < 0)
+      end += dimSize;
+    start = std::clamp(start, int64_t {0}, dimSize);
+    end = std::clamp(end, int64_t {0}, dimSize);
+
+    // Number of selected elements is ceil(extent / step). It is written as
+    // 1 + (extent - 1) / step because the usual `extent + step - 1` overflows
+    // a signed int (undefined behavior) for very large steps.
+    int64_t extent = std::max(end - start, int64_t {0});
+    int64_t size = extent == 0 ? 0 : 1 + (extent - 1) / step;
+
+    // An empty slice selects nothing; anchor it at offset 0 rather than at a
+    // clamped start that may equal dimSize (one past the last element).
+    if (size == 0)
+      start = 0;
+
+    offsets[axis] = rewriter.getIndexAttr(start);
+    sizes[axis] = rewriter.getIndexAttr(size);
+    strides[axis] = rewriter.getIndexAttr(step);
+    computedShape[axis] = size;
+  }
+
+  // The ONNX result type must agree with the shape implied by the parameters;
+  // refuse rather than emit an op whose type disagrees with its slice.
+  if (llvm::ArrayRef(computedShape) != resultType.getShape())
+    return failure();
+
+  return tensor::ExtractSliceOp::create(rewriter, loc, resultType, data, offsets, sizes, strides).getResult();
+}
+
+/// Converts `onnx.Slice` into `tensor.extract_slice`.
+///
+/// Requires statically shaped ranked data/result types and compile-time
+/// constant starts/ends (and axes/steps when they are not none). When the data
+/// operand is compile-time computable (per isCompileTimeComputable), the slice
+/// is emitted directly. Otherwise it is built inside createSpatCompute<1>,
+/// which receives the data operand and yields the sliced tensor via
+/// spatial::SpatYieldOp.
+struct Slice final : OpConversionPattern<ONNXSliceOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(ONNXSliceOp sliceOp,
+                                ONNXSliceOpAdaptor adaptor,
+                                ConversionPatternRewriter& rewriter) const override {
+    auto dataType = dyn_cast<RankedTensorType>(adaptor.getData().getType());
+    auto resultType = dyn_cast<RankedTensorType>(sliceOp.getResult().getType());
+    if (!dataType || !resultType)
+      return rewriter.notifyMatchFailure(sliceOp, "requires ranked tensor data and result types");
+    if (!dataType.hasStaticShape() || !resultType.hasStaticShape())
+      return rewriter.notifyMatchFailure(sliceOp, "requires statically shaped data and result types");
+
+    auto starts = getConstantIntValues(adaptor.getStarts());
+    if (failed(starts))
+      return rewriter.notifyMatchFailure(sliceOp, "requires compile-time constant starts");
+    auto ends = getConstantIntValues(adaptor.getEnds());
+    if (failed(ends))
+      return rewriter.notifyMatchFailure(sliceOp, "requires compile-time constant ends");
+
+    std::optional<SmallVector<int64_t>> axes;
+    if (!isNoneValueLike(adaptor.getAxes())) {
+      auto parsedAxes = getConstantIntValues(adaptor.getAxes());
+      if (failed(parsedAxes))
+        return rewriter.notifyMatchFailure(sliceOp, "requires compile-time constant axes when present");
+      axes = std::move(*parsedAxes);
+    }
+
+    std::optional<SmallVector<int64_t>> steps;
+    if (!isNoneValueLike(adaptor.getSteps())) {
+      auto parsedSteps = getConstantIntValues(adaptor.getSteps());
+      if (failed(parsedSteps))
+        return rewriter.notifyMatchFailure(sliceOp, "requires compile-time constant steps when present");
+      steps = std::move(*parsedSteps);
+      // Rejected here (as well as in buildSlice) to report a precise reason.
+      if (llvm::any_of(*steps, [](int64_t step) { return step <= 0; }))
+        return rewriter.notifyMatchFailure(sliceOp, "supports only positive constant steps");
+    }
+
+    // Non-owning views over the parsed parameters. The owning vectors live
+    // until this function returns; this assumes createSpatCompute invokes the
+    // body callback before returning (TODO confirm).
+    auto asOptionalRef =
+        [](const std::optional<SmallVector<int64_t>>& values) -> std::optional<ArrayRef<int64_t>> {
+      if (!values)
+        return std::nullopt;
+      return ArrayRef<int64_t>(*values);
+    };
+    ArrayRef<int64_t> startsRef = *starts;
+    ArrayRef<int64_t> endsRef = *ends;
+    std::optional<ArrayRef<int64_t>> axesRef = asOptionalRef(axes);
+    std::optional<ArrayRef<int64_t>> stepsRef = asOptionalRef(steps);
+
+    Location loc = sliceOp.getLoc();
+    auto tryBuildSlice = [&](Value data) {
+      return buildSlice(data, dataType, resultType, startsRef, endsRef, axesRef, stepsRef, rewriter, loc);
+    };
+
+    if (isCompileTimeComputable(adaptor.getData())) {
+      auto sliced = tryBuildSlice(adaptor.getData());
+      if (failed(sliced))
+        return rewriter.notifyMatchFailure(sliceOp, "failed to normalize static slice parameters");
+      rewriter.replaceOp(sliceOp, *sliced);
+      return success();
+    }
+
+    auto computeOp =
+        createSpatCompute<1>(rewriter, loc, TypeRange {resultType}, {}, adaptor.getData(), [&](Value data) {
+          auto sliced = tryBuildSlice(data);
+          if (failed(sliced))
+            return failure();
+          spatial::SpatYieldOp::create(rewriter, loc, *sliced);
+          return success();
+        });
+    if (failed(computeOp))
+      return rewriter.notifyMatchFailure(sliceOp, "failed to build runtime tensor.extract_slice lowering");
+
+    rewriter.replaceOp(sliceOp, computeOp->getResults());
+    return success();
+  }
+};
+
+} // namespace
+
+/// Adds the ONNX Slice conversion pattern (`Slice`) to `patterns`.
+void populateSlicePatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
+  patterns.add<Slice>(ctx);
+}
+
+} // namespace onnx_mlir