From 568529ea5f61d61d262ca86f91e04cb6cc690dcf Mon Sep 17 00:00:00 2001 From: NiccoloN Date: Fri, 20 Mar 2026 22:00:46 +0100 Subject: [PATCH] fix batched conv --- .../Conversion/ONNXToSpatial/Math/Conv.cpp | 65 ++++++++++--------- src/PIM/Dialect/Spatial/CMakeLists.txt | 3 +- 2 files changed, 36 insertions(+), 32 deletions(-) diff --git a/src/PIM/Conversion/ONNXToSpatial/Math/Conv.cpp b/src/PIM/Conversion/ONNXToSpatial/Math/Conv.cpp index cf8de30..a9d4dfb 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Math/Conv.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Math/Conv.cpp @@ -38,11 +38,11 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp, assert("Only support static shapes" && xType.hasStaticShape() && wType.hasStaticShape() && outType.hasStaticShape()); assert("Only support 2D convolution" && xType.getRank() == 4); - assert("Only support batch size 1 for input" && xType.getDimSize(0) == 1); // We need to understand what is group assert("Only support group=1" && convOp.getGroup() == 1); + const int64_t batchSize = xType.getDimSize(0); const int64_t numChannelsIn = xType.getDimSize(1); const int64_t xHeight = xType.getDimSize(2); const int64_t xWidth = xType.getDimSize(3); @@ -107,7 +107,8 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp, // B (weights): [patchSize, cOut] -- W^T, stored in crossbar columns // Gemm output: [numPatches, cOut] const int64_t patchSize = numChannelsIn * wHeight * wWidth; - const int64_t numPatches = outHeight * outWidth; + const int64_t numPatchesPerBatch = outHeight * outWidth; + const int64_t numPatches = batchSize * numPatchesPerBatch; auto elemType = xType.getElementType(); auto im2colType = RankedTensorType::get({numPatches, patchSize}, elemType); @@ -115,7 +116,7 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp, auto wFlatType = RankedTensorType::get({numChannelsOut, patchSize}, wType.getElementType()); auto wTransType = RankedTensorType::get({patchSize, numChannelsOut}, wType.getElementType()); 
auto gemmOutType = RankedTensorType::get({numPatches, numChannelsOut}, outType.getElementType()); - auto nhwcType = RankedTensorType::get({1, outHeight, outWidth, numChannelsOut}, outType.getElementType()); + auto nhwcType = RankedTensorType::get({batchSize, outHeight, outWidth, numChannelsOut}, outType.getElementType()); // Prepare weight matrix W for crossbar storage: // W: [numChannelsOut, numChannelsIn, wHeight, wWidth] -> [numChannelsOut, patchSize] -> [patchSize, numChannelsOut] @@ -160,7 +161,7 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp, if (padHeightBegin || padHeightEnd || padWidthBegin || padWidthEnd) { const int64_t paddedHeight = xHeight + padHeightBegin + padHeightEnd; const int64_t paddedWidth = xWidth + padWidthBegin + padWidthEnd; - auto paddedType = RankedTensorType::get({1, numChannelsIn, paddedHeight, paddedWidth}, elemType); + auto paddedType = RankedTensorType::get({batchSize, numChannelsIn, paddedHeight, paddedWidth}, elemType); SmallVector<OpFoldResult> lowPads = {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0), rewriter.getIndexAttr(padHeightBegin), @@ -182,36 +183,38 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp, } // Build im2col [numPatches, patchSize]: - // For each output position (oh, ow), extract the patch from x + // For each batch/output position (n, oh, ow), extract the patch from x SmallVector<Value> im2colRows; im2colRows.reserve(numPatches); - for (int64_t oh = 0; oh < outHeight; oh++) { - for (int64_t ow = 0; ow < outWidth; ow++) { - SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(0), - rewriter.getIndexAttr(0), - rewriter.getIndexAttr(oh * strideHeight), - rewriter.getIndexAttr(ow * strideWidth)}; - SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1), - rewriter.getIndexAttr(numChannelsIn), - rewriter.getIndexAttr(wHeight), - rewriter.getIndexAttr(wWidth)}; - SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), - rewriter.getIndexAttr(1), - rewriter.getIndexAttr(dilationHeight), - rewriter.getIndexAttr(dilationWidth)}; - 
auto patchType = RankedTensorType::get({1, numChannelsIn, wHeight, wWidth}, elemType); - Value patch = tensor::ExtractSliceOp::create(rewriter, loc, patchType, paddedInput, offsets, sizes, strides); + for (int64_t n = 0; n < batchSize; n++) { + for (int64_t oh = 0; oh < outHeight; oh++) { + for (int64_t ow = 0; ow < outWidth; ow++) { + SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(n), + rewriter.getIndexAttr(0), + rewriter.getIndexAttr(oh * strideHeight), + rewriter.getIndexAttr(ow * strideWidth)}; + SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1), + rewriter.getIndexAttr(numChannelsIn), + rewriter.getIndexAttr(wHeight), + rewriter.getIndexAttr(wWidth)}; + SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), + rewriter.getIndexAttr(1), + rewriter.getIndexAttr(dilationHeight), + rewriter.getIndexAttr(dilationWidth)}; + auto patchType = RankedTensorType::get({1, numChannelsIn, wHeight, wWidth}, elemType); + Value patch = tensor::ExtractSliceOp::create(rewriter, loc, patchType, paddedInput, offsets, sizes, strides); - // Flatten [1, numChannelsIn, wHeight, wWidth] -> [1, patchSize] - Value row = tensor::CollapseShapeOp::create(rewriter, - loc, - rowType, - patch, - SmallVector<ReassociationIndices> { - {0}, - {1, 2, 3} - }); - im2colRows.push_back(row); + // Flatten [1, numChannelsIn, wHeight, wWidth] -> [1, patchSize] + Value row = tensor::CollapseShapeOp::create(rewriter, - loc, + loc, + rowType, + patch, + SmallVector<ReassociationIndices> { + {0}, + {1, 2, 3} + }); + im2colRows.push_back(row); + } } } diff --git a/src/PIM/Dialect/Spatial/CMakeLists.txt b/src/PIM/Dialect/Spatial/CMakeLists.txt index 2ff4c97..21266c4 100644 --- a/src/PIM/Dialect/Spatial/CMakeLists.txt +++ b/src/PIM/Dialect/Spatial/CMakeLists.txt @@ -7,9 +7,10 @@ add_onnx_mlir_library(SpatialOps Transforms/SpatialBufferizableOpInterface.cpp DEPENDS + OMONNXIncGen OMSpatialIncGen LINK_LIBS PUBLIC MLIRIR OMMlirDialects -) \ No newline at end of file +)