This commit is contained in:
@@ -22,6 +22,7 @@ add_pim_library(OMONNXToSpatial
|
||||
Patterns/Tensor/Gather.cpp
|
||||
Patterns/Tensor/Resize.cpp
|
||||
Patterns/Tensor/Reshape.cpp
|
||||
Patterns/Tensor/Slice.cpp
|
||||
Patterns/Tensor/Split.cpp
|
||||
Patterns/Tensor/Transpose.cpp
|
||||
ONNXToSpatialPass.cpp
|
||||
|
||||
@@ -138,6 +138,7 @@ void ONNXToSpatialPass::runOnOperation() {
|
||||
target.addIllegalOp<ONNXGatherOp>();
|
||||
target.addIllegalOp<ONNXReshapeOp>();
|
||||
target.addIllegalOp<ONNXResizeOp>();
|
||||
target.addIllegalOp<ONNXSliceOp>();
|
||||
target.addIllegalOp<ONNXLRNOp>();
|
||||
target.addIllegalOp<ONNXReduceMeanV13Op>();
|
||||
target.addIllegalOp<ONNXSplitOp>();
|
||||
|
||||
@@ -22,6 +22,7 @@ void populateConversionPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||
populateGatherPatterns(patterns, ctx);
|
||||
populateResizePatterns(patterns, ctx);
|
||||
populateReshapePatterns(patterns, ctx);
|
||||
populateSlicePatterns(patterns, ctx);
|
||||
populateSplitPatterns(patterns, ctx);
|
||||
populateTransposePatterns(patterns, ctx);
|
||||
}
|
||||
|
||||
@@ -29,6 +29,7 @@ void populateConcatPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext
|
||||
void populateGatherPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
void populateResizePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
void populateReshapePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
void populateSlicePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
void populateSplitPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
void populateTransposePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
|
||||
@@ -0,0 +1,189 @@
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <optional>
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
/// Reads the elements of a host-constant integer tensor as signed 64-bit values.
///
/// Returns failure when `value` does not resolve, through
/// getHostConstDenseElementsAttr, to a DenseIntElementsAttr, or when an element
/// is stored with more than 64 bits.
///
/// Elements are fetched as APInt and sign-extended rather than through
/// getValues<int64_t>(): the typed accessor is only valid when the element
/// storage is exactly 64 bits wide, while ONNX Slice also allows int32 index
/// tensors for starts/ends/axes/steps.
static FailureOr<SmallVector<int64_t>> getConstantIntValues(Value value) {
  auto denseAttr = dyn_cast_or_null<DenseIntElementsAttr>(getHostConstDenseElementsAttr(value));
  if (!denseAttr)
    return failure();

  SmallVector<int64_t> values;
  values.reserve(denseAttr.getNumElements());
  for (llvm::APInt element : denseAttr.getValues<llvm::APInt>()) {
    // getSExtValue() requires the value to fit in 64 bits.
    if (element.getBitWidth() > 64)
      return failure();
    values.push_back(element.getSExtValue());
  }
  return values;
}
|
||||
|
||||
/// Reports whether `value` is produced by an ONNXNoneOp.
/// A value without a defining operation yields false.
static bool isNoneValueLike(Value value) {
  Operation* producer = value.getDefiningOp();
  return producer != nullptr && isa<ONNXNoneOp>(producer);
}
|
||||
|
||||
/// Builds a tensor.extract_slice that realizes a slice described by constant
/// starts/ends and optional axes/steps.
///
/// \param data        Value to slice.
/// \param dataType    Type used for the source shape; must be statically shaped.
/// \param resultType  Type given to the created op; must be statically shaped
///                    and have the same rank as `dataType`.
/// \param starts      Start index per sliced axis.
/// \param ends        End index per sliced axis; must match `starts` in length.
/// \param axes        Axis per entry. When absent, entry i applies to axis i.
/// \param steps       Step per entry. When absent, every step is 1.
/// \param rewriter    Rewriter used to create attributes and the op.
/// \param loc         Location attached to the created op.
///
/// \returns the extract_slice result, or failure when the inputs are
/// inconsistent in length, an axis is invalid or repeated, a step is not
/// strictly positive, or the computed shape differs from `resultType`'s shape.
static FailureOr<Value> buildSlice(Value data,
                                   RankedTensorType dataType,
                                   RankedTensorType resultType,
                                   ArrayRef<int64_t> starts,
                                   ArrayRef<int64_t> ends,
                                   std::optional<ArrayRef<int64_t>> axes,
                                   std::optional<ArrayRef<int64_t>> steps,
                                   ConversionPatternRewriter& rewriter,
                                   Location loc) {
  const int64_t rank = dataType.getRank();
  const bool shapesAreStatic = dataType.hasStaticShape() && resultType.hasStaticShape();
  if (!shapesAreStatic || resultType.getRank() != rank)
    return failure();

  // All per-entry arrays must describe the same number of sliced axes.
  const size_t sliceCount = starts.size();
  if (ends.size() != sliceCount)
    return failure();
  if (axes.has_value() && axes->size() != sliceCount)
    return failure();
  if (steps.has_value() && steps->size() != sliceCount)
    return failure();

  // Resolve which axis each (start, end, step) entry applies to.
  SmallVector<int64_t> targetAxes;
  targetAxes.reserve(sliceCount);
  if (!axes) {
    // Positional default: entry i slices axis i.
    if (sliceCount > static_cast<size_t>(rank))
      return failure();
    for (size_t position = 0; position < sliceCount; ++position)
      targetAxes.push_back(static_cast<int64_t>(position));
  }
  else {
    SmallVector<bool> axisTaken(rank, false);
    for (int64_t rawAxis : *axes) {
      auto resolvedAxis = normalizeAxisChecked(rawAxis, rank);
      // Reject invalid axes and axes that are listed more than once.
      if (failed(resolvedAxis) || axisTaken[*resolvedAxis])
        return failure();
      axisTaken[*resolvedAxis] = true;
      targetAxes.push_back(*resolvedAxis);
    }
  }

  SmallVector<int64_t> stepPerEntry = steps ? SmallVector<int64_t>(steps->begin(), steps->end())
                                            : SmallVector<int64_t>(sliceCount, int64_t {1});

  // Start from an identity slice (zero offsets, full sizes, unit strides) and
  // overwrite only the axes that are actually sliced.
  ArrayRef<int64_t> inputShape = dataType.getShape();
  SmallVector<int64_t> expectedShape(inputShape.begin(), inputShape.end());
  SmallVector<OpFoldResult> offsets = getZeroOffsets(rewriter, rank);
  SmallVector<OpFoldResult> sizes = getStaticSizes(rewriter, inputShape);
  SmallVector<OpFoldResult> strides = getUnitStrides(rewriter, rank);

  for (size_t entry = 0; entry < targetAxes.size(); ++entry) {
    const int64_t stride = stepPerEntry[entry];
    if (stride <= 0)
      return failure();

    const int64_t axis = targetAxes[entry];
    const int64_t dimSize = inputShape[axis];

    // Normalize each bound with normalizeIndex, then clamp it into [0, dimSize].
    const int64_t wrappedStart = normalizeIndex(starts[entry], dimSize);
    const int64_t wrappedEnd = normalizeIndex(ends[entry], dimSize);
    const int64_t begin = std::clamp(wrappedStart, int64_t {0}, dimSize);
    const int64_t finish = std::clamp(wrappedEnd, int64_t {0}, dimSize);

    // Number of elements selected: ceil(max(finish - begin, 0) / stride).
    const int64_t span = finish > begin ? finish - begin : int64_t {0};
    const int64_t count = (span + stride - 1) / stride;

    offsets[axis] = rewriter.getIndexAttr(begin);
    sizes[axis] = rewriter.getIndexAttr(count);
    strides[axis] = rewriter.getIndexAttr(stride);
    expectedShape[axis] = count;
  }

  // The declared result type must agree with the shape derived above.
  if (ArrayRef<int64_t>(expectedShape) != resultType.getShape())
    return failure();

  auto extractOp = tensor::ExtractSliceOp::create(rewriter, loc, resultType, data, offsets, sizes, strides);
  return extractOp.getResult();
}
|
||||
|
||||
/// Conversion pattern that lowers ONNXSliceOp to tensor.extract_slice.
///
/// Requirements: ranked, statically shaped input and result types; starts and
/// ends resolvable to constants; axes and steps either ONNX "none" values or
/// resolvable to constants; all steps strictly positive.
///
/// When the data operand is compile-time computable the extract_slice replaces
/// the op directly. Otherwise it is built inside the region created by
/// createSpatCompute and returned through a SpatYieldOp.
struct Slice final : OpConversionPattern<ONNXSliceOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(ONNXSliceOp sliceOp,
                                ONNXSliceOpAdaptor adaptor,
                                ConversionPatternRewriter& rewriter) const override {
    Value input = adaptor.getData();
    auto inputType = dyn_cast<RankedTensorType>(input.getType());
    auto outputType = dyn_cast<RankedTensorType>(sliceOp.getResult().getType());
    const bool typesUsable = inputType && outputType && inputType.hasStaticShape() && outputType.hasStaticShape();
    if (!typesUsable)
      return failure();

    // Mandatory operands: starts and ends.
    auto startValues = getConstantIntValues(adaptor.getStarts());
    auto endValues = getConstantIntValues(adaptor.getEnds());
    if (failed(startValues))
      return rewriter.notifyMatchFailure(sliceOp, "requires compile-time constant starts");
    if (failed(endValues))
      return rewriter.notifyMatchFailure(sliceOp, "requires compile-time constant ends");

    // Optional operand: axes.
    std::optional<SmallVector<int64_t>> axisValues;
    if (Value axesOperand = adaptor.getAxes(); !isNoneValueLike(axesOperand)) {
      auto parsed = getConstantIntValues(axesOperand);
      if (failed(parsed))
        return rewriter.notifyMatchFailure(sliceOp, "requires compile-time constant axes when present");
      axisValues = std::move(*parsed);
    }

    // Optional operand: steps. Only strictly positive steps are supported.
    std::optional<SmallVector<int64_t>> stepValues;
    if (Value stepsOperand = adaptor.getSteps(); !isNoneValueLike(stepsOperand)) {
      auto parsed = getConstantIntValues(stepsOperand);
      if (failed(parsed))
        return rewriter.notifyMatchFailure(sliceOp, "requires compile-time constant steps when present");
      stepValues = std::move(*parsed);
      for (int64_t step : *stepValues)
        if (step <= 0)
          return rewriter.notifyMatchFailure(sliceOp, "supports only positive constant steps");
    }

    // Non-owning views over the parsed vectors, in the form buildSlice expects.
    ArrayRef<int64_t> startsView = *startValues;
    ArrayRef<int64_t> endsView = *endValues;
    std::optional<ArrayRef<int64_t>> axesView;
    if (axisValues)
      axesView = ArrayRef<int64_t>(*axisValues);
    std::optional<ArrayRef<int64_t>> stepsView;
    if (stepValues)
      stepsView = ArrayRef<int64_t>(*stepValues);

    Location loc = sliceOp.getLoc();
    auto emitSlice = [&](Value source) {
      return buildSlice(source, inputType, outputType, startsView, endsView, axesView, stepsView, rewriter, loc);
    };

    // Compile-time computable data: emit the extract_slice in place.
    if (isCompileTimeComputable(input)) {
      auto slicedValue = emitSlice(input);
      if (failed(slicedValue))
        return rewriter.notifyMatchFailure(sliceOp, "failed to normalize static slice parameters");
      rewriter.replaceOp(sliceOp, *slicedValue);
      return success();
    }

    // Otherwise build the slice inside a compute region and yield it.
    auto computeOp =
        createSpatCompute<1>(rewriter, loc, TypeRange {outputType}, {}, adaptor.getData(), [&](Value regionData) {
          auto slicedValue = emitSlice(regionData);
          if (failed(slicedValue))
            return failure();
          spatial::SpatYieldOp::create(rewriter, loc, *slicedValue);
          return success();
        });
    if (failed(computeOp))
      return rewriter.notifyMatchFailure(sliceOp, "failed to build runtime tensor.extract_slice lowering");

    rewriter.replaceOp(sliceOp, computeOp->getResults());
    return success();
  }
};
|
||||
|
||||
} // namespace
|
||||
|
||||
/// Adds the ONNX Slice conversion pattern to `patterns`, constructed with `ctx`.
void populateSlicePatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
  patterns.add<Slice>(ctx);
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -1340,6 +1340,118 @@ def split_uneven_channel_axis_4d():
|
||||
save_model(model, "split/uneven_channel_axis_4d", "split_uneven_channel_axis_4d.onnx")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Slice tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def slice_2d_basic():
    """Generate a Slice model that crops a [4, 6] float input to [2, 3] with explicit axes and unit steps."""
    graph_input = helper.make_tensor_value_info("X", TensorProto.FLOAT, [4, 6])
    graph_output = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 3])
    initializers = [
        make_int64_initializer("starts", [1, 2]),
        make_int64_initializer("ends", [3, 5]),
        make_int64_initializer("axes", [0, 1]),
        make_int64_initializer("steps", [1, 1]),
    ]
    slice_node = helper.make_node("Slice", ["X", "starts", "ends", "axes", "steps"], ["Y"])
    graph = helper.make_graph(
        [slice_node], "slice_2d_basic", [graph_input], [graph_output], initializer=initializers
    )
    opset = helper.make_opsetid("", 13)
    save_model(helper.make_model(graph, opset_imports=[opset]), "slice/2d_basic", "slice_2d_basic.onnx")
