better transpose pattern and cleanup
Validate Operations / validate-operations (push) Has been cancelled

This commit is contained in:
NiccoloN
2026-06-03 12:26:31 +02:00
parent 636310d0cb
commit 0a5e73c3ea
8 changed files with 75 additions and 165 deletions
@@ -123,39 +123,16 @@ static Value extractBatchMatrix(Value value,
// Builds a transpose of `value` that swaps its last two dimensions.
//
// `value` must have a RankedTensorType (the cast below asserts otherwise).
// Rank-2 inputs use permutation {1, 0}. Every other rank takes the `else`
// branch, which indexes shape[2] and so presumably expects rank 3 — TODO confirm
// with callers; there is no rank check here.
//
// NOTE(review): this span looks like a flattened diff in which lines from two
// revisions are interleaved (one going through transposeMaybeInCompute, one
// creating ONNXTransposeOp directly). Read the inline notes before relying on
// the control flow as written.
static Value transposeLastTwoDims(Value value, PatternRewriter& rewriter, Location loc) {
auto type = cast<RankedTensorType>(value.getType());
auto shape = type.getShape();
RankedTensorType transposedType;
SmallVector<int64_t> perm;
// Helper that creates an ONNXTransposeOp on `value` with the given result type
// and permutation, and returns the op's result.
auto createONNXTranspose = [&](RankedTensorType resultType, ArrayRef<int64_t> permutation) {
return ONNXTransposeOp::create(rewriter, loc, resultType, value, rewriter.getI64ArrayAttr(permutation)).getResult();
};
if (type.getRank() == 2) {
// 2-D case: swap the two dimensions. This result type carries no encoding.
transposedType = RankedTensorType::get({shape[1], shape[0]}, type.getElementType());
perm = {1, 0};
}
else {
// Non-2-D case: keep dimension 0 and swap dimensions 1 and 2.
transposedType = RankedTensorType::get({shape[0], shape[2], shape[1]}, type.getElementType());
perm = {0, 2, 1};
// NOTE(review): the two lines below return a {1, 0} transpose with a 2-D
// result type from inside the non-2-D branch, which makes the assignments
// just above dead. They look like they belong to the rank-2 path of another
// revision — confirm against the real file.
auto resultType = RankedTensorType::get({shape[1], shape[0]}, type.getElementType(), type.getEncoding());
return createONNXTranspose(resultType, {1, 0});
}
// As written, only the rank-2 path reaches this call.
return transposeMaybeInCompute(value, transposedType, perm, rewriter, loc);
}
// Transposes the last two dimensions of `value` inside a region built by
// createSpatCompute<1>, and returns result 0 of that compute op.
//
// `value` must have a RankedTensorType (the cast below asserts otherwise).
// Rank-2 inputs use permutation {1, 0}; every other rank uses {0, 2, 1} and
// indexes shape[2], so rank 3 is presumably expected — TODO confirm with
// callers. The result types built here carry no encoding.
static Value transposeLastTwoDimsInCompute(Value value, PatternRewriter& rewriter, Location loc) {
auto type = cast<RankedTensorType>(value.getType());
auto shape = type.getShape();
RankedTensorType transposedType;
SmallVector<int64_t> perm;
if (type.getRank() == 2) {
// 2-D case: swap the two dimensions.
transposedType = RankedTensorType::get({shape[1], shape[0]}, type.getElementType());
perm = {1, 0};
}
else {
// Non-2-D case: keep dimension 0 and swap dimensions 1 and 2.
transposedType = RankedTensorType::get({shape[0], shape[2], shape[1]}, type.getElementType());
perm = {0, 2, 1};
}
// The callback receives the compute's input, transposes it, and yields the
// transposed value through SpatYieldOp.
auto transposeCompute = createSpatCompute<1>(rewriter, loc, transposedType, {}, ValueRange {value}, [&](Value input) {
Value transposed = ONNXTransposeOp::create(rewriter, loc, transposedType, input, rewriter.getI64ArrayAttr(perm));
spatial::SpatYieldOp::create(rewriter, loc, transposed);
});
return transposeCompute.getResult(0);
// NOTE(review): the two statements below come after the return above and use
// `createONNXTranspose`, which is not declared in this function. They look
// like the tail of a different revision of transposeLastTwoDims left behind by
// a flattened diff — confirm against the real file.
auto resultType = RankedTensorType::get({shape[0], shape[2], shape[1]}, type.getElementType(), type.getEncoding());
return createONNXTranspose(resultType, {0, 2, 1});
}
static Value createZeroPaddedTensor(Value value, RankedTensorType resultType, PatternRewriter& rewriter, Location loc) {
@@ -372,32 +349,6 @@ static Value extractDynamicBatchedBColumn(Value matrix,
.getResult();
}
// Slices one row out of a 3-D `matrix` and collapses it to `vectorType`.
//
// The slice starts at offsets (batchOffset, row, 0) and has sizes
// (1, 1, vectorType.getDimSize(1)) with unit strides. The batch offset is the
// constant 0 when `sourceBatchCount` is 1 and the dynamic `batch` value
// otherwise. The resulting 1x1xN slice is collapsed to `vectorType` by merging
// dimensions 0 and 1.
static Value extractDynamicBatchedBRow(Value matrix,
                                       int64_t sourceBatchCount,
                                       Value batch,
                                       Value row,
                                       RankedTensorType vectorType,
                                       PatternRewriter& rewriter,
                                       Location loc) {
  const int64_t rowLength = vectorType.getDimSize(1);

  // A single-batch source is always read at batch index 0.
  OpFoldResult batchOffset(batch);
  if (sourceBatchCount == 1)
    batchOffset = OpFoldResult(rewriter.getIndexAttr(0));

  SmallVector<OpFoldResult> sliceOffsets {batchOffset, row, rewriter.getIndexAttr(0)};
  SmallVector<OpFoldResult> sliceSizes {
      rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(rowLength)};

  auto slicedType = RankedTensorType::get({1, 1, rowLength}, vectorType.getElementType());
  auto slicedRow = tensor::ExtractSliceOp::create(
      rewriter, loc, slicedType, matrix, sliceOffsets, sliceSizes, getUnitStrides(rewriter, 3));

  // Merge the two leading unit dimensions; keep the row dimension as is.
  SmallVector<ReassociationIndices> reassociation {
      {0, 1},
      {2}
  };
  return tensor::CollapseShapeOp::create(rewriter, loc, vectorType, slicedRow, reassociation);
}
static Value extractDynamicBatchedRowVector(Value matrix,
int64_t sourceBatchCount,
Value batch,
@@ -432,7 +383,6 @@ static FailureOr<spatial::SpatComputeBatch> createBatchedVvdmulBatch(Value a,
RankedTensorType bType,
RankedTensorType scalarPiecesType,
RankedTensorType outType,
bool bAlreadyTransposed,
PatternRewriter& rewriter,
Location loc) {
const int64_t numBatches = outType.getDimSize(0);
@@ -459,9 +409,7 @@ static FailureOr<spatial::SpatComputeBatch> createBatchedVvdmulBatch(Value a,
Value aVector =
extractDynamicBatchedRowVector(args.inputs[0], aBatchCount, batch, row, vectorType, rewriter, loc);
Value bVector =
bAlreadyTransposed
? extractDynamicBatchedBRow(args.inputs[1], bBatchCount, batch, column, vectorType, rewriter, loc)
: extractDynamicBatchedBColumn(args.inputs[1], bBatchCount, batch, column, vectorType, rewriter, loc);
extractDynamicBatchedBColumn(args.inputs[1], bBatchCount, batch, column, vectorType, rewriter, loc);
Value scalar = spatial::SpatVVDMulOp::create(rewriter, loc, scalarType, aVector, bVector).getResult();
SmallVector<OpFoldResult> outputOffsets {args.lane, rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> scalarSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
@@ -537,15 +485,6 @@ static FailureOr<Value> createBatchedDynamicOutputCompute(Value scalarPieces,
return computeOp->getResult(0);
}
// Wraps a {0, 2, 1} ONNX transpose of `value` in a region built by
// createSpatCompute<1> and returns result 0 of that compute op. `outputType` is
// used both as the compute's result type and as the transpose's result type.
static Value transposeBatchedOutput(Value value, RankedTensorType outputType, PatternRewriter& rewriter, Location loc) {
  // Keep dimension 0 in place and swap dimensions 1 and 2.
  auto swapLastTwo = rewriter.getI64ArrayAttr({0, 2, 1});
  return createSpatCompute<1>(rewriter, loc, TypeRange {outputType}, {}, ValueRange {value}, [&](Value computeInput) {
           Value swapped = ONNXTransposeOp::create(rewriter, loc, outputType, computeInput, swapLastTwo);
           spatial::SpatYieldOp::create(rewriter, loc, swapped);
         })
      .getResult(0);
}
static Value extractBatchedReductionPiece(Value partialPiecesArg,
Value batch,
Value hSlice,
@@ -764,7 +703,7 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
int64_t gemmK = shapeInfo->k;
int64_t gemmN = shapeInfo->n;
if (useTransposedForm) {
lhs = transposeLastTwoDimsInCompute(matmulOp.getB(), rewriter, loc);
lhs = transposeLastTwoDims(matmulOp.getB(), rewriter, loc);
lhsBatchForGemm = shapeInfo->rhsBatch;
rhs = transposeLastTwoDims(matmulOp.getA(), rewriter, loc);
rhsBatchForGemm = shapeInfo->lhsBatch;
@@ -787,15 +726,10 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
rewriter.getBoolAttr(false),
rewriter.getBoolAttr(false))
.getY();
if (useTransposedForm) {
auto transposeCompute =
createSpatCompute<1>(rewriter, loc, TypeRange {shapeInfo->outType}, {}, gemmResult, [&](Value input) {
Value transposed =
ONNXTransposeOp::create(rewriter, loc, shapeInfo->outType, input, rewriter.getI64ArrayAttr({1, 0}));
spatial::SpatYieldOp::create(rewriter, loc, transposed);
});
gemmResult = transposeCompute.getResult(0);
}
if (useTransposedForm)
gemmResult =
ONNXTransposeOp::create(rewriter, loc, shapeInfo->outType, gemmResult, rewriter.getI64ArrayAttr({1, 0}))
.getResult();
rewriter.replaceOp(matmulOp, gemmResult);
return success();
}
@@ -822,7 +756,7 @@ struct MatMulBatchedToSpatialComputes : OpRewritePattern<ONNXMatMulOp> {
int64_t gemmK = shapeInfo->k;
int64_t gemmN = shapeInfo->n;
if (useTransposedForm) {
lhs = transposeLastTwoDimsInCompute(matmulOp.getB(), rewriter, loc);
lhs = transposeLastTwoDims(matmulOp.getB(), rewriter, loc);
lhsBatchForGemm = shapeInfo->rhsBatch;
rhs = transposeLastTwoDims(matmulOp.getA(), rewriter, loc);
rhsBatchForGemm = shapeInfo->lhsBatch;
@@ -880,12 +814,14 @@ struct MatMulBatchedToSpatialComputes : OpRewritePattern<ONNXMatMulOp> {
if (failed(result))
return failure();
Value finalResult = *result;
if (useTransposedForm)
finalResult = transposeBatchedOutput(
finalResult,
RankedTensorType::get({shapeInfo->batch, shapeInfo->m, shapeInfo->n}, shapeInfo->outType.getElementType()),
rewriter,
loc);
if (useTransposedForm) {
auto transposedOutType = RankedTensorType::get({shapeInfo->batch, shapeInfo->m, shapeInfo->n},
shapeInfo->outType.getElementType(),
shapeInfo->outType.getEncoding());
finalResult =
ONNXTransposeOp::create(rewriter, loc, transposedOutType, finalResult, rewriter.getI64ArrayAttr({0, 2, 1}))
.getResult();
}
finalResult = expandBatchDims(finalResult, shapeInfo->outType, shapeInfo->batchShape.size(), rewriter, loc);
rewriter.replaceOp(matmulOp, finalResult);
return success();
@@ -901,7 +837,6 @@ struct MatMulBatchedToSpatialComputes : OpRewritePattern<ONNXMatMulOp> {
rhsBatchedType,
scalarPiecesType,
directOutType,
false,
rewriter,
loc);
if (failed(batchOp))
@@ -911,12 +846,14 @@ struct MatMulBatchedToSpatialComputes : OpRewritePattern<ONNXMatMulOp> {
if (failed(result))
return failure();
Value finalResult = *result;
if (useTransposedForm)
finalResult = transposeBatchedOutput(
finalResult,
RankedTensorType::get({shapeInfo->batch, shapeInfo->m, shapeInfo->n}, shapeInfo->outType.getElementType()),
rewriter,
loc);
if (useTransposedForm) {
auto transposedOutType = RankedTensorType::get({shapeInfo->batch, shapeInfo->m, shapeInfo->n},
shapeInfo->outType.getElementType(),
shapeInfo->outType.getEncoding());
finalResult =
ONNXTransposeOp::create(rewriter, loc, transposedOutType, finalResult, rewriter.getI64ArrayAttr({0, 2, 1}))
.getResult();
}
finalResult = expandBatchDims(finalResult, shapeInfo->outType, shapeInfo->batchShape.size(), rewriter, loc);
rewriter.replaceOp(matmulOp, finalResult);
return success();