automatic code reformat

This commit is contained in:
NiccoloN
2026-05-22 15:23:48 +02:00
parent d136136d22
commit 8337a11ce9
37 changed files with 312 additions and 354 deletions
@@ -99,15 +99,17 @@ auto createSpatCompute(RewriterT& rewriter,
using BodyResult = detail::InvokeWithBlockArgsResultT<std::decay_t<BodyFn>, std::make_index_sequence<NumInputs>>;
if constexpr (std::is_same_v<BodyResult, void>) {
detail::invokeWithValues(
std::forward<BodyFn>(body), detail::getInputBlockArgs(block, weights.size()), std::make_index_sequence<NumInputs> {});
detail::invokeWithValues(std::forward<BodyFn>(body),
detail::getInputBlockArgs(block, weights.size()),
std::make_index_sequence<NumInputs> {});
rewriter.setInsertionPointAfter(computeOp);
return computeOp;
}
else {
auto bodyResult = detail::invokeWithValues(
std::forward<BodyFn>(body), detail::getInputBlockArgs(block, weights.size()), std::make_index_sequence<NumInputs> {});
auto bodyResult = detail::invokeWithValues(std::forward<BodyFn>(body),
detail::getInputBlockArgs(block, weights.size()),
std::make_index_sequence<NumInputs> {});
if (mlir::failed(bodyResult)) {
rewriter.setInsertionPointAfter(computeOp);
rewriter.eraseOp(computeOp);
@@ -423,8 +423,11 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
SmallVector<Value> vmmOutputs;
vmmOutputs.reserve(aHSlices[coreId].size());
for (auto aHSliceId : llvm::seq<size_t>(0, aHSlices[coreId].size()))
vmmOutputs.push_back(spatial::SpatVMMOp::create(
rewriter, gemmLoc, currOutHSliceType, computeOp.getWeightArgument(aHSliceId), computeOp.getInputArgument(aHSliceId)));
vmmOutputs.push_back(spatial::SpatVMMOp::create(rewriter,
gemmLoc,
currOutHSliceType,
computeOp.getWeightArgument(aHSliceId),
computeOp.getInputArgument(aHSliceId)));
if (vmmOutputs.empty()) {
gemmOp.emitOpError("requires at least one non-empty slice when lowering tiled Gemm to Spatial VMMs");
return failure();
@@ -579,8 +582,8 @@ LogicalResult GemmToSpatialComputeBatch::matchAndRewrite(ONNXGemmOp gemmOp,
rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front());
SmallVector<OpFoldResult> outputOffsets {lane, rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> outputSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(outType.getDimSize(1))};
tensor::ParallelInsertSliceOp::create(rewriter, loc, laneResult, packedOutput, outputOffsets, outputSizes,
unitStrides);
tensor::ParallelInsertSliceOp::create(
rewriter, loc, laneResult, packedOutput, outputOffsets, outputSizes, unitStrides);
rewriter.setInsertionPointAfter(batchOp);
rewriter.replaceOp(gemmOp, batchOp.getResults());
@@ -38,23 +38,16 @@ static FailureOr<SmallVector<int64_t>> inferSupportedBatchShape(ArrayRef<int64_t
return SmallVector<int64_t>(lhsBatchShape.begin(), lhsBatchShape.end());
}
static Value collapseBatchDims(Value value,
int64_t batchSize,
int64_t rows,
int64_t cols,
PatternRewriter& rewriter,
Location loc) {
static Value
collapseBatchDims(Value value, int64_t batchSize, int64_t rows, int64_t cols, PatternRewriter& rewriter, Location loc) {
auto type = cast<RankedTensorType>(value.getType());
if (type.getRank() == 2 || type.getRank() == 3)
return value;
auto collapsedType =
RankedTensorType::get({batchSize, rows, cols}, type.getElementType(), type.getEncoding());
SmallVector<ReassociationIndices> reassociation = {
ReassociationIndices {},
ReassociationIndices {static_cast<int64_t>(type.getRank() - 2)},
ReassociationIndices {static_cast<int64_t>(type.getRank() - 1)}
};
auto collapsedType = RankedTensorType::get({batchSize, rows, cols}, type.getElementType(), type.getEncoding());
SmallVector<ReassociationIndices> reassociation = {ReassociationIndices {},
ReassociationIndices {static_cast<int64_t>(type.getRank() - 2)},
ReassociationIndices {static_cast<int64_t>(type.getRank() - 1)}};
for (int64_t dim = 0; dim < type.getRank() - 2; ++dim)
reassociation.front().push_back(dim);
@@ -72,19 +65,14 @@ static Value collapseBatchDims(Value value,
return collapseCompute.getResult(0);
}
static Value expandBatchDims(Value value,
RankedTensorType outputType,
size_t batchRank,
PatternRewriter& rewriter,
Location loc) {
static Value
expandBatchDims(Value value, RankedTensorType outputType, size_t batchRank, PatternRewriter& rewriter, Location loc) {
if (cast<RankedTensorType>(value.getType()) == outputType)
return value;
SmallVector<ReassociationIndices> reassociation = {
ReassociationIndices {},
ReassociationIndices {static_cast<int64_t>(batchRank)},
ReassociationIndices {static_cast<int64_t>(batchRank + 1)}
};
SmallVector<ReassociationIndices> reassociation = {ReassociationIndices {},
ReassociationIndices {static_cast<int64_t>(batchRank)},
ReassociationIndices {static_cast<int64_t>(batchRank + 1)}};
for (size_t dim = 0; dim < batchRank; ++dim)
reassociation.front().push_back(static_cast<int64_t>(dim));
@@ -58,24 +58,21 @@ static Value buildNearestResizeLoop(Value input,
Value outputC = channelLoop.getInductionVar();
Value outputChannelAcc = channelLoop.getRegionIterArgs().front();
Value inputC =
buildNearestAsymmetricIndex(outputC, inputType.getDimSize(1), resultType.getDimSize(1), rewriter, loc);
Value inputC = buildNearestAsymmetricIndex(outputC, inputType.getDimSize(1), resultType.getDimSize(1), rewriter, loc);
auto heightLoop = scf::ForOp::create(rewriter, loc, c0, cOutputH, c1, ValueRange {outputChannelAcc});
rewriter.setInsertionPointToStart(heightLoop.getBody());
Value outputH = heightLoop.getInductionVar();
Value outputHeightAcc = heightLoop.getRegionIterArgs().front();
Value inputH =
buildNearestAsymmetricIndex(outputH, inputType.getDimSize(2), resultType.getDimSize(2), rewriter, loc);
Value inputH = buildNearestAsymmetricIndex(outputH, inputType.getDimSize(2), resultType.getDimSize(2), rewriter, loc);
auto widthLoop = scf::ForOp::create(rewriter, loc, c0, cOutputW, c1, ValueRange {outputHeightAcc});
rewriter.setInsertionPointToStart(widthLoop.getBody());
Value outputW = widthLoop.getInductionVar();
Value outputWidthAcc = widthLoop.getRegionIterArgs().front();
Value inputW =
buildNearestAsymmetricIndex(outputW, inputType.getDimSize(3), resultType.getDimSize(3), rewriter, loc);
Value inputW = buildNearestAsymmetricIndex(outputW, inputType.getDimSize(3), resultType.getDimSize(3), rewriter, loc);
SmallVector<OpFoldResult> inputOffsets = {inputN, inputC, inputH, inputW};
Value inputSlice =
@@ -114,8 +111,8 @@ struct Resize : OpConversionPattern<ONNXResizeOp> {
if (resizeOp.getMode() != "nearest" || resizeOp.getCoordinateTransformationMode() != "asymmetric"
|| resizeOp.getNearestMode() != "floor")
return rewriter.notifyMatchFailure(
resizeOp, "resize lowering currently supports only nearest + asymmetric + floor.");
return rewriter.notifyMatchFailure(resizeOp,
"resize lowering currently supports only nearest + asymmetric + floor.");
if (llvm::any_of(inputType.getShape(), [](int64_t dim) { return dim <= 0; })
|| llvm::any_of(resultType.getShape(), [](int64_t dim) { return dim <= 0; }))
@@ -94,8 +94,8 @@ struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatCom
}
llvm::append_range(newBlockArgTypes, newInputTypes);
llvm::append_range(newBlockArgLocs, newInputLocs);
auto* newBlock =
rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), TypeRange(newBlockArgTypes), newBlockArgLocs);
auto* newBlock = rewriter.createBlock(
&newCompute.getBody(), newCompute.getBody().end(), TypeRange(newBlockArgTypes), newBlockArgLocs);
newCompute.getProperties().setOperandSegmentSizes(
{static_cast<int>(newWeights.size()), static_cast<int>(newInputs.size())});
rewriter.setInsertionPointToStart(newBlock);
@@ -228,7 +228,8 @@ struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::Sp
mapper.map(oldArg, *clonedValue);
}
for (auto resultIndex : llvm::seq<size_t>(0, compute.getNumResults()))
mapper.map(compute.getOutputArgument(resultIndex), newBlock->getArgument(1 + newWeights.size() + newInputs.size() + resultIndex));
mapper.map(compute.getOutputArgument(resultIndex),
newBlock->getArgument(1 + newWeights.size() + newInputs.size() + resultIndex));
for (Operation& op : oldBlock)
rewriter.clone(op, mapper);