Refactor + ReduceMean batched
Validate Operations / validate-operations (push) Has been cancelled

This commit is contained in:
ilgeco
2026-05-29 15:57:13 +02:00
parent 832bd7f1f7
commit 819d8af0f7
27 changed files with 929 additions and 568 deletions
@@ -13,16 +13,6 @@ using namespace mlir;
namespace onnx_mlir {
namespace {
static int64_t normalizeAxis(int64_t axis, int64_t rank) { return axis >= 0 ? axis : rank + axis; }
static SmallVector<int64_t> permuteShape(ArrayRef<int64_t> shape, ArrayRef<int64_t> permutation) {
SmallVector<int64_t> permutedShape;
permutedShape.reserve(permutation.size());
for (int64_t axis : permutation)
permutedShape.push_back(shape[axis]);
return permutedShape;
}
static Value buildLoopSoftmaxSlice(Value input,
Value accumulator,
RankedTensorType inputType,
@@ -36,7 +26,7 @@ static Value buildLoopSoftmaxSlice(Value input,
SmallVector<OpFoldResult> offsets;
SmallVector<OpFoldResult> sizes;
SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
SmallVector<OpFoldResult> strides = getUnitStrides(rewriter, rank);
offsets.reserve(rank);
sizes.reserve(rank);
@@ -110,44 +100,31 @@ struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> {
if (!inputType || !inputType.hasStaticShape())
return failure();
int64_t axis = normalizeAxis(softmaxOp.getAxis(), inputType.getRank());
if (axis < 0 || axis >= inputType.getRank())
auto axis = normalizeAxisChecked(softmaxOp.getAxis(), inputType.getRank());
if (failed(axis))
return failure();
Value input = adaptor.getInput();
Value result;
if (axis == inputType.getRank() - 1) {
if (*axis == inputType.getRank() - 1) {
result = createLoopSoftmaxCompute(input, rewriter, softmaxOp.getLoc());
}
else {
SmallVector<int64_t> permutation;
permutation.reserve(inputType.getRank());
for (int64_t dim = 0; dim < inputType.getRank(); ++dim)
if (dim != axis)
if (dim != *axis)
permutation.push_back(dim);
permutation.push_back(axis);
SmallVector<int64_t> inversePermutation(inputType.getRank());
for (auto [newIndex, oldIndex] : llvm::enumerate(permutation))
inversePermutation[oldIndex] = static_cast<int64_t>(newIndex);
permutation.push_back(*axis);
SmallVector<int64_t> inversePermutation = invertPermutation(permutation);
auto transposedType = RankedTensorType::get(
permuteShape(inputType.getShape(), permutation), inputType.getElementType(), inputType.getEncoding());
auto preTransposeCompute =
createSpatCompute<1>(rewriter, softmaxOp.getLoc(), TypeRange {transposedType}, {}, input, [&](Value x) {
Value transposed = ONNXTransposeOp::create(
rewriter, softmaxOp.getLoc(), transposedType, x, rewriter.getI64ArrayAttr(permutation));
spatial::SpatYieldOp::create(rewriter, softmaxOp.getLoc(), transposed);
});
Value transposedInput = preTransposeCompute.getResult(0);
Value transposedInput =
transposeMaybeInCompute(input, transposedType, permutation, rewriter, softmaxOp.getLoc());
Value transposedResult = createLoopSoftmaxCompute(transposedInput, rewriter, softmaxOp.getLoc());
auto postTransposeCompute =
createSpatCompute<1>(rewriter, softmaxOp.getLoc(), TypeRange {inputType}, {}, transposedResult, [&](Value x) {
Value transposed = ONNXTransposeOp::create(
rewriter, softmaxOp.getLoc(), inputType, x, rewriter.getI64ArrayAttr(inversePermutation));
spatial::SpatYieldOp::create(rewriter, softmaxOp.getLoc(), transposed);
});
result = postTransposeCompute.getResult(0);
result = transposeMaybeInCompute(
transposedResult, inputType, inversePermutation, rewriter, softmaxOp.getLoc());
}
rewriter.replaceOp(softmaxOp, result);