// Source snapshot: commit 4f3570520c ("reuse code for subviews") - 404 lines, 20 KiB, C++.
#include "mlir/Dialect/Arith/IR/Arith.h"
|
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
#include "mlir/IR/BuiltinAttributes.h"
|
|
#include "mlir/IR/BuiltinTypes.h"
|
|
|
|
#include "llvm/ADT/APFloat.h"
|
|
#include "llvm/ADT/APInt.h"
|
|
#include "llvm/ADT/SmallVector.h"
|
|
|
|
#include <algorithm>
|
|
#include <optional>
|
|
#include <type_traits>
|
|
|
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
|
|
|
using namespace mlir;
|
|
|
|
namespace onnx_mlir {
|
|
namespace {
|
|
|
|
template <typename ArrayAttrT>
|
|
static int64_t getI64(ArrayAttrT arrayAttr, size_t index) {
|
|
return cast<IntegerAttr>(arrayAttr[index]).getInt();
|
|
}
|
|
|
|
/// Optional-attribute variant of getI64: returns `defaultValue` when `arrayAttr` is empty,
/// otherwise element `index` of the contained attribute.
template <typename ArrayAttrT>
static int64_t getOptionalI64(std::optional<ArrayAttrT> arrayAttr, size_t index, int64_t defaultValue) {
  if (!arrayAttr.has_value())
    return defaultValue;
  return getI64(*arrayAttr, index);
}
|
|
|
|
/// Copies `tile` into a freshly created tensor.empty of the same shape and element type.
///
/// The copy is a full-extent tensor.insert_slice: zero offsets, the tile's own static sizes and unit
/// strides. The insert_slice result is returned. The empty destination is built from shape and
/// element type only. NOTE(review): presumably this gives downstream lowering a standalone tensor
/// rather than a strided view of a larger one; confirm.
static Value materializeContiguousTile(ConversionPatternRewriter& rewriter, Location loc, Value tile) {
  auto tileType = cast<RankedTensorType>(tile.getType());
  const int64_t rank = tileType.getRank();

  SmallVector<OpFoldResult> zeroOffsets(rank, rewriter.getIndexAttr(0));
  SmallVector<OpFoldResult> unitStrides(rank, rewriter.getIndexAttr(1));
  SmallVector<OpFoldResult> fullSizes;
  fullSizes.reserve(rank);
  for (int64_t extent : tileType.getShape())
    fullSizes.push_back(rewriter.getIndexAttr(extent));

  Value destination = tensor::EmptyOp::create(rewriter, loc, tileType.getShape(), tileType.getElementType());
  return tensor::InsertSliceOp::create(rewriter, loc, tile, destination, zeroOffsets, fullSizes, unitStrides);
}
|
|
|
|
/// Emits the scalar arith.constant that seeds a pooling reduction.
///
/// If `useMinimumValue` is false the constant is `getZeroAttr(elementType)`. Otherwise it is the
/// smallest value of the element type:
///  - -infinity for FloatType elements;
///  - APInt::getSignedMinValue(width) for IntegerType elements.
/// Any other element type reaches llvm_unreachable.
static Value
createPoolFillElement(ConversionPatternRewriter& rewriter, Location loc, Type elementType, bool useMinimumValue) {
  TypedAttr fillAttr;
  if (!useMinimumValue) {
    fillAttr = rewriter.getZeroAttr(elementType);
  }
  else if (auto floatType = dyn_cast<FloatType>(elementType)) {
    fillAttr = rewriter.getFloatAttr(floatType,
                                     llvm::APFloat::getInf(floatType.getFloatSemantics(), /*Negative=*/true));
  }
  else if (auto integerType = dyn_cast<IntegerType>(elementType)) {
    fillAttr = rewriter.getIntegerAttr(integerType, llvm::APInt::getSignedMinValue(integerType.getWidth()));
  }
  else {
    llvm_unreachable("unsupported pool element type");
  }
  return arith::ConstantOp::create(rewriter, loc, elementType, fillAttr);
}
|
|
|
|
/// Broadcasts the pooling seed scalar (zero, or the element minimum when `useMinimumValue`)
/// over every element of `tensorType` using tensor.splat.
static Value createPoolFillTensor(ConversionPatternRewriter& rewriter,
                                  Location loc,
                                  RankedTensorType tensorType,
                                  bool useMinimumValue) {
  Type elementType = tensorType.getElementType();
  Value seed = createPoolFillElement(rewriter, loc, elementType, useMinimumValue);
  return tensor::SplatOp::create(rewriter, loc, tensorType, seed);
}
|
|
|
|
template <typename PoolOp>
|
|
static Value createPaddedPoolInput(ConversionPatternRewriter& rewriter,
|
|
Location loc,
|
|
PoolOp poolOp,
|
|
Value input,
|
|
RankedTensorType inputType,
|
|
int64_t padTop,
|
|
int64_t padLeft,
|
|
int64_t padBottom,
|
|
int64_t padRight) {
|
|
if (padTop == 0 && padLeft == 0 && padBottom == 0 && padRight == 0)
|
|
return input;
|
|
|
|
auto paddedType = RankedTensorType::get({inputType.getDimSize(0),
|
|
inputType.getDimSize(1),
|
|
inputType.getDimSize(2) + padTop + padBottom,
|
|
inputType.getDimSize(3) + padLeft + padRight},
|
|
inputType.getElementType(),
|
|
inputType.getEncoding());
|
|
SmallVector<OpFoldResult> lowPads = {
|
|
rewriter.getIndexAttr(0), rewriter.getIndexAttr(0), rewriter.getIndexAttr(padTop), rewriter.getIndexAttr(padLeft)};
|
|
SmallVector<OpFoldResult> highPads = {rewriter.getIndexAttr(0),
|
|
rewriter.getIndexAttr(0),
|
|
rewriter.getIndexAttr(padBottom),
|
|
rewriter.getIndexAttr(padRight)};
|
|
auto padOp = tensor::PadOp::create(rewriter, loc, paddedType, input, lowPads, highPads);
|
|
auto* padBlock = new Block();
|
|
for (int index = 0; index < paddedType.getRank(); ++index)
|
|
padBlock->addArgument(rewriter.getIndexType(), loc);
|
|
padOp.getRegion().push_back(padBlock);
|
|
rewriter.setInsertionPointToStart(padBlock);
|
|
Value padValue =
|
|
createPoolFillElement(rewriter, loc, inputType.getElementType(), std::is_same_v<PoolOp, ONNXMaxPoolSingleOutOp>);
|
|
tensor::YieldOp::create(rewriter, loc, padValue);
|
|
rewriter.setInsertionPointAfter(padOp);
|
|
return padOp.getResult();
|
|
}
|
|
|
|
/// Counts the kernel taps along one spatial axis that land on a real (unpadded) input element.
/// For output position `outIndex`, this is the number of `k` in [0, kernelSize) with
/// 0 <= outIndex * stride + k * dilation - padBefore < inputSize.
static int64_t countInBoundsTaps(
  int64_t outIndex, int64_t kernelSize, int64_t stride, int64_t dilation, int64_t padBefore, int64_t inputSize) {
  int64_t count = 0;
  for (int64_t k = 0; k < kernelSize; ++k) {
    const int64_t inIndex = outIndex * stride + k * dilation - padBefore;
    if (inIndex >= 0 && inIndex < inputSize)
      ++count;
  }
  return count;
}

/// Builds an arith.constant of shape [1, channels, outputHeight, outputWidth] holding 1 / divisor
/// for each output element of an AveragePool.
///
/// The divisor for output pixel (outH, outW) is chosen as follows:
///  - if `countIncludePad` is true: kernelHeight * kernelWidth;
///  - otherwise: the number of kernel taps whose input coordinate
///    (outH * strideHeight + kh * dilationHeight - padTop, outW * strideWidth + kw * dilationWidth - padLeft)
///    lies inside the unpadded inputHeight x inputWidth plane.
///
/// The result uses `outType`'s element type and encoding.
///
/// Fails, after emitting an op error on `op`, in two cases:
///  - the element type is not a FloatType;
///  - some divisor is not positive, e.g. a window lying entirely in padding with `countIncludePad` off.
static FailureOr<Value> createAverageScaleTensor(ConversionPatternRewriter& rewriter,
                                                 Location loc,
                                                 Operation* op,
                                                 RankedTensorType outType,
                                                 int64_t channels,
                                                 int64_t inputHeight,
                                                 int64_t inputWidth,
                                                 int64_t outputHeight,
                                                 int64_t outputWidth,
                                                 int64_t kernelHeight,
                                                 int64_t kernelWidth,
                                                 int64_t strideHeight,
                                                 int64_t strideWidth,
                                                 int64_t dilationHeight,
                                                 int64_t dilationWidth,
                                                 int64_t padTop,
                                                 int64_t padLeft,
                                                 bool countIncludePad) {
  auto elemType = dyn_cast<FloatType>(outType.getElementType());
  if (!elemType) {
    op->emitOpError("AveragePool lowering requires a floating-point element type");
    return failure();
  }

  // The divisor depends only on the output pixel, never on the channel. Compute the per-pixel
  // scales once rather than re-running the kernel sweep for every channel.
  // The window is separable: the valid tap count is (valid rows) * (valid columns).
  SmallVector<Attribute> pixelScales;
  pixelScales.reserve(static_cast<size_t>(outputHeight * outputWidth));
  for (int64_t outH = 0; outH < outputHeight; ++outH) {
    const int64_t validRows =
      countInBoundsTaps(outH, kernelHeight, strideHeight, dilationHeight, padTop, inputHeight);
    for (int64_t outW = 0; outW < outputWidth; ++outW) {
      const int64_t validCols = countInBoundsTaps(outW, kernelWidth, strideWidth, dilationWidth, padLeft, inputWidth);
      const int64_t divisor = countIncludePad ? kernelHeight * kernelWidth : validRows * validCols;
      if (divisor <= 0) {
        op->emitOpError("AveragePool divisor must be positive");
        return failure();
      }
      pixelScales.push_back(rewriter.getFloatAttr(elemType, 1.0 / static_cast<double>(divisor)));
    }
  }

  // Replicate the spatial plane once per channel; the row-major layout of [1, C, OH, OW] puts all
  // pixels of channel 0 first, then channel 1, and so on.
  auto scaleType = RankedTensorType::get({1, channels, outputHeight, outputWidth}, elemType, outType.getEncoding());
  SmallVector<Attribute> scaleValues;
  scaleValues.reserve(static_cast<size_t>(channels) * pixelScales.size());
  for (int64_t channel = 0; channel < channels; ++channel)
    scaleValues.append(pixelScales.begin(), pixelScales.end());

  auto scaleAttr = DenseElementsAttr::get(scaleType, scaleValues);
  return arith::ConstantOp::create(rewriter, loc, scaleType, scaleAttr).getResult();
}
|
|
|
|
// Primary template for the pooling conversion patterns. It is intentionally left undefined; only the
// explicit specializations for the supported ONNX pooling ops can be instantiated.
template <class PoolOp>
struct PoolToSpatialCompute;
|
|
|
|
template <typename PoolOp, typename PoolOpAdaptor, typename ReduceOp>
|
|
struct PoolToSpatialComputeBase : public OpConversionPattern<PoolOp> {
|
|
using OpConversionPattern<PoolOp>::OpConversionPattern;
|
|
|
|
LogicalResult matchAndRewrite(PoolOp poolOp, PoolOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final {
|
|
Location loc = poolOp.getLoc();
|
|
Value x = adaptor.getX();
|
|
|
|
auto xType = dyn_cast<RankedTensorType>(x.getType());
|
|
auto outType = dyn_cast<RankedTensorType>(poolOp.getResult().getType());
|
|
if (!xType || !outType || !xType.hasStaticShape() || !outType.hasStaticShape())
|
|
return rewriter.notifyMatchFailure(poolOp, "pool lowering requires static ranked tensor types.");
|
|
if (xType.getRank() != 4 || outType.getRank() != 4)
|
|
return rewriter.notifyMatchFailure(poolOp, "only 2D NCHW pool is supported.");
|
|
|
|
ArrayAttr kernelAttr = poolOp.getKernelShape();
|
|
if (!kernelAttr || kernelAttr.size() != 2)
|
|
return rewriter.notifyMatchFailure(poolOp, "pool lowering expects a 2D kernel.");
|
|
|
|
const int64_t batchSize = xType.getDimSize(0);
|
|
const int64_t channels = xType.getDimSize(1);
|
|
const int64_t inputHeight = xType.getDimSize(2);
|
|
const int64_t inputWidth = xType.getDimSize(3);
|
|
const int64_t outputHeight = outType.getDimSize(2);
|
|
const int64_t outputWidth = outType.getDimSize(3);
|
|
const int64_t kernelHeight = getI64(kernelAttr, 0);
|
|
const int64_t kernelWidth = getI64(kernelAttr, 1);
|
|
const int64_t strideHeight = getOptionalI64(poolOp.getStrides(), 0, 1);
|
|
const int64_t strideWidth = getOptionalI64(poolOp.getStrides(), 1, 1);
|
|
const int64_t dilationHeight = getOptionalI64(poolOp.getDilations(), 0, 1);
|
|
const int64_t dilationWidth = getOptionalI64(poolOp.getDilations(), 1, 1);
|
|
|
|
int64_t padTop = 0;
|
|
int64_t padLeft = 0;
|
|
int64_t padBottom = 0;
|
|
int64_t padRight = 0;
|
|
|
|
if (auto padsAttr = poolOp.getPads()) {
|
|
if (padsAttr->size() != 4)
|
|
return rewriter.notifyMatchFailure(poolOp, "pads must have four elements.");
|
|
padTop = getI64(*padsAttr, 0);
|
|
padLeft = getI64(*padsAttr, 1);
|
|
padBottom = getI64(*padsAttr, 2);
|
|
padRight = getI64(*padsAttr, 3);
|
|
}
|
|
else {
|
|
StringRef autoPad = poolOp.getAutoPad();
|
|
if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") {
|
|
const int64_t effectiveKernelH = (kernelHeight - 1) * dilationHeight + 1;
|
|
const int64_t effectiveKernelW = (kernelWidth - 1) * dilationWidth + 1;
|
|
const int64_t totalPadH =
|
|
std::max<int64_t>(0, (outputHeight - 1) * strideHeight + effectiveKernelH - inputHeight);
|
|
const int64_t totalPadW = std::max<int64_t>(0, (outputWidth - 1) * strideWidth + effectiveKernelW - inputWidth);
|
|
|
|
if (autoPad == "SAME_UPPER") {
|
|
padTop = totalPadH / 2;
|
|
padBottom = totalPadH - padTop;
|
|
padLeft = totalPadW / 2;
|
|
padRight = totalPadW - padLeft;
|
|
}
|
|
else {
|
|
padBottom = totalPadH / 2;
|
|
padTop = totalPadH - padBottom;
|
|
padRight = totalPadW / 2;
|
|
padLeft = totalPadW - padRight;
|
|
}
|
|
}
|
|
else if (autoPad != "NOTSET" && autoPad != "VALID") {
|
|
return rewriter.notifyMatchFailure(poolOp, "unsupported auto_pad value.");
|
|
}
|
|
}
|
|
|
|
const int64_t xbarSize = static_cast<int64_t>(crossbarSize.getValue());
|
|
const int64_t channelTileCount = (channels + xbarSize - 1) / xbarSize;
|
|
const int64_t outputPatchCount = batchSize * outputHeight * outputWidth;
|
|
const bool countIncludePad = [&]() {
|
|
if constexpr (std::is_same_v<PoolOp, ONNXAveragePoolOp>)
|
|
return poolOp.getCountIncludePad() == 1;
|
|
return true;
|
|
}();
|
|
Value averageScaleTensor;
|
|
if constexpr (std::is_same_v<PoolOp, ONNXAveragePoolOp>) {
|
|
auto maybeAverageScaleTensor = createAverageScaleTensor(rewriter,
|
|
loc,
|
|
poolOp,
|
|
outType,
|
|
channels,
|
|
inputHeight,
|
|
inputWidth,
|
|
outputHeight,
|
|
outputWidth,
|
|
kernelHeight,
|
|
kernelWidth,
|
|
strideHeight,
|
|
strideWidth,
|
|
dilationHeight,
|
|
dilationWidth,
|
|
padTop,
|
|
padLeft,
|
|
countIncludePad);
|
|
if (failed(maybeAverageScaleTensor))
|
|
return failure();
|
|
averageScaleTensor = *maybeAverageScaleTensor;
|
|
}
|
|
constexpr size_t numInputs = 1;
|
|
auto computeOp =
|
|
createSpatCompute<numInputs>(rewriter, loc, outType, {}, ValueRange {x}, [&](Value xArg) -> LogicalResult {
|
|
Value paddedInput =
|
|
createPaddedPoolInput(rewriter, loc, poolOp, xArg, xType, padTop, padLeft, padBottom, padRight);
|
|
Value pooledOutputInit = tensor::EmptyOp::create(rewriter, loc, outType.getShape(), outType.getElementType());
|
|
|
|
Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
|
|
Value c1 = arith::ConstantIndexOp::create(rewriter, loc, 1);
|
|
Value cOutputPatchCount = arith::ConstantIndexOp::create(rewriter, loc, outputPatchCount);
|
|
Value cOutputPixelsPerBatch = arith::ConstantIndexOp::create(rewriter, loc, outputHeight * outputWidth);
|
|
Value cOutputWidth = arith::ConstantIndexOp::create(rewriter, loc, outputWidth);
|
|
Value cStrideHeight = arith::ConstantIndexOp::create(rewriter, loc, strideHeight);
|
|
Value cStrideWidth = arith::ConstantIndexOp::create(rewriter, loc, strideWidth);
|
|
|
|
auto outputLoop = scf::ForOp::create(rewriter, loc, c0, cOutputPatchCount, c1, ValueRange {pooledOutputInit});
|
|
rewriter.setInsertionPointToStart(outputLoop.getBody());
|
|
|
|
Value outputPatchIndex = outputLoop.getInductionVar();
|
|
Value pooledOutputAcc = outputLoop.getRegionIterArgs().front();
|
|
|
|
Value batchIndex = arith::DivUIOp::create(rewriter, loc, outputPatchIndex, cOutputPixelsPerBatch);
|
|
Value batchPatchIndex = arith::RemUIOp::create(rewriter, loc, outputPatchIndex, cOutputPixelsPerBatch);
|
|
Value outHeightIndex = arith::DivUIOp::create(rewriter, loc, batchPatchIndex, cOutputWidth);
|
|
Value outWidthIndex = arith::RemUIOp::create(rewriter, loc, batchPatchIndex, cOutputWidth);
|
|
Value windowBaseH = arith::MulIOp::create(rewriter, loc, outHeightIndex, cStrideHeight);
|
|
Value windowBaseW = arith::MulIOp::create(rewriter, loc, outWidthIndex, cStrideWidth);
|
|
|
|
Value updatedOutput = pooledOutputAcc;
|
|
for (int64_t channelTile = 0; channelTile < channelTileCount; ++channelTile) {
|
|
const int64_t tileChannels = std::min<int64_t>(xbarSize, channels - channelTile * xbarSize);
|
|
auto tileType = RankedTensorType::get({1, tileChannels, 1, 1}, outType.getElementType());
|
|
Value reducedWindow =
|
|
createPoolFillTensor(rewriter, loc, tileType, std::is_same_v<PoolOp, ONNXMaxPoolSingleOutOp>);
|
|
|
|
for (int64_t kernelH = 0; kernelH < kernelHeight; ++kernelH) {
|
|
Value paddedInH = windowBaseH;
|
|
if (kernelH * dilationHeight != 0) {
|
|
Value kernelHOffset = arith::ConstantIndexOp::create(rewriter, loc, kernelH * dilationHeight);
|
|
paddedInH = arith::AddIOp::create(rewriter, loc, paddedInH, kernelHOffset);
|
|
}
|
|
|
|
for (int64_t kernelW = 0; kernelW < kernelWidth; ++kernelW) {
|
|
Value paddedInW = windowBaseW;
|
|
if (kernelW * dilationWidth != 0) {
|
|
Value kernelWOffset = arith::ConstantIndexOp::create(rewriter, loc, kernelW * dilationWidth);
|
|
paddedInW = arith::AddIOp::create(rewriter, loc, paddedInW, kernelWOffset);
|
|
}
|
|
|
|
SmallVector<OpFoldResult> offsets = {
|
|
batchIndex, rewriter.getIndexAttr(channelTile * xbarSize), paddedInH, paddedInW};
|
|
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
|
|
rewriter.getIndexAttr(tileChannels),
|
|
rewriter.getIndexAttr(1),
|
|
rewriter.getIndexAttr(1)};
|
|
SmallVector<OpFoldResult> strides = {
|
|
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
|
Value windowValue =
|
|
tensor::ExtractSliceOp::create(rewriter, loc, tileType, paddedInput, offsets, sizes, strides);
|
|
windowValue = materializeContiguousTile(rewriter, loc, windowValue);
|
|
reducedWindow = ReduceOp::create(rewriter, loc, tileType, reducedWindow, windowValue);
|
|
}
|
|
}
|
|
|
|
if constexpr (std::is_same_v<PoolOp, ONNXAveragePoolOp>) {
|
|
SmallVector<OpFoldResult> scaleOffsets = {
|
|
rewriter.getIndexAttr(0), rewriter.getIndexAttr(channelTile * xbarSize), outHeightIndex, outWidthIndex};
|
|
SmallVector<OpFoldResult> scaleSizes = {rewriter.getIndexAttr(1),
|
|
rewriter.getIndexAttr(tileChannels),
|
|
rewriter.getIndexAttr(1),
|
|
rewriter.getIndexAttr(1)};
|
|
SmallVector<OpFoldResult> scaleStrides = {
|
|
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
|
Value scaleSlice = tensor::ExtractSliceOp::create(
|
|
rewriter, loc, tileType, averageScaleTensor, scaleOffsets, scaleSizes, scaleStrides);
|
|
scaleSlice = materializeContiguousTile(rewriter, loc, scaleSlice);
|
|
reducedWindow = spatial::SpatVMulOp::create(rewriter, loc, tileType, reducedWindow, scaleSlice);
|
|
}
|
|
|
|
SmallVector<OpFoldResult> outputOffsets = {
|
|
batchIndex, rewriter.getIndexAttr(channelTile * xbarSize), outHeightIndex, outWidthIndex};
|
|
SmallVector<OpFoldResult> outputSizes = {rewriter.getIndexAttr(1),
|
|
rewriter.getIndexAttr(tileChannels),
|
|
rewriter.getIndexAttr(1),
|
|
rewriter.getIndexAttr(1)};
|
|
SmallVector<OpFoldResult> outputStrides = {
|
|
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
|
updatedOutput = tensor::InsertSliceOp::create(
|
|
rewriter, loc, reducedWindow, updatedOutput, outputOffsets, outputSizes, outputStrides);
|
|
}
|
|
|
|
scf::YieldOp::create(rewriter, loc, updatedOutput);
|
|
|
|
rewriter.setInsertionPointAfter(outputLoop);
|
|
spatial::SpatYieldOp::create(rewriter, loc, outputLoop.getResult(0));
|
|
return success();
|
|
});
|
|
if (failed(computeOp))
|
|
return failure();
|
|
|
|
rewriter.replaceOp(poolOp, computeOp->getResult(0));
|
|
return success();
|
|
}
|
|
};
|
|
|
|
template <>
|
|
struct PoolToSpatialCompute<ONNXMaxPoolSingleOutOp>
|
|
: public PoolToSpatialComputeBase<ONNXMaxPoolSingleOutOp, ONNXMaxPoolSingleOutOpAdaptor, spatial::SpatVMaxOp> {
|
|
using PoolToSpatialComputeBase::PoolToSpatialComputeBase;
|
|
};
|
|
|
|
template <>
|
|
struct PoolToSpatialCompute<ONNXAveragePoolOp>
|
|
: public PoolToSpatialComputeBase<ONNXAveragePoolOp, ONNXAveragePoolOpAdaptor, spatial::SpatVAddOp> {
|
|
using PoolToSpatialComputeBase::PoolToSpatialComputeBase;
|
|
};
|
|
|
|
} // namespace
|
|
|
|
/// Adds the ONNX MaxPool (single output) and AveragePool to spatial-compute conversion patterns to
/// `patterns`, each constructed with `ctx`.
void populatePoolPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
  patterns.add<PoolToSpatialCompute<ONNXMaxPoolSingleOutOp>, PoolToSpatialCompute<ONNXAveragePoolOp>>(ctx);
}
|
|
|
|
} // namespace onnx_mlir
|