|
||||
|
||||
|
||||
def slice_default_axes():
    """Generate a Slice model that passes only starts and ends, leaving axes and steps omitted."""
    graph_input = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 3, 4])
    graph_output = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 2, 4])
    initializers = [
        make_int64_initializer("starts", [0, 1, 0]),
        make_int64_initializer("ends", [2, 3, 4]),
    ]
    slice_node = helper.make_node("Slice", ["X", "starts", "ends"], ["Y"])
    graph = helper.make_graph(
        [slice_node], "slice_default_axes", [graph_input], [graph_output], initializer=initializers
    )
    opset = helper.make_opsetid("", 13)
    save_model(helper.make_model(graph, opset_imports=[opset]), "slice/default_axes", "slice_default_axes.onnx")
|
||||
|
||||
|
||||
def slice_negative_axis():
    """Generate a Slice model whose single sliced axis is given as -1."""
    graph_input = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 3, 5])
    graph_output = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 3, 3])
    initializers = [
        make_int64_initializer("starts", [1]),
        make_int64_initializer("ends", [4]),
        make_int64_initializer("axes", [-1]),
    ]
    slice_node = helper.make_node("Slice", ["X", "starts", "ends", "axes"], ["Y"])
    graph = helper.make_graph(
        [slice_node], "slice_negative_axis", [graph_input], [graph_output], initializer=initializers
    )
    opset = helper.make_opsetid("", 13)
    save_model(helper.make_model(graph, opset_imports=[opset]), "slice/negative_axis", "slice_negative_axis.onnx")
|
||||
|
||||
|
||||
def slice_negative_indices():
    """Generate a Slice model that uses negative start and end indices on axis 1."""
    graph_input = helper.make_tensor_value_info("X", TensorProto.FLOAT, [3, 5])
    graph_output = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [3, 3])
    initializers = [
        make_int64_initializer("starts", [-4]),
        make_int64_initializer("ends", [-1]),
        make_int64_initializer("axes", [1]),
    ]
    slice_node = helper.make_node("Slice", ["X", "starts", "ends", "axes"], ["Y"])
    graph = helper.make_graph(
        [slice_node], "slice_negative_indices", [graph_input], [graph_output], initializer=initializers
    )
    opset = helper.make_opsetid("", 13)
    save_model(
        helper.make_model(graph, opset_imports=[opset]), "slice/negative_indices", "slice_negative_indices.onnx"
    )
|
||||
|
||||
|
||||
def slice_step2():
    """Generate a Slice model that takes every second element along axis 2 of a [2, 3, 8] input."""
    graph_input = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 3, 8])
    graph_output = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 3, 4])
    initializers = [
        make_int64_initializer("starts", [0]),
        make_int64_initializer("ends", [8]),
        make_int64_initializer("axes", [2]),
        make_int64_initializer("steps", [2]),
    ]
    slice_node = helper.make_node("Slice", ["X", "starts", "ends", "axes", "steps"], ["Y"])
    graph = helper.make_graph(
        [slice_node], "slice_step2", [graph_input], [graph_output], initializer=initializers
    )
    opset = helper.make_opsetid("", 13)
    save_model(helper.make_model(graph, opset_imports=[opset]), "slice/step2", "slice_step2.onnx")
