Refactor + ReduceMean batched
Validate Operations / validate-operations (push) Has been cancelled

This commit is contained in:
ilgeco
2026-05-29 15:57:13 +02:00
parent 832bd7f1f7
commit 819d8af0f7
27 changed files with 929 additions and 568 deletions
@@ -23,28 +23,10 @@ using namespace mlir;
namespace onnx_mlir {
namespace {
template <typename ArrayAttrT>
static int64_t getI64(ArrayAttrT arrayAttr, size_t index) {
return cast<IntegerAttr>(arrayAttr[index]).getInt();
}
template <typename ArrayAttrT>
static int64_t getOptionalI64(std::optional<ArrayAttrT> arrayAttr, size_t index, int64_t defaultValue) {
return arrayAttr ? getI64(*arrayAttr, index) : defaultValue;
}
static Value materializeContiguousTile(ConversionPatternRewriter& rewriter, Location loc, Value tile) {
auto tileType = cast<RankedTensorType>(tile.getType());
Value empty = tensor::EmptyOp::create(rewriter, loc, tileType.getShape(), tileType.getElementType());
SmallVector<OpFoldResult> offsets(tileType.getRank(), rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> sizes;
sizes.reserve(tileType.getRank());
for (int64_t dimSize : tileType.getShape())
sizes.push_back(rewriter.getIndexAttr(dimSize));
SmallVector<OpFoldResult> strides(tileType.getRank(), rewriter.getIndexAttr(1));
return tensor::InsertSliceOp::create(rewriter, loc, tile, empty, offsets, sizes, strides);
return insertStaticSlice(rewriter, loc, tile, empty, getZeroOffsets(rewriter, tileType.getRank()));
}
static Value
@@ -197,12 +179,12 @@ struct PoolToSpatialComputeBase : public OpConversionPattern<PoolOp> {
const int64_t inputWidth = xType.getDimSize(3);
const int64_t outputHeight = outType.getDimSize(2);
const int64_t outputWidth = outType.getDimSize(3);
const int64_t kernelHeight = getI64(kernelAttr, 0);
const int64_t kernelWidth = getI64(kernelAttr, 1);
const int64_t strideHeight = getOptionalI64(poolOp.getStrides(), 0, 1);
const int64_t strideWidth = getOptionalI64(poolOp.getStrides(), 1, 1);
const int64_t dilationHeight = getOptionalI64(poolOp.getDilations(), 0, 1);
const int64_t dilationWidth = getOptionalI64(poolOp.getDilations(), 1, 1);
const int64_t kernelHeight = getI64Attr(kernelAttr, 0);
const int64_t kernelWidth = getI64Attr(kernelAttr, 1);
const int64_t strideHeight = getOptionalI64Attr(poolOp.getStrides(), 0, 1);
const int64_t strideWidth = getOptionalI64Attr(poolOp.getStrides(), 1, 1);
const int64_t dilationHeight = getOptionalI64Attr(poolOp.getDilations(), 0, 1);
const int64_t dilationWidth = getOptionalI64Attr(poolOp.getDilations(), 1, 1);
int64_t padTop = 0;
int64_t padLeft = 0;
@@ -212,10 +194,10 @@ struct PoolToSpatialComputeBase : public OpConversionPattern<PoolOp> {
if (auto padsAttr = poolOp.getPads()) {
if (padsAttr->size() != 4)
return rewriter.notifyMatchFailure(poolOp, "pads must have four elements.");
padTop = getI64(*padsAttr, 0);
padLeft = getI64(*padsAttr, 1);
padBottom = getI64(*padsAttr, 2);
padRight = getI64(*padsAttr, 3);
padTop = getI64Attr(*padsAttr, 0);
padLeft = getI64Attr(*padsAttr, 1);
padBottom = getI64Attr(*padsAttr, 2);
padRight = getI64Attr(*padsAttr, 3);
}
else {
StringRef autoPad = poolOp.getAutoPad();