Some checks failed
Validate Operations / validate-operations (push) Failing after 2h50m56s
add relu validation add spatial compute helper minor refactors
266 lines
11 KiB
C++
266 lines
11 KiB
C++
#include "mlir/Dialect/Arith/IR/Arith.h"
|
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
#include "mlir/IR/BuiltinAttributes.h"
|
|
#include "mlir/IR/BuiltinTypes.h"
|
|
|
|
#include "llvm/ADT/SmallVector.h"
|
|
|
|
#include <algorithm>
|
|
#include <cassert>
|
|
#include <optional>
|
|
#include <type_traits>
|
|
|
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
|
|
|
using namespace mlir;
|
|
|
|
namespace onnx_mlir {
|
|
namespace {
|
|
|
|
template <typename ArrayAttrT>
|
|
static int64_t getI64(ArrayAttrT arrayAttr, size_t index) {
|
|
return cast<IntegerAttr>(arrayAttr[index]).getInt();
|
|
}
|
|
|
|
template <typename ArrayAttrT>
|
|
static int64_t getOptionalI64(std::optional<ArrayAttrT> arrayAttr, size_t index, int64_t defaultValue) {
|
|
return arrayAttr ? getI64(*arrayAttr, index) : defaultValue;
|
|
}
|
|
|
|
// Concatenates `values` along `axis`. A single value is passed through
// unchanged so callers never emit a degenerate one-operand concat.
static Value concatAlongAxis(ConversionPatternRewriter& rewriter, Location loc, int64_t axis, ArrayRef<Value> values) {
  assert(!values.empty() && "Expected at least one value to concatenate.");
  if (values.size() > 1)
    return tensor::ConcatOp::create(rewriter, loc, axis, values);
  return values.front();
}
|
|
|
|
// Copies `tile` into a fresh tensor.empty of the same shape via a full-extent
// insert_slice, so downstream ops see a value not aliased to a strided
// extract_slice of the original input.
static Value materializeContiguousTile(ConversionPatternRewriter& rewriter, Location loc, Value tile) {
  auto tileType = cast<RankedTensorType>(tile.getType());
  const int64_t rank = tileType.getRank();

  Value destination = tensor::EmptyOp::create(rewriter, loc, tileType.getShape(), tileType.getElementType());

  // Full-tensor slice: zero offsets, unit strides, sizes equal to the shape.
  OpFoldResult zero = rewriter.getIndexAttr(0);
  OpFoldResult one = rewriter.getIndexAttr(1);
  SmallVector<OpFoldResult> offsets(rank, zero);
  SmallVector<OpFoldResult> strides(rank, one);
  SmallVector<OpFoldResult> sizes;
  sizes.reserve(rank);
  for (int64_t dimSize : tileType.getShape())
    sizes.push_back(rewriter.getIndexAttr(dimSize));

  return tensor::InsertSliceOp::create(rewriter, loc, tile, destination, offsets, sizes, strides);
}
|
|
|
|
template <typename ReduceOp>
|
|
static Value reduceWindowValues(ConversionPatternRewriter& rewriter, Location loc, ArrayRef<Value> windowValues) {
|
|
assert(!windowValues.empty() && "Expected at least one pool window value.");
|
|
|
|
Value reduced = windowValues.front();
|
|
for (Value value : windowValues.drop_front())
|
|
reduced = ReduceOp::create(rewriter, loc, reduced.getType(), reduced, value);
|
|
return reduced;
|
|
}
|
|
|
|
// Turns a summed window into an average by multiplying with a splat constant
// holding 1/divisor. A divisor of 1 is an identity and emits no IR.
static Value
scaleAverageWindow(ConversionPatternRewriter& rewriter, Location loc, Value reducedWindow, int64_t divisor) {
  assert(divisor > 0 && "AveragePool divisor must be positive.");
  if (divisor == 1)
    return reducedWindow;

  auto tileType = cast<RankedTensorType>(reducedWindow.getType());
  const double reciprocal = 1.0 / static_cast<double>(divisor);
  auto elementAttr = rewriter.getFloatAttr(tileType.getElementType(), reciprocal);
  auto splatAttr = DenseElementsAttr::get(tileType, elementAttr);
  Value scaleTensor = arith::ConstantOp::create(rewriter, loc, tileType, splatAttr);
  return spatial::SpatVMulOp::create(rewriter, loc, tileType, reducedWindow, scaleTensor);
}
|
|
|
|
// Primary template, intentionally undefined: only the explicit
// specializations below (max pool and average pool) are instantiable.
template <typename PoolOp>
struct PoolToSpatialCompute;
|
|
|
|
// Shared lowering for ONNX 2D pooling ops (NCHW, fully static shapes) into a
// spatial.weighted_compute region. The pool is completely unrolled at compile
// time: each output element is built from extract_slice taps of its input
// window, folded pairwise with ReduceOp, and the results are reassembled with
// tensor.concat along C, W, H, and N.
//
// Template parameters:
//   PoolOp        - the ONNX pooling op being lowered.
//   PoolOpAdaptor - the op's generated conversion adaptor.
//   ReduceOp      - spatial elementwise op combining two window values
//                   (SpatVMaxOp for max pool, SpatVAddOp for average pool).
template <typename PoolOp, typename PoolOpAdaptor, typename ReduceOp>
struct PoolToSpatialComputeBase : public OpConversionPattern<PoolOp> {
  using OpConversionPattern<PoolOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(PoolOp poolOp, PoolOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final {
    Location loc = poolOp.getLoc();
    Value x = adaptor.getX();

    // The fully-unrolled lowering needs every extent at compile time, so both
    // input and output must be ranked tensors with static shapes.
    auto xType = dyn_cast<RankedTensorType>(x.getType());
    auto outType = dyn_cast<RankedTensorType>(poolOp.getResult().getType());
    if (!xType || !outType || !xType.hasStaticShape() || !outType.hasStaticShape())
      return rewriter.notifyMatchFailure(poolOp, "pool lowering requires static ranked tensor types.");
    if (xType.getRank() != 4 || outType.getRank() != 4)
      return rewriter.notifyMatchFailure(poolOp, "only 2D NCHW pool is supported.");

    ArrayAttr kernelAttr = poolOp.getKernelShape();
    if (!kernelAttr || kernelAttr.size() != 2)
      return rewriter.notifyMatchFailure(poolOp, "pool lowering expects a 2D kernel.");

    // NCHW layout: dim 0 = batch, 1 = channels, 2 = height, 3 = width.
    const int64_t batchSize = xType.getDimSize(0);
    const int64_t channels = xType.getDimSize(1);
    const int64_t inputHeight = xType.getDimSize(2);
    const int64_t inputWidth = xType.getDimSize(3);
    const int64_t outputHeight = outType.getDimSize(2);
    const int64_t outputWidth = outType.getDimSize(3);
    const int64_t kernelHeight = getI64(kernelAttr, 0);
    const int64_t kernelWidth = getI64(kernelAttr, 1);
    // Strides and dilations default to 1 when the attribute is absent.
    const int64_t strideHeight = getOptionalI64(poolOp.getStrides(), 0, 1);
    const int64_t strideWidth = getOptionalI64(poolOp.getStrides(), 1, 1);
    const int64_t dilationHeight = getOptionalI64(poolOp.getDilations(), 0, 1);
    const int64_t dilationWidth = getOptionalI64(poolOp.getDilations(), 1, 1);

    int64_t padTop = 0;
    int64_t padLeft = 0;
    int64_t padBottom = 0;
    int64_t padRight = 0;

    // Explicit pads take precedence; otherwise padding is derived from the
    // auto_pad attribute. ONNX 2D pads order is [top, left, bottom, right].
    if (auto padsAttr = poolOp.getPads()) {
      if (padsAttr->size() != 4)
        return rewriter.notifyMatchFailure(poolOp, "pads must have four elements.");
      padTop = getI64(*padsAttr, 0);
      padLeft = getI64(*padsAttr, 1);
      padBottom = getI64(*padsAttr, 2);
      padRight = getI64(*padsAttr, 3);
    }
    else {
      StringRef autoPad = poolOp.getAutoPad();
      if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") {
        // Total padding required so the given output extent is reachable with
        // this stride and the dilation-adjusted effective kernel size.
        const int64_t effectiveKernelH = (kernelHeight - 1) * dilationHeight + 1;
        const int64_t effectiveKernelW = (kernelWidth - 1) * dilationWidth + 1;
        const int64_t totalPadH =
            std::max<int64_t>(0, (outputHeight - 1) * strideHeight + effectiveKernelH - inputHeight);
        const int64_t totalPadW = std::max<int64_t>(0, (outputWidth - 1) * strideWidth + effectiveKernelW - inputWidth);

        if (autoPad == "SAME_UPPER") {
          // SAME_UPPER: any odd leftover unit of padding goes at the end.
          padTop = totalPadH / 2;
          padBottom = totalPadH - padTop;
          padLeft = totalPadW / 2;
          padRight = totalPadW - padLeft;
        }
        else {
          // SAME_LOWER: any odd leftover unit of padding goes at the start.
          padBottom = totalPadH / 2;
          padTop = totalPadH - padBottom;
          padRight = totalPadW / 2;
          padLeft = totalPadW - padRight;
        }
      }
      else if (autoPad != "NOTSET" && autoPad != "VALID") {
        return rewriter.notifyMatchFailure(poolOp, "unsupported auto_pad value.");
      }
    }

    // Bottom/right pads are only needed for output-shape inference, which is
    // already baked into outType; the window loops below skip out-of-bounds
    // taps directly, so these values are intentionally unused.
    (void) padBottom;
    (void) padRight;

    // Channels are processed in tiles of at most one crossbar width.
    // NOTE(review): crossbarSize appears to be a compiler option (see
    // PimCompilerOptions.hpp) -- confirm it is always > 0 here.
    const int64_t xbarSize = static_cast<int64_t>(crossbarSize.getValue());
    const int64_t channelTileCount = (channels + xbarSize - 1) / xbarSize;
    auto computeOp = spatial::SpatWeightedCompute::create(rewriter, loc, outType, SmallVector<Value>(), ValueRange {x});

    // Build the compute body: a single block whose argument is the input.
    // NOTE(review): block is allocated with plain `new` instead of
    // rewriter.createBlock -- verify the conversion driver can still roll
    // back this pattern cleanly if it fails later.
    auto* computeBlock = new Block();
    computeBlock->addArgument(xType, loc);
    computeOp.getBody().push_back(computeBlock);
    rewriter.setInsertionPointToStart(computeBlock);

    Value input = computeBlock->getArgument(0);
    SmallVector<Value> batchResults;
    batchResults.reserve(batchSize);

    // Fully unroll over batch, output row, output column, and channel tile.
    for (int64_t batch = 0; batch < batchSize; ++batch) {
      SmallVector<Value> rows;
      rows.reserve(outputHeight);

      for (int64_t outH = 0; outH < outputHeight; ++outH) {
        SmallVector<Value> rowPixels;
        rowPixels.reserve(outputWidth);

        for (int64_t outW = 0; outW < outputWidth; ++outW) {
          SmallVector<Value> outputChannelTiles;
          outputChannelTiles.reserve(channelTileCount);

          for (int64_t channelTile = 0; channelTile < channelTileCount; ++channelTile) {
            // The last tile may be narrower than the crossbar.
            const int64_t tileChannels = std::min<int64_t>(xbarSize, channels - channelTile * xbarSize);
            auto tileType = RankedTensorType::get({1, tileChannels, 1, 1}, outType.getElementType());

            // Collect the in-bounds taps of this output element's window;
            // taps that fall into the padding region are skipped entirely.
            SmallVector<Value> windowValues;
            windowValues.reserve(kernelHeight * kernelWidth);
            for (int64_t kernelH = 0; kernelH < kernelHeight; ++kernelH) {
              const int64_t inH = outH * strideHeight + kernelH * dilationHeight - padTop;
              if (inH < 0 || inH >= inputHeight)
                continue;

              for (int64_t kernelW = 0; kernelW < kernelWidth; ++kernelW) {
                const int64_t inW = outW * strideWidth + kernelW * dilationWidth - padLeft;
                if (inW < 0 || inW >= inputWidth)
                  continue;

                // 1 x tileChannels x 1 x 1 slice at (batch, cTile, inH, inW).
                SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(batch),
                    rewriter.getIndexAttr(channelTile * xbarSize),
                    rewriter.getIndexAttr(inH),
                    rewriter.getIndexAttr(inW)};
                SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
                    rewriter.getIndexAttr(tileChannels),
                    rewriter.getIndexAttr(1),
                    rewriter.getIndexAttr(1)};
                SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1),
                    rewriter.getIndexAttr(1),
                    rewriter.getIndexAttr(1),
                    rewriter.getIndexAttr(1)};
                Value windowValue =
                    tensor::ExtractSliceOp::create(rewriter, loc, tileType, input, offsets, sizes, strides);
                windowValue = materializeContiguousTile(rewriter, loc, windowValue);
                windowValues.push_back(windowValue);
              }
            }

            // A window that is all padding has no defined pool value.
            if (windowValues.empty())
              return rewriter.notifyMatchFailure(poolOp, "pool window resolved to zero valid elements.");

            Value reducedWindow = reduceWindowValues<ReduceOp>(rewriter, loc, windowValues);
            if constexpr (std::is_same_v<PoolOp, ONNXAveragePoolOp>) {
              // AveragePool: scale the vadd-accumulated sum. With
              // count_include_pad the divisor is the full kernel area;
              // otherwise only the in-bounds taps count.
              const bool countIncludePad = poolOp.getCountIncludePad() == 1;
              const int64_t divisor =
                  countIncludePad ? kernelHeight * kernelWidth : static_cast<int64_t>(windowValues.size());
              reducedWindow = scaleAverageWindow(rewriter, loc, reducedWindow, divisor);
            }

            outputChannelTiles.push_back(reducedWindow);
          }

          // Reassemble channel tiles along C.
          rowPixels.push_back(concatAlongAxis(rewriter, loc, /*axis=*/1, outputChannelTiles));
        }

        // Reassemble output columns along W.
        rows.push_back(concatAlongAxis(rewriter, loc, /*axis=*/3, rowPixels));
      }

      // Reassemble output rows along H.
      batchResults.push_back(concatAlongAxis(rewriter, loc, /*axis=*/2, rows));
    }

    // Reassemble batches along N and yield from the compute region.
    Value pooledOutput = concatAlongAxis(rewriter, loc, /*axis=*/0, batchResults);
    spatial::SpatYieldOp::create(rewriter, loc, pooledOutput);

    rewriter.replaceOp(poolOp, computeOp.getResult(0));
    return success();
  }
};
|
|
|
|
// Max pool: reduce each window with the elementwise spatial.vmax op.
template <>
struct PoolToSpatialCompute<ONNXMaxPoolSingleOutOp>
    : public PoolToSpatialComputeBase<ONNXMaxPoolSingleOutOp, ONNXMaxPoolSingleOutOpAdaptor, spatial::SpatVMaxOp> {
  using PoolToSpatialComputeBase::PoolToSpatialComputeBase;
};
|
|
|
|
// Average pool: sum each window with spatial.vadd; the base pattern then
// scales the sum by 1/divisor (see scaleAverageWindow).
template <>
struct PoolToSpatialCompute<ONNXAveragePoolOp>
    : public PoolToSpatialComputeBase<ONNXAveragePoolOp, ONNXAveragePoolOpAdaptor, spatial::SpatVAddOp> {
  using PoolToSpatialComputeBase::PoolToSpatialComputeBase;
};
|
|
|
|
} // namespace
|
|
|
|
// Registers the pool-to-spatial-compute lowering patterns (max pool and
// average pool) into the given pattern set.
void populatePoolPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
  patterns.insert<PoolToSpatialCompute<ONNXMaxPoolSingleOutOp>,
      PoolToSpatialCompute<ONNXAveragePoolOp>>(ctx);
}
|
|
|
|
} // namespace onnx_mlir
|