#pragma once #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Value.h" #include namespace onnx_mlir { struct ConvLoweringState { mlir::Value x; mlir::Value w; mlir::Value b; mlir::RankedTensorType xType; mlir::RankedTensorType wType; mlir::RankedTensorType outType; int64_t batchSize; int64_t numChannelsIn; int64_t xHeight; int64_t xWidth; int64_t numChannelsOut; int64_t wHeight; int64_t wWidth; int64_t outHeight; int64_t outWidth; int64_t group; int64_t numChannelsInPerGroup; int64_t numChannelsOutPerGroup; int64_t padHeightBegin; int64_t padHeightEnd; int64_t padWidthBegin; int64_t padWidthEnd; int64_t strideHeight; int64_t strideWidth; int64_t dilationHeight; int64_t dilationWidth; bool hasBias; }; struct ConvGeometry { int64_t batchSize; int64_t numChannelsIn; int64_t xHeight; int64_t xWidth; int64_t numChannelsOut; int64_t wHeight; int64_t wWidth; int64_t outHeight; int64_t outWidth; int64_t group; int64_t numChannelsInPerGroup; int64_t numChannelsOutPerGroup; int64_t k; int64_t c; int64_t p; int64_t xbarSize; int64_t pack; uint64_t im2colElements; bool hasBias; bool isDepthwise; }; struct RowInterval { int64_t begin = 0; int64_t end = 0; }; struct ConvRowDemand { RowInterval outputRows; RowInterval neededInputRows; RowInterval acquiredInputRows; int64_t topHaloRows = 0; int64_t bottomHaloRows = 0; }; bool isDepthwiseConv(int64_t group, int64_t numChannelsIn, int64_t numChannelsOut, int64_t numChannelsInPerGroup); ConvGeometry buildConvGeometry(const ConvLoweringState& state); uint64_t chooseStreamChunkPositions(const ConvGeometry& geo, int64_t packFactor); RowInterval computeConvInputRowsForOutputRows(RowInterval outputRows, const ConvLoweringState& state); ConvRowDemand buildConvRowDemand(RowInterval outputRows, const ConvLoweringState& state); } // namespace onnx_mlir