This commit is contained in:
@@ -0,0 +1,86 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
struct ConvLoweringState {
|
||||
mlir::Value x;
|
||||
mlir::Value w;
|
||||
mlir::Value b;
|
||||
mlir::RankedTensorType xType;
|
||||
mlir::RankedTensorType wType;
|
||||
mlir::RankedTensorType outType;
|
||||
int64_t batchSize;
|
||||
int64_t numChannelsIn;
|
||||
int64_t xHeight;
|
||||
int64_t xWidth;
|
||||
int64_t numChannelsOut;
|
||||
int64_t wHeight;
|
||||
int64_t wWidth;
|
||||
int64_t outHeight;
|
||||
int64_t outWidth;
|
||||
int64_t group;
|
||||
int64_t numChannelsInPerGroup;
|
||||
int64_t numChannelsOutPerGroup;
|
||||
int64_t padHeightBegin;
|
||||
int64_t padHeightEnd;
|
||||
int64_t padWidthBegin;
|
||||
int64_t padWidthEnd;
|
||||
int64_t strideHeight;
|
||||
int64_t strideWidth;
|
||||
int64_t dilationHeight;
|
||||
int64_t dilationWidth;
|
||||
bool hasBias;
|
||||
};
|
||||
|
||||
struct ConvGeometry {
|
||||
int64_t batchSize;
|
||||
int64_t numChannelsIn;
|
||||
int64_t xHeight;
|
||||
int64_t xWidth;
|
||||
int64_t numChannelsOut;
|
||||
int64_t wHeight;
|
||||
int64_t wWidth;
|
||||
int64_t outHeight;
|
||||
int64_t outWidth;
|
||||
int64_t group;
|
||||
int64_t numChannelsInPerGroup;
|
||||
int64_t numChannelsOutPerGroup;
|
||||
int64_t k;
|
||||
int64_t c;
|
||||
int64_t p;
|
||||
int64_t xbarSize;
|
||||
int64_t pack;
|
||||
uint64_t im2colElements;
|
||||
bool hasBias;
|
||||
bool isDepthwise;
|
||||
};
|
||||
|
||||
struct RowInterval {
|
||||
int64_t begin = 0;
|
||||
int64_t end = 0;
|
||||
};
|
||||
|
||||
struct ConvRowDemand {
|
||||
RowInterval outputRows;
|
||||
RowInterval neededInputRows;
|
||||
RowInterval acquiredInputRows;
|
||||
int64_t topHaloRows = 0;
|
||||
int64_t bottomHaloRows = 0;
|
||||
};
|
||||
|
||||
bool isDepthwiseConv(int64_t group, int64_t numChannelsIn, int64_t numChannelsOut, int64_t numChannelsInPerGroup);
|
||||
|
||||
ConvGeometry buildConvGeometry(const ConvLoweringState& state);
|
||||
|
||||
uint64_t chooseStreamChunkPositions(const ConvGeometry& geo, int64_t packFactor);
|
||||
|
||||
RowInterval computeConvInputRowsForOutputRows(RowInterval outputRows, const ConvLoweringState& state);
|
||||
|
||||
ConvRowDemand buildConvRowDemand(RowInterval outputRows, const ConvLoweringState& state);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
Reference in New Issue
Block a user