convolution uses crossbar size better
This commit is contained in:
@@ -1,12 +1,16 @@
|
|||||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
#include "mlir/IR/BuiltinAttributes.h"
|
||||||
#include "mlir/IR/BuiltinTypes.h"
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
|
|
||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
#include <cassert>
|
#include <cassert>
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
@@ -24,122 +28,150 @@ struct ConvToGemm : OpConversionPattern<ONNXConvOp> {
|
|||||||
ConversionPatternRewriter& rewriter) const override;
|
ConversionPatternRewriter& rewriter) const override;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace
|
static DenseElementsAttr getDenseConstantAttr(Value value) {
|
||||||
|
if (auto constantOp = value.getDefiningOp<arith::ConstantOp>())
|
||||||
|
return dyn_cast<DenseElementsAttr>(constantOp.getValue());
|
||||||
|
|
||||||
LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
if (auto constantOp = value.getDefiningOp<ONNXConstantOp>())
|
||||||
ONNXConvOpAdaptor convOpAdaptor,
|
return dyn_cast_or_null<DenseElementsAttr>(constantOp.getValueAttr());
|
||||||
ConversionPatternRewriter& rewriter) const {
|
|
||||||
Location loc = convOp.getLoc();
|
|
||||||
Value x = convOpAdaptor.getX();
|
|
||||||
Value w = convOpAdaptor.getW();
|
|
||||||
Value b = convOpAdaptor.getB();
|
|
||||||
|
|
||||||
auto xType = cast<RankedTensorType>(x.getType());
|
return nullptr;
|
||||||
auto wType = cast<RankedTensorType>(w.getType());
|
}
|
||||||
auto outType = cast<RankedTensorType>(convOp.getY().getType());
|
|
||||||
|
|
||||||
assert("Only support static shapes" && xType.hasStaticShape() && wType.hasStaticShape() && outType.hasStaticShape());
|
static int64_t getI64FromArrayAttr(ArrayAttr arr, size_t idx) { return cast<IntegerAttr>(arr[idx]).getInt(); }
|
||||||
assert("Only support 2D convolution" && xType.getRank() == 4);
|
|
||||||
|
|
||||||
// We need to understand what is group
|
static Value expandBiasIfNeeded(Value bias, ConversionPatternRewriter& rewriter, Location loc) {
|
||||||
assert("Only support group=1" && convOp.getGroup() == 1);
|
auto biasType = cast<RankedTensorType>(bias.getType());
|
||||||
|
if (biasType.getRank() != 1)
|
||||||
|
return bias;
|
||||||
|
|
||||||
const int64_t batchSize = xType.getDimSize(0);
|
auto expandedBiasType = RankedTensorType::get({1, biasType.getDimSize(0)}, biasType.getElementType());
|
||||||
const int64_t numChannelsIn = xType.getDimSize(1);
|
return tensor::ExpandShapeOp::create(rewriter,
|
||||||
const int64_t xHeight = xType.getDimSize(2);
|
loc,
|
||||||
const int64_t xWidth = xType.getDimSize(3);
|
expandedBiasType,
|
||||||
const int64_t numChannelsOut = wType.getDimSize(0);
|
bias,
|
||||||
const int64_t wHeight = wType.getDimSize(2);
|
SmallVector<ReassociationIndices> {
|
||||||
const int64_t wWidth = wType.getDimSize(3);
|
{0, 1}
|
||||||
const int64_t outHeight = outType.getDimSize(2);
|
});
|
||||||
const int64_t outWidth = outType.getDimSize(3);
|
}
|
||||||
|
|
||||||
// Read optional conv attributes (ONNX defaults: stride=1, dilation=1, pad=0)
|
static Value createPaddedRows(Value tensorValue,
|
||||||
auto getI64 = [](ArrayAttr arr, size_t idx) -> int64_t { return cast<IntegerAttr>(arr[idx]).getInt(); };
|
RankedTensorType tensorType,
|
||||||
|
int64_t paddedRows,
|
||||||
|
ConversionPatternRewriter& rewriter,
|
||||||
|
Location loc) {
|
||||||
|
if (tensorType.getDimSize(0) == paddedRows)
|
||||||
|
return tensorValue;
|
||||||
|
|
||||||
const auto stridesAttr = convOp.getStrides();
|
auto paddedType = RankedTensorType::get({paddedRows, tensorType.getDimSize(1)}, tensorType.getElementType());
|
||||||
const auto dilationsAttr = convOp.getDilations();
|
SmallVector<OpFoldResult> lowPads = {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
|
||||||
const auto padsAttr = convOp.getPads();
|
SmallVector<OpFoldResult> highPads = {rewriter.getIndexAttr(paddedRows - tensorType.getDimSize(0)),
|
||||||
|
rewriter.getIndexAttr(0)};
|
||||||
|
auto padOp = tensor::PadOp::create(rewriter, loc, paddedType, tensorValue, lowPads, highPads);
|
||||||
|
auto* padBlock = new Block();
|
||||||
|
for (int i = 0; i < 2; i++)
|
||||||
|
padBlock->addArgument(rewriter.getIndexType(), loc);
|
||||||
|
padOp.getRegion().push_back(padBlock);
|
||||||
|
rewriter.setInsertionPointToStart(padBlock);
|
||||||
|
auto zero = arith::ConstantOp::create(
|
||||||
|
rewriter, loc, tensorType.getElementType(), rewriter.getZeroAttr(tensorType.getElementType()));
|
||||||
|
tensor::YieldOp::create(rewriter, loc, zero.getResult());
|
||||||
|
rewriter.setInsertionPointAfter(padOp);
|
||||||
|
return padOp.getResult();
|
||||||
|
}
|
||||||
|
|
||||||
const int64_t strideHeight = stridesAttr ? getI64(*stridesAttr, 0) : 1;
|
static Value buildPackedWeight(DenseElementsAttr wDenseAttr,
|
||||||
const int64_t strideWidth = stridesAttr ? getI64(*stridesAttr, 1) : 1;
|
Value wTrans,
|
||||||
const int64_t dilationHeight = dilationsAttr ? getI64(*dilationsAttr, 0) : 1;
|
RankedTensorType wType,
|
||||||
const int64_t dilationWidth = dilationsAttr ? getI64(*dilationsAttr, 1) : 1;
|
int64_t numChannelsIn,
|
||||||
|
int64_t numChannelsOut,
|
||||||
|
int64_t wHeight,
|
||||||
|
int64_t wWidth,
|
||||||
|
int64_t patchSize,
|
||||||
|
int64_t packFactor,
|
||||||
|
ConversionPatternRewriter& rewriter,
|
||||||
|
Location loc) {
|
||||||
|
if (packFactor == 1)
|
||||||
|
return wTrans;
|
||||||
|
|
||||||
int64_t padHeightBegin = 0;
|
auto packedWeightType =
|
||||||
int64_t padHeightEnd = 0;
|
RankedTensorType::get({packFactor * patchSize, packFactor * numChannelsOut}, wType.getElementType());
|
||||||
int64_t padWidthBegin = 0;
|
SmallVector<Attribute> sourceValues(wDenseAttr.getValues<Attribute>());
|
||||||
int64_t padWidthEnd = 0;
|
SmallVector<Attribute> packedValues(packedWeightType.getNumElements(),
|
||||||
|
cast<Attribute>(rewriter.getZeroAttr(wType.getElementType())));
|
||||||
|
|
||||||
if (padsAttr) {
|
for (int64_t copyId = 0; copyId < packFactor; copyId++) {
|
||||||
padHeightBegin = getI64(*padsAttr, 0);
|
for (int64_t outChannel = 0; outChannel < numChannelsOut; outChannel++) {
|
||||||
padWidthBegin = getI64(*padsAttr, 1);
|
for (int64_t inChannel = 0; inChannel < numChannelsIn; inChannel++) {
|
||||||
padHeightEnd = getI64(*padsAttr, 2);
|
for (int64_t kernelH = 0; kernelH < wHeight; kernelH++) {
|
||||||
padWidthEnd = getI64(*padsAttr, 3);
|
for (int64_t kernelW = 0; kernelW < wWidth; kernelW++) {
|
||||||
}
|
const int64_t sourceFlatIndex =
|
||||||
else {
|
(((outChannel * numChannelsIn) + inChannel) * wHeight + kernelH) * wWidth + kernelW;
|
||||||
// Compute padding from auto_pad attribute
|
const int64_t patchIndex = ((inChannel * wHeight) + kernelH) * wWidth + kernelW;
|
||||||
const auto autoPad = convOp.getAutoPad();
|
const int64_t targetRow = copyId * patchSize + patchIndex;
|
||||||
if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") {
|
const int64_t targetCol = copyId * numChannelsOut + outChannel;
|
||||||
const int64_t effectiveKernelH = (wHeight - 1) * dilationHeight + 1;
|
packedValues[targetRow * (packFactor * numChannelsOut) + targetCol] = sourceValues[sourceFlatIndex];
|
||||||
const int64_t effectiveKernelW = (wWidth - 1) * dilationWidth + 1;
|
}
|
||||||
const int64_t totalPadH =
|
}
|
||||||
std::max(static_cast<int64_t>(0), (outHeight - 1) * strideHeight + effectiveKernelH - xHeight);
|
|
||||||
const int64_t totalPadW =
|
|
||||||
std::max(static_cast<int64_t>(0), (outWidth - 1) * strideWidth + effectiveKernelW - xWidth);
|
|
||||||
|
|
||||||
if (autoPad == "SAME_UPPER") {
|
|
||||||
padHeightBegin = totalPadH / 2;
|
|
||||||
padHeightEnd = totalPadH - padHeightBegin;
|
|
||||||
padWidthBegin = totalPadW / 2;
|
|
||||||
padWidthEnd = totalPadW - padWidthBegin;
|
|
||||||
}
|
|
||||||
else { // SAME_LOWER
|
|
||||||
padHeightEnd = totalPadH / 2;
|
|
||||||
padHeightBegin = totalPadH - padHeightEnd;
|
|
||||||
padWidthEnd = totalPadW / 2;
|
|
||||||
padWidthBegin = totalPadW - padWidthEnd;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// "NOTSET" or "VALID" -> all pads stay 0
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// im2col layout (flipped with respect to the standard, so filters sit in B = crossbar):
|
auto packedAttr = DenseElementsAttr::get(packedWeightType, packedValues);
|
||||||
// A (im2col): [numPatches, patchSize] -- one row per output spatial position
|
return arith::ConstantOp::create(rewriter, loc, packedWeightType, packedAttr);
|
||||||
// B (weights): [patchSize, cOut] -- W^T, stored in crossbar columns
|
}
|
||||||
// Gemm output: [numPatches, cOut]
|
|
||||||
const int64_t patchSize = numChannelsIn * wHeight * wWidth;
|
|
||||||
const int64_t numPatchesPerBatch = outHeight * outWidth;
|
|
||||||
const int64_t numPatches = batchSize * numPatchesPerBatch;
|
|
||||||
|
|
||||||
|
static Value buildPackedBias(bool hasBias,
|
||||||
|
Value gemmBias,
|
||||||
|
Value biasMatrix,
|
||||||
|
DenseElementsAttr biasDenseAttr,
|
||||||
|
RankedTensorType outType,
|
||||||
|
int64_t numChannelsOut,
|
||||||
|
int64_t packFactor,
|
||||||
|
ConversionPatternRewriter& rewriter,
|
||||||
|
Location loc) {
|
||||||
|
if (!hasBias)
|
||||||
|
return gemmBias;
|
||||||
|
|
||||||
|
if (packFactor == 1)
|
||||||
|
return biasMatrix;
|
||||||
|
|
||||||
|
SmallVector<Attribute> sourceValues(biasDenseAttr.getValues<Attribute>());
|
||||||
|
SmallVector<Attribute> packedValues;
|
||||||
|
packedValues.reserve(packFactor * numChannelsOut);
|
||||||
|
for (int64_t copyId = 0; copyId < packFactor; copyId++)
|
||||||
|
packedValues.append(sourceValues.begin(), sourceValues.end());
|
||||||
|
|
||||||
|
auto packedBiasType = RankedTensorType::get({1, packFactor * numChannelsOut}, outType.getElementType());
|
||||||
|
auto packedBiasAttr = DenseElementsAttr::get(packedBiasType, packedValues);
|
||||||
|
return arith::ConstantOp::create(rewriter, loc, packedBiasType, packedBiasAttr).getResult();
|
||||||
|
}
|
||||||
|
|
||||||
|
static Value createIm2colCompute(Value x,
|
||||||
|
RankedTensorType xType,
|
||||||
|
RankedTensorType im2colType,
|
||||||
|
RankedTensorType rowType,
|
||||||
|
int64_t batchSize,
|
||||||
|
int64_t numChannelsIn,
|
||||||
|
int64_t xHeight,
|
||||||
|
int64_t xWidth,
|
||||||
|
int64_t wHeight,
|
||||||
|
int64_t wWidth,
|
||||||
|
int64_t padHeightBegin,
|
||||||
|
int64_t padHeightEnd,
|
||||||
|
int64_t padWidthBegin,
|
||||||
|
int64_t padWidthEnd,
|
||||||
|
int64_t strideHeight,
|
||||||
|
int64_t strideWidth,
|
||||||
|
int64_t dilationHeight,
|
||||||
|
int64_t dilationWidth,
|
||||||
|
int64_t outWidth,
|
||||||
|
int64_t patchSize,
|
||||||
|
int64_t numPatches,
|
||||||
|
int64_t numPatchesPerBatch,
|
||||||
|
ConversionPatternRewriter& rewriter,
|
||||||
|
Location loc) {
|
||||||
auto elemType = xType.getElementType();
|
auto elemType = xType.getElementType();
|
||||||
auto im2colType = RankedTensorType::get({numPatches, patchSize}, elemType);
|
|
||||||
auto rowType = RankedTensorType::get({1, patchSize}, elemType);
|
|
||||||
auto wFlatType = RankedTensorType::get({numChannelsOut, patchSize}, wType.getElementType());
|
|
||||||
auto wTransType = RankedTensorType::get({patchSize, numChannelsOut}, wType.getElementType());
|
|
||||||
auto gemmOutType = RankedTensorType::get({numPatches, numChannelsOut}, outType.getElementType());
|
|
||||||
auto nhwcType = RankedTensorType::get({batchSize, outHeight, outWidth, numChannelsOut}, outType.getElementType());
|
|
||||||
|
|
||||||
// Prepare weight matrix W for crossbar storage:
|
|
||||||
// W: [numChannelsOut, numChannelsIn, wHeight, wWidth] -> [numChannelsOut, patchSize] -> [patchSize, numChannelsOut]
|
|
||||||
Value wFlat = tensor::CollapseShapeOp::create(rewriter,
|
|
||||||
loc,
|
|
||||||
wFlatType,
|
|
||||||
w,
|
|
||||||
SmallVector<ReassociationIndices> {
|
|
||||||
{0},
|
|
||||||
{1, 2, 3}
|
|
||||||
});
|
|
||||||
Value wTrans = ONNXTransposeOp::create(rewriter, loc, wTransType, wFlat, rewriter.getI64ArrayAttr({1, 0}));
|
|
||||||
|
|
||||||
// Pass bias through directly; Gemm handles rank-1 C canonicalization.
|
|
||||||
bool hasB = !isa<ONNXNoneOp>(b.getDefiningOp());
|
|
||||||
Value gemmC;
|
|
||||||
if (hasB)
|
|
||||||
gemmC = b;
|
|
||||||
else
|
|
||||||
gemmC = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());
|
|
||||||
|
|
||||||
constexpr size_t numInputs = 1;
|
constexpr size_t numInputs = 1;
|
||||||
auto im2colComputeOp = createSpatCompute<numInputs>(rewriter, loc, im2colType, {}, x, [&](Value xArg) {
|
auto im2colComputeOp = createSpatCompute<numInputs>(rewriter, loc, im2colType, {}, x, [&](Value xArg) {
|
||||||
Value paddedInput = xArg;
|
Value paddedInput = xArg;
|
||||||
@@ -226,23 +258,104 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
|||||||
Value im2col = im2colLoop.getResult(0);
|
Value im2col = im2colLoop.getResult(0);
|
||||||
spatial::SpatYieldOp::create(rewriter, loc, im2col);
|
spatial::SpatYieldOp::create(rewriter, loc, im2col);
|
||||||
});
|
});
|
||||||
|
return im2colComputeOp.getResult(0);
|
||||||
|
}
|
||||||
|
|
||||||
// Gemm: A @ B + C = im2col @ W^T + b
|
static Value createPackedIm2colRows(Value im2col,
|
||||||
// [numPatches, patchSize] @ [patchSize, numChannelsOut] + [1, numChannelsOut] -> [numPatches, numChannelsOut]
|
RankedTensorType im2colType,
|
||||||
auto gemmOp = ONNXGemmOp::create(rewriter,
|
Type elemType,
|
||||||
loc,
|
int64_t numPatches,
|
||||||
gemmOutType,
|
int64_t patchSize,
|
||||||
im2colComputeOp.getResult(0),
|
int64_t packFactor,
|
||||||
wTrans,
|
ConversionPatternRewriter& rewriter,
|
||||||
gemmC,
|
Location loc) {
|
||||||
rewriter.getF32FloatAttr(1.0f),
|
if (packFactor == 1)
|
||||||
rewriter.getF32FloatAttr(1.0f),
|
return im2col;
|
||||||
rewriter.getBoolAttr(false),
|
|
||||||
rewriter.getBoolAttr(false));
|
|
||||||
Value gemmOut = gemmOp.getY();
|
|
||||||
|
|
||||||
|
const int64_t packedNumRows = ceilIntegerDivide(numPatches, packFactor);
|
||||||
|
const int64_t paddedNumPatches = packedNumRows * packFactor;
|
||||||
|
auto groupedType = RankedTensorType::get({packedNumRows, packFactor, patchSize}, elemType);
|
||||||
|
auto packedType = RankedTensorType::get({packedNumRows, packFactor * patchSize}, elemType);
|
||||||
|
auto packedComputeOp = createSpatCompute<1>(rewriter, loc, packedType, {}, im2col, [&](Value im2colArg) {
|
||||||
|
Value paddedIm2col = createPaddedRows(im2colArg, im2colType, paddedNumPatches, rewriter, loc);
|
||||||
|
Value groupedIm2col = tensor::ExpandShapeOp::create(rewriter,
|
||||||
|
loc,
|
||||||
|
groupedType,
|
||||||
|
paddedIm2col,
|
||||||
|
SmallVector<ReassociationIndices> {
|
||||||
|
{0, 1},
|
||||||
|
{2}
|
||||||
|
});
|
||||||
|
Value packedIm2col = tensor::CollapseShapeOp::create(rewriter,
|
||||||
|
loc,
|
||||||
|
packedType,
|
||||||
|
groupedIm2col,
|
||||||
|
SmallVector<ReassociationIndices> {
|
||||||
|
{0},
|
||||||
|
{1, 2}
|
||||||
|
});
|
||||||
|
spatial::SpatYieldOp::create(rewriter, loc, packedIm2col);
|
||||||
|
});
|
||||||
|
return packedComputeOp.getResult(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
static Value createUnpackedOutput(Value packedOutput,
|
||||||
|
RankedTensorType gemmOutType,
|
||||||
|
RankedTensorType outType,
|
||||||
|
int64_t numPatches,
|
||||||
|
int64_t numChannelsOut,
|
||||||
|
int64_t packFactor,
|
||||||
|
ConversionPatternRewriter& rewriter,
|
||||||
|
Location loc) {
|
||||||
|
if (packFactor == 1)
|
||||||
|
return packedOutput;
|
||||||
|
|
||||||
|
const int64_t packedNumRows = ceilIntegerDivide(numPatches, packFactor);
|
||||||
|
const int64_t paddedNumPatches = packedNumRows * packFactor;
|
||||||
|
auto expandedType = RankedTensorType::get({packedNumRows, packFactor, numChannelsOut}, outType.getElementType());
|
||||||
|
auto paddedType = RankedTensorType::get({paddedNumPatches, numChannelsOut}, outType.getElementType());
|
||||||
|
auto unpackComputeOp = createSpatCompute<1>(rewriter, loc, gemmOutType, {}, packedOutput, [&](Value packedOutputArg) {
|
||||||
|
Value expandedOutput = tensor::ExpandShapeOp::create(rewriter,
|
||||||
|
loc,
|
||||||
|
expandedType,
|
||||||
|
packedOutputArg,
|
||||||
|
SmallVector<ReassociationIndices> {
|
||||||
|
{0},
|
||||||
|
{1, 2}
|
||||||
|
});
|
||||||
|
Value paddedOutput = tensor::CollapseShapeOp::create(rewriter,
|
||||||
|
loc,
|
||||||
|
paddedType,
|
||||||
|
expandedOutput,
|
||||||
|
SmallVector<ReassociationIndices> {
|
||||||
|
{0, 1},
|
||||||
|
{2}
|
||||||
|
});
|
||||||
|
|
||||||
|
Value unpackedOutput = paddedOutput;
|
||||||
|
if (paddedNumPatches != numPatches) {
|
||||||
|
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
|
||||||
|
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(numPatches), rewriter.getIndexAttr(numChannelsOut)};
|
||||||
|
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||||
|
unpackedOutput =
|
||||||
|
tensor::ExtractSliceOp::create(rewriter, loc, gemmOutType, paddedOutput, offsets, sizes, strides);
|
||||||
|
}
|
||||||
|
|
||||||
|
spatial::SpatYieldOp::create(rewriter, loc, unpackedOutput);
|
||||||
|
});
|
||||||
|
return unpackComputeOp.getResult(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
static Value createCollectedConvOutput(Value gemmOut,
|
||||||
|
Type convType,
|
||||||
|
RankedTensorType nhwcType,
|
||||||
|
RankedTensorType outType,
|
||||||
|
ConversionPatternRewriter& rewriter,
|
||||||
|
Location loc) {
|
||||||
auto collectComputeOp =
|
auto collectComputeOp =
|
||||||
createSpatCompute<numInputs>(rewriter, loc, convOp.getType(), {}, ValueRange {gemmOut}, [&](Value gemmOutArg) {
|
createSpatCompute(rewriter, loc, convType, {}, ValueRange {gemmOut}, [&](ValueRange gemmOutArgs) {
|
||||||
|
Value gemmOutArg = gemmOutArgs.front();
|
||||||
|
|
||||||
// Restore to NCHW layout:
|
// Restore to NCHW layout:
|
||||||
// [numPatches, numChannelsOut]
|
// [numPatches, numChannelsOut]
|
||||||
// -> [1, outHeight, outWidth, numChannelsOut]
|
// -> [1, outHeight, outWidth, numChannelsOut]
|
||||||
@@ -256,11 +369,225 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
|||||||
{3}
|
{3}
|
||||||
});
|
});
|
||||||
Value nchwOut = ONNXTransposeOp::create(rewriter, loc, outType, nhwcOut, rewriter.getI64ArrayAttr({0, 3, 1, 2}));
|
Value nchwOut = ONNXTransposeOp::create(rewriter, loc, outType, nhwcOut, rewriter.getI64ArrayAttr({0, 3, 1, 2}));
|
||||||
|
|
||||||
spatial::SpatYieldOp::create(rewriter, loc, nchwOut);
|
spatial::SpatYieldOp::create(rewriter, loc, nchwOut);
|
||||||
});
|
});
|
||||||
|
return collectComputeOp.getResult(0);
|
||||||
|
}
|
||||||
|
|
||||||
rewriter.replaceOp(convOp, collectComputeOp.getResult(0));
|
} // namespace
|
||||||
|
|
||||||
|
LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
||||||
|
ONNXConvOpAdaptor convOpAdaptor,
|
||||||
|
ConversionPatternRewriter& rewriter) const {
|
||||||
|
Location loc = convOp.getLoc();
|
||||||
|
Value x = convOpAdaptor.getX();
|
||||||
|
Value w = convOpAdaptor.getW();
|
||||||
|
Value b = convOpAdaptor.getB();
|
||||||
|
|
||||||
|
auto xType = cast<RankedTensorType>(x.getType());
|
||||||
|
auto wType = cast<RankedTensorType>(w.getType());
|
||||||
|
auto outType = cast<RankedTensorType>(convOp.getY().getType());
|
||||||
|
|
||||||
|
assert("Only support static shapes" && xType.hasStaticShape() && wType.hasStaticShape() && outType.hasStaticShape());
|
||||||
|
assert("Only support 2D convolution" && xType.getRank() == 4);
|
||||||
|
|
||||||
|
// We need to understand what is group
|
||||||
|
assert("Only support group=1" && convOp.getGroup() == 1);
|
||||||
|
|
||||||
|
const int64_t batchSize = xType.getDimSize(0);
|
||||||
|
const int64_t numChannelsIn = xType.getDimSize(1);
|
||||||
|
const int64_t xHeight = xType.getDimSize(2);
|
||||||
|
const int64_t xWidth = xType.getDimSize(3);
|
||||||
|
const int64_t numChannelsOut = wType.getDimSize(0);
|
||||||
|
const int64_t wHeight = wType.getDimSize(2);
|
||||||
|
const int64_t wWidth = wType.getDimSize(3);
|
||||||
|
const int64_t outHeight = outType.getDimSize(2);
|
||||||
|
const int64_t outWidth = outType.getDimSize(3);
|
||||||
|
|
||||||
|
// Read optional conv attributes (ONNX defaults: stride=1, dilation=1, pad=0)
|
||||||
|
const auto stridesAttr = convOp.getStrides();
|
||||||
|
const auto dilationsAttr = convOp.getDilations();
|
||||||
|
const auto padsAttr = convOp.getPads();
|
||||||
|
|
||||||
|
const int64_t strideHeight = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 0) : 1;
|
||||||
|
const int64_t strideWidth = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 1) : 1;
|
||||||
|
const int64_t dilationHeight = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 0) : 1;
|
||||||
|
const int64_t dilationWidth = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 1) : 1;
|
||||||
|
|
||||||
|
int64_t padHeightBegin = 0;
|
||||||
|
int64_t padHeightEnd = 0;
|
||||||
|
int64_t padWidthBegin = 0;
|
||||||
|
int64_t padWidthEnd = 0;
|
||||||
|
|
||||||
|
if (padsAttr) {
|
||||||
|
padHeightBegin = getI64FromArrayAttr(*padsAttr, 0);
|
||||||
|
padWidthBegin = getI64FromArrayAttr(*padsAttr, 1);
|
||||||
|
padHeightEnd = getI64FromArrayAttr(*padsAttr, 2);
|
||||||
|
padWidthEnd = getI64FromArrayAttr(*padsAttr, 3);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
// Compute padding from auto_pad attribute
|
||||||
|
const auto autoPad = convOp.getAutoPad();
|
||||||
|
if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") {
|
||||||
|
const int64_t effectiveKernelH = (wHeight - 1) * dilationHeight + 1;
|
||||||
|
const int64_t effectiveKernelW = (wWidth - 1) * dilationWidth + 1;
|
||||||
|
const int64_t totalPadH =
|
||||||
|
std::max(static_cast<int64_t>(0), (outHeight - 1) * strideHeight + effectiveKernelH - xHeight);
|
||||||
|
const int64_t totalPadW =
|
||||||
|
std::max(static_cast<int64_t>(0), (outWidth - 1) * strideWidth + effectiveKernelW - xWidth);
|
||||||
|
|
||||||
|
if (autoPad == "SAME_UPPER") {
|
||||||
|
padHeightBegin = totalPadH / 2;
|
||||||
|
padHeightEnd = totalPadH - padHeightBegin;
|
||||||
|
padWidthBegin = totalPadW / 2;
|
||||||
|
padWidthEnd = totalPadW - padWidthBegin;
|
||||||
|
}
|
||||||
|
else { // SAME_LOWER
|
||||||
|
padHeightEnd = totalPadH / 2;
|
||||||
|
padHeightBegin = totalPadH - padHeightEnd;
|
||||||
|
padWidthEnd = totalPadW / 2;
|
||||||
|
padWidthBegin = totalPadW - padWidthEnd;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// "NOTSET" or "VALID" -> all pads stay 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// im2col layout (flipped with respect to the standard, so filters sit in B = crossbar):
|
||||||
|
// A (im2col): [numPatches, patchSize] -- one row per output spatial position
|
||||||
|
// B (weights): [patchSize, cOut] -- W^T, stored in crossbar columns
|
||||||
|
// Gemm output: [numPatches, cOut]
|
||||||
|
const int64_t patchSize = numChannelsIn * wHeight * wWidth;
|
||||||
|
const int64_t numPatchesPerBatch = outHeight * outWidth;
|
||||||
|
const int64_t numPatches = batchSize * numPatchesPerBatch;
|
||||||
|
|
||||||
|
auto elemType = xType.getElementType();
|
||||||
|
auto im2colType = RankedTensorType::get({numPatches, patchSize}, elemType);
|
||||||
|
auto rowType = RankedTensorType::get({1, patchSize}, elemType);
|
||||||
|
auto wFlatType = RankedTensorType::get({numChannelsOut, patchSize}, wType.getElementType());
|
||||||
|
auto wTransType = RankedTensorType::get({patchSize, numChannelsOut}, wType.getElementType());
|
||||||
|
auto gemmOutType = RankedTensorType::get({numPatches, numChannelsOut}, outType.getElementType());
|
||||||
|
auto nhwcType = RankedTensorType::get({batchSize, outHeight, outWidth, numChannelsOut}, outType.getElementType());
|
||||||
|
|
||||||
|
const int64_t xbarSize = static_cast<int64_t>(crossbarSize.getValue());
|
||||||
|
const int64_t wMaxDim = std::max(patchSize, numChannelsOut);
|
||||||
|
const int64_t maxParallelPixels = std::max<int64_t>(1, xbarSize / wMaxDim);
|
||||||
|
auto wDenseAttr = getDenseConstantAttr(w);
|
||||||
|
|
||||||
|
// Prepare weight matrix W for crossbar storage:
|
||||||
|
// W: [numChannelsOut, numChannelsIn, wHeight, wWidth] -> [numChannelsOut, patchSize] -> [patchSize, numChannelsOut]
|
||||||
|
Value wFlat = tensor::CollapseShapeOp::create(rewriter,
|
||||||
|
loc,
|
||||||
|
wFlatType,
|
||||||
|
w,
|
||||||
|
SmallVector<ReassociationIndices> {
|
||||||
|
{0},
|
||||||
|
{1, 2, 3}
|
||||||
|
});
|
||||||
|
Value wTrans = ONNXTransposeOp::create(rewriter, loc, wTransType, wFlat, rewriter.getI64ArrayAttr({1, 0}));
|
||||||
|
|
||||||
|
// Pass bias through directly; Gemm handles rank-1 C canonicalization.
|
||||||
|
bool hasB = !isa<ONNXNoneOp>(b.getDefiningOp());
|
||||||
|
Value gemmC = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());
|
||||||
|
Value biasMatrix;
|
||||||
|
DenseElementsAttr biasDenseAttr;
|
||||||
|
if (hasB) {
|
||||||
|
gemmC = b;
|
||||||
|
biasDenseAttr = getDenseConstantAttr(b);
|
||||||
|
biasMatrix = expandBiasIfNeeded(b, rewriter, loc);
|
||||||
|
}
|
||||||
|
const bool canPackWeightsAsConstants = static_cast<bool>(wDenseAttr);
|
||||||
|
const bool canPackBiasAsConstants = !hasB || static_cast<bool>(biasDenseAttr);
|
||||||
|
const int64_t effectiveMaxParallelPixels =
|
||||||
|
(canPackWeightsAsConstants && canPackBiasAsConstants) ? maxParallelPixels : 1;
|
||||||
|
|
||||||
|
Value im2col = createIm2colCompute(x,
|
||||||
|
xType,
|
||||||
|
im2colType,
|
||||||
|
rowType,
|
||||||
|
batchSize,
|
||||||
|
numChannelsIn,
|
||||||
|
xHeight,
|
||||||
|
xWidth,
|
||||||
|
wHeight,
|
||||||
|
wWidth,
|
||||||
|
padHeightBegin,
|
||||||
|
padHeightEnd,
|
||||||
|
padWidthBegin,
|
||||||
|
padWidthEnd,
|
||||||
|
strideHeight,
|
||||||
|
strideWidth,
|
||||||
|
dilationHeight,
|
||||||
|
dilationWidth,
|
||||||
|
outWidth,
|
||||||
|
patchSize,
|
||||||
|
numPatches,
|
||||||
|
numPatchesPerBatch,
|
||||||
|
rewriter,
|
||||||
|
loc);
|
||||||
|
|
||||||
|
Value gemmOut;
|
||||||
|
if (effectiveMaxParallelPixels == 1) {
|
||||||
|
// Fallback to the plain im2col GEMM when a single crossbar cannot fit multiple pixels.
|
||||||
|
gemmOut = ONNXGemmOp::create(rewriter,
|
||||||
|
loc,
|
||||||
|
gemmOutType,
|
||||||
|
im2col,
|
||||||
|
wTrans,
|
||||||
|
gemmC,
|
||||||
|
rewriter.getF32FloatAttr(1.0f),
|
||||||
|
rewriter.getF32FloatAttr(1.0f),
|
||||||
|
rewriter.getBoolAttr(false),
|
||||||
|
rewriter.getBoolAttr(false))
|
||||||
|
.getY();
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
// Keep the standard im2col view of convolution:
|
||||||
|
// A (im2col): [numPatches, patchSize] -- one row per output spatial position
|
||||||
|
// B (weights): [patchSize, cOut] -- W^T, stored in crossbar columns
|
||||||
|
// but repack several old rows into one new row so we use the available crossbar size better.
|
||||||
|
//
|
||||||
|
// We want to process N spatial pixels at the exact same time. Instead of doing N separate
|
||||||
|
// operations of (1 x patchSize) x (patchSize x cOut), we construct a block-diagonal weight matrix
|
||||||
|
// containing N copies of W^T and concatenate N im2col rows into one longer row:
|
||||||
|
// A_packed: [ceil(numPatches / N), N * patchSize]
|
||||||
|
// B_packed: [N * patchSize, N * cOut]
|
||||||
|
// Y_packed: [ceil(numPatches / N), N * cOut]
|
||||||
|
// The downstream GemmToManyGemv pass still splits by row, but now there are fewer, longer rows.
|
||||||
|
const int64_t packedNumRows = ceilIntegerDivide(numPatches, effectiveMaxParallelPixels);
|
||||||
|
auto packedOutType =
|
||||||
|
RankedTensorType::get({packedNumRows, effectiveMaxParallelPixels * numChannelsOut}, outType.getElementType());
|
||||||
|
|
||||||
|
Value packedA = createPackedIm2colRows(
|
||||||
|
im2col, im2colType, elemType, numPatches, patchSize, effectiveMaxParallelPixels, rewriter, loc);
|
||||||
|
Value packedB = buildPackedWeight(wDenseAttr,
|
||||||
|
wTrans,
|
||||||
|
wType,
|
||||||
|
numChannelsIn,
|
||||||
|
numChannelsOut,
|
||||||
|
wHeight,
|
||||||
|
wWidth,
|
||||||
|
patchSize,
|
||||||
|
effectiveMaxParallelPixels,
|
||||||
|
rewriter,
|
||||||
|
loc);
|
||||||
|
Value packedC = buildPackedBias(
|
||||||
|
hasB, gemmC, biasMatrix, biasDenseAttr, outType, numChannelsOut, effectiveMaxParallelPixels, rewriter, loc);
|
||||||
|
Value packedOut = ONNXGemmOp::create(rewriter,
|
||||||
|
loc,
|
||||||
|
packedOutType,
|
||||||
|
packedA,
|
||||||
|
packedB,
|
||||||
|
packedC,
|
||||||
|
rewriter.getF32FloatAttr(1.0f),
|
||||||
|
rewriter.getF32FloatAttr(1.0f),
|
||||||
|
rewriter.getBoolAttr(false),
|
||||||
|
rewriter.getBoolAttr(false))
|
||||||
|
.getY();
|
||||||
|
gemmOut = createUnpackedOutput(
|
||||||
|
packedOut, gemmOutType, outType, numPatches, numChannelsOut, effectiveMaxParallelPixels, rewriter, loc);
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter.replaceOp(convOp, createCollectedConvOutput(gemmOut, convOp.getType(), nhwcType, outType, rewriter, loc));
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user