batched matmul pattern
Validate Operations / validate-operations (push) Has been cancelled

add conv helpers
new validation tests for matmul
This commit is contained in:
NiccoloN
2026-05-29 19:07:24 +02:00
parent 8bb0babf1b
commit a41f694cf0
18 changed files with 877 additions and 192 deletions
@@ -51,23 +51,107 @@ static Value createPaddedRows(Value tensorValue,
if (tensorType.getDimSize(0) == paddedRows)
return tensorValue;
auto paddedType = RankedTensorType::get({paddedRows, tensorType.getDimSize(1)}, tensorType.getElementType());
auto paddedType =
RankedTensorType::get({paddedRows, tensorType.getDimSize(1)}, tensorType.getElementType(), tensorType.getEncoding());
SmallVector<OpFoldResult> lowPads = {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> highPads = {rewriter.getIndexAttr(paddedRows - tensorType.getDimSize(0)),
rewriter.getIndexAttr(0)};
auto padOp = tensor::PadOp::create(rewriter, loc, paddedType, tensorValue, lowPads, highPads);
auto* padBlock = new Block();
for (int i = 0; i < 2; i++)
for (int i = 0; i < 2; ++i)
padBlock->addArgument(rewriter.getIndexType(), loc);
padOp.getRegion().push_back(padBlock);
rewriter.setInsertionPointToStart(padBlock);
auto zero = getOrCreateConstant(rewriter, padOp.getOperation(), rewriter.getZeroAttr(tensorType.getElementType()),
auto zero = getOrCreateConstant(rewriter,
padOp.getOperation(),
rewriter.getZeroAttr(tensorType.getElementType()),
tensorType.getElementType());
tensor::YieldOp::create(rewriter, loc, zero);
rewriter.setInsertionPointAfter(padOp);
return padOp.getResult();
}
static Value packRowsForParallelGemm(Value rows,
RankedTensorType rowsType,
int64_t packFactor,
ConversionPatternRewriter& rewriter,
Location loc) {
if (packFactor == 1)
return rows;
const int64_t packedNumRows = ceilIntegerDivide(rowsType.getDimSize(0), packFactor);
const int64_t paddedNumRows = packedNumRows * packFactor;
const int64_t rowWidth = rowsType.getDimSize(1);
auto groupedType =
RankedTensorType::get({packedNumRows, packFactor, rowWidth}, rowsType.getElementType(), rowsType.getEncoding());
auto packedType =
RankedTensorType::get({packedNumRows, packFactor * rowWidth}, rowsType.getElementType(), rowsType.getEncoding());
Value paddedRows = createPaddedRows(rows, rowsType, paddedNumRows, rewriter, loc);
Value groupedRows = tensor::ExpandShapeOp::create(rewriter,
loc,
groupedType,
paddedRows,
SmallVector<ReassociationIndices> {
{0, 1},
{2}
});
return tensor::CollapseShapeOp::create(rewriter,
loc,
packedType,
groupedRows,
SmallVector<ReassociationIndices> {
{0},
{1, 2}
});
}
static Value unpackRowsFromParallelGemm(Value packedRows,
RankedTensorType packedRowsType,
int64_t unpackedRows,
int64_t rowWidth,
int64_t packFactor,
ConversionPatternRewriter& rewriter,
Location loc) {
if (packFactor == 1)
return packedRows;
const int64_t packedNumRows = packedRowsType.getDimSize(0);
const int64_t paddedNumRows = packedNumRows * packFactor;
auto expandedType =
RankedTensorType::get({packedNumRows, packFactor, rowWidth},
packedRowsType.getElementType(),
packedRowsType.getEncoding());
auto paddedType =
RankedTensorType::get({paddedNumRows, rowWidth}, packedRowsType.getElementType(), packedRowsType.getEncoding());
auto unpackedType =
RankedTensorType::get({unpackedRows, rowWidth}, packedRowsType.getElementType(), packedRowsType.getEncoding());
Value expandedRows = tensor::ExpandShapeOp::create(rewriter,
loc,
expandedType,
packedRows,
SmallVector<ReassociationIndices> {
{0},
{1, 2}
});
Value paddedRows = tensor::CollapseShapeOp::create(rewriter,
loc,
paddedType,
expandedRows,
SmallVector<ReassociationIndices> {
{0, 1},
{2}
});
if (paddedNumRows == unpackedRows)
return paddedRows;
SmallVector<OpFoldResult> offsets {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> sizes {rewriter.getIndexAttr(unpackedRows), rewriter.getIndexAttr(rowWidth)};
SmallVector<OpFoldResult> strides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
return tensor::ExtractSliceOp::create(rewriter, loc, unpackedType, paddedRows, offsets, sizes, strides);
}
static Value buildPackedWeight(DenseElementsAttr wDenseAttr,
Value wTrans,
RankedTensorType wType,
@@ -189,7 +273,6 @@ static Value createIm2colRowComputes(Value x,
Location loc) {
auto elemType = xType.getElementType();
constexpr size_t numInputs = 1;
const int64_t packedNumRows = ceilIntegerDivide(numPatches, packFactor);
auto im2colComputeOp =
createSpatCompute<numInputs>(rewriter, loc, TypeRange {gemmInputRowsType}, {}, x, [&](Value xArg) {
Value paddedInput = xArg;
@@ -278,26 +361,7 @@ static Value createIm2colRowComputes(Value x,
Value gemmInputRows = im2col;
if (packFactor != 1) {
const int64_t paddedNumPatches = packedNumRows * packFactor;
auto groupedType = RankedTensorType::get({packedNumRows, packFactor, patchSize}, elemType);
auto packedType = RankedTensorType::get({packedNumRows, packFactor * patchSize}, elemType);
Value paddedIm2col = createPaddedRows(im2col, im2colType, paddedNumPatches, rewriter, loc);
Value groupedIm2col = tensor::ExpandShapeOp::create(rewriter,
loc,
groupedType,
paddedIm2col,
SmallVector<ReassociationIndices> {
{0, 1},
{2}
});
gemmInputRows = tensor::CollapseShapeOp::create(rewriter,
loc,
packedType,
groupedIm2col,
SmallVector<ReassociationIndices> {
{0},
{1, 2}
});
gemmInputRows = packRowsForParallelGemm(im2col, im2colType, packFactor, rewriter, loc);
}
spatial::SpatYieldOp::create(rewriter, loc, gemmInputRows);
@@ -316,41 +380,15 @@ static Value createCollectedConvOutput(ValueRange gemmRows,
int64_t packFactor,
ConversionPatternRewriter& rewriter,
Location loc) {
const int64_t packedNumRows = ceilIntegerDivide(numPatches, packFactor);
const int64_t paddedNumPatches = packedNumRows * packFactor;
auto collectComputeOp = createSpatCompute(rewriter, loc, convType, {}, gemmRows, [&](ValueRange gemmRowArgs) {
Value gemmOut;
if (packFactor == 1) {
gemmOut = createSpatConcat(rewriter, loc, /*axis=*/0, gemmRowArgs);
}
else {
auto expandedType = RankedTensorType::get({packedNumRows, packFactor, numChannelsOut}, outType.getElementType());
auto paddedType = RankedTensorType::get({paddedNumPatches, numChannelsOut}, outType.getElementType());
Value packedOutput = createSpatConcat(rewriter, loc, /*axis=*/0, gemmRowArgs);
Value expandedOutput = tensor::ExpandShapeOp::create(rewriter,
loc,
expandedType,
packedOutput,
SmallVector<ReassociationIndices> {
{0},
{1, 2}
});
Value paddedOutput = tensor::CollapseShapeOp::create(rewriter,
loc,
paddedType,
expandedOutput,
SmallVector<ReassociationIndices> {
{0, 1},
{2}
});
gemmOut = paddedOutput;
if (paddedNumPatches != numPatches) {
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(numPatches), rewriter.getIndexAttr(numChannelsOut)};
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
gemmOut = tensor::ExtractSliceOp::create(rewriter, loc, gemmOutType, paddedOutput, offsets, sizes, strides);
}
gemmOut = unpackRowsFromParallelGemm(
packedOutput, cast<RankedTensorType>(packedOutput.getType()), numPatches, numChannelsOut, packFactor, rewriter, loc);
}
// Restore to NCHW layout: