better transpose pattern and cleanup
Validate Operations / validate-operations (push) Has been cancelled

This commit is contained in:
NiccoloN
2026-06-03 12:26:31 +02:00
parent 636310d0cb
commit 0a5e73c3ea
8 changed files with 75 additions and 165 deletions
@@ -331,14 +331,6 @@ static Value extractDynamicGemmBColumn(
return tensor::ExpandShapeOp::create(rewriter, loc, vectorType, collapsed, expandReassociation).getResult();
}
/// Builds a `tensor.extract_slice` that reads one row out of `transposedB`.
///
/// The slice starts at offset (`row`, 0), has size
/// 1 x `vectorType.getDimSize(1)`, and uses unit strides in both dimensions.
/// The result is typed as `vectorType`.
///
/// \param transposedB  Source tensor value the slice is taken from.
/// \param row          Offset along dimension 0 at which the slice starts.
/// \param vectorType   Result type; its dimension 1 gives the slice width.
/// \param rewriter     Rewriter used to create the index attributes and the op.
/// \param loc          Location attached to the created op.
/// \return The result value of the created `tensor::ExtractSliceOp`.
static Value extractTransposedBRow(
    Value transposedB, Value row, RankedTensorType vectorType, ConversionPatternRewriter& rewriter, Location loc) {
  // Index attributes shared by the offset/size/stride lists below.
  const auto zeroAttr = rewriter.getIndexAttr(0);
  const auto oneAttr = rewriter.getIndexAttr(1);
  const auto widthAttr = rewriter.getIndexAttr(vectorType.getDimSize(1));

  SmallVector<OpFoldResult> sliceOffsets {row, zeroAttr};
  SmallVector<OpFoldResult> sliceSizes {oneAttr, widthAttr};
  SmallVector<OpFoldResult> sliceStrides {oneAttr, oneAttr};

  auto sliceOp =
      tensor::ExtractSliceOp::create(rewriter, loc, vectorType, transposedB, sliceOffsets, sliceSizes, sliceStrides);
  return sliceOp.getResult();
}
static Value extractDynamicGemmRowVector(
Value matrix, Value row, RankedTensorType vectorType, ConversionPatternRewriter& rewriter, Location loc) {
SmallVector<OpFoldResult> offsets {row, rewriter.getIndexAttr(0)};
@@ -424,7 +416,6 @@ static FailureOr<spatial::SpatComputeBatch> createVvdmulBatch(Value a,
RankedTensorType bType,
RankedTensorType scalarPiecesType,
RankedTensorType outType,
bool bAlreadyTransposed,
ConversionPatternRewriter& rewriter,
Location loc) {
const int64_t numOutRows = outType.getDimSize(0);
@@ -446,8 +437,7 @@ static FailureOr<spatial::SpatComputeBatch> createVvdmulBatch(Value a,
auto vectorType = RankedTensorType::get({1, reductionSize}, aType.getElementType());
auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType());
Value aVector = extractDynamicGemmRowVector(args.inputs[0], row, vectorType, rewriter, loc);
Value bVector = bAlreadyTransposed ? extractTransposedBRow(args.inputs[1], column, vectorType, rewriter, loc)
: extractDynamicGemmBColumn(args.inputs[1], column, vectorType, rewriter, loc);
Value bVector = extractDynamicGemmBColumn(args.inputs[1], column, vectorType, rewriter, loc);
Value scalar = spatial::SpatVVDMulOp::create(rewriter, loc, scalarType, aVector, bVector).getResult();
SmallVector<OpFoldResult> outputOffsets {args.lane, rewriter.getIndexAttr(0)};
@@ -739,6 +729,13 @@ LogicalResult GemmToSpatialComputes::matchAndRewrite(ONNXGemmOp gemmOp,
const int64_t numOutCols = outType.getDimSize(1);
const int64_t reductionSize = aType.getDimSize(1);
if (gemmOpAdaptor.getTransB()) {
auto bShape = bType.getShape();
auto transposedType = RankedTensorType::get({bShape[1], bShape[0]}, bType.getElementType(), bType.getEncoding());
b = ONNXTransposeOp::create(rewriter, loc, transposedType, b, rewriter.getI64ArrayAttr({1, 0})).getResult();
bType = transposedType;
}
if (!isCompileTimeComputable(b)) {
bool hasC = hasGemmBias(c);
float alpha = gemmOpAdaptor.getAlpha().convertToFloat();
@@ -758,10 +755,8 @@ LogicalResult GemmToSpatialComputes::matchAndRewrite(ONNXGemmOp gemmOp,
biasType = *verifiedBiasType;
}
const int64_t expectedBRows = gemmOpAdaptor.getTransB() ? numOutCols : reductionSize;
const int64_t expectedBCols = gemmOpAdaptor.getTransB() ? reductionSize : numOutCols;
if (aType.getDimSize(0) != numOutRows || bType.getDimSize(0) != expectedBRows
|| bType.getDimSize(1) != expectedBCols) {
if (aType.getDimSize(0) != numOutRows || bType.getDimSize(0) != reductionSize
|| bType.getDimSize(1) != numOutCols) {
gemmOp.emitOpError("has inconsistent A, B, and output shapes");
return failure();
}
@@ -773,8 +768,7 @@ LogicalResult GemmToSpatialComputes::matchAndRewrite(ONNXGemmOp gemmOp,
}
auto scalarPiecesType = RankedTensorType::get({laneCount64, 1}, outType.getElementType());
auto batchOp =
createVvdmulBatch(a, b, aType, bType, scalarPiecesType, outType, gemmOpAdaptor.getTransB(), rewriter, loc);
auto batchOp = createVvdmulBatch(a, b, aType, bType, scalarPiecesType, outType, rewriter, loc);
if (failed(batchOp))
return failure();
auto outputCompute = createDynamicGemmOutputCompute(
@@ -793,13 +787,6 @@ LogicalResult GemmToSpatialComputes::matchAndRewrite(ONNXGemmOp gemmOp,
b = *scaledB;
bType = cast<RankedTensorType>(b.getType());
if (gemmOpAdaptor.getTransB()) {
auto bShape = bType.getShape();
auto transposedType = RankedTensorType::get({bShape[1], bShape[0]}, bType.getElementType());
b = transposeMaybeInCompute(b, transposedType, {1, 0}, rewriter, loc);
bType = cast<RankedTensorType>(b.getType());
}
if (aType.getDimSize(0) != numOutRows || bType.getDimSize(0) != reductionSize || bType.getDimSize(1) != numOutCols) {
gemmOp.emitOpError("has inconsistent A, B, and output shapes after transpose handling");
return failure();