Add conv helpers and new validation tests for matmul
This commit is contained in:
@@ -51,23 +51,107 @@ static Value createPaddedRows(Value tensorValue,
|
||||
if (tensorType.getDimSize(0) == paddedRows)
|
||||
return tensorValue;
|
||||
|
||||
auto paddedType = RankedTensorType::get({paddedRows, tensorType.getDimSize(1)}, tensorType.getElementType());
|
||||
auto paddedType =
|
||||
RankedTensorType::get({paddedRows, tensorType.getDimSize(1)}, tensorType.getElementType(), tensorType.getEncoding());
|
||||
SmallVector<OpFoldResult> lowPads = {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> highPads = {rewriter.getIndexAttr(paddedRows - tensorType.getDimSize(0)),
|
||||
rewriter.getIndexAttr(0)};
|
||||
auto padOp = tensor::PadOp::create(rewriter, loc, paddedType, tensorValue, lowPads, highPads);
|
||||
auto* padBlock = new Block();
|
||||
for (int i = 0; i < 2; i++)
|
||||
for (int i = 0; i < 2; ++i)
|
||||
padBlock->addArgument(rewriter.getIndexType(), loc);
|
||||
padOp.getRegion().push_back(padBlock);
|
||||
rewriter.setInsertionPointToStart(padBlock);
|
||||
auto zero = getOrCreateConstant(rewriter, padOp.getOperation(), rewriter.getZeroAttr(tensorType.getElementType()),
|
||||
auto zero = getOrCreateConstant(rewriter,
|
||||
padOp.getOperation(),
|
||||
rewriter.getZeroAttr(tensorType.getElementType()),
|
||||
tensorType.getElementType());
|
||||
tensor::YieldOp::create(rewriter, loc, zero);
|
||||
rewriter.setInsertionPointAfter(padOp);
|
||||
return padOp.getResult();
|
||||
}
|
||||
|
||||
/// Packs `packFactor` consecutive rows of a 2-D tensor side by side into a
/// single, wider row.
///
/// With R = rowsType.getDimSize(0), W = rowsType.getDimSize(1) and
/// G = ceilIntegerDivide(R, packFactor), the emitted IR is:
///   1. createPaddedRows(...) to bring the row count to G * packFactor,
///   2. tensor.expand_shape  [G * packFactor, W] -> [G, packFactor, W],
///   3. tensor.collapse_shape [G, packFactor, W] -> [G, packFactor * W].
/// Element type and encoding of `rowsType` are carried over to every
/// intermediate and result type.
///
/// @param rows        Tensor value to pack.
/// @param rowsType    Ranked type describing `rows`; dims 0 and 1 are read.
/// @param packFactor  Number of rows merged into one packed row.
/// @param rewriter    Rewriter used to create the new ops.
/// @param loc         Location attached to the created ops.
/// @return `rows` itself when `packFactor == 1`, otherwise the result of the
///         collapse_shape op.
static Value packRowsForParallelGemm(Value rows,
                                     RankedTensorType rowsType,
                                     int64_t packFactor,
                                     ConversionPatternRewriter& rewriter,
                                     Location loc) {
  // A pack factor of one leaves the layout untouched, so no ops are emitted.
  if (packFactor == 1)
    return rows;

  auto elementType = rowsType.getElementType();
  auto encoding = rowsType.getEncoding();
  const int64_t width = rowsType.getDimSize(1);
  const int64_t numGroups = ceilIntegerDivide(rowsType.getDimSize(0), packFactor);
  const int64_t totalRows = numGroups * packFactor;

  auto splitType = RankedTensorType::get({numGroups, packFactor, width}, elementType, encoding);
  auto mergedType = RankedTensorType::get({numGroups, packFactor * width}, elementType, encoding);

  // Make the row count an exact multiple of packFactor before reshaping.
  Value padded = createPaddedRows(rows, rowsType, totalRows, rewriter, loc);

  // Source dim 0 is split into (group, slot-in-group); dim 1 is kept as is.
  SmallVector<ReassociationIndices> splitRowDim {
    {0, 1},
    {2}
  };
  Value split = tensor::ExpandShapeOp::create(rewriter, loc, splitType, padded, splitRowDim);

  // Fold (slot-in-group, width) into a single, wider row dimension.
  SmallVector<ReassociationIndices> mergeIntoRow {
    {0},
    {1, 2}
  };
  return tensor::CollapseShapeOp::create(rewriter, loc, mergedType, split, mergeIntoRow);
}
|
||||
|
||||
/// Splits each packed row back into `packFactor` rows of width `rowWidth`
/// and drops any surplus trailing rows.
///
/// With G = packedRowsType.getDimSize(0), the emitted IR is:
///   1. tensor.expand_shape  [G, packFactor * rowWidth] -> [G, packFactor, rowWidth],
///   2. tensor.collapse_shape [G, packFactor, rowWidth] -> [G * packFactor, rowWidth],
///   3. only if G * packFactor != unpackedRows: tensor.extract_slice taking
///      the leading `unpackedRows` rows (offset 0, stride 1).
/// Element type and encoding of `packedRowsType` are carried over to every
/// intermediate and result type.
///
/// @param packedRows      Tensor value holding the packed rows.
/// @param packedRowsType  Ranked type describing `packedRows`; dim 0 is read.
/// @param unpackedRows    Number of rows the result should have.
/// @param rowWidth        Width of a single unpacked row.
/// @param packFactor      Number of rows stored in each packed row.
/// @param rewriter        Rewriter used to create the new ops.
/// @param loc             Location attached to the created ops.
/// @return `packedRows` itself when `packFactor == 1`; otherwise the
///         collapsed tensor, sliced down to `unpackedRows` rows if needed.
static Value unpackRowsFromParallelGemm(Value packedRows,
                                        RankedTensorType packedRowsType,
                                        int64_t unpackedRows,
                                        int64_t rowWidth,
                                        int64_t packFactor,
                                        ConversionPatternRewriter& rewriter,
                                        Location loc) {
  // A pack factor of one means nothing was packed, so no ops are emitted.
  if (packFactor == 1)
    return packedRows;

  auto elementType = packedRowsType.getElementType();
  auto encoding = packedRowsType.getEncoding();
  const int64_t numGroups = packedRowsType.getDimSize(0);
  const int64_t totalRows = numGroups * packFactor;

  auto splitType = RankedTensorType::get({numGroups, packFactor, rowWidth}, elementType, encoding);
  auto flatType = RankedTensorType::get({totalRows, rowWidth}, elementType, encoding);
  auto slicedType = RankedTensorType::get({unpackedRows, rowWidth}, elementType, encoding);

  // Break the wide row dimension into (slot-in-group, width).
  SmallVector<ReassociationIndices> splitRowDim {
    {0},
    {1, 2}
  };
  Value split = tensor::ExpandShapeOp::create(rewriter, loc, splitType, packedRows, splitRowDim);

  // Merge (group, slot-in-group) back into one row index.
  SmallVector<ReassociationIndices> mergeGroups {
    {0, 1},
    {2}
  };
  Value flat = tensor::CollapseShapeOp::create(rewriter, loc, flatType, split, mergeGroups);

  // Only slice when the flattened row count differs from the requested one.
  if (totalRows != unpackedRows) {
    SmallVector<OpFoldResult> offsets {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
    SmallVector<OpFoldResult> sizes {rewriter.getIndexAttr(unpackedRows), rewriter.getIndexAttr(rowWidth)};
    SmallVector<OpFoldResult> strides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
    return tensor::ExtractSliceOp::create(rewriter, loc, slicedType, flat, offsets, sizes, strides);
  }

  return flat;
}
|
||||
|
||||
static Value buildPackedWeight(DenseElementsAttr wDenseAttr,
|
||||
Value wTrans,
|
||||
RankedTensorType wType,
|
||||
@@ -189,7 +273,6 @@ static Value createIm2colRowComputes(Value x,
|
||||
Location loc) {
|
||||
auto elemType = xType.getElementType();
|
||||
constexpr size_t numInputs = 1;
|
||||
const int64_t packedNumRows = ceilIntegerDivide(numPatches, packFactor);
|
||||
auto im2colComputeOp =
|
||||
createSpatCompute<numInputs>(rewriter, loc, TypeRange {gemmInputRowsType}, {}, x, [&](Value xArg) {
|
||||
Value paddedInput = xArg;
|
||||
@@ -278,26 +361,7 @@ static Value createIm2colRowComputes(Value x,
|
||||
|
||||
Value gemmInputRows = im2col;
|
||||
if (packFactor != 1) {
|
||||
const int64_t paddedNumPatches = packedNumRows * packFactor;
|
||||
auto groupedType = RankedTensorType::get({packedNumRows, packFactor, patchSize}, elemType);
|
||||
auto packedType = RankedTensorType::get({packedNumRows, packFactor * patchSize}, elemType);
|
||||
Value paddedIm2col = createPaddedRows(im2col, im2colType, paddedNumPatches, rewriter, loc);
|
||||
Value groupedIm2col = tensor::ExpandShapeOp::create(rewriter,
|
||||
loc,
|
||||
groupedType,
|
||||
paddedIm2col,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0, 1},
|
||||
{2}
|
||||
});
|
||||
gemmInputRows = tensor::CollapseShapeOp::create(rewriter,
|
||||
loc,
|
||||
packedType,
|
||||
groupedIm2col,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0},
|
||||
{1, 2}
|
||||
});
|
||||
gemmInputRows = packRowsForParallelGemm(im2col, im2colType, packFactor, rewriter, loc);
|
||||
}
|
||||
|
||||
spatial::SpatYieldOp::create(rewriter, loc, gemmInputRows);
|
||||
@@ -316,41 +380,15 @@ static Value createCollectedConvOutput(ValueRange gemmRows,
|
||||
int64_t packFactor,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
const int64_t packedNumRows = ceilIntegerDivide(numPatches, packFactor);
|
||||
const int64_t paddedNumPatches = packedNumRows * packFactor;
|
||||
auto collectComputeOp = createSpatCompute(rewriter, loc, convType, {}, gemmRows, [&](ValueRange gemmRowArgs) {
|
||||
Value gemmOut;
|
||||
if (packFactor == 1) {
|
||||
gemmOut = createSpatConcat(rewriter, loc, /*axis=*/0, gemmRowArgs);
|
||||
}
|
||||
else {
|
||||
auto expandedType = RankedTensorType::get({packedNumRows, packFactor, numChannelsOut}, outType.getElementType());
|
||||
auto paddedType = RankedTensorType::get({paddedNumPatches, numChannelsOut}, outType.getElementType());
|
||||
Value packedOutput = createSpatConcat(rewriter, loc, /*axis=*/0, gemmRowArgs);
|
||||
Value expandedOutput = tensor::ExpandShapeOp::create(rewriter,
|
||||
loc,
|
||||
expandedType,
|
||||
packedOutput,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0},
|
||||
{1, 2}
|
||||
});
|
||||
Value paddedOutput = tensor::CollapseShapeOp::create(rewriter,
|
||||
loc,
|
||||
paddedType,
|
||||
expandedOutput,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0, 1},
|
||||
{2}
|
||||
});
|
||||
|
||||
gemmOut = paddedOutput;
|
||||
if (paddedNumPatches != numPatches) {
|
||||
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(numPatches), rewriter.getIndexAttr(numChannelsOut)};
|
||||
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
gemmOut = tensor::ExtractSliceOp::create(rewriter, loc, gemmOutType, paddedOutput, offsets, sizes, strides);
|
||||
}
|
||||
gemmOut = unpackRowsFromParallelGemm(
|
||||
packedOutput, cast<RankedTensorType>(packedOutput.getType()), numPatches, numChannelsOut, packFactor, rewriter, loc);
|
||||
}
|
||||
|
||||
// Restore to NCHW layout:
|
||||
|
||||
Reference in New Issue
Block a user