Refactor + ReduceMean batched
Validate Operations / validate-operations (push) Has been cancelled

This commit is contained in:
ilgeco
2026-05-29 15:57:13 +02:00
parent 832bd7f1f7
commit 819d8af0f7
27 changed files with 929 additions and 568 deletions
@@ -28,8 +28,6 @@ struct ConvToGemm : OpConversionPattern<ONNXConvOp> {
ConversionPatternRewriter& rewriter) const override;
};
static int64_t getI64FromArrayAttr(ArrayAttr arr, size_t idx) { return cast<IntegerAttr>(arr[idx]).getInt(); }
static Value expandBiasIfNeeded(Value bias, ConversionPatternRewriter& rewriter, Location loc) {
auto biasType = cast<RankedTensorType>(bias.getType());
if (biasType.getRank() != 1)
@@ -615,10 +613,10 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
return failure();
}
const int64_t strideHeight = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 0) : 1;
const int64_t strideWidth = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 1) : 1;
const int64_t dilationHeight = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 0) : 1;
const int64_t dilationWidth = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 1) : 1;
const int64_t strideHeight = getOptionalI64Attr(stridesAttr, 0, 1);
const int64_t strideWidth = getOptionalI64Attr(stridesAttr, 1, 1);
const int64_t dilationHeight = getOptionalI64Attr(dilationsAttr, 0, 1);
const int64_t dilationWidth = getOptionalI64Attr(dilationsAttr, 1, 1);
int64_t padHeightBegin = 0;
int64_t padHeightEnd = 0;
@@ -626,10 +624,10 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
int64_t padWidthEnd = 0;
if (padsAttr) {
padHeightBegin = getI64FromArrayAttr(*padsAttr, 0);
padWidthBegin = getI64FromArrayAttr(*padsAttr, 1);
padHeightEnd = getI64FromArrayAttr(*padsAttr, 2);
padWidthEnd = getI64FromArrayAttr(*padsAttr, 3);
padHeightBegin = getI64Attr(*padsAttr, 0);
padWidthBegin = getI64Attr(*padsAttr, 1);
padHeightEnd = getI64Attr(*padsAttr, 2);
padWidthEnd = getI64Attr(*padsAttr, 3);
}
else {
// Compute padding from auto_pad attribute