This commit is contained in:
@@ -51,8 +51,8 @@ static Value createPaddedRows(Value tensorValue,
|
||||
if (tensorType.getDimSize(0) == paddedRows)
|
||||
return tensorValue;
|
||||
|
||||
auto paddedType =
|
||||
RankedTensorType::get({paddedRows, tensorType.getDimSize(1)}, tensorType.getElementType(), tensorType.getEncoding());
|
||||
auto paddedType = RankedTensorType::get(
|
||||
{paddedRows, tensorType.getDimSize(1)}, tensorType.getElementType(), tensorType.getEncoding());
|
||||
SmallVector<OpFoldResult> lowPads = {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> highPads = {rewriter.getIndexAttr(paddedRows - tensorType.getDimSize(0)),
|
||||
rewriter.getIndexAttr(0)};
|
||||
@@ -62,20 +62,15 @@ static Value createPaddedRows(Value tensorValue,
|
||||
padBlock->addArgument(rewriter.getIndexType(), loc);
|
||||
padOp.getRegion().push_back(padBlock);
|
||||
rewriter.setInsertionPointToStart(padBlock);
|
||||
auto zero = getOrCreateConstant(rewriter,
|
||||
padOp.getOperation(),
|
||||
rewriter.getZeroAttr(tensorType.getElementType()),
|
||||
tensorType.getElementType());
|
||||
auto zero = getOrCreateConstant(
|
||||
rewriter, padOp.getOperation(), rewriter.getZeroAttr(tensorType.getElementType()), tensorType.getElementType());
|
||||
tensor::YieldOp::create(rewriter, loc, zero);
|
||||
rewriter.setInsertionPointAfter(padOp);
|
||||
return padOp.getResult();
|
||||
}
|
||||
|
||||
static Value packRowsForParallelGemm(Value rows,
|
||||
RankedTensorType rowsType,
|
||||
int64_t packFactor,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
static Value packRowsForParallelGemm(
|
||||
Value rows, RankedTensorType rowsType, int64_t packFactor, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
if (packFactor == 1)
|
||||
return rows;
|
||||
|
||||
@@ -118,10 +113,8 @@ static Value unpackRowsFromParallelGemm(Value packedRows,
|
||||
|
||||
const int64_t packedNumRows = packedRowsType.getDimSize(0);
|
||||
const int64_t paddedNumRows = packedNumRows * packFactor;
|
||||
auto expandedType =
|
||||
RankedTensorType::get({packedNumRows, packFactor, rowWidth},
|
||||
packedRowsType.getElementType(),
|
||||
packedRowsType.getEncoding());
|
||||
auto expandedType = RankedTensorType::get(
|
||||
{packedNumRows, packFactor, rowWidth}, packedRowsType.getElementType(), packedRowsType.getEncoding());
|
||||
auto paddedType =
|
||||
RankedTensorType::get({paddedNumRows, rowWidth}, packedRowsType.getElementType(), packedRowsType.getEncoding());
|
||||
auto unpackedType =
|
||||
@@ -193,11 +186,8 @@ static Value buildPackedWeight(DenseElementsAttr wDenseAttr,
|
||||
return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), packedAttr, packedWeightType);
|
||||
}
|
||||
|
||||
static Value createConvWeightMatrix(Value w,
|
||||
RankedTensorType wFlatType,
|
||||
RankedTensorType wTransType,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
static Value createConvWeightMatrix(
|
||||
Value w, RankedTensorType wFlatType, RankedTensorType wTransType, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
auto buildWeightMatrix = [&](Value weight) -> Value {
|
||||
Value wFlat = tensor::CollapseShapeOp::create(rewriter,
|
||||
loc,
|
||||
@@ -360,9 +350,8 @@ static Value createIm2colRowComputes(Value x,
|
||||
Value im2col = im2colLoop.getResult(0);
|
||||
|
||||
Value gemmInputRows = im2col;
|
||||
if (packFactor != 1) {
|
||||
if (packFactor != 1)
|
||||
gemmInputRows = packRowsForParallelGemm(im2col, im2colType, packFactor, rewriter, loc);
|
||||
}
|
||||
|
||||
spatial::SpatYieldOp::create(rewriter, loc, gemmInputRows);
|
||||
});
|
||||
@@ -387,8 +376,13 @@ static Value createCollectedConvOutput(ValueRange gemmRows,
|
||||
}
|
||||
else {
|
||||
Value packedOutput = createSpatConcat(rewriter, loc, /*axis=*/0, gemmRowArgs);
|
||||
gemmOut = unpackRowsFromParallelGemm(
|
||||
packedOutput, cast<RankedTensorType>(packedOutput.getType()), numPatches, numChannelsOut, packFactor, rewriter, loc);
|
||||
gemmOut = unpackRowsFromParallelGemm(packedOutput,
|
||||
cast<RankedTensorType>(packedOutput.getType()),
|
||||
numPatches,
|
||||
numChannelsOut,
|
||||
packFactor,
|
||||
rewriter,
|
||||
loc);
|
||||
}
|
||||
|
||||
// Restore to NCHW layout:
|
||||
|
||||
Reference in New Issue
Block a user