570 lines
29 KiB
C++
570 lines
29 KiB
C++
#include "mlir/Dialect/Arith/IR/Arith.h"
|
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
#include "mlir/IR/BuiltinAttributes.h"
|
|
#include "mlir/IR/BuiltinTypes.h"
|
|
|
|
#include "llvm/ADT/SmallVector.h"
|
|
|
|
#include <algorithm>
|
|
#include <cassert>
|
|
|
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
|
|
|
using namespace mlir;
|
|
|
|
namespace onnx_mlir {
|
|
namespace {
|
|
|
|
/// Conversion pattern that rewrites a 2-D `onnx.Conv` into an im2col matrix, an `onnx.Gemm` against the
/// flattened and transposed weights, and a final step that restores the NCHW output layout.
struct ConvToGemm : public OpConversionPattern<ONNXConvOp> {
  // Inherit the (context, benefit) constructors of the base pattern.
  using OpConversionPattern<ONNXConvOp>::OpConversionPattern;

  /// Performs the rewrite; defined out of line below.
  LogicalResult matchAndRewrite(ONNXConvOp convOp,
                                ONNXConvOpAdaptor convOpAdaptor,
                                ConversionPatternRewriter& rewriter) const override;
};
|
|
|
|
/// Returns the DenseElementsAttr payload of `value` when it is produced by an `arith.constant` or an
/// `onnx.Constant` whose value is a DenseElementsAttr; returns a null attribute otherwise (including for
/// block arguments, which have no defining op).
static DenseElementsAttr getDenseConstantAttr(Value value) {
  Operation* producer = value.getDefiningOp();

  if (auto arithConstant = dyn_cast_or_null<arith::ConstantOp>(producer))
    return dyn_cast<DenseElementsAttr>(arithConstant.getValue());

  // onnx.Constant may carry no `value` attribute at all, hence the null-tolerant cast.
  if (auto onnxConstant = dyn_cast_or_null<ONNXConstantOp>(producer))
    return dyn_cast_or_null<DenseElementsAttr>(onnxConstant.getValueAttr());

  return {};
}
|
|
|
|
/// Reads element `idx` of `arr` as an integer. The element must be an IntegerAttr, and `idx` must be in range.
static int64_t getI64FromArrayAttr(ArrayAttr arr, size_t idx) {
  auto element = cast<IntegerAttr>(arr[idx]);
  return element.getInt();
}
|
|
|
|
/// Reshapes a rank-1 bias [C] into a row matrix [1, C] with `tensor.expand_shape`. Biases of any other rank
/// are returned unchanged. `bias` must have a RankedTensorType.
static Value expandBiasIfNeeded(Value bias, ConversionPatternRewriter& rewriter, Location loc) {
  auto biasType = cast<RankedTensorType>(bias.getType());

  if (biasType.getRank() == 1) {
    auto rowType = RankedTensorType::get({1, biasType.getDimSize(0)}, biasType.getElementType());
    // Both result dimensions come from the single source dimension.
    SmallVector<ReassociationIndices> reassociation {{0, 1}};
    return tensor::ExpandShapeOp::create(rewriter, loc, rowType, bias, reassociation);
  }

  return bias;
}
|
|
|
|
/// Appends zero-valued rows to the rank-2 tensor `tensorValue` (of type `tensorType`) so that its leading
/// dimension becomes `paddedRows`, using a `tensor.pad` with high padding on dimension 0 only.
/// If the tensor already has `paddedRows` rows, it is returned unchanged.
/// On return, the rewriter's insertion point is just after the created pad op.
static Value createPaddedRows(Value tensorValue,
                              RankedTensorType tensorType,
                              int64_t paddedRows,
                              ConversionPatternRewriter& rewriter,
                              Location loc) {
  const int64_t currentRows = tensorType.getDimSize(0);
  if (currentRows == paddedRows)
    return tensorValue;

  Type elementType = tensorType.getElementType();
  const int64_t extraRows = paddedRows - currentRows;
  auto resultType = RankedTensorType::get({paddedRows, tensorType.getDimSize(1)}, elementType);

  OpFoldResult noPad = rewriter.getIndexAttr(0);
  SmallVector<OpFoldResult> low = {noPad, noPad};
  SmallVector<OpFoldResult> high = {rewriter.getIndexAttr(extraRows), noPad};
  auto pad = tensor::PadOp::create(rewriter, loc, resultType, tensorValue, low, high);

  // The pad region takes one index argument per tensor dimension and yields the padding value.
  Block* body = new Block();
  body->addArgument(rewriter.getIndexType(), loc);
  body->addArgument(rewriter.getIndexType(), loc);
  pad.getRegion().push_back(body);

  rewriter.setInsertionPointToStart(body);
  auto zeroConstant = arith::ConstantOp::create(rewriter, loc, elementType, rewriter.getZeroAttr(elementType));
  tensor::YieldOp::create(rewriter, loc, zeroConstant.getResult());

  rewriter.setInsertionPointAfter(pad);
  return pad.getResult();
}
|
|
|
|
/// Produces the GEMM B operand for a given pack factor.
///
/// With `packFactor == 1`, returns `wTrans` (W^T, [patchSize, numChannelsOut]) unchanged.
/// Otherwise it ignores `wTrans` and materializes an `arith.constant` of shape
/// [packFactor * patchSize, packFactor * numChannelsOut]. The constant is block-diagonal: each of the
/// `packFactor` diagonal blocks is W^T and everything else is zero. The values are read from `wDenseAttr`,
/// which is indexed as a row-major [numChannelsOut, numChannelsIn, wHeight, wWidth] tensor and must be
/// non-null in this case.
static Value buildPackedWeight(DenseElementsAttr wDenseAttr,
                               Value wTrans,
                               RankedTensorType wType,
                               int64_t numChannelsIn,
                               int64_t numChannelsOut,
                               int64_t wHeight,
                               int64_t wWidth,
                               int64_t patchSize,
                               int64_t packFactor,
                               ConversionPatternRewriter& rewriter,
                               Location loc) {
  // No replication needed: use the transposed weight as-is.
  if (packFactor == 1)
    return wTrans;

  Type elementType = wType.getElementType();
  const int64_t packedRows = packFactor * patchSize;
  const int64_t packedCols = packFactor * numChannelsOut;
  auto packedType = RankedTensorType::get({packedRows, packedCols}, elementType);

  SmallVector<Attribute> weightValues(wDenseAttr.getValues<Attribute>());
  Attribute zero = rewriter.getZeroAttr(elementType);
  SmallVector<Attribute> blockDiagonal(packedType.getNumElements(), zero);

  for (int64_t block = 0; block < packFactor; ++block) {
    // Top-left corner of this diagonal block.
    const int64_t rowBase = block * patchSize;
    const int64_t colBase = block * numChannelsOut;
    for (int64_t oc = 0; oc < numChannelsOut; ++oc) {
      for (int64_t ic = 0; ic < numChannelsIn; ++ic) {
        for (int64_t kh = 0; kh < wHeight; ++kh) {
          for (int64_t kw = 0; kw < wWidth; ++kw) {
            const int64_t srcIndex = (((oc * numChannelsIn) + ic) * wHeight + kh) * wWidth + kw;
            // Position of (ic, kh, kw) inside one flattened patch, i.e. the row of W^T.
            const int64_t withinPatch = ((ic * wHeight) + kh) * wWidth + kw;
            const int64_t row = rowBase + withinPatch;
            const int64_t col = colBase + oc;
            blockDiagonal[row * packedCols + col] = weightValues[srcIndex];
          }
        }
      }
    }
  }

  auto blockDiagonalAttr = DenseElementsAttr::get(packedType, blockDiagonal);
  return arith::ConstantOp::create(rewriter, loc, packedType, blockDiagonalAttr);
}
|
|
|
|
/// Produces the GEMM C (bias) operand for a given pack factor.
///
/// - Without a bias (`hasBias == false`), returns `gemmBias` unchanged.
/// - With `packFactor == 1`, returns `biasMatrix` unchanged.
/// - Otherwise builds a [1, packFactor * numChannelsOut] `arith.constant` holding the values of
///   `biasDenseAttr` repeated `packFactor` times, so every packed output block receives the same bias.
///   `biasDenseAttr` must be non-null in this case.
static Value buildPackedBias(bool hasBias,
                             Value gemmBias,
                             Value biasMatrix,
                             DenseElementsAttr biasDenseAttr,
                             RankedTensorType outType,
                             int64_t numChannelsOut,
                             int64_t packFactor,
                             ConversionPatternRewriter& rewriter,
                             Location loc) {
  if (!hasBias)
    return gemmBias;
  if (packFactor == 1)
    return biasMatrix;

  SmallVector<Attribute> channelBias(biasDenseAttr.getValues<Attribute>());
  const int64_t packedWidth = packFactor * numChannelsOut;

  SmallVector<Attribute> tiledBias;
  tiledBias.reserve(packedWidth);
  for (int64_t remaining = packFactor; remaining > 0; --remaining)
    tiledBias.append(channelBias.begin(), channelBias.end());

  auto packedBiasType = RankedTensorType::get({1, packedWidth}, outType.getElementType());
  auto packedBiasAttr = DenseElementsAttr::get(packedBiasType, tiledBias);
  return arith::ConstantOp::create(rewriter, loc, packedBiasType, packedBiasAttr).getResult();
}
|
|
|
|
/// Builds the GEMM A operand (the im2col rows) inside a single Spatial compute region.
///
/// The compute takes `x` as its only input and yields one value of type `gemmInputRowsType`:
///  1. If any pad amount is non-zero, zero-pads the two spatial dimensions of the input
///     ([batchSize, numChannelsIn, xHeight, xWidth] -> [batchSize, numChannelsIn, xHeight + padH, xWidth + padW]).
///  2. An scf.for over all `numPatches` output pixels extracts the [1, numChannelsIn, wHeight, wWidth] input
///     window of each pixel (honouring strides and dilations). It flattens the window into an `im2colRowType`
///     ([1, patchSize]) row and inserts the row into an `im2colType` ([numPatches, patchSize]) accumulator.
///  3. If `packFactor` > 1, appends zero rows up to a multiple of `packFactor`, then reshapes to
///     [ceil(numPatches / packFactor), packFactor * patchSize], so `packFactor` consecutive patches share one row.
///
/// Returns the single result of the compute op.
static Value createIm2colRowComputes(Value x,
                                     RankedTensorType xType,
                                     RankedTensorType im2colType,
                                     RankedTensorType im2colRowType,
                                     RankedTensorType gemmInputRowsType,
                                     int64_t batchSize,
                                     int64_t numChannelsIn,
                                     int64_t xHeight,
                                     int64_t xWidth,
                                     int64_t wHeight,
                                     int64_t wWidth,
                                     int64_t padHeightBegin,
                                     int64_t padHeightEnd,
                                     int64_t padWidthBegin,
                                     int64_t padWidthEnd,
                                     int64_t strideHeight,
                                     int64_t strideWidth,
                                     int64_t dilationHeight,
                                     int64_t dilationWidth,
                                     int64_t outWidth,
                                     int64_t patchSize,
                                     int64_t numPatches,
                                     int64_t numPatchesPerBatch,
                                     int64_t packFactor,
                                     ConversionPatternRewriter& rewriter,
                                     Location loc) {
  auto elemType = xType.getElementType();
  constexpr size_t numInputs = 1;
  const int64_t packedNumRows = ceilIntegerDivide(numPatches, packFactor);

  auto im2colComputeOp =
      createSpatCompute<numInputs>(rewriter, loc, TypeRange {gemmInputRowsType}, {}, x, [&](Value xArg) {
        Value paddedInput = xArg;

        // Pad input with zeros if needed:
        // [batchSize, numChannelsIn, xHeight, xWidth]
        //   -> [batchSize, numChannelsIn, xHeight + padHeight, xWidth + padWidth]
        if (padHeightBegin || padHeightEnd || padWidthBegin || padWidthEnd) {
          const int64_t paddedHeight = xHeight + padHeightBegin + padHeightEnd;
          const int64_t paddedWidth = xWidth + padWidthBegin + padWidthEnd;
          auto paddedType = RankedTensorType::get({batchSize, numChannelsIn, paddedHeight, paddedWidth}, elemType);
          SmallVector<OpFoldResult> lowPads = {rewriter.getIndexAttr(0),
                                               rewriter.getIndexAttr(0),
                                               rewriter.getIndexAttr(padHeightBegin),
                                               rewriter.getIndexAttr(padWidthBegin)};
          SmallVector<OpFoldResult> highPads = {rewriter.getIndexAttr(0),
                                                rewriter.getIndexAttr(0),
                                                rewriter.getIndexAttr(padHeightEnd),
                                                rewriter.getIndexAttr(padWidthEnd)};
          auto padOp = tensor::PadOp::create(rewriter, loc, paddedType, paddedInput, lowPads, highPads);
          auto* padBlock = new Block();
          for (int i = 0; i < 4; i++)
            padBlock->addArgument(rewriter.getIndexType(), loc);
          padOp.getRegion().push_back(padBlock);
          rewriter.setInsertionPointToStart(padBlock);
          // getZeroAttr yields a correctly typed zero for any element type (getFloatAttr is only valid for
          // float element types), and matches how createPaddedRows builds its padding value.
          auto zero = arith::ConstantOp::create(rewriter, loc, elemType, rewriter.getZeroAttr(elemType));
          tensor::YieldOp::create(rewriter, loc, zero.getResult());
          rewriter.setInsertionPointAfter(padOp);
          paddedInput = padOp.getResult();
        }

        // Build im2col [numPatches, patchSize] incrementally to keep the IR small
        // until the late PIM unrolling step.
        Value im2colInit = tensor::EmptyOp::create(rewriter, loc, im2colType.getShape(), elemType);
        auto c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
        auto c1 = arith::ConstantIndexOp::create(rewriter, loc, 1);
        auto cNumPatches = arith::ConstantIndexOp::create(rewriter, loc, numPatches);
        auto cNumPatchesPerBatch = arith::ConstantIndexOp::create(rewriter, loc, numPatchesPerBatch);
        auto cOutWidth = arith::ConstantIndexOp::create(rewriter, loc, outWidth);
        auto cStrideHeight = arith::ConstantIndexOp::create(rewriter, loc, strideHeight);
        auto cStrideWidth = arith::ConstantIndexOp::create(rewriter, loc, strideWidth);

        auto im2colLoop = scf::ForOp::create(rewriter, loc, c0, cNumPatches, c1, ValueRange {im2colInit});
        rewriter.setInsertionPointToStart(im2colLoop.getBody());

        Value patchIndex = im2colLoop.getInductionVar();
        Value im2colAcc = im2colLoop.getRegionIterArgs().front();

        // Decompose the flat patch index into (batch, output row, output column) and compute the top-left
        // coordinate of its window in the (padded) input.
        Value batchIndex = arith::DivUIOp::create(rewriter, loc, patchIndex, cNumPatchesPerBatch);
        Value batchPatchIndex = arith::RemUIOp::create(rewriter, loc, patchIndex, cNumPatchesPerBatch);
        Value outHeightIndex = arith::DivUIOp::create(rewriter, loc, batchPatchIndex, cOutWidth);
        Value outWidthIndex = arith::RemUIOp::create(rewriter, loc, batchPatchIndex, cOutWidth);
        Value inputHeightOffset = arith::MulIOp::create(rewriter, loc, outHeightIndex, cStrideHeight);
        Value inputWidthOffset = arith::MulIOp::create(rewriter, loc, outWidthIndex, cStrideWidth);

        // Dilation is expressed as the slice stride over the spatial dimensions.
        SmallVector<OpFoldResult> offsets = {batchIndex, rewriter.getIndexAttr(0), inputHeightOffset, inputWidthOffset};
        SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
                                           rewriter.getIndexAttr(numChannelsIn),
                                           rewriter.getIndexAttr(wHeight),
                                           rewriter.getIndexAttr(wWidth)};
        SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1),
                                             rewriter.getIndexAttr(1),
                                             rewriter.getIndexAttr(dilationHeight),
                                             rewriter.getIndexAttr(dilationWidth)};
        auto patchType = RankedTensorType::get({1, numChannelsIn, wHeight, wWidth}, elemType);
        Value patch = tensor::ExtractSliceOp::create(rewriter, loc, patchType, paddedInput, offsets, sizes, strides);