|
||||
|
||||
|
||||
def slice_nchw_spatial_crop():
    """Generate a Slice model that crops axes 2 and 3 of a [1, 3, 8, 8] input down to [1, 3, 4, 6]."""
    graph_input = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 3, 8, 8])
    graph_output = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 3, 4, 6])
    initializers = [
        make_int64_initializer("starts", [2, 1]),
        make_int64_initializer("ends", [6, 7]),
        make_int64_initializer("axes", [2, 3]),
    ]
    slice_node = helper.make_node("Slice", ["X", "starts", "ends", "axes"], ["Y"])
    graph = helper.make_graph(
        [slice_node], "slice_nchw_spatial_crop", [graph_input], [graph_output], initializer=initializers
    )
    opset = helper.make_opsetid("", 13)
    save_model(
        helper.make_model(graph, opset_imports=[opset]), "slice/nchw_spatial_crop", "slice_nchw_spatial_crop.onnx"
    )
|
||||
|
||||
|
||||
def slice_after_conv():
    """Generate a two-node model: a 3x3 Conv (unit strides, zero pads) whose output is cropped by Slice on axes 2 and 3."""
    rng = np.random.default_rng(108)
    graph_input = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 3, 8, 8])
    graph_output = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 4, 4, 4])
    # Seeded generator keeps the convolution weights reproducible across runs.
    weight_values = rng.uniform(-1, 1, (4, 3, 3, 3)).astype(np.float32)
    initializers = [
        numpy_helper.from_array(weight_values, name="W"),
        make_int64_initializer("starts", [1, 1]),
        make_int64_initializer("ends", [5, 5]),
        make_int64_initializer("axes", [2, 3]),
    ]
    conv_node = helper.make_node(
        "Conv", ["X", "W"], ["C"], kernel_shape=[3, 3], strides=[1, 1], pads=[0, 0, 0, 0]
    )
    crop_node = helper.make_node("Slice", ["C", "starts", "ends", "axes"], ["Y"])
    graph = helper.make_graph(
        [conv_node, crop_node], "slice_after_conv", [graph_input], [graph_output], initializer=initializers
    )
    opset = helper.make_opsetid("", 13)
    save_model(helper.make_model(graph, opset_imports=[opset]), "slice/after_conv", "slice_after_conv.onnx")
