diff --git a/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt b/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt index 0b7e8cc..32964aa 100644 --- a/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt +++ b/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt @@ -22,6 +22,7 @@ add_pim_library(OMONNXToSpatial Patterns/Tensor/Gather.cpp Patterns/Tensor/Resize.cpp Patterns/Tensor/Reshape.cpp + Patterns/Tensor/Slice.cpp Patterns/Tensor/Split.cpp Patterns/Tensor/Transpose.cpp ONNXToSpatialPass.cpp diff --git a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp index edf311e..c3d42f7 100644 --- a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp @@ -138,6 +138,7 @@ void ONNXToSpatialPass::runOnOperation() { target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns.cpp index ffa0b1f..0a747e9 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns.cpp @@ -22,6 +22,7 @@ void populateConversionPatterns(RewritePatternSet& patterns, MLIRContext* ctx) { populateGatherPatterns(patterns, ctx); populateResizePatterns(patterns, ctx); populateReshapePatterns(patterns, ctx); + populateSlicePatterns(patterns, ctx); populateSplitPatterns(patterns, ctx); populateTransposePatterns(patterns, ctx); } diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns.hpp b/src/PIM/Conversion/ONNXToSpatial/Patterns.hpp index e58729e..c040536 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns.hpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns.hpp @@ -29,6 +29,7 @@ void populateConcatPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext void populateGatherPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); 
void populateResizePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populateReshapePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); +void populateSlicePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populateSplitPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populateTransposePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Slice.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Slice.cpp
new file mode 100644
index 0000000..3f8867f
--- /dev/null
+++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Slice.cpp
@@ -0,0 +1,242 @@
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+#include "llvm/ADT/APInt.h"
+#include "llvm/ADT/SmallVector.h"
+
+#include <algorithm>
+#include <optional>
+
+#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
+#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
+#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
+#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
+#include "src/Dialect/ONNX/ONNXOps.hpp"
+
+using namespace mlir;
+
+namespace onnx_mlir {
+namespace {
+
+/// Returns the dense elements attribute held by the constant op that defines
+/// `value` (an `arith.constant` or an ONNX constant). Returns a null attribute
+/// when `value` has no such defining op or the op holds another attribute kind.
+static DenseElementsAttr getDenseConstantAttr(Value value) {
+  if (auto constantOp = value.getDefiningOp<arith::ConstantOp>())
+    return dyn_cast<DenseElementsAttr>(constantOp.getValue());
+  if (auto constantOp = value.getDefiningOp<ONNXConstantOp>())
+    return dyn_cast_or_null<DenseElementsAttr>(constantOp.getValueAttr());
+  return nullptr;
+}
+
+/// Reads `value` as a compile-time integer constant and returns its elements
+/// sign-extended to int64_t. Fails when `value` is not defined by a constant op
+/// holding a dense integer attribute.
+static FailureOr<SmallVector<int64_t>> getConstantIntValues(Value value) {
+  auto denseAttr = dyn_cast_or_null<DenseIntElementsAttr>(getDenseConstantAttr(value));
+  if (!denseAttr)
+    return failure();
+
+  // Iterate as APInt instead of int64_t: ONNX Slice accepts int32 as well as
+  // int64 index tensors, and a typed int64_t view requires 64-bit elements.
+  SmallVector<int64_t> values;
+  values.reserve(denseAttr.getNumElements());
+  for (const llvm::APInt& element : denseAttr.getValues<llvm::APInt>())
+    values.push_back(element.getSExtValue());
+  return values;
+}
+
+/// Returns true when `value` is produced by an ONNX none op, which is how the
+/// callers below detect an omitted optional operand.
+static bool isNoneValueLike(Value value) { return isa_and_nonnull<ONNXNoneOp>(value.getDefiningOp()); }
+
+/// Builds a `tensor.extract_slice` of `data` from constant slice parameters.
+///
+/// `starts`, `ends` and, when provided, `axes` and `steps` must have the same
+/// length. Without `axes` the slices apply to dimensions 0..starts.size()-1;
+/// without `steps` every step is 1. Negative starts/ends are offset by the
+/// dimension size, then both are clamped to [0, dimSize].
+///
+/// Fails when a shape is not static, the ranks differ, the parameter lengths
+/// disagree, an axis is rejected by normalizeAxisChecked or repeated, a step is
+/// not positive, or the computed shape differs from `resultType`.
+static FailureOr<Value> buildSlice(Value data,
+                                   RankedTensorType dataType,
+                                   RankedTensorType resultType,
+                                   ArrayRef<int64_t> starts,
+                                   ArrayRef<int64_t> ends,
+                                   std::optional<ArrayRef<int64_t>> axes,
+                                   std::optional<ArrayRef<int64_t>> steps,
+                                   ConversionPatternRewriter& rewriter,
+                                   Location loc) {
+  int64_t rank = dataType.getRank();
+  if (!dataType.hasStaticShape() || !resultType.hasStaticShape() || resultType.getRank() != rank)
+    return failure();
+
+  if (starts.size() != ends.size())
+    return failure();
+  if (axes && axes->size() != starts.size())
+    return failure();
+  if (steps && steps->size() != starts.size())
+    return failure();
+
+  SmallVector<int64_t> normalizedAxes;
+  if (axes) {
+    // Remember which axes were already sliced so a repeated axis is rejected.
+    SmallVector<bool> seenAxes(rank, false);
+    normalizedAxes.reserve(axes->size());
+    for (int64_t axis : *axes) {
+      auto normalizedAxis = normalizeAxisChecked(axis, rank);
+      if (failed(normalizedAxis))
+        return failure();
+      if (seenAxes[*normalizedAxis])
+        return failure();
+      seenAxes[*normalizedAxis] = true;
+      normalizedAxes.push_back(*normalizedAxis);
+    }
+  }
+  else {
+    // No explicit axes: slice the leading starts.size() dimensions in order.
+    if (starts.size() > static_cast<size_t>(rank))
+      return failure();
+    normalizedAxes.reserve(starts.size());
+    for (size_t i = 0; i < starts.size(); ++i)
+      normalizedAxes.push_back(static_cast<int64_t>(i));
+  }
+
+  SmallVector<int64_t> normalizedSteps;
+  if (steps)
+    normalizedSteps.assign(steps->begin(), steps->end());
+  else
+    normalizedSteps.assign(starts.size(), 1);
+
+  // Start from the helpers' whole-tensor defaults; only sliced axes are
+  // overwritten in the loop below.
+  SmallVector<int64_t> computedShape(dataType.getShape().begin(), dataType.getShape().end());
+  SmallVector<OpFoldResult> offsets = getZeroOffsets(rewriter, rank);
+  SmallVector<OpFoldResult> sizes = getStaticSizes(rewriter, dataType.getShape());
+  SmallVector<OpFoldResult> strides = getUnitStrides(rewriter, rank);
+
+  for (auto [sliceIndex, axis] : llvm::enumerate(normalizedAxes)) {
+    int64_t step = normalizedSteps[sliceIndex];
+    if (step <= 0)
+      return failure();
+
+    int64_t dimSize = dataType.getShape()[axis];
+    int64_t start = starts[sliceIndex];
+    int64_t end = ends[sliceIndex];
+
+    if (start < 0)
+      start += dimSize;
+    if (end < 0)
+      end += dimSize;
+
+    start = std::clamp(start, int64_t {0}, dimSize);
+    end = std::clamp(end, int64_t {0}, dimSize);
+
+    // Ceiling division of extent by step, written so that extent + step is
+    // never formed: that sum overflows int64_t for very large steps.
+    int64_t extent = std::max(end - start, int64_t {0});
+    int64_t size = extent == 0 ? 0 : (extent - 1) / step + 1;
+
+    offsets[axis] = rewriter.getIndexAttr(start);
+    sizes[axis] = rewriter.getIndexAttr(size);
+    strides[axis] = rewriter.getIndexAttr(step);
+    computedShape[axis] = size;
+  }
+
+  // The op is created with resultType, so the computed extents must match it.
+  if (llvm::ArrayRef<int64_t>(computedShape) != resultType.getShape())
+    return failure();
+
+  return tensor::ExtractSliceOp::create(rewriter, loc, resultType, data, offsets, sizes, strides).getResult();
+}
+
+/// Converts an ONNX Slice whose starts/ends/axes/steps are constants into
+/// tensor.extract_slice. When isCompileTimeComputable accepts the data operand
+/// the slice replaces the op directly; otherwise it is built through
+/// createSpatCompute and yielded with spatial::SpatYieldOp.
+struct Slice final : OpConversionPattern<ONNXSliceOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(ONNXSliceOp sliceOp,
+                                ONNXSliceOpAdaptor adaptor,
+                                ConversionPatternRewriter& rewriter) const override {
+    auto dataType = dyn_cast<RankedTensorType>(adaptor.getData().getType());
+    auto resultType = dyn_cast<RankedTensorType>(sliceOp.getResult().getType());
+    if (!dataType || !resultType || !dataType.hasStaticShape() || !resultType.hasStaticShape())
+      return rewriter.notifyMatchFailure(sliceOp, "requires statically shaped ranked tensor data and result");
+
+    auto starts = getConstantIntValues(adaptor.getStarts());
+    auto ends = getConstantIntValues(adaptor.getEnds());
+    if (failed(starts))
+      return rewriter.notifyMatchFailure(sliceOp, "requires compile-time constant starts");
+    if (failed(ends))
+      return rewriter.notifyMatchFailure(sliceOp, "requires compile-time constant ends");
+
+    // axes and steps are optional operands; absent ones stay std::nullopt.
+    std::optional<SmallVector<int64_t>> axes;
+    if (!isNoneValueLike(adaptor.getAxes())) {
+      auto parsedAxes = getConstantIntValues(adaptor.getAxes());
+      if (failed(parsedAxes))
+        return rewriter.notifyMatchFailure(sliceOp, "requires compile-time constant axes when present");
+      axes = std::move(*parsedAxes);
+    }
+
+    std::optional<SmallVector<int64_t>> steps;
+    if (!isNoneValueLike(adaptor.getSteps())) {
+      auto parsedSteps = getConstantIntValues(adaptor.getSteps());
+      if (failed(parsedSteps))
+        return rewriter.notifyMatchFailure(sliceOp, "requires compile-time constant steps when present");
+      steps = std::move(*parsedSteps);
+      if (llvm::any_of(*steps, [](int64_t step) { return step <= 0; }))
+        return rewriter.notifyMatchFailure(sliceOp, "supports only positive constant steps");
+    }
+
+    ArrayRef<int64_t> startsRef = *starts;
+    ArrayRef<int64_t> endsRef = *ends;
+    std::optional<ArrayRef<int64_t>> axesRef = axes ? std::optional<ArrayRef<int64_t>>(ArrayRef<int64_t>(*axes))
+                                                    : std::nullopt;
+    std::optional<ArrayRef<int64_t>> stepsRef = steps ? std::optional<ArrayRef<int64_t>>(ArrayRef<int64_t>(*steps))
+                                                      : std::nullopt;
+
+    Location loc = sliceOp.getLoc();
+    auto tryBuildSlice = [&](Value data) {
+      return buildSlice(data, dataType, resultType, startsRef, endsRef, axesRef, stepsRef, rewriter, loc);
+    };
+
+    // Data reported as compile-time computable is sliced in place, with no
+    // spatial compute wrapper.
+    if (isCompileTimeComputable(adaptor.getData())) {
+      auto sliced = tryBuildSlice(adaptor.getData());
+      if (failed(sliced))
+        return rewriter.notifyMatchFailure(sliceOp, "failed to normalize static slice parameters");
+      rewriter.replaceOp(sliceOp, *sliced);
+      return success();
+    }
+
+    // Otherwise build the slice inside the compute op body and yield it.
+    auto computeOp =
+        createSpatCompute<1>(rewriter, loc, TypeRange {resultType}, {}, adaptor.getData(), [&](Value data) {
+      auto sliced = tryBuildSlice(data);
+      if (failed(sliced))
+        return failure();
+      spatial::SpatYieldOp::create(rewriter, loc, *sliced);
+      return success();
+    });
+    if (failed(computeOp))
+      return rewriter.notifyMatchFailure(sliceOp, "failed to build runtime tensor.extract_slice lowering");
+
+    rewriter.replaceOp(sliceOp, computeOp->getResults());
+    return success();
+  }
+};
+
+} // namespace
+
+/// Registers the ONNX Slice conversion pattern in `patterns`.
+void populateSlicePatterns(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.add<Slice>(ctx); }
+
+} // namespace onnx_mlir