        // [1, numChannelsIn, wHeight, wWidth] -> [1, patchSize]
        Value row = tensor::CollapseShapeOp::create(rewriter,
                                                    loc,
                                                    im2colRowType,
                                                    patch,
                                                    SmallVector<ReassociationIndices> {
                                                      {0},
                                                      {1, 2, 3}
        });

        SmallVector<OpFoldResult> rowOffsets = {patchIndex, rewriter.getIndexAttr(0)};
        SmallVector<OpFoldResult> rowSizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(patchSize)};
        SmallVector<OpFoldResult> rowStrides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
        Value updatedIm2col =
            tensor::InsertSliceOp::create(rewriter, loc, row, im2colAcc, rowOffsets, rowSizes, rowStrides);
        scf::YieldOp::create(rewriter, loc, updatedIm2col);

        rewriter.setInsertionPointAfter(im2colLoop);
        Value im2col = im2colLoop.getResult(0);

        Value gemmInputRows = im2col;
        if (packFactor != 1) {
          // [numPatches, patchSize] -> (zero rows) [paddedNumPatches, patchSize]
          //   -> [packedNumRows, packFactor, patchSize] -> [packedNumRows, packFactor * patchSize]
          const int64_t paddedNumPatches = packedNumRows * packFactor;
          auto groupedType = RankedTensorType::get({packedNumRows, packFactor, patchSize}, elemType);
          auto packedType = RankedTensorType::get({packedNumRows, packFactor * patchSize}, elemType);
          Value paddedIm2col = createPaddedRows(im2col, im2colType, paddedNumPatches, rewriter, loc);
          Value groupedIm2col = tensor::ExpandShapeOp::create(rewriter,
                                                              loc,
                                                              groupedType,
                                                              paddedIm2col,
                                                              SmallVector<ReassociationIndices> {
                                                                {0, 1},
                                                                {2}
          });
          gemmInputRows = tensor::CollapseShapeOp::create(rewriter,
                                                          loc,
                                                          packedType,
                                                          groupedIm2col,
                                                          SmallVector<ReassociationIndices> {
                                                            {0},
                                                            {1, 2}
          });
        }