|
||||
|
||||
|
||||
def slice_large_channel_1024():
    """Generate a Slice model that extracts indices 128..640 along axis 1 of a [1, 1024, 1, 1] input."""
    graph_input = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 1024, 1, 1])
    graph_output = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 512, 1, 1])
    initializers = [
        make_int64_initializer("starts", [128]),
        make_int64_initializer("ends", [640]),
        make_int64_initializer("axes", [1]),
    ]
    slice_node = helper.make_node("Slice", ["X", "starts", "ends", "axes"], ["Y"])
    graph = helper.make_graph(
        [slice_node], "slice_large_channel_1024", [graph_input], [graph_output], initializer=initializers
    )
    opset = helper.make_opsetid("", 13)
    save_model(
        helper.make_model(graph, opset_imports=[opset]), "slice/large_channel_1024", "slice_large_channel_1024.onnx"
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Gather tests
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -1880,6 +1992,16 @@ if __name__ == "__main__":
|
||||
split_negative_axis()
|
||||
split_uneven_channel_axis_4d()
|
||||
|
||||
print("\nGenerating Slice tests:")
|
||||
slice_2d_basic()
|
||||
slice_default_axes()
|
||||
slice_negative_axis()
|
||||
slice_negative_indices()
|
||||
slice_step2()
|
||||
slice_nchw_spatial_crop()
|
||||
slice_after_conv()
|
||||
slice_large_channel_1024()
|
||||
|
||||
print("\nGenerating Softmax tests:")
|
||||
softmax_basic()
|
||||
softmax_3d_last_axis()
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Reference in New Issue
Block a user