fix batched conv
This commit is contained in:
@@ -38,11 +38,11 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
|||||||
|
|
||||||
assert("Only support static shapes" && xType.hasStaticShape() && wType.hasStaticShape() && outType.hasStaticShape());
|
assert("Only support static shapes" && xType.hasStaticShape() && wType.hasStaticShape() && outType.hasStaticShape());
|
||||||
assert("Only support 2D convolution" && xType.getRank() == 4);
|
assert("Only support 2D convolution" && xType.getRank() == 4);
|
||||||
assert("Only support batch size 1 for input" && xType.getDimSize(0) == 1);
|
|
||||||
|
|
||||||
// We need to understand what is group
|
// We need to understand what is group
|
||||||
assert("Only support group=1" && convOp.getGroup() == 1);
|
assert("Only support group=1" && convOp.getGroup() == 1);
|
||||||
|
|
||||||
|
const int64_t batchSize = xType.getDimSize(0);
|
||||||
const int64_t numChannelsIn = xType.getDimSize(1);
|
const int64_t numChannelsIn = xType.getDimSize(1);
|
||||||
const int64_t xHeight = xType.getDimSize(2);
|
const int64_t xHeight = xType.getDimSize(2);
|
||||||
const int64_t xWidth = xType.getDimSize(3);
|
const int64_t xWidth = xType.getDimSize(3);
|
||||||
@@ -107,7 +107,8 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
|||||||
// B (weights): [patchSize, cOut] -- W^T, stored in crossbar columns
|
// B (weights): [patchSize, cOut] -- W^T, stored in crossbar columns
|
||||||
// Gemm output: [numPatches, cOut]
|
// Gemm output: [numPatches, cOut]
|
||||||
const int64_t patchSize = numChannelsIn * wHeight * wWidth;
|
const int64_t patchSize = numChannelsIn * wHeight * wWidth;
|
||||||
const int64_t numPatches = outHeight * outWidth;
|
const int64_t numPatchesPerBatch = outHeight * outWidth;
|
||||||
|
const int64_t numPatches = batchSize * numPatchesPerBatch;
|
||||||
|
|
||||||
auto elemType = xType.getElementType();
|
auto elemType = xType.getElementType();
|
||||||
auto im2colType = RankedTensorType::get({numPatches, patchSize}, elemType);
|
auto im2colType = RankedTensorType::get({numPatches, patchSize}, elemType);
|
||||||
@@ -115,7 +116,7 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
|||||||
auto wFlatType = RankedTensorType::get({numChannelsOut, patchSize}, wType.getElementType());
|
auto wFlatType = RankedTensorType::get({numChannelsOut, patchSize}, wType.getElementType());
|
||||||
auto wTransType = RankedTensorType::get({patchSize, numChannelsOut}, wType.getElementType());
|
auto wTransType = RankedTensorType::get({patchSize, numChannelsOut}, wType.getElementType());
|
||||||
auto gemmOutType = RankedTensorType::get({numPatches, numChannelsOut}, outType.getElementType());
|
auto gemmOutType = RankedTensorType::get({numPatches, numChannelsOut}, outType.getElementType());
|
||||||
auto nhwcType = RankedTensorType::get({1, outHeight, outWidth, numChannelsOut}, outType.getElementType());
|
auto nhwcType = RankedTensorType::get({batchSize, outHeight, outWidth, numChannelsOut}, outType.getElementType());
|
||||||
|
|
||||||
// Prepare weight matrix W for crossbar storage:
|
// Prepare weight matrix W for crossbar storage:
|
||||||
// W: [numChannelsOut, numChannelsIn, wHeight, wWidth] -> [numChannelsOut, patchSize] -> [patchSize, numChannelsOut]
|
// W: [numChannelsOut, numChannelsIn, wHeight, wWidth] -> [numChannelsOut, patchSize] -> [patchSize, numChannelsOut]
|
||||||
@@ -160,7 +161,7 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
|||||||
if (padHeightBegin || padHeightEnd || padWidthBegin || padWidthEnd) {
|
if (padHeightBegin || padHeightEnd || padWidthBegin || padWidthEnd) {
|
||||||
const int64_t paddedHeight = xHeight + padHeightBegin + padHeightEnd;
|
const int64_t paddedHeight = xHeight + padHeightBegin + padHeightEnd;
|
||||||
const int64_t paddedWidth = xWidth + padWidthBegin + padWidthEnd;
|
const int64_t paddedWidth = xWidth + padWidthBegin + padWidthEnd;
|
||||||
auto paddedType = RankedTensorType::get({1, numChannelsIn, paddedHeight, paddedWidth}, elemType);
|
auto paddedType = RankedTensorType::get({batchSize, numChannelsIn, paddedHeight, paddedWidth}, elemType);
|
||||||
SmallVector<OpFoldResult> lowPads = {rewriter.getIndexAttr(0),
|
SmallVector<OpFoldResult> lowPads = {rewriter.getIndexAttr(0),
|
||||||
rewriter.getIndexAttr(0),
|
rewriter.getIndexAttr(0),
|
||||||
rewriter.getIndexAttr(padHeightBegin),
|
rewriter.getIndexAttr(padHeightBegin),
|
||||||
@@ -182,36 +183,38 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Build im2col [numPatches, patchSize]:
|
// Build im2col [numPatches, patchSize]:
|
||||||
// For each output position (oh, ow), extract the patch from x
|
// For each batch/output position (n, oh, ow), extract the patch from x
|
||||||
SmallVector<Value> im2colRows;
|
SmallVector<Value> im2colRows;
|
||||||
im2colRows.reserve(numPatches);
|
im2colRows.reserve(numPatches);
|
||||||
for (int64_t oh = 0; oh < outHeight; oh++) {
|
for (int64_t n = 0; n < batchSize; n++) {
|
||||||
for (int64_t ow = 0; ow < outWidth; ow++) {
|
for (int64_t oh = 0; oh < outHeight; oh++) {
|
||||||
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(0),
|
for (int64_t ow = 0; ow < outWidth; ow++) {
|
||||||
rewriter.getIndexAttr(0),
|
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(n),
|
||||||
rewriter.getIndexAttr(oh * strideHeight),
|
rewriter.getIndexAttr(0),
|
||||||
rewriter.getIndexAttr(ow * strideWidth)};
|
rewriter.getIndexAttr(oh * strideHeight),
|
||||||
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
|
rewriter.getIndexAttr(ow * strideWidth)};
|
||||||
rewriter.getIndexAttr(numChannelsIn),
|
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
|
||||||
rewriter.getIndexAttr(wHeight),
|
rewriter.getIndexAttr(numChannelsIn),
|
||||||
rewriter.getIndexAttr(wWidth)};
|
rewriter.getIndexAttr(wHeight),
|
||||||
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1),
|
rewriter.getIndexAttr(wWidth)};
|
||||||
rewriter.getIndexAttr(1),
|
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1),
|
||||||
rewriter.getIndexAttr(dilationHeight),
|
rewriter.getIndexAttr(1),
|
||||||
rewriter.getIndexAttr(dilationWidth)};
|
rewriter.getIndexAttr(dilationHeight),
|
||||||
auto patchType = RankedTensorType::get({1, numChannelsIn, wHeight, wWidth}, elemType);
|
rewriter.getIndexAttr(dilationWidth)};
|
||||||
Value patch = tensor::ExtractSliceOp::create(rewriter, loc, patchType, paddedInput, offsets, sizes, strides);
|
auto patchType = RankedTensorType::get({1, numChannelsIn, wHeight, wWidth}, elemType);
|
||||||
|
Value patch = tensor::ExtractSliceOp::create(rewriter, loc, patchType, paddedInput, offsets, sizes, strides);
|
||||||
|
|
||||||
// Flatten [1, numChannelsIn, wHeight, wWidth] -> [1, patchSize]
|
// Flatten [1, numChannelsIn, wHeight, wWidth] -> [1, patchSize]
|
||||||
Value row = tensor::CollapseShapeOp::create(rewriter,
|
Value row = tensor::CollapseShapeOp::create(rewriter,
|
||||||
loc,
|
loc,
|
||||||
rowType,
|
rowType,
|
||||||
patch,
|
patch,
|
||||||
SmallVector<ReassociationIndices> {
|
SmallVector<ReassociationIndices> {
|
||||||
{0},
|
{0},
|
||||||
{1, 2, 3}
|
{1, 2, 3}
|
||||||
});
|
});
|
||||||
im2colRows.push_back(row);
|
im2colRows.push_back(row);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -7,9 +7,10 @@ add_onnx_mlir_library(SpatialOps
|
|||||||
Transforms/SpatialBufferizableOpInterface.cpp
|
Transforms/SpatialBufferizableOpInterface.cpp
|
||||||
|
|
||||||
DEPENDS
|
DEPENDS
|
||||||
|
OMONNXIncGen
|
||||||
OMSpatialIncGen
|
OMSpatialIncGen
|
||||||
|
|
||||||
LINK_LIBS PUBLIC
|
LINK_LIBS PUBLIC
|
||||||
MLIRIR
|
MLIRIR
|
||||||
OMMlirDialects
|
OMMlirDialects
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user