|
|
|
@@ -0,0 +1,189 @@
|
|
|
|
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
|
|
|
#include "mlir/IR/BuiltinAttributes.h"
|
|
|
|
|
#include "mlir/Transforms/DialectConversion.h"
|
|
|
|
|
|
|
|
|
|
#include "llvm/ADT/SmallVector.h"
|
|
|
|
|
|
|
|
|
|
#include <algorithm>
|
|
|
|
|
#include <optional>
|
|
|
|
|
|
|
|
|
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
|
|
|
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
|
|
|
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
|
|
|
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
|
|
|
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
|
|
|
|
|
|
|
|
|
using namespace mlir;
|
|
|
|
|
|
|
|
|
|
namespace onnx_mlir {
|
|
|
|
|
namespace {
|
|
|
|
|
|
|
|
|
|
static FailureOr<SmallVector<int64_t>> getConstantIntValues(Value value) {
|
|
|
|
|
auto denseAttr = dyn_cast_or_null<DenseIntElementsAttr>(getHostConstDenseElementsAttr(value));
|
|
|
|
|
if (!denseAttr)
|
|
|
|
|
return failure();
|
|
|
|
|
return SmallVector<int64_t>(denseAttr.getValues<int64_t>().begin(), denseAttr.getValues<int64_t>().end());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static bool isNoneValueLike(Value value) { return isa_and_nonnull<ONNXNoneOp>(value.getDefiningOp()); }
|
|
|
|
|
|
|
|
|
|
static FailureOr<Value> buildSlice(Value data,
|
|
|
|
|
RankedTensorType dataType,
|
|
|
|
|
RankedTensorType resultType,
|
|
|
|
|
ArrayRef<int64_t> starts,
|
|
|
|
|
ArrayRef<int64_t> ends,
|
|
|
|
|
std::optional<ArrayRef<int64_t>> axes,
|
|
|
|
|
std::optional<ArrayRef<int64_t>> steps,
|
|
|
|
|
ConversionPatternRewriter& rewriter,
|
|
|
|
|
Location loc) {
|
|
|
|
|
int64_t rank = dataType.getRank();
|
|
|
|
|
if (!dataType.hasStaticShape() || !resultType.hasStaticShape() || resultType.getRank() != rank)
|
|
|
|
|
return failure();
|
|
|
|
|
|
|
|
|
|
if (starts.size() != ends.size())
|
|
|
|
|
return failure();
|
|
|
|
|
if (axes && axes->size() != starts.size())
|
|
|
|
|
return failure();
|
|
|
|
|
if (steps && steps->size() != starts.size())
|
|
|
|
|
return failure();
|
|
|
|
|
|
|
|
|
|
SmallVector<int64_t> normalizedAxes;
|
|
|
|
|
if (axes) {
|
|
|
|
|
SmallVector<bool> seenAxes(rank, false);
|
|
|
|
|
normalizedAxes.reserve(axes->size());
|
|
|
|
|
for (int64_t axis : *axes) {
|
|
|
|
|
auto normalizedAxis = normalizeAxisChecked(axis, rank);
|
|
|
|
|
if (failed(normalizedAxis))
|
|
|
|
|
return failure();
|
|
|
|
|
if (seenAxes[*normalizedAxis])
|
|
|
|
|
return failure();
|
|
|
|
|
seenAxes[*normalizedAxis] = true;
|
|
|
|
|
normalizedAxes.push_back(*normalizedAxis);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
else {
|
|
|
|
|
if (starts.size() > static_cast<size_t>(rank))
|
|
|
|
|
return failure();
|
|
|
|
|
normalizedAxes.reserve(starts.size());
|
|
|
|
|
for (size_t i = 0; i < starts.size(); ++i)
|
|
|
|
|
normalizedAxes.push_back(static_cast<int64_t>(i));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
SmallVector<int64_t> normalizedSteps;
|
|
|
|
|
if (steps)
|
|
|
|
|
normalizedSteps.assign(steps->begin(), steps->end());
|
|
|
|
|
else
|
|
|
|
|
normalizedSteps.assign(starts.size(), 1);
|
|
|
|
|
|
|
|
|
|
SmallVector<int64_t> computedShape(dataType.getShape().begin(), dataType.getShape().end());
|
|
|
|
|
SmallVector<OpFoldResult> offsets = getZeroOffsets(rewriter, rank);
|
|
|
|
|
SmallVector<OpFoldResult> sizes = getStaticSizes(rewriter, dataType.getShape());
|
|
|
|
|
SmallVector<OpFoldResult> strides = getUnitStrides(rewriter, rank);
|
|
|
|
|
|
|
|
|
|
for (auto [sliceIndex, axis] : llvm::enumerate(normalizedAxes)) {
|
|
|
|
|
int64_t step = normalizedSteps[sliceIndex];
|
|
|
|
|
if (step <= 0)
|
|
|
|
|
return failure();
|
|
|
|
|
|
|
|
|
|
int64_t dimSize = dataType.getShape()[axis];
|
|
|
|
|
int64_t start = starts[sliceIndex];
|
|
|
|
|
int64_t end = ends[sliceIndex];
|
|
|
|
|
|
|
|
|
|
start = normalizeIndex(start, dimSize);
|
|
|
|
|
end = normalizeIndex(end, dimSize);
|
|
|
|
|
|
|
|
|
|
start = std::clamp(start, int64_t {0}, dimSize);
|
|
|
|
|
end = std::clamp(end, int64_t {0}, dimSize);
|
|
|
|
|
|
|
|
|
|
int64_t extent = std::max(end - start, int64_t {0});
|
|
|
|
|
int64_t size = (extent + step - 1) / step;
|
|
|
|
|
|
|
|
|
|
offsets[axis] = rewriter.getIndexAttr(start);
|
|
|
|
|
sizes[axis] = rewriter.getIndexAttr(size);
|
|
|
|
|
strides[axis] = rewriter.getIndexAttr(step);
|
|
|
|
|
computedShape[axis] = size;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (llvm::ArrayRef(computedShape) != resultType.getShape())
|
|
|
|
|
return failure();
|
|
|
|
|
|
|
|
|
|
return tensor::ExtractSliceOp::create(rewriter, loc, resultType, data, offsets, sizes, strides).getResult();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
struct Slice final : OpConversionPattern<ONNXSliceOp> {
|
|
|
|
|
using OpConversionPattern::OpConversionPattern;
|
|
|
|
|
|
|
|
|
|
LogicalResult matchAndRewrite(ONNXSliceOp sliceOp,
|
|
|
|
|
ONNXSliceOpAdaptor adaptor,
|
|
|
|
|
ConversionPatternRewriter& rewriter) const override {
|
|
|
|
|
auto dataType = dyn_cast<RankedTensorType>(adaptor.getData().getType());
|
|
|
|
|
auto resultType = dyn_cast<RankedTensorType>(sliceOp.getResult().getType());
|
|
|
|
|
if (!dataType || !resultType || !dataType.hasStaticShape() || !resultType.hasStaticShape())
|
|
|
|
|
return failure();
|
|
|
|
|
|
|
|
|
|
auto starts = getConstantIntValues(adaptor.getStarts());
|
|
|
|
|
auto ends = getConstantIntValues(adaptor.getEnds());
|
|
|
|
|
if (failed(starts))
|
|
|
|
|
return rewriter.notifyMatchFailure(sliceOp, "requires compile-time constant starts");
|
|
|
|
|
if (failed(ends))
|
|
|
|
|
return rewriter.notifyMatchFailure(sliceOp, "requires compile-time constant ends");
|
|
|
|
|
|
|
|
|
|
std::optional<SmallVector<int64_t>> axes;
|
|
|
|
|
if (!isNoneValueLike(adaptor.getAxes())) {
|
|
|
|
|
auto parsedAxes = getConstantIntValues(adaptor.getAxes());
|
|
|
|
|
if (failed(parsedAxes))
|
|
|
|
|
return rewriter.notifyMatchFailure(sliceOp, "requires compile-time constant axes when present");
|
|
|
|
|
axes = std::move(*parsedAxes);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::optional<SmallVector<int64_t>> steps;
|
|
|
|
|
if (!isNoneValueLike(adaptor.getSteps())) {
|
|
|
|
|
auto parsedSteps = getConstantIntValues(adaptor.getSteps());
|
|
|
|
|
if (failed(parsedSteps))
|
|
|
|
|
return rewriter.notifyMatchFailure(sliceOp, "requires compile-time constant steps when present");
|
|
|
|
|
steps = std::move(*parsedSteps);
|
|
|
|
|
if (llvm::any_of(*steps, [](int64_t step) { return step <= 0; }))
|
|
|
|
|
return rewriter.notifyMatchFailure(sliceOp, "supports only positive constant steps");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ArrayRef<int64_t> startsRef = *starts;
|
|
|
|
|
ArrayRef<int64_t> endsRef = *ends;
|
|
|
|
|
std::optional<ArrayRef<int64_t>> axesRef = axes ? std::optional<ArrayRef<int64_t>>(ArrayRef<int64_t>(*axes))
|
|
|
|
|
: std::nullopt;
|
|
|
|
|
std::optional<ArrayRef<int64_t>> stepsRef = steps ? std::optional<ArrayRef<int64_t>>(ArrayRef<int64_t>(*steps))
|
|
|
|
|
: std::nullopt;
|
|
|
|
|
|
|
|
|
|
Location loc = sliceOp.getLoc();
|
|
|
|
|
auto tryBuildSlice = [&](Value data) {
|
|
|
|
|
return buildSlice(data, dataType, resultType, startsRef, endsRef, axesRef, stepsRef, rewriter, loc);
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
if (isCompileTimeComputable(adaptor.getData())) {
|
|
|
|
|
auto sliced = tryBuildSlice(adaptor.getData());
|
|
|
|
|
if (failed(sliced))
|
|
|
|
|
return rewriter.notifyMatchFailure(sliceOp, "failed to normalize static slice parameters");
|
|
|
|
|
rewriter.replaceOp(sliceOp, *sliced);
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto computeOp =
|
|
|
|
|
createSpatCompute<1>(rewriter, loc, TypeRange {resultType}, {}, adaptor.getData(), [&](Value data) {
|
|
|
|
|
auto sliced = tryBuildSlice(data);
|
|
|
|
|
if (failed(sliced))
|
|
|
|
|
return failure();
|
|
|
|
|
spatial::SpatYieldOp::create(rewriter, loc, *sliced);
|
|
|
|
|
return success();
|
|
|
|
|
});
|
|
|
|
|
if (failed(computeOp))
|
|
|
|
|
return rewriter.notifyMatchFailure(sliceOp, "failed to build runtime tensor.extract_slice lowering");
|
|
|
|
|
|
|
|
|
|
rewriter.replaceOp(sliceOp, computeOp->getResults());
|
|
|
|
|
return success();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
void populateSlicePatterns(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.add<Slice>(ctx); }
|
|
|
|
|
|
|
|
|
|
} // namespace onnx_mlir
|