87 lines
1.9 KiB
C++
87 lines
1.9 KiB
C++
#pragma once
|
|
|
|
#include "mlir/IR/BuiltinTypes.h"
|
|
#include "mlir/IR/Value.h"
|
|
|
|
#include <cstdint>
|
|
|
|
namespace onnx_mlir {
|
|
|
|
struct ConvLoweringState {
|
|
mlir::Value x;
|
|
mlir::Value w;
|
|
mlir::Value b;
|
|
mlir::RankedTensorType xType;
|
|
mlir::RankedTensorType wType;
|
|
mlir::RankedTensorType outType;
|
|
int64_t batchSize;
|
|
int64_t numChannelsIn;
|
|
int64_t xHeight;
|
|
int64_t xWidth;
|
|
int64_t numChannelsOut;
|
|
int64_t wHeight;
|
|
int64_t wWidth;
|
|
int64_t outHeight;
|
|
int64_t outWidth;
|
|
int64_t group;
|
|
int64_t numChannelsInPerGroup;
|
|
int64_t numChannelsOutPerGroup;
|
|
int64_t padHeightBegin;
|
|
int64_t padHeightEnd;
|
|
int64_t padWidthBegin;
|
|
int64_t padWidthEnd;
|
|
int64_t strideHeight;
|
|
int64_t strideWidth;
|
|
int64_t dilationHeight;
|
|
int64_t dilationWidth;
|
|
bool hasBias;
|
|
};
|
|
|
|
/// Plain-integer description of a convolution's shape, as returned by
/// buildConvGeometry() and consumed by chooseStreamChunkPositions().
///
/// Every member carries a default member initializer. Previously none did,
/// so `ConvGeometry g;` left all of them indeterminate. This matches
/// RowInterval/ConvRowDemand and keeps `ConvGeometry g;` equal to
/// `ConvGeometry g{};`. The struct remains an aggregate (C++14+).
struct ConvGeometry {
  // Input extents (NOTE(review): names suggest an N, C, H, W ordering — confirm).
  int64_t batchSize = 0;
  int64_t numChannelsIn = 0;
  int64_t xHeight = 0;
  int64_t xWidth = 0;

  // Weight extents.
  int64_t numChannelsOut = 0;
  int64_t wHeight = 0;
  int64_t wWidth = 0;

  // Output spatial extents.
  int64_t outHeight = 0;
  int64_t outWidth = 0;

  // Grouping.
  int64_t group = 0;
  int64_t numChannelsInPerGroup = 0;
  int64_t numChannelsOutPerGroup = 0;

  // NOTE(review): k, c, p, xbarSize and pack are not defined by anything in
  // this header; presumably they are the matrix dimensions / packing of an
  // im2col-style formulation — confirm against buildConvGeometry().
  int64_t k = 0;
  int64_t c = 0;
  int64_t p = 0;
  int64_t xbarSize = 0;
  int64_t pack = 0;

  // Total element count of the im2col buffer (unsigned, unlike the fields
  // above; presumably to tolerate large products — TODO confirm).
  uint64_t im2colElements = 0;

  bool hasBias = false;
  // Presumably the result of isDepthwiseConv() on the grouping fields — confirm.
  bool isDepthwise = false;
};
|
|
|
|
/// A contiguous range of rows identified by its two endpoints; both default
/// to zero. NOTE(review): whether `end` is exclusive (half-open) is not
/// established by this header — confirm against the functions that build it.
struct RowInterval {
  int64_t begin{0};
  int64_t end{0};
};
|
|
|
|
struct ConvRowDemand {
|
|
RowInterval outputRows;
|
|
RowInterval neededInputRows;
|
|
RowInterval acquiredInputRows;
|
|
int64_t topHaloRows = 0;
|
|
int64_t bottomHaloRows = 0;
|
|
};
|
|
|
|
/// Returns whether a convolution with the given group and channel counts is
/// treated as depthwise. The criterion lives in the definition (not visible
/// here); NOTE(review): presumably one input channel per group with
/// group == numChannelsIn — confirm. Marked [[nodiscard]]: it is a pure query,
/// so ignoring its result is always a mistake.
[[nodiscard]] bool isDepthwiseConv(int64_t group, int64_t numChannelsIn,
                                   int64_t numChannelsOut,
                                   int64_t numChannelsInPerGroup);
|
|
|
|
/// Derives the integer ConvGeometry of a convolution from its lowering state.
/// `state` is only read (const reference); the geometry is returned by value.
/// Marked [[nodiscard]] because discarding the computed geometry is a bug.
[[nodiscard]] ConvGeometry buildConvGeometry(const ConvLoweringState& state);
|
|
|
|
/// Chooses a chunk size, in positions, for streaming the convolution
/// described by `geo` with the given `packFactor`.
/// NOTE(review): the unit ("positions") comes from the name only — presumably
/// output spatial positions per chunk; confirm against the definition.
/// Marked [[nodiscard]] because the chosen size is the function's only output.
[[nodiscard]] uint64_t chooseStreamChunkPositions(const ConvGeometry& geo,
                                                  int64_t packFactor);
|
|
|
|
/// Maps an interval of output rows to the interval of input rows needed to
/// compute them, using the convolution parameters in `state`.
/// NOTE(review): presumably driven by strideHeight, dilationHeight, wHeight
/// and padHeightBegin; whether the result is clamped to [0, xHeight] is not
/// visible here — confirm.
/// Marked [[nodiscard]] because the interval is the function's only output.
[[nodiscard]] RowInterval
computeConvInputRowsForOutputRows(RowInterval outputRows,
                                  const ConvLoweringState& state);
|
|
|
|
/// Builds the row bookkeeping (needed/acquired input rows and halo counts)
/// for the given interval of output rows of the convolution in `state`.
/// Marked [[nodiscard]] because the demand record is the function's only output.
[[nodiscard]] ConvRowDemand buildConvRowDemand(RowInterval outputRows,
                                               const ConvLoweringState& state);
|
|
|
|
} // namespace onnx_mlir
|