        spatial::SpatYieldOp::create(rewriter, loc, gemmInputRows);
      });

  return im2colComputeOp.getResult(0);
}
|
|
|
|
/// Collects the GEMM output row blocks into the final convolution result of type `convType`.
///
/// Everything is emitted inside one Spatial compute region that takes `gemmRows` as inputs:
///  - the row blocks are concatenated along axis 0;
///  - if `packFactor` > 1, each packed row [packFactor * numChannelsOut] is split back into `packFactor`
///    per-patch rows, and any trailing rows beyond `numPatches` (added by packing) are sliced off, giving
///    `gemmOutType` ([numPatches, numChannelsOut]);
///  - the [numPatches, numChannelsOut] matrix is reshaped to `nhwcType` and transposed to NCHW `outType`.
static Value createCollectedConvOutput(ValueRange gemmRows,
                                       Type convType,
                                       RankedTensorType gemmOutType,
                                       RankedTensorType nhwcType,
                                       RankedTensorType outType,
                                       int64_t numPatches,
                                       int64_t numChannelsOut,
                                       int64_t packFactor,
                                       ConversionPatternRewriter& rewriter,
                                       Location loc) {
  const int64_t rowGroups = ceilIntegerDivide(numPatches, packFactor);
  const int64_t roundedPatches = rowGroups * packFactor;
  Type resultElemType = outType.getElementType();

  auto collectOp = createSpatCompute(rewriter, loc, convType, {}, gemmRows, [&](ValueRange rowArgs) {
    // All GEMM row blocks stacked along the row axis.
    Value stacked = createSpatConcat(rewriter, loc, /*axis=*/0, rowArgs);
    Value patchRows = stacked;

    if (packFactor != 1) {
      // [rowGroups, packFactor * Cout] -> [rowGroups, packFactor, Cout] -> [rowGroups * packFactor, Cout]
      auto splitType = RankedTensorType::get({rowGroups, packFactor, numChannelsOut}, resultElemType);
      auto unpackedType = RankedTensorType::get({roundedPatches, numChannelsOut}, resultElemType);
      Value split = tensor::ExpandShapeOp::create(rewriter,
                                                  loc,
                                                  splitType,
                                                  stacked,
                                                  SmallVector<ReassociationIndices> {
                                                    {0},
                                                    {1, 2}
      });
      Value unpacked = tensor::CollapseShapeOp::create(rewriter,
                                                       loc,
                                                       unpackedType,
                                                       split,
                                                       SmallVector<ReassociationIndices> {
                                                         {0, 1},
                                                         {2}
      });
      patchRows = unpacked;

      // Drop the trailing rows produced by zero-padded patches.
      if (roundedPatches != numPatches) {
        SmallVector<OpFoldResult> sliceOffsets = {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
        SmallVector<OpFoldResult> sliceSizes = {rewriter.getIndexAttr(numPatches),
                                                rewriter.getIndexAttr(numChannelsOut)};
        SmallVector<OpFoldResult> sliceStrides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
        patchRows = tensor::ExtractSliceOp::create(
            rewriter, loc, gemmOutType, unpacked, sliceOffsets, sliceSizes, sliceStrides);
      }
    }

    // Restore the NCHW layout:
    // [numPatches, Cout] -> [batch, outHeight, outWidth, Cout] -> [batch, Cout, outHeight, outWidth]
    Value nhwcOut = tensor::ExpandShapeOp::create(rewriter,
                                                  loc,
                                                  nhwcType,
                                                  patchRows,
                                                  SmallVector<ReassociationIndices> {
                                                    {0, 1, 2},
                                                    {3}
    });
    Value nchwOut = ONNXTransposeOp::create(rewriter, loc, outType, nhwcOut, rewriter.getI64ArrayAttr({0, 3, 1, 2}));
    spatial::SpatYieldOp::create(rewriter, loc, nchwOut);
  });

  return collectOp.getResult(0);
}
|
|
|
|
} // namespace
|
|
|
|
/// Lowers a static-shape 2-D `onnx.Conv` with group = 1 to im2col + `onnx.Gemm`.
///
/// Steps:
///  1. Validate the supported subset: ranked static 4-D shapes, group = 1, and well-formed strides, dilations
///     and pads. Anything else is rejected with notifyMatchFailure before any IR is created.
///  2. Resolve strides, dilations and explicit or auto_pad padding.
///  3. Choose how many output pixels share one GEMM row (the pack factor). It is limited by the crossbar size
///     and the number of patches, and packing is only used when W (and B, if present) are dense constants.
///  4. Build A (im2col rows), B (W^T or its block-diagonal packed form) and C (bias), emit the Gemm, and
///     collect its result back into the NCHW output.
LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
                                          ONNXConvOpAdaptor convOpAdaptor,
                                          ConversionPatternRewriter& rewriter) const {
  Location loc = convOp.getLoc();
  Value x = convOpAdaptor.getX();
  Value w = convOpAdaptor.getW();
  Value b = convOpAdaptor.getB();

  // Preconditions are reported through notifyMatchFailure rather than assert. Asserts vanish in release builds,
  // which would let unsupported convolutions be lowered incorrectly. All checks run before any IR is created,
  // so a failed match leaves the IR untouched.
  auto xType = dyn_cast<RankedTensorType>(x.getType());
  auto wType = dyn_cast<RankedTensorType>(w.getType());
  auto outType = dyn_cast<RankedTensorType>(convOp.getY().getType());
  if (!xType || !wType || !outType || !xType.hasStaticShape() || !wType.hasStaticShape() || !outType.hasStaticShape())
    return rewriter.notifyMatchFailure(convOp, "only static shapes are supported");
  if (xType.getRank() != 4 || wType.getRank() != 4 || outType.getRank() != 4)
    return rewriter.notifyMatchFailure(convOp, "only 2D convolution is supported");
  // Grouped convolution is not handled yet.
  if (convOp.getGroup() != 1)
    return rewriter.notifyMatchFailure(convOp, "only group=1 is supported");
  // With group=1 the weight's input-channel dimension must equal the input's. Weight packing indexes W using
  // the input's channel count.
  if (wType.getDimSize(1) != xType.getDimSize(1))
    return rewriter.notifyMatchFailure(convOp, "weight and input channel counts differ");

  // Read optional conv attributes (ONNX defaults: stride=1, dilation=1, pad=0)
  auto stridesAttr = convOp.getStrides();
  auto dilationsAttr = convOp.getDilations();
  auto padsAttr = convOp.getPads();
  // Guard the element accesses below against short attribute arrays.
  if ((stridesAttr && stridesAttr->size() < 2) || (dilationsAttr && dilationsAttr->size() < 2)
      || (padsAttr && padsAttr->size() < 4))
    return rewriter.notifyMatchFailure(convOp, "strides/dilations/pads have too few entries for a 2D convolution");

  const int64_t batchSize = xType.getDimSize(0);
  const int64_t numChannelsIn = xType.getDimSize(1);
  const int64_t xHeight = xType.getDimSize(2);
  const int64_t xWidth = xType.getDimSize(3);
  const int64_t numChannelsOut = wType.getDimSize(0);
  const int64_t wHeight = wType.getDimSize(2);
  const int64_t wWidth = wType.getDimSize(3);
  const int64_t outHeight = outType.getDimSize(2);
  const int64_t outWidth = outType.getDimSize(3);

  const int64_t strideHeight = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 0) : 1;
  const int64_t strideWidth = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 1) : 1;
  const int64_t dilationHeight = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 0) : 1;
  const int64_t dilationWidth = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 1) : 1;

  int64_t padHeightBegin = 0;
  int64_t padHeightEnd = 0;
  int64_t padWidthBegin = 0;
  int64_t padWidthEnd = 0;

  if (padsAttr) {
    // ONNX order: [h_begin, w_begin, h_end, w_end]
    padHeightBegin = getI64FromArrayAttr(*padsAttr, 0);
    padWidthBegin = getI64FromArrayAttr(*padsAttr, 1);
    padHeightEnd = getI64FromArrayAttr(*padsAttr, 2);
    padWidthEnd = getI64FromArrayAttr(*padsAttr, 3);
  }
  else {
    // Compute padding from auto_pad attribute
    const auto autoPad = convOp.getAutoPad();
    if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") {
      const int64_t effectiveKernelH = (wHeight - 1) * dilationHeight + 1;
      const int64_t effectiveKernelW = (wWidth - 1) * dilationWidth + 1;
      const int64_t totalPadH =
          std::max(static_cast<int64_t>(0), (outHeight - 1) * strideHeight + effectiveKernelH - xHeight);
      const int64_t totalPadW =
          std::max(static_cast<int64_t>(0), (outWidth - 1) * strideWidth + effectiveKernelW - xWidth);

      if (autoPad == "SAME_UPPER") {
        // Odd remainder goes to the end.
        padHeightBegin = totalPadH / 2;
        padHeightEnd = totalPadH - padHeightBegin;
        padWidthBegin = totalPadW / 2;
        padWidthEnd = totalPadW - padWidthBegin;
      }
      else { // SAME_LOWER: odd remainder goes to the beginning.
        padHeightEnd = totalPadH / 2;
        padHeightBegin = totalPadH - padHeightEnd;
        padWidthEnd = totalPadW / 2;
        padWidthBegin = totalPadW - padWidthEnd;
      }
    }
    // "NOTSET" or "VALID" -> all pads stay 0
  }

  // im2col layout (flipped with respect to the standard, so filters sit in B = crossbar):
  // A (im2col): [numPatches, patchSize] -- one row per output spatial position
  // B (weights): [patchSize, cOut] -- W^T, stored in crossbar columns
  // Gemm output: [numPatches, cOut]
  const int64_t patchSize = numChannelsIn * wHeight * wWidth;
  const int64_t numPatchesPerBatch = outHeight * outWidth;
  const int64_t numPatches = batchSize * numPatchesPerBatch;

  auto elemType = xType.getElementType();
  auto im2colType = RankedTensorType::get({numPatches, patchSize}, elemType);
  auto rowType = RankedTensorType::get({1, patchSize}, elemType);
  auto gemmOutType = RankedTensorType::get({numPatches, numChannelsOut}, outType.getElementType());
  auto nhwcType = RankedTensorType::get({batchSize, outHeight, outWidth, numChannelsOut}, outType.getElementType());

  // Keep the standard im2col view of convolution and optionally repack several rows into one GEMM row to use
  // the available crossbar size better.
  //
  // To process N pixels at once, instead of N separate (1 x patchSize) x (patchSize x cOut) products we build a
  // block-diagonal weight matrix with N copies of W^T and concatenate N im2col rows into one longer row:
  //   A_packed: [ceil(numPatches / N), N * patchSize]
  //   B_packed: [N * patchSize, N * cOut]
  //   Y_packed: [ceil(numPatches / N), N * cOut]
  // N is bounded by the crossbar (both N * patchSize and N * cOut must fit). It is also capped at numPatches:
  // a larger N would only add zero rows to A and zero blocks to B.
  const int64_t xbarSize = static_cast<int64_t>(crossbarSize.getValue());
  const int64_t wMaxDim = std::max(patchSize, numChannelsOut);
  const int64_t maxParallelPixels = std::max<int64_t>(1, std::min(xbarSize / wMaxDim, numPatches));

  // Packing rewrites W (and B) at compile time, so it needs their values as dense constants.
  DenseElementsAttr wDenseAttr = getDenseConstantAttr(w);
  // B is optional; an absent bias is a NoneType value. Check the type rather than the defining op, because a
  // bias that is a block argument has no defining op.
  const bool hasB = b && !isa<NoneType>(b.getType());
  DenseElementsAttr biasDenseAttr = hasB ? getDenseConstantAttr(b) : DenseElementsAttr();
  const bool canPackWeightsAsConstants = static_cast<bool>(wDenseAttr);
  const bool canPackBiasAsConstants = !hasB || static_cast<bool>(biasDenseAttr);
  const int64_t effectiveMaxParallelPixels =
      (canPackWeightsAsConstants && canPackBiasAsConstants) ? maxParallelPixels : 1;
  const bool isPacked = effectiveMaxParallelPixels != 1;

  // Unpacked B operand, W: [numChannelsOut, numChannelsIn, wHeight, wWidth] -> [numChannelsOut, patchSize]
  // -> W^T [patchSize, numChannelsOut]. When packing, buildPackedWeight builds B from wDenseAttr and does not
  // use W^T, so W^T is only built when it is actually consumed.
  Value wTrans;
  if (!isPacked) {
    auto wFlatType = RankedTensorType::get({numChannelsOut, patchSize}, wType.getElementType());
    auto wTransType = RankedTensorType::get({patchSize, numChannelsOut}, wType.getElementType());
    Value wFlat = tensor::CollapseShapeOp::create(rewriter,
                                                  loc,
                                                  wFlatType,
                                                  w,
                                                  SmallVector<ReassociationIndices> {
                                                    {0},
                                                    {1, 2, 3}
    });
    wTrans = ONNXTransposeOp::create(rewriter, loc, wTransType, wFlat, rewriter.getI64ArrayAttr({1, 0}));
  }

  // C operand. Without a bias, a fresh None value is passed to the Gemm. With a bias on the unpacked path, a
  // rank-1 bias is reshaped to [1, cOut]. On the packed path, buildPackedBias builds its own constant.
  Value gemmBias = b;
  Value biasMatrix;
  if (hasB) {
    if (!isPacked)
      biasMatrix = expandBiasIfNeeded(b, rewriter, loc);
  }
  else {
    gemmBias = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());
  }

  const int64_t packedNumRows = ceilIntegerDivide(numPatches, effectiveMaxParallelPixels);
  auto gemmInputRowsType = RankedTensorType::get({packedNumRows, effectiveMaxParallelPixels * patchSize}, elemType);
  auto gemmOutputRowsType =
      RankedTensorType::get({packedNumRows, effectiveMaxParallelPixels * numChannelsOut}, outType.getElementType());

  Value gemmInputRows = createIm2colRowComputes(x,
                                                xType,
                                                im2colType,
                                                rowType,
                                                gemmInputRowsType,
                                                batchSize,
                                                numChannelsIn,
                                                xHeight,
                                                xWidth,
                                                wHeight,
                                                wWidth,
                                                padHeightBegin,
                                                padHeightEnd,
                                                padWidthBegin,
                                                padWidthEnd,
                                                strideHeight,
                                                strideWidth,
                                                dilationHeight,
                                                dilationWidth,
                                                outWidth,
                                                patchSize,
                                                numPatches,
                                                numPatchesPerBatch,
                                                effectiveMaxParallelPixels,
                                                rewriter,
                                                loc);

  Value gemmB = buildPackedWeight(wDenseAttr,
                                  wTrans,
                                  wType,
                                  numChannelsIn,
                                  numChannelsOut,
                                  wHeight,
                                  wWidth,
                                  patchSize,
                                  effectiveMaxParallelPixels,
                                  rewriter,
                                  loc);
  Value gemmC = buildPackedBias(
      hasB, gemmBias, biasMatrix, biasDenseAttr, outType, numChannelsOut, effectiveMaxParallelPixels, rewriter, loc);

  // Y = 1.0 * A * B + 1.0 * C, with neither operand transposed.
  Value gemmRows = ONNXGemmOp::create(rewriter,
                                      loc,
                                      gemmOutputRowsType,
                                      gemmInputRows,
                                      gemmB,
                                      gemmC,
                                      rewriter.getF32FloatAttr(1.0f),
                                      rewriter.getF32FloatAttr(1.0f),
                                      rewriter.getBoolAttr(false),
                                      rewriter.getBoolAttr(false))
                       .getY();

  rewriter.replaceOp(convOp,
                     createCollectedConvOutput(ValueRange {gemmRows},
                                               convOp.getType(),
                                               gemmOutType,
                                               nhwcType,
                                               outType,
                                               numPatches,
                                               numChannelsOut,
                                               effectiveMaxParallelPixels,
                                               rewriter,
                                               loc));
  return success();
}
|
|
|
|
/// Registers the im2col-based Conv -> Gemm lowering pattern in `patterns`.
void populateConvPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
  patterns.add<ConvToGemm>(ctx);
}
|
|
|
|
} // namespace onnx_mlir
|