automatic code-reformat
All checks were successful
Validate Operations / validate-operations (push) Successful in 18m22s

This commit is contained in:
NiccoloN
2026-04-09 14:27:23 +02:00
parent 1a0192d1f9
commit 9e0d31af50
16 changed files with 88 additions and 103 deletions

View File

@@ -175,7 +175,7 @@ bool encapsulateConcat(IRRewriter& rewriter, Location loc, Operation* inst) {
auto newCompute = spatial::SpatWeightedCompute::create(rewriter, loc, inst->getResultTypes().front(), sources); auto newCompute = spatial::SpatWeightedCompute::create(rewriter, loc, inst->getResultTypes().front(), sources);
llvm::SmallVector<Type> sourceTypes; llvm::SmallVector<Type> sourceTypes;
llvm::SmallVector<Location> sourceLoc; llvm::SmallVector<Location> sourceLoc;
for (auto source : sources){ for (auto source : sources) {
sourceTypes.push_back(source.getType()); sourceTypes.push_back(source.getType());
sourceLoc.push_back(loc); sourceLoc.push_back(loc);
} }
@@ -183,7 +183,7 @@ bool encapsulateConcat(IRRewriter& rewriter, Location loc, Operation* inst) {
newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) sources.size()}); newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) sources.size()});
rewriter.setInsertionPointToEnd(BB); rewriter.setInsertionPointToEnd(BB);
IRMapping mapper; IRMapping mapper;
for(auto [source,bbArg] : llvm::zip(sources, BB->getArguments())) for (auto [source, bbArg] : llvm::zip(sources, BB->getArguments()))
mapper.map(source, bbArg); mapper.map(source, bbArg);
auto newConcat = rewriter.clone(*inst, mapper); auto newConcat = rewriter.clone(*inst, mapper);
spatial::SpatYieldOp::create(rewriter, loc, newConcat->getResult(0)); spatial::SpatYieldOp::create(rewriter, loc, newConcat->getResult(0));

View File

