255 lines
11 KiB
Plaintext
255 lines
11 KiB
Plaintext
diff --git a/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt b/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt
|
|
index 0b7e8cc..32964aa 100644
|
|
--- a/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt
|
|
+++ b/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt
|
|
@@ -22,6 +22,7 @@ add_pim_library(OMONNXToSpatial
|
|
Patterns/Tensor/Gather.cpp
|
|
Patterns/Tensor/Resize.cpp
|
|
Patterns/Tensor/Reshape.cpp
|
|
+ Patterns/Tensor/Slice.cpp
|
|
Patterns/Tensor/Split.cpp
|
|
Patterns/Tensor/Transpose.cpp
|
|
ONNXToSpatialPass.cpp
|
|
diff --git a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp
|
|
index edf311e..c3d42f7 100644
|
|
--- a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp
|
|
+++ b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp
|
|
@@ -138,6 +138,7 @@ void ONNXToSpatialPass::runOnOperation() {
|
|
target.addIllegalOp<ONNXGatherOp>();
|
|
target.addIllegalOp<ONNXReshapeOp>();
|
|
target.addIllegalOp<ONNXResizeOp>();
|
|
+ target.addIllegalOp<ONNXSliceOp>();
|
|
target.addIllegalOp<ONNXLRNOp>();
|
|
target.addIllegalOp<ONNXReduceMeanV13Op>();
|
|
target.addIllegalOp<ONNXSplitOp>();
|
|
diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns.cpp
|
|
index ffa0b1f..0a747e9 100644
|
|
--- a/src/PIM/Conversion/ONNXToSpatial/Patterns.cpp
|
|
+++ b/src/PIM/Conversion/ONNXToSpatial/Patterns.cpp
|
|
@@ -22,6 +22,7 @@ void populateConversionPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
|
|
populateGatherPatterns(patterns, ctx);
|
|
populateResizePatterns(patterns, ctx);
|
|
populateReshapePatterns(patterns, ctx);
|
|
+ populateSlicePatterns(patterns, ctx);
|
|
populateSplitPatterns(patterns, ctx);
|
|
populateTransposePatterns(patterns, ctx);
|
|
}
|
|
diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns.hpp b/src/PIM/Conversion/ONNXToSpatial/Patterns.hpp
|
|
index e58729e..c040536 100644
|
|
--- a/src/PIM/Conversion/ONNXToSpatial/Patterns.hpp
|
|
+++ b/src/PIM/Conversion/ONNXToSpatial/Patterns.hpp
|
|
@@ -29,6 +29,7 @@ void populateConcatPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext
|
|
void populateGatherPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
|
void populateResizePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
|
void populateReshapePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
|
+void populateSlicePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
|
void populateSplitPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
|
void populateTransposePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
|
|
|
diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Slice.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Slice.cpp
|
|
new file mode 100644
|
|
index 0000000..3f8867f
|
|
--- /dev/null
|
|
+++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Slice.cpp
|
|
@@ -0,0 +1,200 @@
|
|
+#include "mlir/Dialect/Arith/IR/Arith.h"
|
|
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
+#include "mlir/IR/BuiltinAttributes.h"
|
|
+#include "mlir/Transforms/DialectConversion.h"
|
|
+
|
|
+#include "llvm/ADT/SmallVector.h"
|
|
+
|
|
+#include <algorithm>
|
|
+#include <optional>
|
|
+
|
|
+#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
|
+#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
|
+#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
|
+#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
|
+#include "src/Dialect/ONNX/ONNXOps.hpp"
|
|
+
|
|
+using namespace mlir;
|
|
+
|
|
+namespace onnx_mlir {
|
|
+namespace {
|
|
+
|
|
+/// Returns the DenseElementsAttr carried by the constant that defines `value`.
+///
+/// Recognizes `arith.constant` and `onnx.Constant` producers. Yields a null
+/// attribute when `value` has any other producer (or none), or when the
+/// constant's attribute is not a DenseElementsAttr.
+static DenseElementsAttr getDenseConstantAttr(Value value) {
+  Attribute payload;
+  if (auto arithConstant = value.getDefiningOp<arith::ConstantOp>())
+    payload = arithConstant.getValue();
+  else if (auto onnxConstant = value.getDefiningOp<ONNXConstantOp>())
+    payload = onnxConstant.getValueAttr();
+  return dyn_cast_or_null<DenseElementsAttr>(payload);
+}
|
|
+
|
|
+/// Reads the dense integer constant that defines `value` into a list of int64_t.
+///
+/// The ONNX Slice spec allows starts/ends/axes/steps to be int32 or int64, so
+/// elements are read as APInt and sign-extended. `getValues<int64_t>()` only
+/// accepts 64-bit element storage and asserts on narrower integer types.
+///
+/// Fails if `value` is not produced by a dense integer constant, or if its
+/// integer element type is wider than 64 bits (not representable in int64_t).
+static FailureOr<SmallVector<int64_t>> getConstantIntValues(Value value) {
+  auto denseAttr = dyn_cast_or_null<DenseIntElementsAttr>(getDenseConstantAttr(value));
+  if (!denseAttr)
+    return failure();
+
+  // getSExtValue() is only valid for values that fit in 64 bits.
+  if (auto intType = dyn_cast<IntegerType>(denseAttr.getElementType()); intType && intType.getWidth() > 64)
+    return failure();
+
+  SmallVector<int64_t> values;
+  values.reserve(static_cast<size_t>(denseAttr.getNumElements()));
+  for (const APInt& element : denseAttr.getValues<APInt>())
+    values.push_back(element.getSExtValue());
+  return values;
+}
|
|
+
|
|
+/// Returns true when `value` stands for an omitted optional ONNX input.
+///
+/// Accepts a null Value, any NoneType-typed value, and the result of an
+/// ONNXNoneOp. The null check comes first because calling getDefiningOp() on
+/// a null Value is invalid.
+static bool isNoneValueLike(Value value) {
+  if (!value || isa<NoneType>(value.getType()))
+    return true;
+  return isa_and_nonnull<ONNXNoneOp>(value.getDefiningOp());
+}
|
|
+
|
|
+/// Builds a `tensor.extract_slice` that implements an ONNX Slice whose
+/// starts/ends/axes/steps are all known at compile time.
+///
+/// Semantics follow the ONNX Slice operator for positive steps:
+/// - Negative starts/ends are offset by the axis size, then clamped to [0, dim].
+/// - When `axes` is absent, the first `starts.size()` axes are sliced.
+/// - When `steps` is absent, every step is 1.
+/// - Axes that are not listed are taken whole.
+///
+/// Returns failure when:
+/// - either type is not statically shaped, or the ranks differ;
+/// - the parameter lists disagree in length;
+/// - an axis is rejected by normalizeAxisChecked or is repeated;
+/// - a step is not positive;
+/// - the computed shape does not equal `resultType`'s shape.
+/// The slice op is created only after all of these checks pass.
+static FailureOr<Value> buildSlice(Value data,
+                                   RankedTensorType dataType,
+                                   RankedTensorType resultType,
+                                   ArrayRef<int64_t> starts,
+                                   ArrayRef<int64_t> ends,
+                                   std::optional<ArrayRef<int64_t>> axes,
+                                   std::optional<ArrayRef<int64_t>> steps,
+                                   ConversionPatternRewriter& rewriter,
+                                   Location loc) {
+  int64_t rank = dataType.getRank();
+  if (!dataType.hasStaticShape() || !resultType.hasStaticShape() || resultType.getRank() != rank)
+    return failure();
+
+  // Every per-axis parameter list must line up with `starts`.
+  if (starts.size() != ends.size())
+    return failure();
+  if (axes && axes->size() != starts.size())
+    return failure();
+  if (steps && steps->size() != starts.size())
+    return failure();
+
+  // Resolve the data axis that each (start, end, step) entry applies to.
+  SmallVector<int64_t> normalizedAxes;
+  if (axes) {
+    SmallVector<bool> seenAxes(rank, false);
+    normalizedAxes.reserve(axes->size());
+    for (int64_t axis : *axes) {
+      auto normalizedAxis = normalizeAxisChecked(axis, rank);
+      if (failed(normalizedAxis))
+        return failure();
+      // Repeated axes are rejected (ONNX leaves their behavior undefined).
+      if (seenAxes[*normalizedAxis])
+        return failure();
+      seenAxes[*normalizedAxis] = true;
+      normalizedAxes.push_back(*normalizedAxis);
+    }
+  }
+  else {
+    // Default axes are [0, 1, ..., starts.size() - 1].
+    if (starts.size() > static_cast<size_t>(rank))
+      return failure();
+    normalizedAxes.reserve(starts.size());
+    for (size_t i = 0; i < starts.size(); ++i)
+      normalizedAxes.push_back(static_cast<int64_t>(i));
+  }
+
+  SmallVector<int64_t> normalizedSteps;
+  if (steps)
+    normalizedSteps.assign(steps->begin(), steps->end());
+  else
+    normalizedSteps.assign(starts.size(), 1);
+
+  // Start from whole-tensor slice parameters and overwrite the sliced axes.
+  // Per the helper names, these are presumably zero offsets, full static sizes,
+  // and unit strides.
+  SmallVector<int64_t> computedShape(dataType.getShape().begin(), dataType.getShape().end());
+  SmallVector<OpFoldResult> offsets = getZeroOffsets(rewriter, rank);
+  SmallVector<OpFoldResult> sizes = getStaticSizes(rewriter, dataType.getShape());
+  SmallVector<OpFoldResult> strides = getUnitStrides(rewriter, rank);
+
+  for (auto [sliceIndex, axis] : llvm::enumerate(normalizedAxes)) {
+    int64_t step = normalizedSteps[sliceIndex];
+    if (step <= 0)
+      return failure();
+
+    int64_t dimSize = dataType.getShape()[axis];
+    int64_t start = starts[sliceIndex];
+    int64_t end = ends[sliceIndex];
+
+    // Negative indices count from the end of the axis; then clamp to [0, dim].
+    if (start < 0)
+      start += dimSize;
+    if (end < 0)
+      end += dimSize;
+
+    start = std::clamp(start, int64_t {0}, dimSize);
+    end = std::clamp(end, int64_t {0}, dimSize);
+
+    // Number of selected elements is ceil(extent / step). It is written as
+    // (extent - 1) / step + 1 because the usual (extent + step - 1) / step
+    // overflows int64_t (undefined behavior) for very large steps such as
+    // INT64_MAX.
+    int64_t extent = std::max(end - start, int64_t {0});
+    int64_t size = extent == 0 ? 0 : (extent - 1) / step + 1;
+
+    offsets[axis] = rewriter.getIndexAttr(start);
+    sizes[axis] = rewriter.getIndexAttr(size);
+    strides[axis] = rewriter.getIndexAttr(step);
+    computedShape[axis] = size;
+  }
+
+  // The declared result type must agree with the shape implied by the constants.
+  if (llvm::ArrayRef(computedShape) != resultType.getShape())
+    return failure();
+
+  return tensor::ExtractSliceOp::create(rewriter, loc, resultType, data, offsets, sizes, strides).getResult();
+}
|
|
+
|
|
+/// Conversion pattern that lowers `onnx.Slice` to `tensor.extract_slice`.
+///
+/// Requirements:
+/// - data and result are statically shaped ranked tensors;
+/// - starts/ends are compile-time integer constants;
+/// - axes/steps are constants too whenever they are not "none";
+/// - every step is positive.
+///
+/// If the data operand is compile-time computable, the slice is emitted
+/// directly. Otherwise it is built inside a single-input spatial compute whose
+/// body yields the slice.
+struct Slice final : OpConversionPattern<ONNXSliceOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(ONNXSliceOp sliceOp,
+                                ONNXSliceOpAdaptor adaptor,
+                                ConversionPatternRewriter& rewriter) const override {
+    Value input = adaptor.getData();
+    auto dataType = dyn_cast<RankedTensorType>(input.getType());
+    auto resultType = dyn_cast<RankedTensorType>(sliceOp.getResult().getType());
+    bool hasStaticTypes = dataType && resultType && dataType.hasStaticShape() && resultType.hasStaticShape();
+    if (!hasStaticTypes)
+      return failure();
+
+    FailureOr<SmallVector<int64_t>> starts = getConstantIntValues(adaptor.getStarts());
+    if (failed(starts))
+      return rewriter.notifyMatchFailure(sliceOp, "requires compile-time constant starts");
+    FailureOr<SmallVector<int64_t>> ends = getConstantIntValues(adaptor.getEnds());
+    if (failed(ends))
+      return rewriter.notifyMatchFailure(sliceOp, "requires compile-time constant ends");
+
+    // Optional inputs: an omitted ("none") input stays std::nullopt; a present
+    // one must fold to integer constants.
+    std::optional<SmallVector<int64_t>> axes;
+    if (Value axesInput = adaptor.getAxes(); !isNoneValueLike(axesInput)) {
+      FailureOr<SmallVector<int64_t>> parsed = getConstantIntValues(axesInput);
+      if (failed(parsed))
+        return rewriter.notifyMatchFailure(sliceOp, "requires compile-time constant axes when present");
+      axes = std::move(*parsed);
+    }
+
+    std::optional<SmallVector<int64_t>> steps;
+    if (Value stepsInput = adaptor.getSteps(); !isNoneValueLike(stepsInput)) {
+      FailureOr<SmallVector<int64_t>> parsed = getConstantIntValues(stepsInput);
+      if (failed(parsed))
+        return rewriter.notifyMatchFailure(sliceOp, "requires compile-time constant steps when present");
+      if (!llvm::all_of(*parsed, [](int64_t step) { return step > 0; }))
+        return rewriter.notifyMatchFailure(sliceOp, "supports only positive constant steps");
+      steps = std::move(*parsed);
+    }
+
+    // Non-owning views over the parsed lists; `axes`/`steps` outlive every use below.
+    std::optional<ArrayRef<int64_t>> axesRef;
+    if (axes)
+      axesRef = ArrayRef<int64_t>(*axes);
+    std::optional<ArrayRef<int64_t>> stepsRef;
+    if (steps)
+      stepsRef = ArrayRef<int64_t>(*steps);
+
+    Location loc = sliceOp.getLoc();
+    auto emitSlice = [&](Value source) {
+      return buildSlice(source, dataType, resultType, *starts, *ends, axesRef, stepsRef, rewriter, loc);
+    };
+
+    if (!isCompileTimeComputable(input)) {
+      // Runtime data: build the slice inside a spatial compute region.
+      auto computeOp =
+        createSpatCompute<1>(rewriter, loc, TypeRange {resultType}, {}, input, [&](Value blockArg) {
+          FailureOr<Value> slicedValue = emitSlice(blockArg);
+          if (failed(slicedValue))
+            return failure();
+          spatial::SpatYieldOp::create(rewriter, loc, *slicedValue);
+          return success();
+        });
+      if (failed(computeOp))
+        return rewriter.notifyMatchFailure(sliceOp, "failed to build runtime tensor.extract_slice lowering");
+      rewriter.replaceOp(sliceOp, computeOp->getResults());
+      return success();
+    }
+
+    FailureOr<Value> slicedValue = emitSlice(input);
+    if (failed(slicedValue))
+      return rewriter.notifyMatchFailure(sliceOp, "failed to normalize static slice parameters");
+    rewriter.replaceOp(sliceOp, *slicedValue);
+    return success();
+  }
+};
|
|
+
|
|
+} // namespace
|
|
+
|
|
+/// Adds the `Slice` conversion pattern (onnx.Slice lowering) to `patterns`.
+void populateSlicePatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
+  patterns.add<Slice>(ctx);
+}
|
|
+
|
|
+} // namespace onnx_mlir
|