From e866ec6f873ba7d8ca0039ac9f4f439b2f49050e Mon Sep 17 00:00:00 2001 From: NiccoloN Date: Tue, 14 Apr 2026 11:06:35 +0200 Subject: [PATCH] convolution uses crossbar size better --- .../ONNXToSpatial/Patterns/Math/Conv.cpp | 559 ++++++++++++++---- 1 file changed, 443 insertions(+), 116 deletions(-) diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp index 7e56e6c..00c946f 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp @@ -1,12 +1,16 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "llvm/ADT/SmallVector.h" +#include #include +#include "src/Accelerators/PIM/Common/PimCommon.hpp" +#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" @@ -24,122 +28,150 @@ struct ConvToGemm : OpConversionPattern { ConversionPatternRewriter& rewriter) const override; }; -} // namespace +static DenseElementsAttr getDenseConstantAttr(Value value) { + if (auto constantOp = value.getDefiningOp()) + return dyn_cast(constantOp.getValue()); -LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp, - ONNXConvOpAdaptor convOpAdaptor, - ConversionPatternRewriter& rewriter) const { - Location loc = convOp.getLoc(); - Value x = convOpAdaptor.getX(); - Value w = convOpAdaptor.getW(); - Value b = convOpAdaptor.getB(); + if (auto constantOp = value.getDefiningOp()) + return dyn_cast_or_null(constantOp.getValueAttr()); - auto xType = cast(x.getType()); - auto wType = cast(w.getType()); - auto outType = cast(convOp.getY().getType()); + return nullptr; +} - assert("Only support static shapes" && xType.hasStaticShape() && wType.hasStaticShape() && outType.hasStaticShape()); - assert("Only support 2D convolution" && xType.getRank() == 4); +static int64_t getI64FromArrayAttr(ArrayAttr arr, size_t idx) { return cast(arr[idx]).getInt(); } - // We need to understand what is group - assert("Only support group=1" && convOp.getGroup() == 1); +static Value expandBiasIfNeeded(Value bias, ConversionPatternRewriter& rewriter, Location loc) { + auto biasType = cast(bias.getType()); + if (biasType.getRank() != 1) + return bias; - const int64_t batchSize = xType.getDimSize(0); - const int64_t numChannelsIn = xType.getDimSize(1); - const int64_t xHeight = xType.getDimSize(2); - const int64_t xWidth = xType.getDimSize(3); - const int64_t numChannelsOut = wType.getDimSize(0); - const int64_t wHeight = wType.getDimSize(2); - const int64_t wWidth = wType.getDimSize(3); - const int64_t outHeight = outType.getDimSize(2); - const int64_t outWidth = outType.getDimSize(3); + auto expandedBiasType = RankedTensorType::get({1, biasType.getDimSize(0)}, biasType.getElementType()); + return tensor::ExpandShapeOp::create(rewriter, + loc, + expandedBiasType, + bias, + SmallVector { + {0, 1} + }); +} - // Read optional conv attributes (ONNX defaults: stride=1, dilation=1, pad=0) - auto getI64 = [](ArrayAttr arr, size_t idx) -> int64_t { return cast(arr[idx]).getInt(); }; +static Value createPaddedRows(Value tensorValue, + RankedTensorType tensorType, + int64_t paddedRows, + ConversionPatternRewriter& rewriter, + Location loc) { + if (tensorType.getDimSize(0) == paddedRows) + return tensorValue; - const auto stridesAttr = convOp.getStrides(); - const auto dilationsAttr = convOp.getDilations(); - const auto padsAttr = convOp.getPads(); + auto paddedType = RankedTensorType::get({paddedRows, tensorType.getDimSize(1)}, tensorType.getElementType()); + SmallVector lowPads = {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)}; + SmallVector highPads = {rewriter.getIndexAttr(paddedRows - tensorType.getDimSize(0)), + rewriter.getIndexAttr(0)}; + auto padOp = tensor::PadOp::create(rewriter, loc, paddedType, tensorValue, lowPads, highPads); + auto* padBlock = new Block(); + for (int i = 0; i < 2; i++) + padBlock->addArgument(rewriter.getIndexType(), loc); + padOp.getRegion().push_back(padBlock); + rewriter.setInsertionPointToStart(padBlock); + auto zero = arith::ConstantOp::create( + rewriter, loc, tensorType.getElementType(), rewriter.getZeroAttr(tensorType.getElementType())); + tensor::YieldOp::create(rewriter, loc, zero.getResult()); + rewriter.setInsertionPointAfter(padOp); + return padOp.getResult(); +} - const int64_t strideHeight = stridesAttr ? getI64(*stridesAttr, 0) : 1; - const int64_t strideWidth = stridesAttr ? getI64(*stridesAttr, 1) : 1; - const int64_t dilationHeight = dilationsAttr ? getI64(*dilationsAttr, 0) : 1; - const int64_t dilationWidth = dilationsAttr ? getI64(*dilationsAttr, 1) : 1; +static Value buildPackedWeight(DenseElementsAttr wDenseAttr, + Value wTrans, + RankedTensorType wType, + int64_t numChannelsIn, + int64_t numChannelsOut, + int64_t wHeight, + int64_t wWidth, + int64_t patchSize, + int64_t packFactor, + ConversionPatternRewriter& rewriter, + Location loc) { + if (packFactor == 1) + return wTrans; - int64_t padHeightBegin = 0; - int64_t padHeightEnd = 0; - int64_t padWidthBegin = 0; - int64_t padWidthEnd = 0; + auto packedWeightType = + RankedTensorType::get({packFactor * patchSize, packFactor * numChannelsOut}, wType.getElementType()); + SmallVector sourceValues(wDenseAttr.getValues()); + SmallVector packedValues(packedWeightType.getNumElements(), + cast(rewriter.getZeroAttr(wType.getElementType()))); - if (padsAttr) { - padHeightBegin = getI64(*padsAttr, 0); - padWidthBegin = getI64(*padsAttr, 1); - padHeightEnd = getI64(*padsAttr, 2); - padWidthEnd = getI64(*padsAttr, 3); - } - else { - // Compute padding from auto_pad attribute - const auto autoPad = convOp.getAutoPad(); - if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") { - const int64_t effectiveKernelH = (wHeight - 1) * dilationHeight + 1; - const int64_t effectiveKernelW = (wWidth - 1) * dilationWidth + 1; - const int64_t totalPadH = - std::max(static_cast(0), (outHeight - 1) * strideHeight + effectiveKernelH - xHeight); - const int64_t totalPadW = - std::max(static_cast(0), (outWidth - 1) * strideWidth + effectiveKernelW - xWidth); - - if (autoPad == "SAME_UPPER") { - padHeightBegin = totalPadH / 2; - padHeightEnd = totalPadH - padHeightBegin; - padWidthBegin = totalPadW / 2; - padWidthEnd = totalPadW - padWidthBegin; - } - else { // SAME_LOWER - padHeightEnd = totalPadH / 2; - padHeightBegin = totalPadH - padHeightEnd; - padWidthEnd = totalPadW / 2; - padWidthBegin = totalPadW - padWidthEnd; + for (int64_t copyId = 0; copyId < packFactor; copyId++) { + for (int64_t outChannel = 0; outChannel < numChannelsOut; outChannel++) { + for (int64_t inChannel = 0; inChannel < numChannelsIn; inChannel++) { + for (int64_t kernelH = 0; kernelH < wHeight; kernelH++) { + for (int64_t kernelW = 0; kernelW < wWidth; kernelW++) { + const int64_t sourceFlatIndex = + (((outChannel * numChannelsIn) + inChannel) * wHeight + kernelH) * wWidth + kernelW; + const int64_t patchIndex = ((inChannel * wHeight) + kernelH) * wWidth + kernelW; + const int64_t targetRow = copyId * patchSize + patchIndex; + const int64_t targetCol = copyId * numChannelsOut + outChannel; + packedValues[targetRow * (packFactor * numChannelsOut) + targetCol] = sourceValues[sourceFlatIndex]; + } + } } } - // "NOTSET" or "VALID" -> all pads stay 0 } - // im2col layout (flipped with respect to the standard, so filters sit in B = crossbar): - // A (im2col): [numPatches, patchSize] -- one row per output spatial position - // B (weights): [patchSize, cOut] -- W^T, stored in crossbar columns - // Gemm output: [numPatches, cOut] - const int64_t patchSize = numChannelsIn * wHeight * wWidth; - const int64_t numPatchesPerBatch = outHeight * outWidth; - const int64_t numPatches = batchSize * numPatchesPerBatch; + auto packedAttr = DenseElementsAttr::get(packedWeightType, packedValues); + return arith::ConstantOp::create(rewriter, loc, packedWeightType, packedAttr); +} +static Value buildPackedBias(bool hasBias, + Value gemmBias, + Value biasMatrix, + DenseElementsAttr biasDenseAttr, + RankedTensorType outType, + int64_t numChannelsOut, + int64_t packFactor, + ConversionPatternRewriter& rewriter, + Location loc) { + if (!hasBias) + return gemmBias; + + if (packFactor == 1) + return biasMatrix; + + SmallVector sourceValues(biasDenseAttr.getValues()); + SmallVector packedValues; + packedValues.reserve(packFactor * numChannelsOut); + for (int64_t copyId = 0; copyId < packFactor; copyId++) + packedValues.append(sourceValues.begin(), sourceValues.end()); + + auto packedBiasType = RankedTensorType::get({1, packFactor * numChannelsOut}, outType.getElementType()); + auto packedBiasAttr = DenseElementsAttr::get(packedBiasType, packedValues); + return arith::ConstantOp::create(rewriter, loc, packedBiasType, packedBiasAttr).getResult(); +} + +static Value createIm2colCompute(Value x, + RankedTensorType xType, + RankedTensorType im2colType, + RankedTensorType rowType, + int64_t batchSize, + int64_t numChannelsIn, + int64_t xHeight, + int64_t xWidth, + int64_t wHeight, + int64_t wWidth, + int64_t padHeightBegin, + int64_t padHeightEnd, + int64_t padWidthBegin, + int64_t padWidthEnd, + int64_t strideHeight, + int64_t strideWidth, + int64_t dilationHeight, + int64_t dilationWidth, + int64_t outWidth, + int64_t patchSize, + int64_t numPatches, + int64_t numPatchesPerBatch, + ConversionPatternRewriter& rewriter, + Location loc) { auto elemType = xType.getElementType(); - auto im2colType = RankedTensorType::get({numPatches, patchSize}, elemType); - auto rowType = RankedTensorType::get({1, patchSize}, elemType); - auto wFlatType = RankedTensorType::get({numChannelsOut, patchSize}, wType.getElementType()); - auto wTransType = RankedTensorType::get({patchSize, numChannelsOut}, wType.getElementType()); - auto gemmOutType = RankedTensorType::get({numPatches, numChannelsOut}, outType.getElementType()); - auto nhwcType = RankedTensorType::get({batchSize, outHeight, outWidth, numChannelsOut}, outType.getElementType()); - - // Prepare weight matrix W for crossbar storage: - // W: [numChannelsOut, numChannelsIn, wHeight, wWidth] -> [numChannelsOut, patchSize] -> [patchSize, numChannelsOut] - Value wFlat = tensor::CollapseShapeOp::create(rewriter, - loc, - wFlatType, - w, - SmallVector { - {0}, - {1, 2, 3} - }); - Value wTrans = ONNXTransposeOp::create(rewriter, loc, wTransType, wFlat, rewriter.getI64ArrayAttr({1, 0})); - - // Pass bias through directly; Gemm handles rank-1 C canonicalization. - bool hasB = !isa(b.getDefiningOp()); - Value gemmC; - if (hasB) - gemmC = b; - else - gemmC = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType()); - constexpr size_t numInputs = 1; auto im2colComputeOp = createSpatCompute(rewriter, loc, im2colType, {}, x, [&](Value xArg) { Value paddedInput = xArg; @@ -226,23 +258,104 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp, Value im2col = im2colLoop.getResult(0); spatial::SpatYieldOp::create(rewriter, loc, im2col); }); + return im2colComputeOp.getResult(0); +} - // Gemm: A @ B + C = im2col @ W^T + b - // [numPatches, patchSize] @ [patchSize, numChannelsOut] + [1, numChannelsOut] -> [numPatches, numChannelsOut] - auto gemmOp = ONNXGemmOp::create(rewriter, - loc, - gemmOutType, - im2colComputeOp.getResult(0), - wTrans, - gemmC, - rewriter.getF32FloatAttr(1.0f), - rewriter.getF32FloatAttr(1.0f), - rewriter.getBoolAttr(false), - rewriter.getBoolAttr(false)); - Value gemmOut = gemmOp.getY(); +static Value createPackedIm2colRows(Value im2col, + RankedTensorType im2colType, + Type elemType, + int64_t numPatches, + int64_t patchSize, + int64_t packFactor, + ConversionPatternRewriter& rewriter, + Location loc) { + if (packFactor == 1) + return im2col; + const int64_t packedNumRows = ceilIntegerDivide(numPatches, packFactor); + const int64_t paddedNumPatches = packedNumRows * packFactor; + auto groupedType = RankedTensorType::get({packedNumRows, packFactor, patchSize}, elemType); + auto packedType = RankedTensorType::get({packedNumRows, packFactor * patchSize}, elemType); + auto packedComputeOp = createSpatCompute<1>(rewriter, loc, packedType, {}, im2col, [&](Value im2colArg) { + Value paddedIm2col = createPaddedRows(im2colArg, im2colType, paddedNumPatches, rewriter, loc); + Value groupedIm2col = tensor::ExpandShapeOp::create(rewriter, + loc, + groupedType, + paddedIm2col, + SmallVector { + {0, 1}, + {2} + }); + Value packedIm2col = tensor::CollapseShapeOp::create(rewriter, + loc, + packedType, + groupedIm2col, + SmallVector { + {0}, + {1, 2} + }); + spatial::SpatYieldOp::create(rewriter, loc, packedIm2col); + }); + return packedComputeOp.getResult(0); +} + +static Value createUnpackedOutput(Value packedOutput, + RankedTensorType gemmOutType, + RankedTensorType outType, + int64_t numPatches, + int64_t numChannelsOut, + int64_t packFactor, + ConversionPatternRewriter& rewriter, + Location loc) { + if (packFactor == 1) + return packedOutput; + + const int64_t packedNumRows = ceilIntegerDivide(numPatches, packFactor); + const int64_t paddedNumPatches = packedNumRows * packFactor; + auto expandedType = RankedTensorType::get({packedNumRows, packFactor, numChannelsOut}, outType.getElementType()); + auto paddedType = RankedTensorType::get({paddedNumPatches, numChannelsOut}, outType.getElementType()); + auto unpackComputeOp = createSpatCompute<1>(rewriter, loc, gemmOutType, {}, packedOutput, [&](Value packedOutputArg) { + Value expandedOutput = tensor::ExpandShapeOp::create(rewriter, + loc, + expandedType, + packedOutputArg, + SmallVector { + {0}, + {1, 2} + }); + Value paddedOutput = tensor::CollapseShapeOp::create(rewriter, + loc, + paddedType, + expandedOutput, + SmallVector { + {0, 1}, + {2} + }); + + Value unpackedOutput = paddedOutput; + if (paddedNumPatches != numPatches) { + SmallVector offsets = {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)}; + SmallVector sizes = {rewriter.getIndexAttr(numPatches), rewriter.getIndexAttr(numChannelsOut)}; + SmallVector strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; + unpackedOutput = + tensor::ExtractSliceOp::create(rewriter, loc, gemmOutType, paddedOutput, offsets, sizes, strides); + } + + spatial::SpatYieldOp::create(rewriter, loc, unpackedOutput); + }); + return unpackComputeOp.getResult(0); +} + +static Value createCollectedConvOutput(Value gemmOut, + Type convType, + RankedTensorType nhwcType, + RankedTensorType outType, + ConversionPatternRewriter& rewriter, + Location loc) { auto collectComputeOp = - createSpatCompute(rewriter, loc, convOp.getType(), {}, ValueRange {gemmOut}, [&](Value gemmOutArg) { + createSpatCompute(rewriter, loc, convType, {}, ValueRange {gemmOut}, [&](ValueRange gemmOutArgs) { + Value gemmOutArg = gemmOutArgs.front(); + // Restore to NCHW layout: // [numPatches, numChannelsOut] // -> [1, outHeight, outWidth, numChannelsOut] @@ -256,11 +369,225 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp, {3} }); Value nchwOut = ONNXTransposeOp::create(rewriter, loc, outType, nhwcOut, rewriter.getI64ArrayAttr({0, 3, 1, 2})); - spatial::SpatYieldOp::create(rewriter, loc, nchwOut); }); + return collectComputeOp.getResult(0); +} - rewriter.replaceOp(convOp, collectComputeOp.getResult(0)); +} // namespace + +LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp, + ONNXConvOpAdaptor convOpAdaptor, + ConversionPatternRewriter& rewriter) const { + Location loc = convOp.getLoc(); + Value x = convOpAdaptor.getX(); + Value w = convOpAdaptor.getW(); + Value b = convOpAdaptor.getB(); + + auto xType = cast(x.getType()); + auto wType = cast(w.getType()); + auto outType = cast(convOp.getY().getType()); + + assert("Only support static shapes" && xType.hasStaticShape() && wType.hasStaticShape() && outType.hasStaticShape()); + assert("Only support 2D convolution" && xType.getRank() == 4); + + // We need to understand what is group + assert("Only support group=1" && convOp.getGroup() == 1); + + const int64_t batchSize = xType.getDimSize(0); + const int64_t numChannelsIn = xType.getDimSize(1); + const int64_t xHeight = xType.getDimSize(2); + const int64_t xWidth = xType.getDimSize(3); + const int64_t numChannelsOut = wType.getDimSize(0); + const int64_t wHeight = wType.getDimSize(2); + const int64_t wWidth = wType.getDimSize(3); + const int64_t outHeight = outType.getDimSize(2); + const int64_t outWidth = outType.getDimSize(3); + + // Read optional conv attributes (ONNX defaults: stride=1, dilation=1, pad=0) + const auto stridesAttr = convOp.getStrides(); + const auto dilationsAttr = convOp.getDilations(); + const auto padsAttr = convOp.getPads(); + + const int64_t strideHeight = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 0) : 1; + const int64_t strideWidth = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 1) : 1; + const int64_t dilationHeight = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 0) : 1; + const int64_t dilationWidth = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 1) : 1; + + int64_t padHeightBegin = 0; + int64_t padHeightEnd = 0; + int64_t padWidthBegin = 0; + int64_t padWidthEnd = 0; + + if (padsAttr) { + padHeightBegin = getI64FromArrayAttr(*padsAttr, 0); + padWidthBegin = getI64FromArrayAttr(*padsAttr, 1); + padHeightEnd = getI64FromArrayAttr(*padsAttr, 2); + padWidthEnd = getI64FromArrayAttr(*padsAttr, 3); + } + else { + // Compute padding from auto_pad attribute + const auto autoPad = convOp.getAutoPad(); + if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") { + const int64_t effectiveKernelH = (wHeight - 1) * dilationHeight + 1; + const int64_t effectiveKernelW = (wWidth - 1) * dilationWidth + 1; + const int64_t totalPadH = + std::max(static_cast(0), (outHeight - 1) * strideHeight + effectiveKernelH - xHeight); + const int64_t totalPadW = + std::max(static_cast(0), (outWidth - 1) * strideWidth + effectiveKernelW - xWidth); + + if (autoPad == "SAME_UPPER") { + padHeightBegin = totalPadH / 2; + padHeightEnd = totalPadH - padHeightBegin; + padWidthBegin = totalPadW / 2; + padWidthEnd = totalPadW - padWidthBegin; + } + else { // SAME_LOWER + padHeightEnd = totalPadH / 2; + padHeightBegin = totalPadH - padHeightEnd; + padWidthEnd = totalPadW / 2; + padWidthBegin = totalPadW - padWidthEnd; + } + } + // "NOTSET" or "VALID" -> all pads stay 0 + } + + // im2col layout (flipped with respect to the standard, so filters sit in B = crossbar): + // A (im2col): [numPatches, patchSize] -- one row per output spatial position + // B (weights): [patchSize, cOut] -- W^T, stored in crossbar columns + // Gemm output: [numPatches, cOut] + const int64_t patchSize = numChannelsIn * wHeight * wWidth; + const int64_t numPatchesPerBatch = outHeight * outWidth; + const int64_t numPatches = batchSize * numPatchesPerBatch; + + auto elemType = xType.getElementType(); + auto im2colType = RankedTensorType::get({numPatches, patchSize}, elemType); + auto rowType = RankedTensorType::get({1, patchSize}, elemType); + auto wFlatType = RankedTensorType::get({numChannelsOut, patchSize}, wType.getElementType()); + auto wTransType = RankedTensorType::get({patchSize, numChannelsOut}, wType.getElementType()); + auto gemmOutType = RankedTensorType::get({numPatches, numChannelsOut}, outType.getElementType()); + auto nhwcType = RankedTensorType::get({batchSize, outHeight, outWidth, numChannelsOut}, outType.getElementType()); + + const int64_t xbarSize = static_cast(crossbarSize.getValue()); + const int64_t wMaxDim = std::max(patchSize, numChannelsOut); + const int64_t maxParallelPixels = std::max(1, xbarSize / wMaxDim); + auto wDenseAttr = getDenseConstantAttr(w); + + // Prepare weight matrix W for crossbar storage: + // W: [numChannelsOut, numChannelsIn, wHeight, wWidth] -> [numChannelsOut, patchSize] -> [patchSize, numChannelsOut] + Value wFlat = tensor::CollapseShapeOp::create(rewriter, + loc, + wFlatType, + w, + SmallVector { + {0}, + {1, 2, 3} + }); + Value wTrans = ONNXTransposeOp::create(rewriter, loc, wTransType, wFlat, rewriter.getI64ArrayAttr({1, 0})); + + // Pass bias through directly; Gemm handles rank-1 C canonicalization. + bool hasB = !isa(b.getDefiningOp()); + Value gemmC = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType()); + Value biasMatrix; + DenseElementsAttr biasDenseAttr; + if (hasB) { + gemmC = b; + biasDenseAttr = getDenseConstantAttr(b); + biasMatrix = expandBiasIfNeeded(b, rewriter, loc); + } + const bool canPackWeightsAsConstants = static_cast(wDenseAttr); + const bool canPackBiasAsConstants = !hasB || static_cast(biasDenseAttr); + const int64_t effectiveMaxParallelPixels = + (canPackWeightsAsConstants && canPackBiasAsConstants) ? maxParallelPixels : 1; + + Value im2col = createIm2colCompute(x, + xType, + im2colType, + rowType, + batchSize, + numChannelsIn, + xHeight, + xWidth, + wHeight, + wWidth, + padHeightBegin, + padHeightEnd, + padWidthBegin, + padWidthEnd, + strideHeight, + strideWidth, + dilationHeight, + dilationWidth, + outWidth, + patchSize, + numPatches, + numPatchesPerBatch, + rewriter, + loc); + + Value gemmOut; + if (effectiveMaxParallelPixels == 1) { + // Fallback to the plain im2col GEMM when a single crossbar cannot fit multiple pixels. + gemmOut = ONNXGemmOp::create(rewriter, + loc, + gemmOutType, + im2col, + wTrans, + gemmC, + rewriter.getF32FloatAttr(1.0f), + rewriter.getF32FloatAttr(1.0f), + rewriter.getBoolAttr(false), + rewriter.getBoolAttr(false)) + .getY(); + } + else { + // Keep the standard im2col view of convolution: + // A (im2col): [numPatches, patchSize] -- one row per output spatial position + // B (weights): [patchSize, cOut] -- W^T, stored in crossbar columns + // but repack several old rows into one new row so we use the available crossbar size better. + // + // We want to process N spatial pixels at the exact same time. Instead of doing N separate + // operations of (1 x patchSize) x (patchSize x cOut), we construct a block-diagonal weight matrix + // containing N copies of W^T and concatenate N im2col rows into one longer row: + // A_packed: [ceil(numPatches / N), N * patchSize] + // B_packed: [N * patchSize, N * cOut] + // Y_packed: [ceil(numPatches / N), N * cOut] + // The downstream GemmToManyGemv pass still splits by row, but now there are fewer, longer rows. + const int64_t packedNumRows = ceilIntegerDivide(numPatches, effectiveMaxParallelPixels); + auto packedOutType = + RankedTensorType::get({packedNumRows, effectiveMaxParallelPixels * numChannelsOut}, outType.getElementType()); + + Value packedA = createPackedIm2colRows( + im2col, im2colType, elemType, numPatches, patchSize, effectiveMaxParallelPixels, rewriter, loc); + Value packedB = buildPackedWeight(wDenseAttr, + wTrans, + wType, + numChannelsIn, + numChannelsOut, + wHeight, + wWidth, + patchSize, + effectiveMaxParallelPixels, + rewriter, + loc); + Value packedC = buildPackedBias( + hasB, gemmC, biasMatrix, biasDenseAttr, outType, numChannelsOut, effectiveMaxParallelPixels, rewriter, loc); + Value packedOut = ONNXGemmOp::create(rewriter, + loc, + packedOutType, + packedA, + packedB, + packedC, + rewriter.getF32FloatAttr(1.0f), + rewriter.getF32FloatAttr(1.0f), + rewriter.getBoolAttr(false), + rewriter.getBoolAttr(false)) + .getY(); + gemmOut = createUnpackedOutput( + packedOut, gemmOutType, outType, numPatches, numChannelsOut, effectiveMaxParallelPixels, rewriter, loc); + } + + rewriter.replaceOp(convOp, createCollectedConvOutput(gemmOut, convOp.getType(), nhwcType, outType, rewriter, loc)); return success(); }