Refactor + ReduceMean batched
Validate Operations / validate-operations (push) Has been cancelled

This commit is contained in:
ilgeco
2026-05-29 15:57:13 +02:00
parent 832bd7f1f7
commit 819d8af0f7
27 changed files with 929 additions and 568 deletions
@@ -19,14 +19,6 @@ using namespace mlir;
namespace onnx_mlir {
namespace {
static bool haveStaticPositiveShape(ArrayRef<int64_t> shape) {
return llvm::all_of(shape, [](int64_t dim) { return dim > 0; });
}
static int64_t getStaticShapeElementCount(ArrayRef<int64_t> shape) {
return std::accumulate(shape.begin(), shape.end(), int64_t {1}, std::multiplies<int64_t> {});
}
static FailureOr<SmallVector<int64_t>> inferSupportedBatchShape(ArrayRef<int64_t> lhsBatchShape,
ArrayRef<int64_t> rhsBatchShape) {
if (lhsBatchShape.empty())
@@ -54,15 +46,7 @@ collapseBatchDims(Value value, int64_t batchSize, int64_t rows, int64_t cols, Pa
auto buildCollapsed = [&](Value input) -> Value {
return tensor::CollapseShapeOp::create(rewriter, loc, collapsedType, input, reassociation);
};
if (isCompileTimeComputable(value))
return buildCollapsed(value);
auto collapseCompute =
createSpatCompute<1>(rewriter, loc, TypeRange {collapsedType}, {}, ValueRange {value}, [&](Value input) {
spatial::SpatYieldOp::create(rewriter, loc, buildCollapsed(input));
});
return collapseCompute.getResult(0);
return materializeOrComputeUnary(value, collapsedType, rewriter, loc, buildCollapsed);
}
static Value
@@ -76,12 +60,10 @@ expandBatchDims(Value value, RankedTensorType outputType, size_t batchRank, Patt
for (size_t dim = 0; dim < batchRank; ++dim)
reassociation.front().push_back(static_cast<int64_t>(dim));
auto expandCompute =
createSpatCompute<1>(rewriter, loc, TypeRange {outputType}, {}, ValueRange {value}, [&](Value input) {
Value expanded = tensor::ExpandShapeOp::create(rewriter, loc, outputType, input, reassociation);
spatial::SpatYieldOp::create(rewriter, loc, expanded);
});
return expandCompute.getResult(0);
auto buildExpanded = [&](Value input) -> Value {
return tensor::ExpandShapeOp::create(rewriter, loc, outputType, input, reassociation).getResult();
};
return materializeOrComputeUnary(value, outputType, rewriter, loc, buildExpanded);
}
static Value extractBatchMatrix(Value value,
@@ -100,7 +82,7 @@ static Value extractBatchMatrix(Value value,
rewriter.getIndexAttr(batchSize == 1 ? 0 : batchIndex), rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> sizes = {
rewriter.getIndexAttr(1), rewriter.getIndexAttr(rows), rewriter.getIndexAttr(cols)};
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
SmallVector<OpFoldResult> strides = getUnitStrides(rewriter, 3);
auto matrixType = RankedTensorType::get({rows, cols}, type.getElementType());
auto buildMatrix = [&](Value input) -> Value {
Value slice = tensor::ExtractSliceOp::create(rewriter, loc, sliceType, input, offsets, sizes, strides);
@@ -114,14 +96,7 @@ static Value extractBatchMatrix(Value value,
});
};
if (isCompileTimeComputable(value))
return buildMatrix(value);
auto batchMatrixCompute =
createSpatCompute<1>(rewriter, loc, TypeRange {matrixType}, {}, ValueRange {value}, [&](Value input) {
spatial::SpatYieldOp::create(rewriter, loc, buildMatrix(input));
});
return batchMatrixCompute.getResult(0);
return materializeOrComputeUnary(value, matrixType, rewriter, loc, buildMatrix);
}
static Value transposeLastTwoDims(Value value, PatternRewriter& rewriter, Location loc) {
@@ -138,18 +113,7 @@ static Value transposeLastTwoDims(Value value, PatternRewriter& rewriter, Locati
perm = {0, 2, 1};
}
auto buildTranspose = [&](Value input) -> Value {
return ONNXTransposeOp::create(rewriter, loc, transposedType, input, rewriter.getI64ArrayAttr(perm));
};
if (isCompileTimeComputable(value))
return buildTranspose(value);
auto transposeCompute =
createSpatCompute<1>(rewriter, loc, TypeRange {transposedType}, {}, ValueRange {value}, [&](Value input) {
spatial::SpatYieldOp::create(rewriter, loc, buildTranspose(input));
});
return transposeCompute.getResult(0);
return transposeMaybeInCompute(value, transposedType, perm, rewriter, loc);
}
static Value transposeLastTwoDimsInCompute(Value value, PatternRewriter& rewriter, Location loc) {
@@ -166,10 +130,11 @@ static Value transposeLastTwoDimsInCompute(Value value, PatternRewriter& rewrite
perm = {0, 2, 1};
}
auto transposeCompute = createSpatCompute<1>(rewriter, loc, transposedType, {}, ValueRange {value}, [&](Value input) {
Value transposed = ONNXTransposeOp::create(rewriter, loc, transposedType, input, rewriter.getI64ArrayAttr(perm));
spatial::SpatYieldOp::create(rewriter, loc, transposed);
});
auto transposeCompute =
createSpatCompute<1>(rewriter, loc, transposedType, {}, ValueRange {value}, [&](Value input) {
Value transposed = ONNXTransposeOp::create(rewriter, loc, transposedType, input, rewriter.getI64ArrayAttr(perm));
spatial::SpatYieldOp::create(rewriter, loc, transposed);
});
return transposeCompute.getResult(0);
}
@@ -203,8 +168,7 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
return failure();
if (lhsType.getRank() < 2 || rhsType.getRank() < 2 || outType.getRank() < 2)
return failure();
if (!haveStaticPositiveShape(lhsType.getShape()) || !haveStaticPositiveShape(rhsType.getShape())
|| !haveStaticPositiveShape(outType.getShape()))
if (!hasStaticPositiveShape(lhsType) || !hasStaticPositiveShape(rhsType) || !hasStaticPositiveShape(outType))
return failure();
SmallVector<int64_t> lhsBatchShape(lhsType.getShape().begin(), lhsType.getShape().end() - 2);