fix batched conv
Some checks failed
Validate Operations / config (push) Successful in 1m3s
Validate Operations / build-mlir-cache (push) Successful in 3m40s
Validate Operations / validate (push) Failing after 2m37s

This commit is contained in:
NiccoloN
2026-03-20 22:00:46 +01:00
parent ca2e1645bb
commit 568529ea5f
2 changed files with 36 additions and 32 deletions

View File

@@ -38,11 +38,11 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
 assert("Only support static shapes" && xType.hasStaticShape() && wType.hasStaticShape() && outType.hasStaticShape());
 assert("Only support 2D convolution" && xType.getRank() == 4);
-assert("Only support batch size 1 for input" && xType.getDimSize(0) == 1);
 // We need to understand what is group
 assert("Only support group=1" && convOp.getGroup() == 1);
+const int64_t batchSize = xType.getDimSize(0);
 const int64_t numChannelsIn = xType.getDimSize(1);
 const int64_t xHeight = xType.getDimSize(2);
 const int64_t xWidth = xType.getDimSize(3);
@@ -107,7 +107,8 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
 // B (weights): [patchSize, cOut] -- W^T, stored in crossbar columns
 // Gemm output: [numPatches, cOut]
 const int64_t patchSize = numChannelsIn * wHeight * wWidth;
-const int64_t numPatches = outHeight * outWidth;
+const int64_t numPatchesPerBatch = outHeight * outWidth;
+const int64_t numPatches = batchSize * numPatchesPerBatch;
 auto elemType = xType.getElementType();
 auto im2colType = RankedTensorType::get({numPatches, patchSize}, elemType);
@@ -115,7 +116,7 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
 auto wFlatType = RankedTensorType::get({numChannelsOut, patchSize}, wType.getElementType());
 auto wTransType = RankedTensorType::get({patchSize, numChannelsOut}, wType.getElementType());
 auto gemmOutType = RankedTensorType::get({numPatches, numChannelsOut}, outType.getElementType());
-auto nhwcType = RankedTensorType::get({1, outHeight, outWidth, numChannelsOut}, outType.getElementType());
+auto nhwcType = RankedTensorType::get({batchSize, outHeight, outWidth, numChannelsOut}, outType.getElementType());
 // Prepare weight matrix W for crossbar storage:
 // W: [numChannelsOut, numChannelsIn, wHeight, wWidth] -> [numChannelsOut, patchSize] -> [patchSize, numChannelsOut]
@@ -160,7 +161,7 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
 if (padHeightBegin || padHeightEnd || padWidthBegin || padWidthEnd) {
   const int64_t paddedHeight = xHeight + padHeightBegin + padHeightEnd;
   const int64_t paddedWidth = xWidth + padWidthBegin + padWidthEnd;
-  auto paddedType = RankedTensorType::get({1, numChannelsIn, paddedHeight, paddedWidth}, elemType);
+  auto paddedType = RankedTensorType::get({batchSize, numChannelsIn, paddedHeight, paddedWidth}, elemType);
   SmallVector<OpFoldResult> lowPads = {rewriter.getIndexAttr(0),
                                        rewriter.getIndexAttr(0),
                                        rewriter.getIndexAttr(padHeightBegin),
@@ -182,36 +183,38 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
 }
 // Build im2col [numPatches, patchSize]:
-// For each output position (oh, ow), extract the patch from x
+// For each batch/output position (n, oh, ow), extract the patch from x
 SmallVector<Value> im2colRows;
 im2colRows.reserve(numPatches);
-for (int64_t oh = 0; oh < outHeight; oh++) {
-  for (int64_t ow = 0; ow < outWidth; ow++) {
-    SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(0),
-                                         rewriter.getIndexAttr(0),
-                                         rewriter.getIndexAttr(oh * strideHeight),
-                                         rewriter.getIndexAttr(ow * strideWidth)};
-    SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
-                                       rewriter.getIndexAttr(numChannelsIn),
-                                       rewriter.getIndexAttr(wHeight),
-                                       rewriter.getIndexAttr(wWidth)};
-    SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1),
-                                         rewriter.getIndexAttr(1),
-                                         rewriter.getIndexAttr(dilationHeight),
-                                         rewriter.getIndexAttr(dilationWidth)};
-    auto patchType = RankedTensorType::get({1, numChannelsIn, wHeight, wWidth}, elemType);
-    Value patch = tensor::ExtractSliceOp::create(rewriter, loc, patchType, paddedInput, offsets, sizes, strides);
-    // Flatten [1, numChannelsIn, wHeight, wWidth] -> [1, patchSize]
-    Value row = tensor::CollapseShapeOp::create(rewriter,
-        loc,
-        rowType,
-        patch,
-        SmallVector<ReassociationIndices> {
-            {0},
-            {1, 2, 3}
-        });
-    im2colRows.push_back(row);
+for (int64_t n = 0; n < batchSize; n++) {
+  for (int64_t oh = 0; oh < outHeight; oh++) {
+    for (int64_t ow = 0; ow < outWidth; ow++) {
+      SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(n),
+                                           rewriter.getIndexAttr(0),
+                                           rewriter.getIndexAttr(oh * strideHeight),
+                                           rewriter.getIndexAttr(ow * strideWidth)};
+      SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
+                                         rewriter.getIndexAttr(numChannelsIn),
+                                         rewriter.getIndexAttr(wHeight),
+                                         rewriter.getIndexAttr(wWidth)};
+      SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1),
+                                           rewriter.getIndexAttr(1),
+                                           rewriter.getIndexAttr(dilationHeight),
+                                           rewriter.getIndexAttr(dilationWidth)};
+      auto patchType = RankedTensorType::get({1, numChannelsIn, wHeight, wWidth}, elemType);
+      Value patch = tensor::ExtractSliceOp::create(rewriter, loc, patchType, paddedInput, offsets, sizes, strides);
+      // Flatten [1, numChannelsIn, wHeight, wWidth] -> [1, patchSize]
+      Value row = tensor::CollapseShapeOp::create(rewriter,
+          loc,
+          rowType,
+          patch,
+          SmallVector<ReassociationIndices> {
+              {0},
+              {1, 2, 3}
+          });
+      im2colRows.push_back(row);
+    }
   }
 }

View File

@@ -7,9 +7,10 @@ add_onnx_mlir_library(SpatialOps
   Transforms/SpatialBufferizableOpInterface.cpp
   DEPENDS
+  OMONNXIncGen
   OMSpatialIncGen
   LINK_LIBS PUBLIC
   MLIRIR
   OMMlirDialects
   )