convolution uses crossbar size better

This commit is contained in:
NiccoloN
2026-04-14 11:06:35 +02:00
parent 0ac163e4b7
commit e866ec6f87

View File

@@ -1,12 +1,16 @@
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include <algorithm>
#include <cassert> #include <cassert>
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -24,122 +28,150 @@ struct ConvToGemm : OpConversionPattern<ONNXConvOp> {
ConversionPatternRewriter& rewriter) const override; ConversionPatternRewriter& rewriter) const override;
}; };
} // namespace static DenseElementsAttr getDenseConstantAttr(Value value) {
if (auto constantOp = value.getDefiningOp<arith::ConstantOp>())
return dyn_cast<DenseElementsAttr>(constantOp.getValue());
LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp, if (auto constantOp = value.getDefiningOp<ONNXConstantOp>())
ONNXConvOpAdaptor convOpAdaptor, return dyn_cast_or_null<DenseElementsAttr>(constantOp.getValueAttr());
ConversionPatternRewriter& rewriter) const {
Location loc = convOp.getLoc();
Value x = convOpAdaptor.getX();
Value w = convOpAdaptor.getW();
Value b = convOpAdaptor.getB();
auto xType = cast<RankedTensorType>(x.getType()); return nullptr;
auto wType = cast<RankedTensorType>(w.getType()); }
auto outType = cast<RankedTensorType>(convOp.getY().getType());
assert("Only support static shapes" && xType.hasStaticShape() && wType.hasStaticShape() && outType.hasStaticShape()); static int64_t getI64FromArrayAttr(ArrayAttr arr, size_t idx) { return cast<IntegerAttr>(arr[idx]).getInt(); }
assert("Only support 2D convolution" && xType.getRank() == 4);
// We need to understand what is group static Value expandBiasIfNeeded(Value bias, ConversionPatternRewriter& rewriter, Location loc) {
assert("Only support group=1" && convOp.getGroup() == 1); auto biasType = cast<RankedTensorType>(bias.getType());
if (biasType.getRank() != 1)
return bias;
const int64_t batchSize = xType.getDimSize(0); auto expandedBiasType = RankedTensorType::get({1, biasType.getDimSize(0)}, biasType.getElementType());
const int64_t numChannelsIn = xType.getDimSize(1); return tensor::ExpandShapeOp::create(rewriter,
const int64_t xHeight = xType.getDimSize(2); loc,
const int64_t xWidth = xType.getDimSize(3); expandedBiasType,
const int64_t numChannelsOut = wType.getDimSize(0); bias,
const int64_t wHeight = wType.getDimSize(2); SmallVector<ReassociationIndices> {
const int64_t wWidth = wType.getDimSize(3); {0, 1}
const int64_t outHeight = outType.getDimSize(2); });
const int64_t outWidth = outType.getDimSize(3); }
// Read optional conv attributes (ONNX defaults: stride=1, dilation=1, pad=0) static Value createPaddedRows(Value tensorValue,
auto getI64 = [](ArrayAttr arr, size_t idx) -> int64_t { return cast<IntegerAttr>(arr[idx]).getInt(); }; RankedTensorType tensorType,
int64_t paddedRows,
ConversionPatternRewriter& rewriter,
Location loc) {
if (tensorType.getDimSize(0) == paddedRows)
return tensorValue;
const auto stridesAttr = convOp.getStrides(); auto paddedType = RankedTensorType::get({paddedRows, tensorType.getDimSize(1)}, tensorType.getElementType());
const auto dilationsAttr = convOp.getDilations(); SmallVector<OpFoldResult> lowPads = {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
const auto padsAttr = convOp.getPads(); SmallVector<OpFoldResult> highPads = {rewriter.getIndexAttr(paddedRows - tensorType.getDimSize(0)),
rewriter.getIndexAttr(0)};
auto padOp = tensor::PadOp::create(rewriter, loc, paddedType, tensorValue, lowPads, highPads);
auto* padBlock = new Block();
for (int i = 0; i < 2; i++)
padBlock->addArgument(rewriter.getIndexType(), loc);
padOp.getRegion().push_back(padBlock);
rewriter.setInsertionPointToStart(padBlock);
auto zero = arith::ConstantOp::create(
rewriter, loc, tensorType.getElementType(), rewriter.getZeroAttr(tensorType.getElementType()));
tensor::YieldOp::create(rewriter, loc, zero.getResult());
rewriter.setInsertionPointAfter(padOp);
return padOp.getResult();
}
const int64_t strideHeight = stridesAttr ? getI64(*stridesAttr, 0) : 1; static Value buildPackedWeight(DenseElementsAttr wDenseAttr,
const int64_t strideWidth = stridesAttr ? getI64(*stridesAttr, 1) : 1; Value wTrans,
const int64_t dilationHeight = dilationsAttr ? getI64(*dilationsAttr, 0) : 1; RankedTensorType wType,
const int64_t dilationWidth = dilationsAttr ? getI64(*dilationsAttr, 1) : 1; int64_t numChannelsIn,
int64_t numChannelsOut,
int64_t wHeight,
int64_t wWidth,
int64_t patchSize,
int64_t packFactor,
ConversionPatternRewriter& rewriter,
Location loc) {
if (packFactor == 1)
return wTrans;
int64_t padHeightBegin = 0; auto packedWeightType =
int64_t padHeightEnd = 0; RankedTensorType::get({packFactor * patchSize, packFactor * numChannelsOut}, wType.getElementType());
int64_t padWidthBegin = 0; SmallVector<Attribute> sourceValues(wDenseAttr.getValues<Attribute>());
int64_t padWidthEnd = 0; SmallVector<Attribute> packedValues(packedWeightType.getNumElements(),
cast<Attribute>(rewriter.getZeroAttr(wType.getElementType())));
if (padsAttr) { for (int64_t copyId = 0; copyId < packFactor; copyId++) {
padHeightBegin = getI64(*padsAttr, 0); for (int64_t outChannel = 0; outChannel < numChannelsOut; outChannel++) {
padWidthBegin = getI64(*padsAttr, 1); for (int64_t inChannel = 0; inChannel < numChannelsIn; inChannel++) {
padHeightEnd = getI64(*padsAttr, 2); for (int64_t kernelH = 0; kernelH < wHeight; kernelH++) {
padWidthEnd = getI64(*padsAttr, 3); for (int64_t kernelW = 0; kernelW < wWidth; kernelW++) {
} const int64_t sourceFlatIndex =
else { (((outChannel * numChannelsIn) + inChannel) * wHeight + kernelH) * wWidth + kernelW;
// Compute padding from auto_pad attribute const int64_t patchIndex = ((inChannel * wHeight) + kernelH) * wWidth + kernelW;
const auto autoPad = convOp.getAutoPad(); const int64_t targetRow = copyId * patchSize + patchIndex;
if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") { const int64_t targetCol = copyId * numChannelsOut + outChannel;
const int64_t effectiveKernelH = (wHeight - 1) * dilationHeight + 1; packedValues[targetRow * (packFactor * numChannelsOut) + targetCol] = sourceValues[sourceFlatIndex];
const int64_t effectiveKernelW = (wWidth - 1) * dilationWidth + 1; }
const int64_t totalPadH = }
std::max(static_cast<int64_t>(0), (outHeight - 1) * strideHeight + effectiveKernelH - xHeight);
const int64_t totalPadW =
std::max(static_cast<int64_t>(0), (outWidth - 1) * strideWidth + effectiveKernelW - xWidth);
if (autoPad == "SAME_UPPER") {
padHeightBegin = totalPadH / 2;
padHeightEnd = totalPadH - padHeightBegin;
padWidthBegin = totalPadW / 2;
padWidthEnd = totalPadW - padWidthBegin;
}
else { // SAME_LOWER
padHeightEnd = totalPadH / 2;
padHeightBegin = totalPadH - padHeightEnd;
padWidthEnd = totalPadW / 2;
padWidthBegin = totalPadW - padWidthEnd;
} }
} }
// "NOTSET" or "VALID" -> all pads stay 0
} }
// im2col layout (flipped with respect to the standard, so filters sit in B = crossbar): auto packedAttr = DenseElementsAttr::get(packedWeightType, packedValues);
// A (im2col): [numPatches, patchSize] -- one row per output spatial position return arith::ConstantOp::create(rewriter, loc, packedWeightType, packedAttr);
// B (weights): [patchSize, cOut] -- W^T, stored in crossbar columns }
// Gemm output: [numPatches, cOut]
const int64_t patchSize = numChannelsIn * wHeight * wWidth;
const int64_t numPatchesPerBatch = outHeight * outWidth;
const int64_t numPatches = batchSize * numPatchesPerBatch;
static Value buildPackedBias(bool hasBias,
Value gemmBias,
Value biasMatrix,
DenseElementsAttr biasDenseAttr,
RankedTensorType outType,
int64_t numChannelsOut,
int64_t packFactor,
ConversionPatternRewriter& rewriter,
Location loc) {
if (!hasBias)
return gemmBias;
if (packFactor == 1)
return biasMatrix;
SmallVector<Attribute> sourceValues(biasDenseAttr.getValues<Attribute>());
SmallVector<Attribute> packedValues;
packedValues.reserve(packFactor * numChannelsOut);
for (int64_t copyId = 0; copyId < packFactor; copyId++)
packedValues.append(sourceValues.begin(), sourceValues.end());
auto packedBiasType = RankedTensorType::get({1, packFactor * numChannelsOut}, outType.getElementType());
auto packedBiasAttr = DenseElementsAttr::get(packedBiasType, packedValues);
return arith::ConstantOp::create(rewriter, loc, packedBiasType, packedBiasAttr).getResult();
}
static Value createIm2colCompute(Value x,
RankedTensorType xType,
RankedTensorType im2colType,
RankedTensorType rowType,
int64_t batchSize,
int64_t numChannelsIn,
int64_t xHeight,
int64_t xWidth,
int64_t wHeight,
int64_t wWidth,
int64_t padHeightBegin,
int64_t padHeightEnd,
int64_t padWidthBegin,
int64_t padWidthEnd,
int64_t strideHeight,
int64_t strideWidth,
int64_t dilationHeight,
int64_t dilationWidth,
int64_t outWidth,
int64_t patchSize,
int64_t numPatches,
int64_t numPatchesPerBatch,
ConversionPatternRewriter& rewriter,
Location loc) {
auto elemType = xType.getElementType(); auto elemType = xType.getElementType();
auto im2colType = RankedTensorType::get({numPatches, patchSize}, elemType);
auto rowType = RankedTensorType::get({1, patchSize}, elemType);
auto wFlatType = RankedTensorType::get({numChannelsOut, patchSize}, wType.getElementType());
auto wTransType = RankedTensorType::get({patchSize, numChannelsOut}, wType.getElementType());
auto gemmOutType = RankedTensorType::get({numPatches, numChannelsOut}, outType.getElementType());
auto nhwcType = RankedTensorType::get({batchSize, outHeight, outWidth, numChannelsOut}, outType.getElementType());
// Prepare weight matrix W for crossbar storage:
// W: [numChannelsOut, numChannelsIn, wHeight, wWidth] -> [numChannelsOut, patchSize] -> [patchSize, numChannelsOut]
Value wFlat = tensor::CollapseShapeOp::create(rewriter,
loc,
wFlatType,
w,
SmallVector<ReassociationIndices> {
{0},
{1, 2, 3}
});
Value wTrans = ONNXTransposeOp::create(rewriter, loc, wTransType, wFlat, rewriter.getI64ArrayAttr({1, 0}));
// Pass bias through directly; Gemm handles rank-1 C canonicalization.
bool hasB = !isa<ONNXNoneOp>(b.getDefiningOp());
Value gemmC;
if (hasB)
gemmC = b;
else
gemmC = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());
constexpr size_t numInputs = 1; constexpr size_t numInputs = 1;
auto im2colComputeOp = createSpatCompute<numInputs>(rewriter, loc, im2colType, {}, x, [&](Value xArg) { auto im2colComputeOp = createSpatCompute<numInputs>(rewriter, loc, im2colType, {}, x, [&](Value xArg) {
Value paddedInput = xArg; Value paddedInput = xArg;
@@ -226,23 +258,104 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
Value im2col = im2colLoop.getResult(0); Value im2col = im2colLoop.getResult(0);
spatial::SpatYieldOp::create(rewriter, loc, im2col); spatial::SpatYieldOp::create(rewriter, loc, im2col);
}); });
return im2colComputeOp.getResult(0);
}
// Gemm: A @ B + C = im2col @ W^T + b static Value createPackedIm2colRows(Value im2col,
// [numPatches, patchSize] @ [patchSize, numChannelsOut] + [1, numChannelsOut] -> [numPatches, numChannelsOut] RankedTensorType im2colType,
auto gemmOp = ONNXGemmOp::create(rewriter, Type elemType,
loc, int64_t numPatches,
gemmOutType, int64_t patchSize,
im2colComputeOp.getResult(0), int64_t packFactor,
wTrans, ConversionPatternRewriter& rewriter,
gemmC, Location loc) {
rewriter.getF32FloatAttr(1.0f), if (packFactor == 1)
rewriter.getF32FloatAttr(1.0f), return im2col;
rewriter.getBoolAttr(false),
rewriter.getBoolAttr(false));
Value gemmOut = gemmOp.getY();
const int64_t packedNumRows = ceilIntegerDivide(numPatches, packFactor);
const int64_t paddedNumPatches = packedNumRows * packFactor;
auto groupedType = RankedTensorType::get({packedNumRows, packFactor, patchSize}, elemType);
auto packedType = RankedTensorType::get({packedNumRows, packFactor * patchSize}, elemType);
auto packedComputeOp = createSpatCompute<1>(rewriter, loc, packedType, {}, im2col, [&](Value im2colArg) {
Value paddedIm2col = createPaddedRows(im2colArg, im2colType, paddedNumPatches, rewriter, loc);
Value groupedIm2col = tensor::ExpandShapeOp::create(rewriter,
loc,
groupedType,
paddedIm2col,
SmallVector<ReassociationIndices> {
{0, 1},
{2}
});
Value packedIm2col = tensor::CollapseShapeOp::create(rewriter,
loc,
packedType,
groupedIm2col,
SmallVector<ReassociationIndices> {
{0},
{1, 2}
});
spatial::SpatYieldOp::create(rewriter, loc, packedIm2col);
});
return packedComputeOp.getResult(0);
}
static Value createUnpackedOutput(Value packedOutput,
RankedTensorType gemmOutType,
RankedTensorType outType,
int64_t numPatches,
int64_t numChannelsOut,
int64_t packFactor,
ConversionPatternRewriter& rewriter,
Location loc) {
if (packFactor == 1)
return packedOutput;
const int64_t packedNumRows = ceilIntegerDivide(numPatches, packFactor);
const int64_t paddedNumPatches = packedNumRows * packFactor;
auto expandedType = RankedTensorType::get({packedNumRows, packFactor, numChannelsOut}, outType.getElementType());
auto paddedType = RankedTensorType::get({paddedNumPatches, numChannelsOut}, outType.getElementType());
auto unpackComputeOp = createSpatCompute<1>(rewriter, loc, gemmOutType, {}, packedOutput, [&](Value packedOutputArg) {
Value expandedOutput = tensor::ExpandShapeOp::create(rewriter,
loc,
expandedType,
packedOutputArg,
SmallVector<ReassociationIndices> {
{0},
{1, 2}
});
Value paddedOutput = tensor::CollapseShapeOp::create(rewriter,
loc,
paddedType,
expandedOutput,
SmallVector<ReassociationIndices> {
{0, 1},
{2}
});
Value unpackedOutput = paddedOutput;
if (paddedNumPatches != numPatches) {
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(numPatches), rewriter.getIndexAttr(numChannelsOut)};
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
unpackedOutput =
tensor::ExtractSliceOp::create(rewriter, loc, gemmOutType, paddedOutput, offsets, sizes, strides);
}
spatial::SpatYieldOp::create(rewriter, loc, unpackedOutput);
});
return unpackComputeOp.getResult(0);
}
static Value createCollectedConvOutput(Value gemmOut,
Type convType,
RankedTensorType nhwcType,
RankedTensorType outType,
ConversionPatternRewriter& rewriter,
Location loc) {
auto collectComputeOp = auto collectComputeOp =
createSpatCompute<numInputs>(rewriter, loc, convOp.getType(), {}, ValueRange {gemmOut}, [&](Value gemmOutArg) { createSpatCompute(rewriter, loc, convType, {}, ValueRange {gemmOut}, [&](ValueRange gemmOutArgs) {
Value gemmOutArg = gemmOutArgs.front();
// Restore to NCHW layout: // Restore to NCHW layout:
// [numPatches, numChannelsOut] // [numPatches, numChannelsOut]
// -> [1, outHeight, outWidth, numChannelsOut] // -> [1, outHeight, outWidth, numChannelsOut]
@@ -256,11 +369,225 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
{3} {3}
}); });
Value nchwOut = ONNXTransposeOp::create(rewriter, loc, outType, nhwcOut, rewriter.getI64ArrayAttr({0, 3, 1, 2})); Value nchwOut = ONNXTransposeOp::create(rewriter, loc, outType, nhwcOut, rewriter.getI64ArrayAttr({0, 3, 1, 2}));
spatial::SpatYieldOp::create(rewriter, loc, nchwOut); spatial::SpatYieldOp::create(rewriter, loc, nchwOut);
}); });
return collectComputeOp.getResult(0);
}
rewriter.replaceOp(convOp, collectComputeOp.getResult(0)); } // namespace
LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
ONNXConvOpAdaptor convOpAdaptor,
ConversionPatternRewriter& rewriter) const {
Location loc = convOp.getLoc();
Value x = convOpAdaptor.getX();
Value w = convOpAdaptor.getW();
Value b = convOpAdaptor.getB();
auto xType = cast<RankedTensorType>(x.getType());
auto wType = cast<RankedTensorType>(w.getType());
auto outType = cast<RankedTensorType>(convOp.getY().getType());
assert("Only support static shapes" && xType.hasStaticShape() && wType.hasStaticShape() && outType.hasStaticShape());
assert("Only support 2D convolution" && xType.getRank() == 4);
// We need to understand what is group
assert("Only support group=1" && convOp.getGroup() == 1);
const int64_t batchSize = xType.getDimSize(0);
const int64_t numChannelsIn = xType.getDimSize(1);
const int64_t xHeight = xType.getDimSize(2);
const int64_t xWidth = xType.getDimSize(3);
const int64_t numChannelsOut = wType.getDimSize(0);
const int64_t wHeight = wType.getDimSize(2);
const int64_t wWidth = wType.getDimSize(3);
const int64_t outHeight = outType.getDimSize(2);
const int64_t outWidth = outType.getDimSize(3);
// Read optional conv attributes (ONNX defaults: stride=1, dilation=1, pad=0)
const auto stridesAttr = convOp.getStrides();
const auto dilationsAttr = convOp.getDilations();
const auto padsAttr = convOp.getPads();
const int64_t strideHeight = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 0) : 1;
const int64_t strideWidth = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 1) : 1;
const int64_t dilationHeight = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 0) : 1;
const int64_t dilationWidth = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 1) : 1;
int64_t padHeightBegin = 0;
int64_t padHeightEnd = 0;
int64_t padWidthBegin = 0;
int64_t padWidthEnd = 0;
if (padsAttr) {
padHeightBegin = getI64FromArrayAttr(*padsAttr, 0);
padWidthBegin = getI64FromArrayAttr(*padsAttr, 1);
padHeightEnd = getI64FromArrayAttr(*padsAttr, 2);
padWidthEnd = getI64FromArrayAttr(*padsAttr, 3);
}
else {
// Compute padding from auto_pad attribute
const auto autoPad = convOp.getAutoPad();
if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") {
const int64_t effectiveKernelH = (wHeight - 1) * dilationHeight + 1;
const int64_t effectiveKernelW = (wWidth - 1) * dilationWidth + 1;
const int64_t totalPadH =
std::max(static_cast<int64_t>(0), (outHeight - 1) * strideHeight + effectiveKernelH - xHeight);
const int64_t totalPadW =
std::max(static_cast<int64_t>(0), (outWidth - 1) * strideWidth + effectiveKernelW - xWidth);
if (autoPad == "SAME_UPPER") {
padHeightBegin = totalPadH / 2;
padHeightEnd = totalPadH - padHeightBegin;
padWidthBegin = totalPadW / 2;
padWidthEnd = totalPadW - padWidthBegin;
}
else { // SAME_LOWER
padHeightEnd = totalPadH / 2;
padHeightBegin = totalPadH - padHeightEnd;
padWidthEnd = totalPadW / 2;
padWidthBegin = totalPadW - padWidthEnd;
}
}
// "NOTSET" or "VALID" -> all pads stay 0
}
// im2col layout (flipped with respect to the standard, so filters sit in B = crossbar):
// A (im2col): [numPatches, patchSize] -- one row per output spatial position
// B (weights): [patchSize, cOut] -- W^T, stored in crossbar columns
// Gemm output: [numPatches, cOut]
const int64_t patchSize = numChannelsIn * wHeight * wWidth;
const int64_t numPatchesPerBatch = outHeight * outWidth;
const int64_t numPatches = batchSize * numPatchesPerBatch;
auto elemType = xType.getElementType();
auto im2colType = RankedTensorType::get({numPatches, patchSize}, elemType);
auto rowType = RankedTensorType::get({1, patchSize}, elemType);
auto wFlatType = RankedTensorType::get({numChannelsOut, patchSize}, wType.getElementType());
auto wTransType = RankedTensorType::get({patchSize, numChannelsOut}, wType.getElementType());
auto gemmOutType = RankedTensorType::get({numPatches, numChannelsOut}, outType.getElementType());
auto nhwcType = RankedTensorType::get({batchSize, outHeight, outWidth, numChannelsOut}, outType.getElementType());
const int64_t xbarSize = static_cast<int64_t>(crossbarSize.getValue());
const int64_t wMaxDim = std::max(patchSize, numChannelsOut);
const int64_t maxParallelPixels = std::max<int64_t>(1, xbarSize / wMaxDim);
auto wDenseAttr = getDenseConstantAttr(w);
// Prepare weight matrix W for crossbar storage:
// W: [numChannelsOut, numChannelsIn, wHeight, wWidth] -> [numChannelsOut, patchSize] -> [patchSize, numChannelsOut]
Value wFlat = tensor::CollapseShapeOp::create(rewriter,
loc,
wFlatType,
w,
SmallVector<ReassociationIndices> {
{0},
{1, 2, 3}
});
Value wTrans = ONNXTransposeOp::create(rewriter, loc, wTransType, wFlat, rewriter.getI64ArrayAttr({1, 0}));
// Pass bias through directly; Gemm handles rank-1 C canonicalization.
bool hasB = !isa<ONNXNoneOp>(b.getDefiningOp());
Value gemmC = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());
Value biasMatrix;
DenseElementsAttr biasDenseAttr;
if (hasB) {
gemmC = b;
biasDenseAttr = getDenseConstantAttr(b);
biasMatrix = expandBiasIfNeeded(b, rewriter, loc);
}
const bool canPackWeightsAsConstants = static_cast<bool>(wDenseAttr);
const bool canPackBiasAsConstants = !hasB || static_cast<bool>(biasDenseAttr);
const int64_t effectiveMaxParallelPixels =
(canPackWeightsAsConstants && canPackBiasAsConstants) ? maxParallelPixels : 1;
Value im2col = createIm2colCompute(x,
xType,
im2colType,
rowType,
batchSize,
numChannelsIn,
xHeight,
xWidth,
wHeight,
wWidth,
padHeightBegin,
padHeightEnd,
padWidthBegin,
padWidthEnd,
strideHeight,
strideWidth,
dilationHeight,
dilationWidth,
outWidth,
patchSize,
numPatches,
numPatchesPerBatch,
rewriter,
loc);
Value gemmOut;
if (effectiveMaxParallelPixels == 1) {
// Fallback to the plain im2col GEMM when a single crossbar cannot fit multiple pixels.
gemmOut = ONNXGemmOp::create(rewriter,
loc,
gemmOutType,
im2col,
wTrans,
gemmC,
rewriter.getF32FloatAttr(1.0f),
rewriter.getF32FloatAttr(1.0f),
rewriter.getBoolAttr(false),
rewriter.getBoolAttr(false))
.getY();
}
else {
// Keep the standard im2col view of convolution:
// A (im2col): [numPatches, patchSize] -- one row per output spatial position
// B (weights): [patchSize, cOut] -- W^T, stored in crossbar columns
// but repack several old rows into one new row so we use the available crossbar size better.
//
// We want to process N spatial pixels at the exact same time. Instead of doing N separate
// operations of (1 x patchSize) x (patchSize x cOut), we construct a block-diagonal weight matrix
// containing N copies of W^T and concatenate N im2col rows into one longer row:
// A_packed: [ceil(numPatches / N), N * patchSize]
// B_packed: [N * patchSize, N * cOut]
// Y_packed: [ceil(numPatches / N), N * cOut]
// The downstream GemmToManyGemv pass still splits by row, but now there are fewer, longer rows.
const int64_t packedNumRows = ceilIntegerDivide(numPatches, effectiveMaxParallelPixels);
auto packedOutType =
RankedTensorType::get({packedNumRows, effectiveMaxParallelPixels * numChannelsOut}, outType.getElementType());
Value packedA = createPackedIm2colRows(
im2col, im2colType, elemType, numPatches, patchSize, effectiveMaxParallelPixels, rewriter, loc);
Value packedB = buildPackedWeight(wDenseAttr,
wTrans,
wType,
numChannelsIn,
numChannelsOut,
wHeight,
wWidth,
patchSize,
effectiveMaxParallelPixels,
rewriter,
loc);
Value packedC = buildPackedBias(
hasB, gemmC, biasMatrix, biasDenseAttr, outType, numChannelsOut, effectiveMaxParallelPixels, rewriter, loc);
Value packedOut = ONNXGemmOp::create(rewriter,
loc,
packedOutType,
packedA,
packedB,
packedC,
rewriter.getF32FloatAttr(1.0f),
rewriter.getF32FloatAttr(1.0f),
rewriter.getBoolAttr(false),
rewriter.getBoolAttr(false))
.getY();
gemmOut = createUnpackedOutput(
packedOut, gemmOutType, outType, numPatches, numChannelsOut, effectiveMaxParallelPixels, rewriter, loc);
}
rewriter.replaceOp(convOp, createCollectedConvOutput(gemmOut, convOp.getType(), nhwcType, outType, rewriter, loc));
return success(); return success();
} }