better reports (dcp merge and memory)
This commit is contained in:
@@ -1,9 +1,12 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/APFloat.h"
|
||||
#include "llvm/ADT/APInt.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <optional>
|
||||
@@ -30,16 +33,6 @@ static int64_t getOptionalI64(std::optional<ArrayAttrT> arrayAttr, size_t index,
|
||||
return arrayAttr ? getI64(*arrayAttr, index) : defaultValue;
|
||||
}
|
||||
|
||||
template <typename PoolOp>
|
||||
static FailureOr<Value> concatAlongAxis(
|
||||
ConversionPatternRewriter& rewriter, Location loc, PoolOp poolOp, int64_t axis, ArrayRef<Value> values) {
|
||||
if (values.empty()) {
|
||||
poolOp.emitOpError("failed to build pooled output because an intermediate concatenation input list was empty");
|
||||
return failure();
|
||||
}
|
||||
return createSpatConcat(rewriter, loc, axis, values);
|
||||
}
|
||||
|
||||
static Value materializeContiguousTile(ConversionPatternRewriter& rewriter, Location loc, Value tile) {
|
||||
auto tileType = cast<RankedTensorType>(tile.getType());
|
||||
Value empty = tensor::EmptyOp::create(rewriter, loc, tileType.getShape(), tileType.getElementType());
|
||||
@@ -54,34 +47,126 @@ static Value materializeContiguousTile(ConversionPatternRewriter& rewriter, Loca
|
||||
return tensor::InsertSliceOp::create(rewriter, loc, tile, empty, offsets, sizes, strides);
|
||||
}
|
||||
|
||||
template <typename ReduceOp>
|
||||
static FailureOr<Value>
|
||||
reduceWindowValues(ConversionPatternRewriter& rewriter, Location loc, Operation* op, ArrayRef<Value> windowValues) {
|
||||
if (windowValues.empty()) {
|
||||
op->emitOpError("pool window resolved to zero valid elements");
|
||||
return failure();
|
||||
static Value createPoolFillElement(
|
||||
ConversionPatternRewriter& rewriter, Location loc, Type elementType, bool useMinimumValue) {
|
||||
if (!useMinimumValue)
|
||||
return arith::ConstantOp::create(rewriter, loc, elementType, rewriter.getZeroAttr(elementType));
|
||||
|
||||
if (auto floatType = dyn_cast<FloatType>(elementType)) {
|
||||
auto minValue = llvm::APFloat::getInf(floatType.getFloatSemantics(), /*Negative=*/true);
|
||||
return arith::ConstantOp::create(rewriter, loc, elementType, rewriter.getFloatAttr(floatType, minValue));
|
||||
}
|
||||
|
||||
Value reduced = windowValues.front();
|
||||
for (Value value : windowValues.drop_front())
|
||||
reduced = ReduceOp::create(rewriter, loc, reduced.getType(), reduced, value);
|
||||
return reduced;
|
||||
if (auto integerType = dyn_cast<IntegerType>(elementType)) {
|
||||
auto minValue = llvm::APInt::getSignedMinValue(integerType.getWidth());
|
||||
return arith::ConstantOp::create(rewriter, loc, elementType, rewriter.getIntegerAttr(integerType, minValue));
|
||||
}
|
||||
|
||||
llvm_unreachable("unsupported pool element type");
|
||||
}
|
||||
|
||||
static FailureOr<Value> scaleAverageWindow(
|
||||
ConversionPatternRewriter& rewriter, Location loc, Operation* op, Value reducedWindow, int64_t divisor) {
|
||||
if (divisor <= 0) {
|
||||
op->emitOpError("AveragePool divisor must be positive");
|
||||
static Value createPoolFillTensor(
|
||||
ConversionPatternRewriter& rewriter, Location loc, RankedTensorType tensorType, bool useMinimumValue) {
|
||||
auto fillElement = createPoolFillElement(rewriter, loc, tensorType.getElementType(), useMinimumValue);
|
||||
return tensor::SplatOp::create(rewriter, loc, tensorType, fillElement);
|
||||
}
|
||||
|
||||
template <typename PoolOp>
|
||||
static Value createPaddedPoolInput(ConversionPatternRewriter& rewriter,
|
||||
Location loc,
|
||||
PoolOp poolOp,
|
||||
Value input,
|
||||
RankedTensorType inputType,
|
||||
int64_t padTop,
|
||||
int64_t padLeft,
|
||||
int64_t padBottom,
|
||||
int64_t padRight) {
|
||||
if (padTop == 0 && padLeft == 0 && padBottom == 0 && padRight == 0)
|
||||
return input;
|
||||
|
||||
auto paddedType = RankedTensorType::get({inputType.getDimSize(0),
|
||||
inputType.getDimSize(1),
|
||||
inputType.getDimSize(2) + padTop + padBottom,
|
||||
inputType.getDimSize(3) + padLeft + padRight},
|
||||
inputType.getElementType(),
|
||||
inputType.getEncoding());
|
||||
SmallVector<OpFoldResult> lowPads = {rewriter.getIndexAttr(0),
|
||||
rewriter.getIndexAttr(0),
|
||||
rewriter.getIndexAttr(padTop),
|
||||
rewriter.getIndexAttr(padLeft)};
|
||||
SmallVector<OpFoldResult> highPads = {rewriter.getIndexAttr(0),
|
||||
rewriter.getIndexAttr(0),
|
||||
rewriter.getIndexAttr(padBottom),
|
||||
rewriter.getIndexAttr(padRight)};
|
||||
auto padOp = tensor::PadOp::create(rewriter, loc, paddedType, input, lowPads, highPads);
|
||||
auto* padBlock = new Block();
|
||||
for (int index = 0; index < paddedType.getRank(); ++index)
|
||||
padBlock->addArgument(rewriter.getIndexType(), loc);
|
||||
padOp.getRegion().push_back(padBlock);
|
||||
rewriter.setInsertionPointToStart(padBlock);
|
||||
Value padValue = createPoolFillElement(
|
||||
rewriter, loc, inputType.getElementType(), std::is_same_v<PoolOp, ONNXMaxPoolSingleOutOp>);
|
||||
tensor::YieldOp::create(rewriter, loc, padValue);
|
||||
rewriter.setInsertionPointAfter(padOp);
|
||||
return padOp.getResult();
|
||||
}
|
||||
|
||||
static FailureOr<Value> createAverageScaleTensor(ConversionPatternRewriter& rewriter,
|
||||
Location loc,
|
||||
Operation* op,
|
||||
RankedTensorType outType,
|
||||
int64_t channels,
|
||||
int64_t inputHeight,
|
||||
int64_t inputWidth,
|
||||
int64_t outputHeight,
|
||||
int64_t outputWidth,
|
||||
int64_t kernelHeight,
|
||||
int64_t kernelWidth,
|
||||
int64_t strideHeight,
|
||||
int64_t strideWidth,
|
||||
int64_t dilationHeight,
|
||||
int64_t dilationWidth,
|
||||
int64_t padTop,
|
||||
int64_t padLeft,
|
||||
bool countIncludePad) {
|
||||
auto elemType = dyn_cast<FloatType>(outType.getElementType());
|
||||
if (!elemType) {
|
||||
op->emitOpError("AveragePool lowering requires a floating-point element type");
|
||||
return failure();
|
||||
}
|
||||
if (divisor == 1)
|
||||
return reducedWindow;
|
||||
|
||||
auto tileType = cast<RankedTensorType>(reducedWindow.getType());
|
||||
double scale = 1.0 / static_cast<double>(divisor);
|
||||
auto scaleAttr = DenseElementsAttr::get(tileType, rewriter.getFloatAttr(tileType.getElementType(), scale));
|
||||
Value scaleTensor = arith::ConstantOp::create(rewriter, loc, tileType, scaleAttr);
|
||||
return spatial::SpatVMulOp::create(rewriter, loc, tileType, reducedWindow, scaleTensor).getResult();
|
||||
auto scaleType = RankedTensorType::get({1, channels, outputHeight, outputWidth}, elemType, outType.getEncoding());
|
||||
SmallVector<Attribute> scaleValues;
|
||||
scaleValues.reserve(static_cast<size_t>(channels * outputHeight * outputWidth));
|
||||
for (int64_t channel = 0; channel < channels; ++channel) {
|
||||
(void) channel;
|
||||
for (int64_t outH = 0; outH < outputHeight; ++outH) {
|
||||
for (int64_t outW = 0; outW < outputWidth; ++outW) {
|
||||
int64_t validCount = 0;
|
||||
for (int64_t kernelH = 0; kernelH < kernelHeight; ++kernelH) {
|
||||
const int64_t inH = outH * strideHeight + kernelH * dilationHeight - padTop;
|
||||
if (inH < 0 || inH >= inputHeight)
|
||||
continue;
|
||||
for (int64_t kernelW = 0; kernelW < kernelWidth; ++kernelW) {
|
||||
const int64_t inW = outW * strideWidth + kernelW * dilationWidth - padLeft;
|
||||
if (inW < 0 || inW >= inputWidth)
|
||||
continue;
|
||||
++validCount;
|
||||
}
|
||||
}
|
||||
|
||||
const int64_t divisor = countIncludePad ? kernelHeight * kernelWidth : validCount;
|
||||
if (divisor <= 0) {
|
||||
op->emitOpError("AveragePool divisor must be positive");
|
||||
return failure();
|
||||
}
|
||||
scaleValues.push_back(rewriter.getFloatAttr(elemType, 1.0 / static_cast<double>(divisor)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
auto scaleAttr = DenseElementsAttr::get(scaleType, scaleValues);
|
||||
return arith::ConstantOp::create(rewriter, loc, scaleType, scaleAttr).getResult();
|
||||
}
|
||||
|
||||
template <typename PoolOp>
|
||||
@@ -159,106 +244,144 @@ struct PoolToSpatialComputeBase : public OpConversionPattern<PoolOp> {
|
||||
}
|
||||
}
|
||||
|
||||
(void) padBottom;
|
||||
(void) padRight;
|
||||
|
||||
const int64_t xbarSize = static_cast<int64_t>(crossbarSize.getValue());
|
||||
const int64_t channelTileCount = (channels + xbarSize - 1) / xbarSize;
|
||||
const int64_t outputPatchCount = batchSize * outputHeight * outputWidth;
|
||||
const bool countIncludePad = [&]() {
|
||||
if constexpr (std::is_same_v<PoolOp, ONNXAveragePoolOp>)
|
||||
return poolOp.getCountIncludePad() == 1;
|
||||
return true;
|
||||
}();
|
||||
Value averageScaleTensor;
|
||||
if constexpr (std::is_same_v<PoolOp, ONNXAveragePoolOp>) {
|
||||
auto maybeAverageScaleTensor = createAverageScaleTensor(rewriter,
|
||||
loc,
|
||||
poolOp,
|
||||
outType,
|
||||
channels,
|
||||
inputHeight,
|
||||
inputWidth,
|
||||
outputHeight,
|
||||
outputWidth,
|
||||
kernelHeight,
|
||||
kernelWidth,
|
||||
strideHeight,
|
||||
strideWidth,
|
||||
dilationHeight,
|
||||
dilationWidth,
|
||||
padTop,
|
||||
padLeft,
|
||||
countIncludePad);
|
||||
if (failed(maybeAverageScaleTensor))
|
||||
return failure();
|
||||
averageScaleTensor = *maybeAverageScaleTensor;
|
||||
}
|
||||
constexpr size_t numInputs = 1;
|
||||
auto computeOp =
|
||||
createSpatCompute<numInputs>(rewriter, loc, outType, {}, ValueRange {x}, [&](Value xArg) -> LogicalResult {
|
||||
SmallVector<Value> batchResults;
|
||||
batchResults.reserve(batchSize);
|
||||
Value paddedInput = createPaddedPoolInput(rewriter, loc, poolOp, xArg, xType, padTop, padLeft, padBottom, padRight);
|
||||
Value pooledOutputInit = tensor::EmptyOp::create(rewriter, loc, outType.getShape(), outType.getElementType());
|
||||
|
||||
for (int64_t batch = 0; batch < batchSize; ++batch) {
|
||||
SmallVector<Value> rows;
|
||||
rows.reserve(outputHeight);
|
||||
Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
|
||||
Value c1 = arith::ConstantIndexOp::create(rewriter, loc, 1);
|
||||
Value cOutputPatchCount = arith::ConstantIndexOp::create(rewriter, loc, outputPatchCount);
|
||||
Value cOutputPixelsPerBatch = arith::ConstantIndexOp::create(rewriter, loc, outputHeight * outputWidth);
|
||||
Value cOutputWidth = arith::ConstantIndexOp::create(rewriter, loc, outputWidth);
|
||||
Value cStrideHeight = arith::ConstantIndexOp::create(rewriter, loc, strideHeight);
|
||||
Value cStrideWidth = arith::ConstantIndexOp::create(rewriter, loc, strideWidth);
|
||||
|
||||
for (int64_t outH = 0; outH < outputHeight; ++outH) {
|
||||
SmallVector<Value> rowPixels;
|
||||
rowPixels.reserve(outputWidth);
|
||||
auto outputLoop = scf::ForOp::create(rewriter, loc, c0, cOutputPatchCount, c1, ValueRange {pooledOutputInit});
|
||||
rewriter.setInsertionPointToStart(outputLoop.getBody());
|
||||
|
||||
for (int64_t outW = 0; outW < outputWidth; ++outW) {
|
||||
SmallVector<Value> outputChannelTiles;
|
||||
outputChannelTiles.reserve(channelTileCount);
|
||||
Value outputPatchIndex = outputLoop.getInductionVar();
|
||||
Value pooledOutputAcc = outputLoop.getRegionIterArgs().front();
|
||||
|
||||
for (int64_t channelTile = 0; channelTile < channelTileCount; ++channelTile) {
|
||||
const int64_t tileChannels = std::min<int64_t>(xbarSize, channels - channelTile * xbarSize);
|
||||
auto tileType = RankedTensorType::get({1, tileChannels, 1, 1}, outType.getElementType());
|
||||
Value batchIndex = arith::DivUIOp::create(rewriter, loc, outputPatchIndex, cOutputPixelsPerBatch);
|
||||
Value batchPatchIndex = arith::RemUIOp::create(rewriter, loc, outputPatchIndex, cOutputPixelsPerBatch);
|
||||
Value outHeightIndex = arith::DivUIOp::create(rewriter, loc, batchPatchIndex, cOutputWidth);
|
||||
Value outWidthIndex = arith::RemUIOp::create(rewriter, loc, batchPatchIndex, cOutputWidth);
|
||||
Value windowBaseH = arith::MulIOp::create(rewriter, loc, outHeightIndex, cStrideHeight);
|
||||
Value windowBaseW = arith::MulIOp::create(rewriter, loc, outWidthIndex, cStrideWidth);
|
||||
|
||||
SmallVector<Value> windowValues;
|
||||
windowValues.reserve(kernelHeight * kernelWidth);
|
||||
for (int64_t kernelH = 0; kernelH < kernelHeight; ++kernelH) {
|
||||
const int64_t inH = outH * strideHeight + kernelH * dilationHeight - padTop;
|
||||
if (inH < 0 || inH >= inputHeight)
|
||||
continue;
|
||||
Value updatedOutput = pooledOutputAcc;
|
||||
for (int64_t channelTile = 0; channelTile < channelTileCount; ++channelTile) {
|
||||
const int64_t tileChannels = std::min<int64_t>(xbarSize, channels - channelTile * xbarSize);
|
||||
auto tileType = RankedTensorType::get({1, tileChannels, 1, 1}, outType.getElementType());
|
||||
Value reducedWindow = createPoolFillTensor(
|
||||
rewriter, loc, tileType, std::is_same_v<PoolOp, ONNXMaxPoolSingleOutOp>);
|
||||
|
||||
for (int64_t kernelW = 0; kernelW < kernelWidth; ++kernelW) {
|
||||
const int64_t inW = outW * strideWidth + kernelW * dilationWidth - padLeft;
|
||||
if (inW < 0 || inW >= inputWidth)
|
||||
continue;
|
||||
|
||||
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(batch),
|
||||
rewriter.getIndexAttr(channelTile * xbarSize),
|
||||
rewriter.getIndexAttr(inH),
|
||||
rewriter.getIndexAttr(inW)};
|
||||
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(tileChannels),
|
||||
rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(1)};
|
||||
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(1)};
|
||||
Value windowValue =
|
||||
tensor::ExtractSliceOp::create(rewriter, loc, tileType, xArg, offsets, sizes, strides);
|
||||
windowValue = materializeContiguousTile(rewriter, loc, windowValue);
|
||||
windowValues.push_back(windowValue);
|
||||
}
|
||||
}
|
||||
|
||||
if (windowValues.empty())
|
||||
return rewriter.notifyMatchFailure(poolOp, "pool window resolved to zero valid elements.");
|
||||
|
||||
auto reducedWindow = reduceWindowValues<ReduceOp>(rewriter, loc, poolOp, windowValues);
|
||||
if (failed(reducedWindow))
|
||||
return failure();
|
||||
Value reducedWindowValue = *reducedWindow;
|
||||
if constexpr (std::is_same_v<PoolOp, ONNXAveragePoolOp>) {
|
||||
const bool countIncludePad = poolOp.getCountIncludePad() == 1;
|
||||
const int64_t divisor =
|
||||
countIncludePad ? kernelHeight * kernelWidth : static_cast<int64_t>(windowValues.size());
|
||||
auto scaledWindow = scaleAverageWindow(rewriter, loc, poolOp, reducedWindowValue, divisor);
|
||||
if (failed(scaledWindow))
|
||||
return failure();
|
||||
reducedWindowValue = *scaledWindow;
|
||||
}
|
||||
|
||||
outputChannelTiles.push_back(reducedWindowValue);
|
||||
}
|
||||
|
||||
auto rowPixel = concatAlongAxis(rewriter, loc, poolOp, /*axis=*/1, outputChannelTiles);
|
||||
if (failed(rowPixel))
|
||||
return failure();
|
||||
rowPixels.push_back(*rowPixel);
|
||||
for (int64_t kernelH = 0; kernelH < kernelHeight; ++kernelH) {
|
||||
Value paddedInH = windowBaseH;
|
||||
if (kernelH * dilationHeight != 0) {
|
||||
Value kernelHOffset = arith::ConstantIndexOp::create(rewriter, loc, kernelH * dilationHeight);
|
||||
paddedInH = arith::AddIOp::create(rewriter, loc, paddedInH, kernelHOffset);
|
||||
}
|
||||
|
||||
auto row = concatAlongAxis(rewriter, loc, poolOp, /*axis=*/3, rowPixels);
|
||||
if (failed(row))
|
||||
return failure();
|
||||
rows.push_back(*row);
|
||||
for (int64_t kernelW = 0; kernelW < kernelWidth; ++kernelW) {
|
||||
Value paddedInW = windowBaseW;
|
||||
if (kernelW * dilationWidth != 0) {
|
||||
Value kernelWOffset = arith::ConstantIndexOp::create(rewriter, loc, kernelW * dilationWidth);
|
||||
paddedInW = arith::AddIOp::create(rewriter, loc, paddedInW, kernelWOffset);
|
||||
}
|
||||
|
||||
SmallVector<OpFoldResult> offsets = {batchIndex,
|
||||
rewriter.getIndexAttr(channelTile * xbarSize),
|
||||
paddedInH,
|
||||
paddedInW};
|
||||
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(tileChannels),
|
||||
rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(1)};
|
||||
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(1)};
|
||||
Value windowValue =
|
||||
tensor::ExtractSliceOp::create(rewriter, loc, tileType, paddedInput, offsets, sizes, strides);
|
||||
windowValue = materializeContiguousTile(rewriter, loc, windowValue);
|
||||
reducedWindow = ReduceOp::create(rewriter, loc, tileType, reducedWindow, windowValue);
|
||||
}
|
||||
}
|
||||
|
||||
auto batchResult = concatAlongAxis(rewriter, loc, poolOp, /*axis=*/2, rows);
|
||||
if (failed(batchResult))
|
||||
return failure();
|
||||
batchResults.push_back(*batchResult);
|
||||
if constexpr (std::is_same_v<PoolOp, ONNXAveragePoolOp>) {
|
||||
SmallVector<OpFoldResult> scaleOffsets = {rewriter.getIndexAttr(0),
|
||||
rewriter.getIndexAttr(channelTile * xbarSize),
|
||||
outHeightIndex,
|
||||
outWidthIndex};
|
||||
SmallVector<OpFoldResult> scaleSizes = {rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(tileChannels),
|
||||
rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(1)};
|
||||
SmallVector<OpFoldResult> scaleStrides = {rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(1)};
|
||||
Value scaleSlice = tensor::ExtractSliceOp::create(
|
||||
rewriter, loc, tileType, averageScaleTensor, scaleOffsets, scaleSizes, scaleStrides);
|
||||
scaleSlice = materializeContiguousTile(rewriter, loc, scaleSlice);
|
||||
reducedWindow = spatial::SpatVMulOp::create(rewriter, loc, tileType, reducedWindow, scaleSlice);
|
||||
}
|
||||
|
||||
SmallVector<OpFoldResult> outputOffsets = {batchIndex,
|
||||
rewriter.getIndexAttr(channelTile * xbarSize),
|
||||
outHeightIndex,
|
||||
outWidthIndex};
|
||||
SmallVector<OpFoldResult> outputSizes = {rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(tileChannels),
|
||||
rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(1)};
|
||||
SmallVector<OpFoldResult> outputStrides = {rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(1)};
|
||||
updatedOutput = tensor::InsertSliceOp::create(
|
||||
rewriter, loc, reducedWindow, updatedOutput, outputOffsets, outputSizes, outputStrides);
|
||||
}
|
||||
|
||||
auto pooledOutput = concatAlongAxis(rewriter, loc, poolOp, /*axis=*/0, batchResults);
|
||||
if (failed(pooledOutput))
|
||||
return failure();
|
||||
spatial::SpatYieldOp::create(rewriter, loc, *pooledOutput);
|
||||
scf::YieldOp::create(rewriter, loc, updatedOutput);
|
||||
|
||||
rewriter.setInsertionPointAfter(outputLoop);
|
||||
spatial::SpatYieldOp::create(rewriter, loc, outputLoop.getResult(0));
|
||||
return success();
|
||||
});
|
||||
if (failed(computeOp))
|
||||
|
||||
Reference in New Issue
Block a user