automatic code reformat
Validate Operations / validate-operations (push) Has been cancelled

This commit is contained in:
NiccoloN
2026-05-29 19:21:37 +02:00
parent a41f694cf0
commit 2d5b03c08f
26 changed files with 183 additions and 168 deletions
@@ -51,8 +51,8 @@ static Value createPaddedRows(Value tensorValue,
if (tensorType.getDimSize(0) == paddedRows)
return tensorValue;
auto paddedType =
RankedTensorType::get({paddedRows, tensorType.getDimSize(1)}, tensorType.getElementType(), tensorType.getEncoding());
auto paddedType = RankedTensorType::get(
{paddedRows, tensorType.getDimSize(1)}, tensorType.getElementType(), tensorType.getEncoding());
SmallVector<OpFoldResult> lowPads = {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> highPads = {rewriter.getIndexAttr(paddedRows - tensorType.getDimSize(0)),
rewriter.getIndexAttr(0)};
@@ -62,20 +62,15 @@ static Value createPaddedRows(Value tensorValue,
padBlock->addArgument(rewriter.getIndexType(), loc);
padOp.getRegion().push_back(padBlock);
rewriter.setInsertionPointToStart(padBlock);
auto zero = getOrCreateConstant(rewriter,
padOp.getOperation(),
rewriter.getZeroAttr(tensorType.getElementType()),
tensorType.getElementType());
auto zero = getOrCreateConstant(
rewriter, padOp.getOperation(), rewriter.getZeroAttr(tensorType.getElementType()), tensorType.getElementType());
tensor::YieldOp::create(rewriter, loc, zero);
rewriter.setInsertionPointAfter(padOp);
return padOp.getResult();
}
static Value packRowsForParallelGemm(Value rows,
RankedTensorType rowsType,
int64_t packFactor,
ConversionPatternRewriter& rewriter,
Location loc) {
static Value packRowsForParallelGemm(
Value rows, RankedTensorType rowsType, int64_t packFactor, ConversionPatternRewriter& rewriter, Location loc) {
if (packFactor == 1)
return rows;
@@ -118,10 +113,8 @@ static Value unpackRowsFromParallelGemm(Value packedRows,
const int64_t packedNumRows = packedRowsType.getDimSize(0);
const int64_t paddedNumRows = packedNumRows * packFactor;
auto expandedType =
RankedTensorType::get({packedNumRows, packFactor, rowWidth},
packedRowsType.getElementType(),
packedRowsType.getEncoding());
auto expandedType = RankedTensorType::get(
{packedNumRows, packFactor, rowWidth}, packedRowsType.getElementType(), packedRowsType.getEncoding());
auto paddedType =
RankedTensorType::get({paddedNumRows, rowWidth}, packedRowsType.getElementType(), packedRowsType.getEncoding());
auto unpackedType =
@@ -193,11 +186,8 @@ static Value buildPackedWeight(DenseElementsAttr wDenseAttr,
return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), packedAttr, packedWeightType);
}
static Value createConvWeightMatrix(Value w,
RankedTensorType wFlatType,
RankedTensorType wTransType,
ConversionPatternRewriter& rewriter,
Location loc) {
static Value createConvWeightMatrix(
Value w, RankedTensorType wFlatType, RankedTensorType wTransType, ConversionPatternRewriter& rewriter, Location loc) {
auto buildWeightMatrix = [&](Value weight) -> Value {
Value wFlat = tensor::CollapseShapeOp::create(rewriter,
loc,
@@ -360,9 +350,8 @@ static Value createIm2colRowComputes(Value x,
Value im2col = im2colLoop.getResult(0);
Value gemmInputRows = im2col;
if (packFactor != 1) {
if (packFactor != 1)
gemmInputRows = packRowsForParallelGemm(im2col, im2colType, packFactor, rewriter, loc);
}
spatial::SpatYieldOp::create(rewriter, loc, gemmInputRows);
});
@@ -387,8 +376,13 @@ static Value createCollectedConvOutput(ValueRange gemmRows,
}
else {
Value packedOutput = createSpatConcat(rewriter, loc, /*axis=*/0, gemmRowArgs);
gemmOut = unpackRowsFromParallelGemm(
packedOutput, cast<RankedTensorType>(packedOutput.getType()), numPatches, numChannelsOut, packFactor, rewriter, loc);
gemmOut = unpackRowsFromParallelGemm(packedOutput,
cast<RankedTensorType>(packedOutput.getType()),
numPatches,
numChannelsOut,
packFactor,
rewriter,
loc);
}
// Restore to NCHW layout:
@@ -252,7 +252,13 @@ static spatial::SpatComputeBatch createVmmBatch(Value a,
Location loc) {
const int64_t laneCount = partialPiecesType.getDimSize(0);
auto batchOp = createSpatComputeBatch(
rewriter, loc, TypeRange {partialPiecesType}, laneCount, ValueRange {b}, ValueRange {a}, [&](detail::SpatComputeBatchBodyArgs args) {
rewriter,
loc,
TypeRange {partialPiecesType},
laneCount,
ValueRange {b},
ValueRange {a},
[&](detail::SpatComputeBatchBodyArgs args) {
Value row = onnx_mlir::modIndexByConstant(rewriter, loc, args.lane, numOutRows);
Value kOffset = createGemmBatchKOffset(args.lane, numOutRows, numKSlices, rewriter, loc);
Value hOffset = createGemmBatchHOffset(args.lane, numOutRows, numKSlices, numOutHSlices, rewriter, loc);
@@ -284,8 +290,8 @@ static spatial::SpatComputeBatch createVmmBatch(Value a,
return *batchOp;
}
static Value createDynamicGemmBatchRow(
Value lane, int64_t numOutCols, ConversionPatternRewriter& rewriter, Location loc) {
static Value
createDynamicGemmBatchRow(Value lane, int64_t numOutCols, ConversionPatternRewriter& rewriter, Location loc) {
if (numOutCols == 1)
return lane;
@@ -294,17 +300,21 @@ static Value createDynamicGemmBatchRow(
return createAffineApplyOrFoldedConstant(rewriter, loc, d0.floorDiv(numOutCols), ValueRange {lane});
}
static Value
extractDynamicGemmBColumn(Value matrix, Value column, RankedTensorType vectorType, ConversionPatternRewriter& rewriter, Location loc) {
static Value extractDynamicGemmBColumn(
Value matrix, Value column, RankedTensorType vectorType, ConversionPatternRewriter& rewriter, Location loc) {
SmallVector<OpFoldResult> offsets {rewriter.getIndexAttr(0), column};
SmallVector<OpFoldResult> strides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
auto columnSliceType = RankedTensorType::get({vectorType.getDimSize(1), 1}, vectorType.getElementType());
Value columnSlice = materializeContiguousTensorSlice(matrix, columnSliceType, offsets, strides, rewriter, loc);
SmallVector<ReassociationIndices> collapseReassociation {ReassociationIndices {0, 1}};
SmallVector<ReassociationIndices> collapseReassociation {
ReassociationIndices {0, 1}
};
auto collapsedType = RankedTensorType::get({vectorType.getDimSize(1)}, vectorType.getElementType());
Value collapsed =
tensor::CollapseShapeOp::create(rewriter, loc, collapsedType, columnSlice, collapseReassociation).getResult();
SmallVector<ReassociationIndices> expandReassociation {ReassociationIndices {0, 1}};
SmallVector<ReassociationIndices> expandReassociation {
ReassociationIndices {0, 1}
};
return tensor::ExpandShapeOp::create(rewriter, loc, vectorType, collapsed, expandReassociation).getResult();
}
@@ -371,13 +381,15 @@ static Value createBroadcastedBiasScalar(Value bias,
Location loc) {
SmallVector<OpFoldResult> unitStrides(biasType.getRank(), rewriter.getIndexAttr(1));
if (biasType.getRank() == 1) {
SmallVector<OpFoldResult> offsets {
biasType.getDimSize(0) == 1 ? OpFoldResult(rewriter.getIndexAttr(0)) : OpFoldResult(column)};
SmallVector<OpFoldResult> offsets {biasType.getDimSize(0) == 1 ? OpFoldResult(rewriter.getIndexAttr(0))
: OpFoldResult(column)};
SmallVector<OpFoldResult> sizes {rewriter.getIndexAttr(1)};
auto vectorType = RankedTensorType::get({1}, scalarType.getElementType());
Value vector = tensor::ExtractSliceOp::create(rewriter, loc, vectorType, bias, offsets, sizes, unitStrides)
.getResult();
SmallVector<ReassociationIndices> reassociation {ReassociationIndices {0, 1}};
Value vector =
tensor::ExtractSliceOp::create(rewriter, loc, vectorType, bias, offsets, sizes, unitStrides).getResult();
SmallVector<ReassociationIndices> reassociation {
ReassociationIndices {0, 1}
};
return tensor::ExpandShapeOp::create(rewriter, loc, scalarType, vector, reassociation).getResult();
}
@@ -407,16 +419,21 @@ static spatial::SpatComputeBatch createVvdmulBatch(Value a,
const int64_t reductionSize = aType.getDimSize(1);
const int64_t laneCount = numOutRows * numOutCols;
auto batchOp = createSpatComputeBatch(
rewriter, loc, TypeRange {scalarPiecesType}, laneCount, ValueRange {}, ValueRange {a, b}, [&](detail::SpatComputeBatchBodyArgs args) {
rewriter,
loc,
TypeRange {scalarPiecesType},
laneCount,
ValueRange {},
ValueRange {a, b},
[&](detail::SpatComputeBatchBodyArgs args) {
Value row = createDynamicGemmBatchRow(args.lane, numOutCols, rewriter, loc);
Value column = onnx_mlir::modIndexByConstant(rewriter, loc, args.lane, numOutCols);
auto vectorType = RankedTensorType::get({1, reductionSize}, aType.getElementType());
auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType());
Value aVector = extractDynamicGemmRowVector(args.inputs[0], row, vectorType, rewriter, loc);
Value bVector = bAlreadyTransposed
? extractTransposedBRow(args.inputs[1], column, vectorType, rewriter, loc)
: extractDynamicGemmBColumn(args.inputs[1], column, vectorType, rewriter, loc);
Value bVector = bAlreadyTransposed ? extractTransposedBRow(args.inputs[1], column, vectorType, rewriter, loc)
: extractDynamicGemmBColumn(args.inputs[1], column, vectorType, rewriter, loc);
Value scalar = spatial::SpatVVDMulOp::create(rewriter, loc, scalarType, aVector, bVector).getResult();
SmallVector<OpFoldResult> outputOffsets {args.lane, rewriter.getIndexAttr(0)};
@@ -578,9 +595,8 @@ static spatial::SpatCompute createReductionCompute(Value partialPieces,
auto buildOutputSlice = [&](Value outputAcc, Value hSlice) -> Value {
Value reduced =
reducePartialPiecesForHSlice(partialPiecesArg, hSlice, pieceType, numKSlices, numOutRows, rewriter, loc);
Value hOffset =
onnx_mlir::multiplyIndexByConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), hSlice,
crossbarSize.getValue());
Value hOffset = onnx_mlir::multiplyIndexByConstant(
rewriter, rewriter.getInsertionBlock()->getParentOp(), hSlice, crossbarSize.getValue());
if (biasArg) {
SmallVector<OpFoldResult> biasOffsets {rewriter.getIndexAttr(0), hOffset};
Value biasSlice =
@@ -721,8 +737,8 @@ LogicalResult GemmToSpatialComputes::matchAndRewrite(ONNXGemmOp gemmOp,
}
auto scalarPiecesType = RankedTensorType::get({laneCount64, 1}, outType.getElementType());
auto batchOp = createVvdmulBatch(
a, b, aType, bType, scalarPiecesType, outType, gemmOpAdaptor.getTransB(), rewriter, loc);
auto batchOp =
createVvdmulBatch(a, b, aType, bType, scalarPiecesType, outType, gemmOpAdaptor.getTransB(), rewriter, loc);
auto outputCompute = createDynamicGemmOutputCompute(
batchOp.getResult(0), hasC ? c : Value(), scalarPiecesType, biasType, outType, alpha, beta, rewriter, loc);
rewriter.replaceOp(gemmOp, outputCompute.getResults());
@@ -70,11 +70,8 @@ static SmallVector<int64_t> getKeptAxes(ArrayRef<bool> reducedAxes) {
return keptAxes;
}
static Value computeLaneIndex(Value lane,
int64_t stride,
int64_t dimSize,
ConversionPatternRewriter& rewriter,
Location loc) {
static Value
computeLaneIndex(Value lane, int64_t stride, int64_t dimSize, ConversionPatternRewriter& rewriter, Location loc) {
if (dimSize == 1)
return getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
@@ -119,35 +116,41 @@ static FailureOr<Value> buildReduceMeanKeepdimsBatch(Value input,
sliceSizes.reserve(inputType.getRank());
insertOffsets.reserve(inputType.getRank());
auto batchOp = createSpatComputeBatch(
rewriter, loc, TypeRange {batchType}, laneCount, {}, ValueRange {input}, [&](detail::SpatComputeBatchBodyArgs args) {
size_t keptAxisIndex = 0;
sliceOffsets.clear();
sliceSizes.clear();
insertOffsets.clear();
for (auto [axis, isReduced] : llvm::enumerate(reducedAxes)) {
if (isReduced) {
sliceOffsets.push_back(rewriter.getIndexAttr(0));
sliceSizes.push_back(rewriter.getIndexAttr(inputType.getDimSize(axis)));
continue;
}
auto batchOp =
createSpatComputeBatch(rewriter,
loc,
TypeRange {batchType},
laneCount,
{},
ValueRange {input},
[&](detail::SpatComputeBatchBodyArgs args) {
size_t keptAxisIndex = 0;
sliceOffsets.clear();
sliceSizes.clear();
insertOffsets.clear();
for (auto [axis, isReduced] : llvm::enumerate(reducedAxes)) {
if (isReduced) {
sliceOffsets.push_back(rewriter.getIndexAttr(0));
sliceSizes.push_back(rewriter.getIndexAttr(inputType.getDimSize(axis)));
continue;
}
Value axisIndex =
computeLaneIndex(args.lane, keptAxisStrides[keptAxisIndex], inputType.getDimSize(axis), rewriter, loc);
++keptAxisIndex;
sliceOffsets.push_back(axisIndex);
sliceSizes.push_back(rewriter.getIndexAttr(1));
}
Value axisIndex = computeLaneIndex(
args.lane, keptAxisStrides[keptAxisIndex], inputType.getDimSize(axis), rewriter, loc);
++keptAxisIndex;
sliceOffsets.push_back(axisIndex);
sliceSizes.push_back(rewriter.getIndexAttr(1));
}
insertOffsets.push_back(args.lane);
insertOffsets.append(inputType.getRank() - 1, rewriter.getIndexAttr(0));
insertOffsets.push_back(args.lane);
insertOffsets.append(inputType.getRank() - 1, rewriter.getIndexAttr(0));
Value slice =
tensor::ExtractSliceOp::create(rewriter, loc, sliceType, args.inputs.front(), sliceOffsets, sliceSizes, unitStrides);
Value reduced = spatial::SpatVAvgOp::create(rewriter, loc, leafType, slice).getResult();
createParallelInsertSliceIntoBatchOutput(
rewriter, loc, reduced, args.outputs.front(), insertOffsets, insertSizes, unitStrides);
});
Value slice = tensor::ExtractSliceOp::create(
rewriter, loc, sliceType, args.inputs.front(), sliceOffsets, sliceSizes, unitStrides);
Value reduced = spatial::SpatVAvgOp::create(rewriter, loc, leafType, slice).getResult();
createParallelInsertSliceIntoBatchOutput(
rewriter, loc, reduced, args.outputs.front(), insertOffsets, insertSizes, unitStrides);
});
if (failed(batchOp))
return failure();
return (*batchOp).getResult(0);
@@ -193,15 +196,15 @@ static Value buildKeepdimsFromLanePackedBatch(Value batchValue,
auto reshapeCompute =
createSpatCompute<1>(rewriter, loc, TypeRange {keepdimsType}, {}, ValueRange {batchValue}, [&](Value input) {
auto flatType = RankedTensorType::get({batchType.getDimSize(0)}, batchType.getElementType(), batchType.getEncoding());
auto flatType =
RankedTensorType::get({batchType.getDimSize(0)}, batchType.getElementType(), batchType.getEncoding());
Value flat = tensor::CollapseShapeOp::create(rewriter, loc, flatType, input, collapseToFlat);
Value compact = flat;
if (compactKeptType != flatType)
compact = tensor::ExpandShapeOp::create(rewriter, loc, compactKeptType, flat, expandFlatToCompact);
Value keepdims = compact;
if (keepdimsType != compactKeptType)
keepdims =
tensor::ExpandShapeOp::create(rewriter, loc, keepdimsType, compact, expandCompactToKeepdims);
keepdims = tensor::ExpandShapeOp::create(rewriter, loc, keepdimsType, compact, expandCompactToKeepdims);
spatial::SpatYieldOp::create(rewriter, loc, keepdims);
});
return reshapeCompute.getResult(0);