@@ -92,10 +92,8 @@ static FailureOr<Value> materializeBroadcastedConstantTensor(Value value,
return arith::ConstantOp::create(rewriter, loc, resultType, broadcastedAttr).getResult(); return arith::ConstantOp::create(rewriter, loc, resultType, broadcastedAttr).getResult();
} }
static FailureOr<Value> prepareElementwiseOperand(Value value, static FailureOr<Value>
RankedTensorType resultType, prepareElementwiseOperand(Value value, RankedTensorType resultType, ConversionPatternRewriter& rewriter, Location loc) {
ConversionPatternRewriter& rewriter,
Location loc) {
auto valueType = dyn_cast<RankedTensorType>(value.getType()); auto valueType = dyn_cast<RankedTensorType>(value.getType());
if (!valueType || !valueType.hasStaticShape()) if (!valueType || !valueType.hasStaticShape())
return failure(); return failure();

View File

@@ -280,8 +280,8 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
for (size_t aSliceId = 0; aSliceId < aHSlices[coreId].size(); aSliceId++) for (size_t aSliceId = 0; aSliceId < aHSlices[coreId].size(); aSliceId++)
weights.push_back(bTiles[outSliceId][coreId][aSliceId]); weights.push_back(bTiles[outSliceId][coreId][aSliceId]);
auto computeOp = auto computeOp = createSpatCompute(
createSpatCompute(rewriter, gemmLoc, currOutHSliceType, weights, aHSlices[coreId], [&](ValueRange aHSlicesArgs) { rewriter, gemmLoc, currOutHSliceType, weights, aHSlices[coreId], [&](ValueRange aHSlicesArgs) {
SmallVector<Value> vmmOutputs; SmallVector<Value> vmmOutputs;
vmmOutputs.reserve(aHSlicesArgs.size()); vmmOutputs.reserve(aHSlicesArgs.size());
for (auto [aHSliceId, computeArg] : llvm::enumerate(aHSlicesArgs)) for (auto [aHSliceId, computeArg] : llvm::enumerate(aHSlicesArgs))

View File

@@ -71,10 +71,8 @@ static SmallVector<ReassociationIndices> buildCollapseReassociation(ArrayRef<boo
return reassociation; return reassociation;
} }
static Value createAverageCompute(Value input, static Value
RankedTensorType resultType, createAverageCompute(Value input, RankedTensorType resultType, ConversionPatternRewriter& rewriter, Location loc) {
ConversionPatternRewriter& rewriter,
Location loc) {
constexpr size_t numInputs = 1; constexpr size_t numInputs = 1;
auto computeOp = createSpatCompute<numInputs>(rewriter, loc, resultType, {}, ValueRange {input}, [&](Value x) { auto computeOp = createSpatCompute<numInputs>(rewriter, loc, resultType, {}, ValueRange {input}, [&](Value x) {
auto avgOp = spatial::SpatVAvgOp::create(rewriter, loc, resultType, x); auto avgOp = spatial::SpatVAvgOp::create(rewriter, loc, resultType, x);
@@ -141,7 +139,8 @@ struct ReduceMeanToSpatialCompute : OpConversionPattern<ONNXReduceMeanV13Op> {
Location loc = reduceMeanOp.getLoc(); Location loc = reduceMeanOp.getLoc();
RankedTensorType leafType = getAllOnesType(inputType, resultType.getElementType()); RankedTensorType leafType = getAllOnesType(inputType, resultType.getElementType());
Value reducedKeepdims = buildReduceMeanKeepdims(adaptor.getData(), reducedAxes, /*axis=*/0, leafType, rewriter, loc); Value reducedKeepdims =
buildReduceMeanKeepdims(adaptor.getData(), reducedAxes, /*axis=*/0, leafType, rewriter, loc);
if (reduceMeanOp.getKeepdims() != 0) { if (reduceMeanOp.getKeepdims() != 0) {
rewriter.replaceOp(reduceMeanOp, reducedKeepdims); rewriter.replaceOp(reduceMeanOp, reducedKeepdims);

View File

@@ -24,18 +24,16 @@ static SmallVector<int64_t> permuteShape(ArrayRef<int64_t> shape, ArrayRef<int64
static Value createSoftmaxCompute(Value input, ConversionPatternRewriter& rewriter, Location loc) { static Value createSoftmaxCompute(Value input, ConversionPatternRewriter& rewriter, Location loc) {
auto inputType = cast<RankedTensorType>(input.getType()); auto inputType = cast<RankedTensorType>(input.getType());
constexpr size_t numInputs = 1; constexpr size_t numInputs = 1;
auto computeOp = createSpatCompute<numInputs>(rewriter, loc, TypeRange {inputType}, {}, ValueRange {input}, [&](Value x) { auto computeOp =
auto softmaxOp = spatial::SpatSoftmaxOp::create(rewriter, loc, inputType, x); createSpatCompute<numInputs>(rewriter, loc, TypeRange {inputType}, {}, ValueRange {input}, [&](Value x) {
spatial::SpatYieldOp::create(rewriter, loc, softmaxOp.getResult()); auto softmaxOp = spatial::SpatSoftmaxOp::create(rewriter, loc, inputType, x);
}); spatial::SpatYieldOp::create(rewriter, loc, softmaxOp.getResult());
});
return computeOp.getResult(0); return computeOp.getResult(0);
} }
static Value buildSoftmax(Value input, static Value
int64_t softmaxAxis, buildSoftmax(Value input, int64_t softmaxAxis, int64_t axis, ConversionPatternRewriter& rewriter, Location loc) {
int64_t axis,
ConversionPatternRewriter& rewriter,
Location loc) {
auto inputType = cast<RankedTensorType>(input.getType()); auto inputType = cast<RankedTensorType>(input.getType());
if (axis == inputType.getRank()) if (axis == inputType.getRank())
return createSoftmaxCompute(input, rewriter, loc); return createSoftmaxCompute(input, rewriter, loc);
@@ -71,7 +69,8 @@ struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> {
Value result; Value result;
if (axis == inputType.getRank() - 1) { if (axis == inputType.getRank() - 1) {
result = buildSoftmax(input, axis, /*axis=*/0, rewriter, softmaxOp.getLoc()); result = buildSoftmax(input, axis, /*axis=*/0, rewriter, softmaxOp.getLoc());
} else { }
else {
SmallVector<int64_t> permutation; SmallVector<int64_t> permutation;
permutation.reserve(inputType.getRank()); permutation.reserve(inputType.getRank());
for (int64_t dim = 0; dim < inputType.getRank(); ++dim) for (int64_t dim = 0; dim < inputType.getRank(); ++dim)
@@ -85,14 +84,15 @@ struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> {
auto transposedType = RankedTensorType::get( auto transposedType = RankedTensorType::get(
permuteShape(inputType.getShape(), permutation), inputType.getElementType(), inputType.getEncoding()); permuteShape(inputType.getShape(), permutation), inputType.getElementType(), inputType.getEncoding());
auto preTransposeCompute = createSpatCompute<1>( auto preTransposeCompute =
rewriter, softmaxOp.getLoc(), TypeRange {transposedType}, {}, input, [&](Value x) { createSpatCompute<1>(rewriter, softmaxOp.getLoc(), TypeRange {transposedType}, {}, input, [&](Value x) {
Value transposed = Value transposed = ONNXTransposeOp::create(
ONNXTransposeOp::create(rewriter, softmaxOp.getLoc(), transposedType, x, rewriter.getI64ArrayAttr(permutation)); rewriter, softmaxOp.getLoc(), transposedType, x, rewriter.getI64ArrayAttr(permutation));
spatial::SpatYieldOp::create(rewriter, softmaxOp.getLoc(), transposed); spatial::SpatYieldOp::create(rewriter, softmaxOp.getLoc(), transposed);
}); });
Value transposedInput = preTransposeCompute.getResult(0); Value transposedInput = preTransposeCompute.getResult(0);
Value transposedResult = buildSoftmax(transposedInput, /*softmaxAxis=*/inputType.getRank() - 1, /*axis=*/0, rewriter, softmaxOp.getLoc()); Value transposedResult = buildSoftmax(
transposedInput, /*softmaxAxis=*/inputType.getRank() - 1, /*axis=*/0, rewriter, softmaxOp.getLoc());
result = ONNXTransposeOp::create( result = ONNXTransposeOp::create(
rewriter, softmaxOp.getLoc(), inputType, transposedResult, rewriter.getI64ArrayAttr(inversePermutation)); rewriter, softmaxOp.getLoc(), inputType, transposedResult, rewriter.getI64ArrayAttr(inversePermutation));
} }

View File

@@ -23,8 +23,6 @@ struct Concat : public OpConversionPattern<ONNXConcatOp> {
} }
}; };
void populateConcatPatterns(RewritePatternSet& patterns, MLIRContext* ctx) { void populateConcatPatterns(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.insert<Concat>(ctx); }
patterns.insert<Concat>(ctx);
}
} // namespace onnx_mlir } // namespace onnx_mlir

View File

@@ -19,11 +19,8 @@ static int64_t normalizeAxis(int64_t axis, int64_t rank) { return axis >= 0 ? ax
static int64_t normalizeIndex(int64_t index, int64_t dimSize) { return index >= 0 ? index : dimSize + index; } static int64_t normalizeIndex(int64_t index, int64_t dimSize) { return index >= 0 ? index : dimSize + index; }
static Value extractSliceAt(Value input, static Value
int64_t axis, extractSliceAt(Value input, int64_t axis, int64_t offset, ConversionPatternRewriter& rewriter, Location loc) {
int64_t offset,
ConversionPatternRewriter& rewriter,
Location loc) {
auto inputType = cast<RankedTensorType>(input.getType()); auto inputType = cast<RankedTensorType>(input.getType());
SmallVector<OpFoldResult> offsets(inputType.getRank(), rewriter.getIndexAttr(0)); SmallVector<OpFoldResult> offsets(inputType.getRank(), rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> sizes; SmallVector<OpFoldResult> sizes;
@@ -110,34 +107,42 @@ struct Gather : OpConversionPattern<ONNXGatherOp> {
SmallVector<int64_t> flatIndices(indicesAttr.getValues<int64_t>().begin(), indicesAttr.getValues<int64_t>().end()); SmallVector<int64_t> flatIndices(indicesAttr.getValues<int64_t>().begin(), indicesAttr.getValues<int64_t>().end());
Location loc = gatherOp.getLoc(); Location loc = gatherOp.getLoc();
auto computeOp = createSpatCompute<1>( auto computeOp =
rewriter, loc, TypeRange {gatherOp.getResult().getType()}, {}, adaptor.getData(), [&](Value data) -> LogicalResult { createSpatCompute<1>(rewriter,
Value result; loc,
if (indicesType.getRank() == 1) { TypeRange {gatherOp.getResult().getType()},
result = concatGatherSlices(data, axis, flatIndices, axisDim, rewriter, loc); {},
} else if (indicesType.getRank() == 2) { adaptor.getData(),
int64_t rowCount = indicesType.getShape()[0]; [&](Value data) -> LogicalResult {
int64_t rowWidth = indicesType.getShape()[1]; Value result;
SmallVector<Value> rows; if (indicesType.getRank() == 1) {
rows.reserve(rowCount); result = concatGatherSlices(data, axis, flatIndices, axisDim, rewriter, loc);
for (int64_t row = 0; row < rowCount; ++row) { }
ArrayRef<int64_t> rowIndices(flatIndices.data() + row * rowWidth, rowWidth); else if (indicesType.getRank() == 2) {
Value gatheredRow = concatGatherSlices(data, axis, rowIndices, axisDim, rewriter, loc); int64_t rowCount = indicesType.getShape()[0];
if (!gatheredRow) int64_t rowWidth = indicesType.getShape()[1];
return failure(); SmallVector<Value> rows;
rows.push_back(addLeadingGatherDim(gatheredRow, axis, rewriter, loc)); rows.reserve(rowCount);
} for (int64_t row = 0; row < rowCount; ++row) {
result = ArrayRef<int64_t> rowIndices(flatIndices.data() + row * rowWidth, rowWidth);
rows.size() == 1 ? rows.front() : tensor::ConcatOp::create(rewriter, loc, /*axis=*/axis, rows).getResult(); Value gatheredRow = concatGatherSlices(data, axis, rowIndices, axisDim, rewriter, loc);
} else { if (!gatheredRow)
return failure(); return failure();
} rows.push_back(addLeadingGatherDim(gatheredRow, axis, rewriter, loc));
}
result = rows.size() == 1
? rows.front()
: tensor::ConcatOp::create(rewriter, loc, /*axis=*/axis, rows).getResult();
}
else {
return failure();
}
if (!result) if (!result)
return failure(); return failure();
spatial::SpatYieldOp::create(rewriter, loc, result); spatial::SpatYieldOp::create(rewriter, loc, result);
return success(); return success();
}); });
if (failed(computeOp)) if (failed(computeOp))
return failure(); return failure();
rewriter.replaceOp(gatherOp, computeOp->getResults()); rewriter.replaceOp(gatherOp, computeOp->getResults());

View File

@@ -15,11 +15,8 @@ using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
namespace { namespace {
static Value extractSliceAt(Value input, static Value
int64_t axis, extractSliceAt(Value input, int64_t axis, int64_t offset, ConversionPatternRewriter& rewriter, Location loc) {
int64_t offset,
ConversionPatternRewriter& rewriter,
Location loc) {
auto inputType = cast<RankedTensorType>(input.getType()); auto inputType = cast<RankedTensorType>(input.getType());
SmallVector<OpFoldResult> offsets(inputType.getRank(), rewriter.getIndexAttr(0)); SmallVector<OpFoldResult> offsets(inputType.getRank(), rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> sizes; SmallVector<OpFoldResult> sizes;
@@ -67,8 +64,7 @@ struct Resize : OpConversionPattern<ONNXResizeOp> {
if (!inputType || !resultType || !inputType.hasStaticShape() || !resultType.hasStaticShape()) if (!inputType || !resultType || !inputType.hasStaticShape() || !resultType.hasStaticShape())
return failure(); return failure();
if (resizeOp.getMode() != "nearest" if (resizeOp.getMode() != "nearest" || resizeOp.getCoordinateTransformationMode() != "asymmetric"
|| resizeOp.getCoordinateTransformationMode() != "asymmetric"
|| resizeOp.getNearestMode() != "floor") || resizeOp.getNearestMode() != "floor")
return failure(); return failure();
@@ -76,9 +72,10 @@ struct Resize : OpConversionPattern<ONNXResizeOp> {
|| llvm::any_of(resultType.getShape(), [](int64_t dim) { return dim <= 0; })) || llvm::any_of(resultType.getShape(), [](int64_t dim) { return dim <= 0; }))
return failure(); return failure();
auto computeOp = createSpatCompute<1>( auto computeOp =
rewriter, resizeOp.getLoc(), TypeRange {resultType}, {}, adaptor.getX(), [&](Value x) { createSpatCompute<1>(rewriter, resizeOp.getLoc(), TypeRange {resultType}, {}, adaptor.getX(), [&](Value x) {
Value result = buildNearestResize(x, inputType.getShape(), resultType.getShape(), /*axis=*/0, rewriter, resizeOp.getLoc()); Value result =
buildNearestResize(x, inputType.getShape(), resultType.getShape(), /*axis=*/0, rewriter, resizeOp.getLoc());
spatial::SpatYieldOp::create(rewriter, resizeOp.getLoc(), result); spatial::SpatYieldOp::create(rewriter, resizeOp.getLoc(), result);
}); });
rewriter.replaceOp(resizeOp, computeOp.getResults()); rewriter.replaceOp(resizeOp, computeOp.getResults());

View File

@@ -12,12 +12,8 @@ namespace {
static int64_t normalizeAxis(int64_t axis, int64_t rank) { return axis >= 0 ? axis : rank + axis; } static int64_t normalizeAxis(int64_t axis, int64_t rank) { return axis >= 0 ? axis : rank + axis; }
static Value extractSliceAt(Value input, static Value extractSliceAt(
int64_t axis, Value input, int64_t axis, int64_t offset, int64_t size, ConversionPatternRewriter& rewriter, Location loc) {
int64_t offset,
int64_t size,
ConversionPatternRewriter& rewriter,
Location loc) {
auto inputType = cast<RankedTensorType>(input.getType()); auto inputType = cast<RankedTensorType>(input.getType());
SmallVector<OpFoldResult> offsets(inputType.getRank(), rewriter.getIndexAttr(0)); SmallVector<OpFoldResult> offsets(inputType.getRank(), rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> sizes; SmallVector<OpFoldResult> sizes;
@@ -33,9 +29,8 @@ static Value extractSliceAt(Value input,
struct Split : OpConversionPattern<ONNXSplitOp> { struct Split : OpConversionPattern<ONNXSplitOp> {
using OpConversionPattern::OpConversionPattern; using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(ONNXSplitOp splitOp, LogicalResult
ONNXSplitOpAdaptor adaptor, matchAndRewrite(ONNXSplitOp splitOp, ONNXSplitOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override {
ConversionPatternRewriter& rewriter) const override {
auto inputType = dyn_cast<RankedTensorType>(adaptor.getInput().getType()); auto inputType = dyn_cast<RankedTensorType>(adaptor.getInput().getType());
if (!inputType || !inputType.hasStaticShape()) if (!inputType || !inputType.hasStaticShape())
return failure(); return failure();

View File

@@ -485,7 +485,7 @@ DCPAnalysisResult GraphDCP::getResult() {
size_t i = 0; size_t i = 0;
for (auto node : nodes) { for (auto node : nodes) {
ret.computeToCPUMap[node->getSpatWeightedCompute()] = cpu; ret.computeToCPUMap[node->getSpatWeightedCompute()] = cpu;
if (i++ == nodes.size() - 1){ if (i++ == nodes.size() - 1) {
ret.isLastComputeOfACpu.insert(node->getSpatWeightedCompute()); ret.isLastComputeOfACpu.insert(node->getSpatWeightedCompute());
ret.cpuToLastComputeMap[cpu] = node->getSpatWeightedCompute(); ret.cpuToLastComputeMap[cpu] = node->getSpatWeightedCompute();
} }

View File

@@ -43,7 +43,5 @@ bool TaskDCP::hasDescendent(TaskDCP* child) {
return false; return false;
} }
//TODO fare qualcosa di sensato // TODO fare qualcosa di sensato
int TaskDCP::computeWeight(GraphDCP* graph, CPU cpu) { int TaskDCP::computeWeight(GraphDCP* graph, CPU cpu) { return orig_weight; }
return orig_weight;
}

View File

@@ -75,11 +75,11 @@ public:
alst = val; alst = val;
} }
bool hasDescendent(TaskDCP* child); bool hasDescendent(TaskDCP* child);
int64_t Id() const { return (int64_t)spatWeightedCompute.getAsOpaquePointer(); } int64_t Id() const { return (int64_t) spatWeightedCompute.getAsOpaquePointer(); }
bool isCP() const { return alst == aest; } bool isCP() const { return alst == aest; }
bool isScheduled() const { return scheduledCPU.has_value(); } bool isScheduled() const { return scheduledCPU.has_value(); }
onnx_mlir::spatial::SpatWeightedCompute getSpatWeightedCompute(){return spatWeightedCompute;} onnx_mlir::spatial::SpatWeightedCompute getSpatWeightedCompute() { return spatWeightedCompute; }
friend std::optional<Edge_pair> addEdge(TaskDCP* parent, TaskDCP* child, Weight_t weight); friend std::optional<Edge_pair> addEdge(TaskDCP* parent, TaskDCP* child, Weight_t weight);
friend void removeEdge(TaskDCP* parent, TaskDCP* child); friend void removeEdge(TaskDCP* parent, TaskDCP* child);

View File

@@ -71,12 +71,7 @@ public:
return true; return true;
} }
auto begin() { auto begin() { return storage.begin(); }
return storage.begin();
}
auto end() {
return storage.end();
}
auto end() { return storage.end(); }
}; };

View File

@@ -1,7 +1,9 @@
#pragma once #pragma once
#include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/BuiltinTypeInterfaces.h"
#include "llvm/Support/Casting.h" #include "llvm/Support/Casting.h"
#include <algorithm> #include <algorithm>
#include <cstdint> #include <cstdint>
#include <utility> #include <utility>
@@ -50,10 +52,9 @@ inline int64_t getSpatWeightCompute(onnx_mlir::spatial::SpatWeightedCompute spat
int64_t tot = 0; int64_t tot = 0;
for (auto& region : spatWeightedCompute.getBody()) { for (auto& region : spatWeightedCompute.getBody()) {
for (auto& inst : region) { for (auto& inst : region) {
for(auto result : inst.getResults()){ for (auto result : inst.getResults())
if(auto element = llvm::dyn_cast<mlir::ShapedType>(result.getType())) if (auto element = llvm::dyn_cast<mlir::ShapedType>(result.getType()))
tot += onnx_mlir::getSizeInBytes(element); tot += onnx_mlir::getSizeInBytes(element);
}
} }
} }
return tot; return tot;

View File

@@ -413,7 +413,7 @@ struct ChannelBroadcastReceiveOpInterface
outputTensor, outputTensor,
rewriter.getI32IntegerAttr(numElements * elementSize), rewriter.getI32IntegerAttr(numElements * elementSize),
rewriter.getI32IntegerAttr(srcCoreId.value())) rewriter.getI32IntegerAttr(srcCoreId.value()))
.getOutput(); .getOutput();
replaceOpWithBufferizedValues(rewriter, op, newValue); replaceOpWithBufferizedValues(rewriter, op, newValue);

View File

@@ -243,8 +243,7 @@ void populateConstantFoldingSubviewPatterns(RewritePatternSet& patterns) {
patterns.add<RewriteCoreSubviewCopyPattern, patterns.add<RewriteCoreSubviewCopyPattern,
RewriteHostSubviewLoadPattern, RewriteHostSubviewLoadPattern,
RewriteHostSubviewStorePattern, RewriteHostSubviewStorePattern,
FoldConstantCoreSubviewPattern>( FoldConstantCoreSubviewPattern>(patterns.getContext());
patterns.getContext());
} }
} // namespace onnx_mlir } // namespace onnx_mlir