This commit is contained in:
@@ -28,8 +28,6 @@ struct ConvToGemm : OpConversionPattern<ONNXConvOp> {
|
||||
ConversionPatternRewriter& rewriter) const override;
|
||||
};
|
||||
|
||||
static int64_t getI64FromArrayAttr(ArrayAttr arr, size_t idx) { return cast<IntegerAttr>(arr[idx]).getInt(); }
|
||||
|
||||
static Value expandBiasIfNeeded(Value bias, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
auto biasType = cast<RankedTensorType>(bias.getType());
|
||||
if (biasType.getRank() != 1)
|
||||
@@ -615,10 +613,10 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
||||
return failure();
|
||||
}
|
||||
|
||||
const int64_t strideHeight = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 0) : 1;
|
||||
const int64_t strideWidth = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 1) : 1;
|
||||
const int64_t dilationHeight = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 0) : 1;
|
||||
const int64_t dilationWidth = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 1) : 1;
|
||||
const int64_t strideHeight = getOptionalI64Attr(stridesAttr, 0, 1);
|
||||
const int64_t strideWidth = getOptionalI64Attr(stridesAttr, 1, 1);
|
||||
const int64_t dilationHeight = getOptionalI64Attr(dilationsAttr, 0, 1);
|
||||
const int64_t dilationWidth = getOptionalI64Attr(dilationsAttr, 1, 1);
|
||||
|
||||
int64_t padHeightBegin = 0;
|
||||
int64_t padHeightEnd = 0;
|
||||
@@ -626,10 +624,10 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
||||
int64_t padWidthEnd = 0;
|
||||
|
||||
if (padsAttr) {
|
||||
padHeightBegin = getI64FromArrayAttr(*padsAttr, 0);
|
||||
padWidthBegin = getI64FromArrayAttr(*padsAttr, 1);
|
||||
padHeightEnd = getI64FromArrayAttr(*padsAttr, 2);
|
||||
padWidthEnd = getI64FromArrayAttr(*padsAttr, 3);
|
||||
padHeightBegin = getI64Attr(*padsAttr, 0);
|
||||
padWidthBegin = getI64Attr(*padsAttr, 1);
|
||||
padHeightEnd = getI64Attr(*padsAttr, 2);
|
||||
padWidthEnd = getI64Attr(*padsAttr, 3);
|
||||
}
|
||||
else {
|
||||
// Compute padding from auto_pad attribute
|
||||
|
||||
Reference in New Issue
Block a user