automatic code-reformat
All checks were successful
Validate Operations / validate-operations (push) Successful in 18m22s

This commit is contained in:
NiccoloN
2026-04-09 14:27:23 +02:00
parent 1a0192d1f9
commit 9e0d31af50
16 changed files with 88 additions and 103 deletions

View File

@@ -92,10 +92,8 @@ static FailureOr<Value> materializeBroadcastedConstantTensor(Value value,
return arith::ConstantOp::create(rewriter, loc, resultType, broadcastedAttr).getResult(); return arith::ConstantOp::create(rewriter, loc, resultType, broadcastedAttr).getResult();
} }
static FailureOr<Value> prepareElementwiseOperand(Value value, static FailureOr<Value>
RankedTensorType resultType, prepareElementwiseOperand(Value value, RankedTensorType resultType, ConversionPatternRewriter& rewriter, Location loc) {
ConversionPatternRewriter& rewriter,
Location loc) {
auto valueType = dyn_cast<RankedTensorType>(value.getType()); auto valueType = dyn_cast<RankedTensorType>(value.getType());
if (!valueType || !valueType.hasStaticShape()) if (!valueType || !valueType.hasStaticShape())
return failure(); return failure();

View File

@@ -280,8 +280,8 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
for (size_t aSliceId = 0; aSliceId < aHSlices[coreId].size(); aSliceId++) for (size_t aSliceId = 0; aSliceId < aHSlices[coreId].size(); aSliceId++)
weights.push_back(bTiles[outSliceId][coreId][aSliceId]); weights.push_back(bTiles[outSliceId][coreId][aSliceId]);
auto computeOp = auto computeOp = createSpatCompute(
createSpatCompute(rewriter, gemmLoc, currOutHSliceType, weights, aHSlices[coreId], [&](ValueRange aHSlicesArgs) { rewriter, gemmLoc, currOutHSliceType, weights, aHSlices[coreId], [&](ValueRange aHSlicesArgs) {
SmallVector<Value> vmmOutputs; SmallVector<Value> vmmOutputs;
vmmOutputs.reserve(aHSlicesArgs.size()); vmmOutputs.reserve(aHSlicesArgs.size());
for (auto [aHSliceId, computeArg] : llvm::enumerate(aHSlicesArgs)) for (auto [aHSliceId, computeArg] : llvm::enumerate(aHSlicesArgs))

View File

@@ -71,10 +71,8 @@ static SmallVector<ReassociationIndices> buildCollapseReassociation(ArrayRef<boo
return reassociation; return reassociation;
} }
static Value createAverageCompute(Value input, static Value
RankedTensorType resultType, createAverageCompute(Value input, RankedTensorType resultType, ConversionPatternRewriter& rewriter, Location loc) {
ConversionPatternRewriter& rewriter,
Location loc) {
constexpr size_t numInputs = 1; constexpr size_t numInputs = 1;
auto computeOp = createSpatCompute<numInputs>(rewriter, loc, resultType, {}, ValueRange {input}, [&](Value x) { auto computeOp = createSpatCompute<numInputs>(rewriter, loc, resultType, {}, ValueRange {input}, [&](Value x) {
auto avgOp = spatial::SpatVAvgOp::create(rewriter, loc, resultType, x); auto avgOp = spatial::SpatVAvgOp::create(rewriter, loc, resultType, x);
@@ -141,7 +139,8 @@ struct ReduceMeanToSpatialCompute : OpConversionPattern<ONNXReduceMeanV13Op> {
Location loc = reduceMeanOp.getLoc(); Location loc = reduceMeanOp.getLoc();
RankedTensorType leafType = getAllOnesType(inputType, resultType.getElementType()); RankedTensorType leafType = getAllOnesType(inputType, resultType.getElementType());
Value reducedKeepdims = buildReduceMeanKeepdims(adaptor.getData(), reducedAxes, /*axis=*/0, leafType, rewriter, loc); Value reducedKeepdims =
buildReduceMeanKeepdims(adaptor.getData(), reducedAxes, /*axis=*/0, leafType, rewriter, loc);
if (reduceMeanOp.getKeepdims() != 0) { if (reduceMeanOp.getKeepdims() != 0) {
rewriter.replaceOp(reduceMeanOp, reducedKeepdims); rewriter.replaceOp(reduceMeanOp, reducedKeepdims);

View File

@@ -24,18 +24,16 @@ static SmallVector<int64_t> permuteShape(ArrayRef<int64_t> shape, ArrayRef<int64
static Value createSoftmaxCompute(Value input, ConversionPatternRewriter& rewriter, Location loc) { static Value createSoftmaxCompute(Value input, ConversionPatternRewriter& rewriter, Location loc) {
auto inputType = cast<RankedTensorType>(input.getType()); auto inputType = cast<RankedTensorType>(input.getType());
constexpr size_t numInputs = 1; constexpr size_t numInputs = 1;
auto computeOp = createSpatCompute<numInputs>(rewriter, loc, TypeRange {inputType}, {}, ValueRange {input}, [&](Value x) { auto computeOp =
createSpatCompute<numInputs>(rewriter, loc, TypeRange {inputType}, {}, ValueRange {input}, [&](Value x) {
auto softmaxOp = spatial::SpatSoftmaxOp::create(rewriter, loc, inputType, x); auto softmaxOp = spatial::SpatSoftmaxOp::create(rewriter, loc, inputType, x);
spatial::SpatYieldOp::create(rewriter, loc, softmaxOp.getResult()); spatial::SpatYieldOp::create(rewriter, loc, softmaxOp.getResult());
}); });
return computeOp.getResult(0); return computeOp.getResult(0);
} }
static Value buildSoftmax(Value input, static Value
int64_t softmaxAxis, buildSoftmax(Value input, int64_t softmaxAxis, int64_t axis, ConversionPatternRewriter& rewriter, Location loc) {
int64_t axis,
ConversionPatternRewriter& rewriter,
Location loc) {
auto inputType = cast<RankedTensorType>(input.getType()); auto inputType = cast<RankedTensorType>(input.getType());
if (axis == inputType.getRank()) if (axis == inputType.getRank())
return createSoftmaxCompute(input, rewriter, loc); return createSoftmaxCompute(input, rewriter, loc);
@@ -71,7 +69,8 @@ struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> {
Value result; Value result;
if (axis == inputType.getRank() - 1) { if (axis == inputType.getRank() - 1) {
result = buildSoftmax(input, axis, /*axis=*/0, rewriter, softmaxOp.getLoc()); result = buildSoftmax(input, axis, /*axis=*/0, rewriter, softmaxOp.getLoc());
} else { }
else {
SmallVector<int64_t> permutation; SmallVector<int64_t> permutation;
permutation.reserve(inputType.getRank()); permutation.reserve(inputType.getRank());
for (int64_t dim = 0; dim < inputType.getRank(); ++dim) for (int64_t dim = 0; dim < inputType.getRank(); ++dim)
@@ -85,14 +84,15 @@ struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> {
auto transposedType = RankedTensorType::get( auto transposedType = RankedTensorType::get(
permuteShape(inputType.getShape(), permutation), inputType.getElementType(), inputType.getEncoding()); permuteShape(inputType.getShape(), permutation), inputType.getElementType(), inputType.getEncoding());
auto preTransposeCompute = createSpatCompute<1>( auto preTransposeCompute =
rewriter, softmaxOp.getLoc(), TypeRange {transposedType}, {}, input, [&](Value x) { createSpatCompute<1>(rewriter, softmaxOp.getLoc(), TypeRange {transposedType}, {}, input, [&](Value x) {
Value transposed = Value transposed = ONNXTransposeOp::create(
ONNXTransposeOp::create(rewriter, softmaxOp.getLoc(), transposedType, x, rewriter.getI64ArrayAttr(permutation)); rewriter, softmaxOp.getLoc(), transposedType, x, rewriter.getI64ArrayAttr(permutation));
spatial::SpatYieldOp::create(rewriter, softmaxOp.getLoc(), transposed); spatial::SpatYieldOp::create(rewriter, softmaxOp.getLoc(), transposed);
}); });
Value transposedInput = preTransposeCompute.getResult(0); Value transposedInput = preTransposeCompute.getResult(0);
Value transposedResult = buildSoftmax(transposedInput, /*softmaxAxis=*/inputType.getRank() - 1, /*axis=*/0, rewriter, softmaxOp.getLoc()); Value transposedResult = buildSoftmax(
transposedInput, /*softmaxAxis=*/inputType.getRank() - 1, /*axis=*/0, rewriter, softmaxOp.getLoc());
result = ONNXTransposeOp::create( result = ONNXTransposeOp::create(
rewriter, softmaxOp.getLoc(), inputType, transposedResult, rewriter.getI64ArrayAttr(inversePermutation)); rewriter, softmaxOp.getLoc(), inputType, transposedResult, rewriter.getI64ArrayAttr(inversePermutation));
} }

View File

@@ -23,8 +23,6 @@ struct Concat : public OpConversionPattern<ONNXConcatOp> {
} }
}; };
void populateConcatPatterns(RewritePatternSet& patterns, MLIRContext* ctx) { void populateConcatPatterns(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.insert<Concat>(ctx); }
patterns.insert<Concat>(ctx);
}
} // namespace onnx_mlir } // namespace onnx_mlir

View File

@@ -19,11 +19,8 @@ static int64_t normalizeAxis(int64_t axis, int64_t rank) { return axis >= 0 ? ax
static int64_t normalizeIndex(int64_t index, int64_t dimSize) { return index >= 0 ? index : dimSize + index; } static int64_t normalizeIndex(int64_t index, int64_t dimSize) { return index >= 0 ? index : dimSize + index; }
static Value extractSliceAt(Value input, static Value
int64_t axis, extractSliceAt(Value input, int64_t axis, int64_t offset, ConversionPatternRewriter& rewriter, Location loc) {
int64_t offset,
ConversionPatternRewriter& rewriter,
Location loc) {
auto inputType = cast<RankedTensorType>(input.getType()); auto inputType = cast<RankedTensorType>(input.getType());
SmallVector<OpFoldResult> offsets(inputType.getRank(), rewriter.getIndexAttr(0)); SmallVector<OpFoldResult> offsets(inputType.getRank(), rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> sizes; SmallVector<OpFoldResult> sizes;
@@ -110,12 +107,18 @@ struct Gather : OpConversionPattern<ONNXGatherOp> {
SmallVector<int64_t> flatIndices(indicesAttr.getValues<int64_t>().begin(), indicesAttr.getValues<int64_t>().end()); SmallVector<int64_t> flatIndices(indicesAttr.getValues<int64_t>().begin(), indicesAttr.getValues<int64_t>().end());
Location loc = gatherOp.getLoc(); Location loc = gatherOp.getLoc();
auto computeOp = createSpatCompute<1>( auto computeOp =
rewriter, loc, TypeRange {gatherOp.getResult().getType()}, {}, adaptor.getData(), [&](Value data) -> LogicalResult { createSpatCompute<1>(rewriter,
loc,
TypeRange {gatherOp.getResult().getType()},
{},
adaptor.getData(),
[&](Value data) -> LogicalResult {
Value result; Value result;
if (indicesType.getRank() == 1) { if (indicesType.getRank() == 1) {
result = concatGatherSlices(data, axis, flatIndices, axisDim, rewriter, loc); result = concatGatherSlices(data, axis, flatIndices, axisDim, rewriter, loc);
} else if (indicesType.getRank() == 2) { }
else if (indicesType.getRank() == 2) {
int64_t rowCount = indicesType.getShape()[0]; int64_t rowCount = indicesType.getShape()[0];
int64_t rowWidth = indicesType.getShape()[1]; int64_t rowWidth = indicesType.getShape()[1];
SmallVector<Value> rows; SmallVector<Value> rows;
@@ -127,9 +130,11 @@ struct Gather : OpConversionPattern<ONNXGatherOp> {
return failure(); return failure();
rows.push_back(addLeadingGatherDim(gatheredRow, axis, rewriter, loc)); rows.push_back(addLeadingGatherDim(gatheredRow, axis, rewriter, loc));
} }
result = result = rows.size() == 1
rows.size() == 1 ? rows.front() : tensor::ConcatOp::create(rewriter, loc, /*axis=*/axis, rows).getResult(); ? rows.front()
} else { : tensor::ConcatOp::create(rewriter, loc, /*axis=*/axis, rows).getResult();
}
else {
return failure(); return failure();
} }

View File

@@ -15,11 +15,8 @@ using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
namespace { namespace {
static Value extractSliceAt(Value input, static Value
int64_t axis, extractSliceAt(Value input, int64_t axis, int64_t offset, ConversionPatternRewriter& rewriter, Location loc) {
int64_t offset,
ConversionPatternRewriter& rewriter,
Location loc) {
auto inputType = cast<RankedTensorType>(input.getType()); auto inputType = cast<RankedTensorType>(input.getType());
SmallVector<OpFoldResult> offsets(inputType.getRank(), rewriter.getIndexAttr(0)); SmallVector<OpFoldResult> offsets(inputType.getRank(), rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> sizes; SmallVector<OpFoldResult> sizes;
@@ -67,8 +64,7 @@ struct Resize : OpConversionPattern<ONNXResizeOp> {
if (!inputType || !resultType || !inputType.hasStaticShape() || !resultType.hasStaticShape()) if (!inputType || !resultType || !inputType.hasStaticShape() || !resultType.hasStaticShape())
return failure(); return failure();
if (resizeOp.getMode() != "nearest" if (resizeOp.getMode() != "nearest" || resizeOp.getCoordinateTransformationMode() != "asymmetric"
|| resizeOp.getCoordinateTransformationMode() != "asymmetric"
|| resizeOp.getNearestMode() != "floor") || resizeOp.getNearestMode() != "floor")
return failure(); return failure();
@@ -76,9 +72,10 @@ struct Resize : OpConversionPattern<ONNXResizeOp> {
|| llvm::any_of(resultType.getShape(), [](int64_t dim) { return dim <= 0; })) || llvm::any_of(resultType.getShape(), [](int64_t dim) { return dim <= 0; }))
return failure(); return failure();
auto computeOp = createSpatCompute<1>( auto computeOp =
rewriter, resizeOp.getLoc(), TypeRange {resultType}, {}, adaptor.getX(), [&](Value x) { createSpatCompute<1>(rewriter, resizeOp.getLoc(), TypeRange {resultType}, {}, adaptor.getX(), [&](Value x) {
Value result = buildNearestResize(x, inputType.getShape(), resultType.getShape(), /*axis=*/0, rewriter, resizeOp.getLoc()); Value result =
buildNearestResize(x, inputType.getShape(), resultType.getShape(), /*axis=*/0, rewriter, resizeOp.getLoc());
spatial::SpatYieldOp::create(rewriter, resizeOp.getLoc(), result); spatial::SpatYieldOp::create(rewriter, resizeOp.getLoc(), result);
}); });
rewriter.replaceOp(resizeOp, computeOp.getResults()); rewriter.replaceOp(resizeOp, computeOp.getResults());

View File

@@ -12,12 +12,8 @@ namespace {
static int64_t normalizeAxis(int64_t axis, int64_t rank) { return axis >= 0 ? axis : rank + axis; } static int64_t normalizeAxis(int64_t axis, int64_t rank) { return axis >= 0 ? axis : rank + axis; }
static Value extractSliceAt(Value input, static Value extractSliceAt(
int64_t axis, Value input, int64_t axis, int64_t offset, int64_t size, ConversionPatternRewriter& rewriter, Location loc) {
int64_t offset,
int64_t size,
ConversionPatternRewriter& rewriter,
Location loc) {
auto inputType = cast<RankedTensorType>(input.getType()); auto inputType = cast<RankedTensorType>(input.getType());
SmallVector<OpFoldResult> offsets(inputType.getRank(), rewriter.getIndexAttr(0)); SmallVector<OpFoldResult> offsets(inputType.getRank(), rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> sizes; SmallVector<OpFoldResult> sizes;
@@ -33,9 +29,8 @@ static Value extractSliceAt(Value input,
struct Split : OpConversionPattern<ONNXSplitOp> { struct Split : OpConversionPattern<ONNXSplitOp> {
using OpConversionPattern::OpConversionPattern; using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(ONNXSplitOp splitOp, LogicalResult
ONNXSplitOpAdaptor adaptor, matchAndRewrite(ONNXSplitOp splitOp, ONNXSplitOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override {
ConversionPatternRewriter& rewriter) const override {
auto inputType = dyn_cast<RankedTensorType>(adaptor.getInput().getType()); auto inputType = dyn_cast<RankedTensorType>(adaptor.getInput().getType());
if (!inputType || !inputType.hasStaticShape()) if (!inputType || !inputType.hasStaticShape())
return failure(); return failure();

View File

@@ -44,6 +44,4 @@ bool TaskDCP::hasDescendent(TaskDCP* child) {
} }
// TODO fare qualcosa di sensato // TODO fare qualcosa di sensato
int TaskDCP::computeWeight(GraphDCP* graph, CPU cpu) { int TaskDCP::computeWeight(GraphDCP* graph, CPU cpu) { return orig_weight; }
return orig_weight;
}

View File

@@ -71,12 +71,7 @@ public:
return true; return true;
} }
auto begin() { auto begin() { return storage.begin(); }
return storage.begin();
}
auto end() {
return storage.end();
}
auto end() { return storage.end(); }
}; };

View File

@@ -1,7 +1,9 @@
#pragma once #pragma once
#include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/BuiltinTypeInterfaces.h"
#include "llvm/Support/Casting.h" #include "llvm/Support/Casting.h"
#include <algorithm> #include <algorithm>
#include <cstdint> #include <cstdint>
#include <utility> #include <utility>
@@ -50,11 +52,10 @@ inline int64_t getSpatWeightCompute(onnx_mlir::spatial::SpatWeightedCompute spat
int64_t tot = 0; int64_t tot = 0;
for (auto& region : spatWeightedCompute.getBody()) { for (auto& region : spatWeightedCompute.getBody()) {
for (auto& inst : region) { for (auto& inst : region) {
for(auto result : inst.getResults()){ for (auto result : inst.getResults())
if (auto element = llvm::dyn_cast<mlir::ShapedType>(result.getType())) if (auto element = llvm::dyn_cast<mlir::ShapedType>(result.getType()))
tot += onnx_mlir::getSizeInBytes(element); tot += onnx_mlir::getSizeInBytes(element);
} }
} }
}
return tot; return tot;
} }

View File

@@ -243,8 +243,7 @@ void populateConstantFoldingSubviewPatterns(RewritePatternSet& patterns) {
patterns.add<RewriteCoreSubviewCopyPattern, patterns.add<RewriteCoreSubviewCopyPattern,
RewriteHostSubviewLoadPattern, RewriteHostSubviewLoadPattern,
RewriteHostSubviewStorePattern, RewriteHostSubviewStorePattern,
FoldConstantCoreSubviewPattern>( FoldConstantCoreSubviewPattern>(patterns.getContext());
patterns.getContext());
} }
} // namespace onnx_mlir } // namespace onnx_mlir