automatic code reformat
Validate Operations / validate-operations (push) Has been cancelled

This commit is contained in:
NiccoloN
2026-05-29 19:21:37 +02:00
parent a41f694cf0
commit 2d5b03c08f
26 changed files with 183 additions and 168 deletions
@@ -51,8 +51,8 @@ static Value createPaddedRows(Value tensorValue,
if (tensorType.getDimSize(0) == paddedRows)
return tensorValue;
auto paddedType =
RankedTensorType::get({paddedRows, tensorType.getDimSize(1)}, tensorType.getElementType(), tensorType.getEncoding());
auto paddedType = RankedTensorType::get(
{paddedRows, tensorType.getDimSize(1)}, tensorType.getElementType(), tensorType.getEncoding());
SmallVector<OpFoldResult> lowPads = {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> highPads = {rewriter.getIndexAttr(paddedRows - tensorType.getDimSize(0)),
rewriter.getIndexAttr(0)};
@@ -62,20 +62,15 @@ static Value createPaddedRows(Value tensorValue,
padBlock->addArgument(rewriter.getIndexType(), loc);
padOp.getRegion().push_back(padBlock);
rewriter.setInsertionPointToStart(padBlock);
auto zero = getOrCreateConstant(rewriter,
padOp.getOperation(),
rewriter.getZeroAttr(tensorType.getElementType()),
tensorType.getElementType());
auto zero = getOrCreateConstant(
rewriter, padOp.getOperation(), rewriter.getZeroAttr(tensorType.getElementType()), tensorType.getElementType());
tensor::YieldOp::create(rewriter, loc, zero);
rewriter.setInsertionPointAfter(padOp);
return padOp.getResult();
}
static Value packRowsForParallelGemm(Value rows,
RankedTensorType rowsType,
int64_t packFactor,
ConversionPatternRewriter& rewriter,
Location loc) {
static Value packRowsForParallelGemm(
Value rows, RankedTensorType rowsType, int64_t packFactor, ConversionPatternRewriter& rewriter, Location loc) {
if (packFactor == 1)
return rows;
@@ -118,10 +113,8 @@ static Value unpackRowsFromParallelGemm(Value packedRows,
const int64_t packedNumRows = packedRowsType.getDimSize(0);
const int64_t paddedNumRows = packedNumRows * packFactor;
auto expandedType =
RankedTensorType::get({packedNumRows, packFactor, rowWidth},
packedRowsType.getElementType(),
packedRowsType.getEncoding());
auto expandedType = RankedTensorType::get(
{packedNumRows, packFactor, rowWidth}, packedRowsType.getElementType(), packedRowsType.getEncoding());
auto paddedType =
RankedTensorType::get({paddedNumRows, rowWidth}, packedRowsType.getElementType(), packedRowsType.getEncoding());
auto unpackedType =
@@ -193,11 +186,8 @@ static Value buildPackedWeight(DenseElementsAttr wDenseAttr,
return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), packedAttr, packedWeightType);
}
static Value createConvWeightMatrix(Value w,
RankedTensorType wFlatType,
RankedTensorType wTransType,
ConversionPatternRewriter& rewriter,
Location loc) {
static Value createConvWeightMatrix(
Value w, RankedTensorType wFlatType, RankedTensorType wTransType, ConversionPatternRewriter& rewriter, Location loc) {
auto buildWeightMatrix = [&](Value weight) -> Value {
Value wFlat = tensor::CollapseShapeOp::create(rewriter,
loc,
@@ -360,9 +350,8 @@ static Value createIm2colRowComputes(Value x,
Value im2col = im2colLoop.getResult(0);
Value gemmInputRows = im2col;
if (packFactor != 1) {
if (packFactor != 1)
gemmInputRows = packRowsForParallelGemm(im2col, im2colType, packFactor, rewriter, loc);
}
spatial::SpatYieldOp::create(rewriter, loc, gemmInputRows);
});
@@ -387,8 +376,13 @@ static Value createCollectedConvOutput(ValueRange gemmRows,
}
else {
Value packedOutput = createSpatConcat(rewriter, loc, /*axis=*/0, gemmRowArgs);
gemmOut = unpackRowsFromParallelGemm(
packedOutput, cast<RankedTensorType>(packedOutput.getType()), numPatches, numChannelsOut, packFactor, rewriter, loc);
gemmOut = unpackRowsFromParallelGemm(packedOutput,
cast<RankedTensorType>(packedOutput.getType()),
numPatches,
numChannelsOut,
packFactor,
rewriter,
loc);
}
// Restore to NCHW layout: