automatic code-reformat
Validate Operations / validate-operations (push) Successful in 18m22s

This commit is contained in:
NiccoloN
2026-04-09 14:27:23 +02:00
parent 1a0192d1f9
commit 9e0d31af50
16 changed files with 88 additions and 103 deletions
@@ -92,10 +92,8 @@ static FailureOr<Value> materializeBroadcastedConstantTensor(Value value,
return arith::ConstantOp::create(rewriter, loc, resultType, broadcastedAttr).getResult();
}
static FailureOr<Value> prepareElementwiseOperand(Value value,
RankedTensorType resultType,
ConversionPatternRewriter& rewriter,
Location loc) {
static FailureOr<Value>
prepareElementwiseOperand(Value value, RankedTensorType resultType, ConversionPatternRewriter& rewriter, Location loc) {
auto valueType = dyn_cast<RankedTensorType>(value.getType());
if (!valueType || !valueType.hasStaticShape())
return failure();
@@ -280,8 +280,8 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
for (size_t aSliceId = 0; aSliceId < aHSlices[coreId].size(); aSliceId++)
weights.push_back(bTiles[outSliceId][coreId][aSliceId]);
auto computeOp =
createSpatCompute(rewriter, gemmLoc, currOutHSliceType, weights, aHSlices[coreId], [&](ValueRange aHSlicesArgs) {
auto computeOp = createSpatCompute(
rewriter, gemmLoc, currOutHSliceType, weights, aHSlices[coreId], [&](ValueRange aHSlicesArgs) {
SmallVector<Value> vmmOutputs;
vmmOutputs.reserve(aHSlicesArgs.size());
for (auto [aHSliceId, computeArg] : llvm::enumerate(aHSlicesArgs))
@@ -71,10 +71,8 @@ static SmallVector<ReassociationIndices> buildCollapseReassociation(ArrayRef<boo
return reassociation;
}
static Value createAverageCompute(Value input,
RankedTensorType resultType,
ConversionPatternRewriter& rewriter,
Location loc) {
static Value
createAverageCompute(Value input, RankedTensorType resultType, ConversionPatternRewriter& rewriter, Location loc) {
constexpr size_t numInputs = 1;
auto computeOp = createSpatCompute<numInputs>(rewriter, loc, resultType, {}, ValueRange {input}, [&](Value x) {
auto avgOp = spatial::SpatVAvgOp::create(rewriter, loc, resultType, x);
@@ -141,7 +139,8 @@ struct ReduceMeanToSpatialCompute : OpConversionPattern<ONNXReduceMeanV13Op> {
Location loc = reduceMeanOp.getLoc();
RankedTensorType leafType = getAllOnesType(inputType, resultType.getElementType());
Value reducedKeepdims = buildReduceMeanKeepdims(adaptor.getData(), reducedAxes, /*axis=*/0, leafType, rewriter, loc);
Value reducedKeepdims =
buildReduceMeanKeepdims(adaptor.getData(), reducedAxes, /*axis=*/0, leafType, rewriter, loc);
if (reduceMeanOp.getKeepdims() != 0) {
rewriter.replaceOp(reduceMeanOp, reducedKeepdims);