reimplement pool lowering
add pool validation align PIM ops/codegen/parser with the ISA move constant materialization to MLIR rename the PIM verification/materialization passes better folded-constant handling
This commit is contained in:
@@ -6,7 +6,7 @@ add_pim_library(OMONNXToSpatial
|
||||
Patterns/Math/Gemm.cpp
|
||||
Patterns/Math/Conv.cpp
|
||||
Patterns/Math/MatMul.cpp
|
||||
Patterns/NN/Pooling.cpp
|
||||
Patterns/NN/Pool.cpp
|
||||
Patterns/NN/ReduceMean.cpp
|
||||
Patterns/Tensor/Concat.cpp
|
||||
Patterns/Tensor/Reshape.cpp
|
||||
|
||||
@@ -93,7 +93,7 @@ void ONNXToSpatialPass::runOnOperation() {
|
||||
patterns.add<removeLRNPattern>(ctx);
|
||||
|
||||
populateConvOpPatterns(patterns, ctx);
|
||||
populatePoolingTilingPattern(patterns, ctx);
|
||||
populatePoolTilingPattern(patterns, ctx);
|
||||
populateOnnxGemmOpPatterns(patterns, ctx);
|
||||
populateReshapeConversionPattern(patterns, ctx);
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ void populateMatMulRewritePatterns(mlir::RewritePatternSet& patterns, mlir::MLIR
|
||||
|
||||
void populateOnnxGemmOpPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
void populatePoolingTilingPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
void populatePoolTilingPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
void populateONNXConcatToTensorConcatPattern(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
|
||||
265
src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Pool.cpp
Normal file
265
src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Pool.cpp
Normal file
@@ -0,0 +1,265 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <optional>
|
||||
#include <type_traits>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
template <typename ArrayAttrT>
|
||||
static int64_t getI64(ArrayAttrT arrayAttr, size_t index) {
|
||||
return cast<IntegerAttr>(arrayAttr[index]).getInt();
|
||||
}
|
||||
|
||||
template <typename ArrayAttrT>
|
||||
static int64_t getOptionalI64(std::optional<ArrayAttrT> arrayAttr, size_t index, int64_t defaultValue) {
|
||||
return arrayAttr ? getI64(*arrayAttr, index) : defaultValue;
|
||||
}
|
||||
|
||||
static Value concatAlongAxis(ConversionPatternRewriter& rewriter, Location loc, int64_t axis, ArrayRef<Value> values) {
|
||||
assert(!values.empty() && "Expected at least one value to concatenate.");
|
||||
if (values.size() == 1)
|
||||
return values.front();
|
||||
return tensor::ConcatOp::create(rewriter, loc, axis, values);
|
||||
}
|
||||
|
||||
static Value materializeContiguousTile(ConversionPatternRewriter& rewriter, Location loc, Value tile) {
|
||||
auto tileType = cast<RankedTensorType>(tile.getType());
|
||||
Value empty = tensor::EmptyOp::create(rewriter, loc, tileType.getShape(), tileType.getElementType());
|
||||
|
||||
SmallVector<OpFoldResult> offsets(tileType.getRank(), rewriter.getIndexAttr(0));
|
||||
SmallVector<OpFoldResult> sizes;
|
||||
sizes.reserve(tileType.getRank());
|
||||
for (int64_t dimSize : tileType.getShape())
|
||||
sizes.push_back(rewriter.getIndexAttr(dimSize));
|
||||
SmallVector<OpFoldResult> strides(tileType.getRank(), rewriter.getIndexAttr(1));
|
||||
|
||||
return tensor::InsertSliceOp::create(rewriter, loc, tile, empty, offsets, sizes, strides);
|
||||
}
|
||||
|
||||
template <typename ReduceOp>
|
||||
static Value reduceWindowValues(ConversionPatternRewriter& rewriter, Location loc, ArrayRef<Value> windowValues) {
|
||||
assert(!windowValues.empty() && "Expected at least one pool window value.");
|
||||
|
||||
Value reduced = windowValues.front();
|
||||
for (Value value : windowValues.drop_front())
|
||||
reduced = ReduceOp::create(rewriter, loc, reduced.getType(), reduced, value);
|
||||
return reduced;
|
||||
}
|
||||
|
||||
static Value
|
||||
scaleAverageWindow(ConversionPatternRewriter& rewriter, Location loc, Value reducedWindow, int64_t divisor) {
|
||||
assert(divisor > 0 && "AveragePool divisor must be positive.");
|
||||
if (divisor == 1)
|
||||
return reducedWindow;
|
||||
|
||||
auto tileType = cast<RankedTensorType>(reducedWindow.getType());
|
||||
double scale = 1.0 / static_cast<double>(divisor);
|
||||
auto scaleAttr = DenseElementsAttr::get(tileType, rewriter.getFloatAttr(tileType.getElementType(), scale));
|
||||
Value scaleTensor = arith::ConstantOp::create(rewriter, loc, tileType, scaleAttr);
|
||||
return spatial::SpatVMulOp::create(rewriter, loc, tileType, reducedWindow, scaleTensor);
|
||||
}
|
||||
|
||||
template <typename PoolOp>
|
||||
struct PoolToSpatialCompute;
|
||||
|
||||
template <typename PoolOp, typename PoolOpAdaptor, typename ReduceOp>
|
||||
struct PoolToSpatialComputeBase : public OpConversionPattern<PoolOp> {
|
||||
using OpConversionPattern<PoolOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(PoolOp poolOp, PoolOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final {
|
||||
Location loc = poolOp.getLoc();
|
||||
Value x = adaptor.getX();
|
||||
|
||||
auto xType = dyn_cast<RankedTensorType>(x.getType());
|
||||
auto outType = dyn_cast<RankedTensorType>(poolOp.getResult().getType());
|
||||
if (!xType || !outType || !xType.hasStaticShape() || !outType.hasStaticShape())
|
||||
return rewriter.notifyMatchFailure(poolOp, "pool lowering requires static ranked tensor types.");
|
||||
if (xType.getRank() != 4 || outType.getRank() != 4)
|
||||
return rewriter.notifyMatchFailure(poolOp, "only 2D NCHW pool is supported.");
|
||||
|
||||
ArrayAttr kernelAttr = poolOp.getKernelShape();
|
||||
if (!kernelAttr || kernelAttr.size() != 2)
|
||||
return rewriter.notifyMatchFailure(poolOp, "pool lowering expects a 2D kernel.");
|
||||
|
||||
const int64_t batchSize = xType.getDimSize(0);
|
||||
const int64_t channels = xType.getDimSize(1);
|
||||
const int64_t inputHeight = xType.getDimSize(2);
|
||||
const int64_t inputWidth = xType.getDimSize(3);
|
||||
const int64_t outputHeight = outType.getDimSize(2);
|
||||
const int64_t outputWidth = outType.getDimSize(3);
|
||||
const int64_t kernelHeight = getI64(kernelAttr, 0);
|
||||
const int64_t kernelWidth = getI64(kernelAttr, 1);
|
||||
const int64_t strideHeight = getOptionalI64(poolOp.getStrides(), 0, 1);
|
||||
const int64_t strideWidth = getOptionalI64(poolOp.getStrides(), 1, 1);
|
||||
const int64_t dilationHeight = getOptionalI64(poolOp.getDilations(), 0, 1);
|
||||
const int64_t dilationWidth = getOptionalI64(poolOp.getDilations(), 1, 1);
|
||||
|
||||
int64_t padTop = 0;
|
||||
int64_t padLeft = 0;
|
||||
int64_t padBottom = 0;
|
||||
int64_t padRight = 0;
|
||||
|
||||
if (auto padsAttr = poolOp.getPads()) {
|
||||
if (padsAttr->size() != 4)
|
||||
return rewriter.notifyMatchFailure(poolOp, "pads must have four elements.");
|
||||
padTop = getI64(*padsAttr, 0);
|
||||
padLeft = getI64(*padsAttr, 1);
|
||||
padBottom = getI64(*padsAttr, 2);
|
||||
padRight = getI64(*padsAttr, 3);
|
||||
}
|
||||
else {
|
||||
StringRef autoPad = poolOp.getAutoPad();
|
||||
if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") {
|
||||
const int64_t effectiveKernelH = (kernelHeight - 1) * dilationHeight + 1;
|
||||
const int64_t effectiveKernelW = (kernelWidth - 1) * dilationWidth + 1;
|
||||
const int64_t totalPadH =
|
||||
std::max<int64_t>(0, (outputHeight - 1) * strideHeight + effectiveKernelH - inputHeight);
|
||||
const int64_t totalPadW = std::max<int64_t>(0, (outputWidth - 1) * strideWidth + effectiveKernelW - inputWidth);
|
||||
|
||||
if (autoPad == "SAME_UPPER") {
|
||||
padTop = totalPadH / 2;
|
||||
padBottom = totalPadH - padTop;
|
||||
padLeft = totalPadW / 2;
|
||||
padRight = totalPadW - padLeft;
|
||||
}
|
||||
else {
|
||||
padBottom = totalPadH / 2;
|
||||
padTop = totalPadH - padBottom;
|
||||
padRight = totalPadW / 2;
|
||||
padLeft = totalPadW - padRight;
|
||||
}
|
||||
}
|
||||
else if (autoPad != "NOTSET" && autoPad != "VALID") {
|
||||
return rewriter.notifyMatchFailure(poolOp, "unsupported auto_pad value.");
|
||||
}
|
||||
}
|
||||
|
||||
(void) padBottom;
|
||||
(void) padRight;
|
||||
|
||||
const int64_t xbarSize = static_cast<int64_t>(crossbarSize.getValue());
|
||||
const int64_t channelTileCount = (channels + xbarSize - 1) / xbarSize;
|
||||
auto computeOp = spatial::SpatWeightedCompute::create(rewriter, loc, outType, SmallVector<Value>(), ValueRange {x});
|
||||
|
||||
auto* computeBlock = new Block();
|
||||
computeBlock->addArgument(xType, loc);
|
||||
computeOp.getBody().push_back(computeBlock);
|
||||
rewriter.setInsertionPointToStart(computeBlock);
|
||||
|
||||
Value input = computeBlock->getArgument(0);
|
||||
SmallVector<Value> batchResults;
|
||||
batchResults.reserve(batchSize);
|
||||
|
||||
for (int64_t batch = 0; batch < batchSize; ++batch) {
|
||||
SmallVector<Value> rows;
|
||||
rows.reserve(outputHeight);
|
||||
|
||||
for (int64_t outH = 0; outH < outputHeight; ++outH) {
|
||||
SmallVector<Value> rowPixels;
|
||||
rowPixels.reserve(outputWidth);
|
||||
|
||||
for (int64_t outW = 0; outW < outputWidth; ++outW) {
|
||||
SmallVector<Value> outputChannelTiles;
|
||||
outputChannelTiles.reserve(channelTileCount);
|
||||
|
||||
for (int64_t channelTile = 0; channelTile < channelTileCount; ++channelTile) {
|
||||
const int64_t tileChannels = std::min<int64_t>(xbarSize, channels - channelTile * xbarSize);
|
||||
auto tileType = RankedTensorType::get({1, tileChannels, 1, 1}, outType.getElementType());
|
||||
|
||||
SmallVector<Value> windowValues;
|
||||
windowValues.reserve(kernelHeight * kernelWidth);
|
||||
for (int64_t kernelH = 0; kernelH < kernelHeight; ++kernelH) {
|
||||
const int64_t inH = outH * strideHeight + kernelH * dilationHeight - padTop;
|
||||
if (inH < 0 || inH >= inputHeight)
|
||||
continue;
|
||||
|
||||
for (int64_t kernelW = 0; kernelW < kernelWidth; ++kernelW) {
|
||||
const int64_t inW = outW * strideWidth + kernelW * dilationWidth - padLeft;
|
||||
if (inW < 0 || inW >= inputWidth)
|
||||
continue;
|
||||
|
||||
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(batch),
|
||||
rewriter.getIndexAttr(channelTile * xbarSize),
|
||||
rewriter.getIndexAttr(inH),
|
||||
rewriter.getIndexAttr(inW)};
|
||||
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(tileChannels),
|
||||
rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(1)};
|
||||
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(1)};
|
||||
Value windowValue =
|
||||
tensor::ExtractSliceOp::create(rewriter, loc, tileType, input, offsets, sizes, strides);
|
||||
windowValue = materializeContiguousTile(rewriter, loc, windowValue);
|
||||
windowValues.push_back(windowValue);
|
||||
}
|
||||
}
|
||||
|
||||
if (windowValues.empty())
|
||||
return rewriter.notifyMatchFailure(poolOp, "pool window resolved to zero valid elements.");
|
||||
|
||||
Value reducedWindow = reduceWindowValues<ReduceOp>(rewriter, loc, windowValues);
|
||||
if constexpr (std::is_same_v<PoolOp, ONNXAveragePoolOp>) {
|
||||
const bool countIncludePad = poolOp.getCountIncludePad() == 1;
|
||||
const int64_t divisor =
|
||||
countIncludePad ? kernelHeight * kernelWidth : static_cast<int64_t>(windowValues.size());
|
||||
reducedWindow = scaleAverageWindow(rewriter, loc, reducedWindow, divisor);
|
||||
}
|
||||
|
||||
outputChannelTiles.push_back(reducedWindow);
|
||||
}
|
||||
|
||||
rowPixels.push_back(concatAlongAxis(rewriter, loc, /*axis=*/1, outputChannelTiles));
|
||||
}
|
||||
|
||||
rows.push_back(concatAlongAxis(rewriter, loc, /*axis=*/3, rowPixels));
|
||||
}
|
||||
|
||||
batchResults.push_back(concatAlongAxis(rewriter, loc, /*axis=*/2, rows));
|
||||
}
|
||||
|
||||
Value pooledOutput = concatAlongAxis(rewriter, loc, /*axis=*/0, batchResults);
|
||||
spatial::SpatYieldOp::create(rewriter, loc, pooledOutput);
|
||||
|
||||
rewriter.replaceOp(poolOp, computeOp.getResult(0));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PoolToSpatialCompute<ONNXMaxPoolSingleOutOp>
|
||||
: public PoolToSpatialComputeBase<ONNXMaxPoolSingleOutOp, ONNXMaxPoolSingleOutOpAdaptor, spatial::SpatVMaxOp> {
|
||||
using PoolToSpatialComputeBase::PoolToSpatialComputeBase;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PoolToSpatialCompute<ONNXAveragePoolOp>
|
||||
: public PoolToSpatialComputeBase<ONNXAveragePoolOp, ONNXAveragePoolOpAdaptor, spatial::SpatVAddOp> {
|
||||
using PoolToSpatialComputeBase::PoolToSpatialComputeBase;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
void populatePoolTilingPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||
patterns.insert<PoolToSpatialCompute<ONNXMaxPoolSingleOutOp>>(ctx);
|
||||
patterns.insert<PoolToSpatialCompute<ONNXAveragePoolOp>>(ctx);
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -1,427 +0,0 @@
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/IR/ValueRange.h"
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
#include <cassert>
|
||||
#include <cmath>
|
||||
#include <cstddef>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
Value applyReducePatternNew(SmallVector<Value>& valuesToReduce,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
std::function<Value(const Value&, const Value&)> reduce,
|
||||
std::function<Value(const Value&)> preprocess,
|
||||
std::function<Value(const Value&)> postprocess) {
|
||||
// Simple case: if we have only one input, just return it
|
||||
if (valuesToReduce.size() == 1)
|
||||
return valuesToReduce[0];
|
||||
|
||||
if (preprocess) {
|
||||
for (auto& valToReduce : valuesToReduce) {
|
||||
rewriter.setInsertionPointAfterValue(valToReduce);
|
||||
valToReduce = preprocess(valToReduce);
|
||||
}
|
||||
}
|
||||
|
||||
// It is possible that `valuesToReduce` contains two entries for the same
|
||||
// computeOp. In this case, we need to apply the reduction within-computef
|
||||
|
||||
// Keep a map between a computeOp and the last Value for this reduction
|
||||
std::unordered_map<Operation*, Value> lastValueForCompute;
|
||||
for (auto& valToReduce : valuesToReduce) {
|
||||
Operation* computeOp = valToReduce.getParentBlock()->getParentOp();
|
||||
// if (valToReduce.getDefiningOp()) {
|
||||
// // If the value is defined by an operation, we take the parent
|
||||
// operation computeOp = valToReduce.getDefiningOp()->getParentOp();
|
||||
// } else {
|
||||
// // Otherwise it is a block argument,
|
||||
// computeOp->getBlock()->getParentOp();
|
||||
// }
|
||||
|
||||
assert(isa<spatial::SpatWeightedCompute>(computeOp) && "Expected a ComputeOp");
|
||||
|
||||
auto it = lastValueForCompute.find(computeOp);
|
||||
|
||||
if (it != lastValueForCompute.end()) {
|
||||
// If we have already seen this computeOp, apply the reduction
|
||||
// within-compute
|
||||
Value lastWithinComputeValue = it->second;
|
||||
|
||||
if (valToReduce.getDefiningOp()->isBeforeInBlock(lastWithinComputeValue.getDefiningOp()))
|
||||
rewriter.setInsertionPointAfterValue(lastWithinComputeValue);
|
||||
else
|
||||
rewriter.setInsertionPointAfterValue(valToReduce);
|
||||
valToReduce = reduce(lastWithinComputeValue, valToReduce);
|
||||
lastValueForCompute[computeOp] = valToReduce;
|
||||
}
|
||||
|
||||
lastValueForCompute[computeOp] = valToReduce;
|
||||
}
|
||||
|
||||
// Now, reconstruct from the map the valuesToReduce list
|
||||
valuesToReduce.clear();
|
||||
valuesToReduce.reserve(lastValueForCompute.size());
|
||||
for (auto& entry : lastValueForCompute)
|
||||
valuesToReduce.push_back(entry.second);
|
||||
|
||||
Location loc = valuesToReduce[0].getLoc();
|
||||
auto channelType = spatial::SpatChannelType::get(rewriter.getContext());
|
||||
|
||||
// Recursive algorithm to reduce the inputs to a single one:
|
||||
// - Take two inputs at a time, and reduce them into a single one, updating
|
||||
// the valuesToReduce list which becomes half the size.
|
||||
// - Repeat until there is only one input left.
|
||||
llvm::OwningArrayRef<Value> valuesToReduceRef(valuesToReduce);
|
||||
while (valuesToReduceRef.size() > 1) {
|
||||
SmallVector<Value> nextValuesToReduce;
|
||||
nextValuesToReduce.reserve(valuesToReduceRef.size() / 2);
|
||||
for (size_t i = 0; i < valuesToReduceRef.size() - 1; i += 2) {
|
||||
auto firstValue = valuesToReduceRef[i];
|
||||
auto secondValue = valuesToReduceRef[i + 1];
|
||||
|
||||
auto firstCompute = firstValue.getParentBlock()->getParentOp();
|
||||
auto secondCompute = secondValue.getParentBlock()->getParentOp();
|
||||
|
||||
assert(isa<spatial::SpatWeightedCompute>(firstCompute));
|
||||
assert(isa<spatial::SpatWeightedCompute>(secondCompute));
|
||||
|
||||
if (secondCompute->isBeforeInBlock(firstCompute)) {
|
||||
std::swap(firstValue, secondValue);
|
||||
std::swap(firstCompute, secondCompute);
|
||||
}
|
||||
|
||||
// 1. Add a channel before the first computeOp
|
||||
rewriter.setInsertionPoint(firstCompute);
|
||||
auto channel = spatial::SpatChannelNewOp::create(rewriter, loc, channelType);
|
||||
|
||||
// 2. Add a sendOp after the first value
|
||||
rewriter.setInsertionPointAfterValue(firstValue);
|
||||
spatial::SpatChannelSendOp::create(rewriter, loc, channel, firstValue);
|
||||
|
||||
// 3. Add a receiveOp after the second value
|
||||
rewriter.setInsertionPointAfterValue(secondValue);
|
||||
auto receivedValue = spatial::SpatChannelReceiveOp::create(rewriter, loc, secondValue.getType(), channel);
|
||||
|
||||
// 4. Apply reduction between second value and received value
|
||||
rewriter.setInsertionPointAfterValue(receivedValue);
|
||||
Value reduced = reduce(receivedValue, secondValue);
|
||||
|
||||
nextValuesToReduce.push_back(reduced);
|
||||
}
|
||||
|
||||
// If we have an odd number of inputs, we need to add the last one to the
|
||||
// newInputs list.
|
||||
if (valuesToReduceRef.size() % 2 == 1)
|
||||
nextValuesToReduce.push_back(valuesToReduceRef.back());
|
||||
|
||||
// Replace the inputOps list with the new one.
|
||||
valuesToReduceRef = llvm::OwningArrayRef<Value>(std::move(nextValuesToReduce));
|
||||
}
|
||||
|
||||
assert(valuesToReduceRef.size() == 1 && "Internal error: expected a single input at this point.");
|
||||
|
||||
auto finalValue = valuesToReduceRef[0];
|
||||
|
||||
if (postprocess) {
|
||||
rewriter.setInsertionPointAfterValue(finalValue);
|
||||
finalValue = postprocess(finalValue);
|
||||
}
|
||||
|
||||
return finalValue;
|
||||
}
|
||||
|
||||
template <typename PoolOp>
|
||||
bool hasPostProcessPoolingWindow() {
|
||||
return false;
|
||||
}
|
||||
|
||||
template <>
|
||||
bool hasPostProcessPoolingWindow<ONNXAveragePoolOp>() {
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename PoolOp>
|
||||
Value postProcessPoolingWindow(ConversionPatternRewriter& rewriter,
|
||||
Location loc,
|
||||
PoolOp poolOp,
|
||||
Value valueToDivide,
|
||||
size_t krn_size,
|
||||
size_t tilesSkippedByPadding) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
template <>
|
||||
Value postProcessPoolingWindow<ONNXAveragePoolOp>(ConversionPatternRewriter& rewriter,
|
||||
Location loc,
|
||||
ONNXAveragePoolOp poolOp,
|
||||
Value valueToDivide,
|
||||
size_t krn_size,
|
||||
size_t tilesSkippedByPadding) {
|
||||
bool countIncludePad = poolOp.getCountIncludePad() == 1;
|
||||
|
||||
size_t divisorNumber = countIncludePad ? krn_size : krn_size - tilesSkippedByPadding;
|
||||
|
||||
RankedTensorType scalarTensor = RankedTensorType::get({1}, rewriter.getF32Type());
|
||||
|
||||
// Put a spat.const before the computeOp, and use its value. We do this to be
|
||||
// compatible with the current code generation, which assumes constant to be
|
||||
// loaded in global memory, which is allocated by adding a spat.const OP
|
||||
// directly under func.func (i.e. alongside ComputeOps)
|
||||
auto computeOp = cast<spatial::SpatWeightedCompute>(valueToDivide.getDefiningOp()->getParentOp());
|
||||
rewriter.setInsertionPoint(computeOp);
|
||||
auto divisorValue = spatial::SpatConstantOp::create(rewriter,
|
||||
loc,
|
||||
scalarTensor,
|
||||
rewriter.getI64IntegerAttr(divisorNumber),
|
||||
/* should_allocate = */ rewriter.getBoolAttr(true));
|
||||
|
||||
rewriter.setInsertionPointAfterValue(valueToDivide);
|
||||
return spatial::SpatVSDivOp::create(rewriter, loc, valueToDivide.getType(), valueToDivide, divisorValue);
|
||||
}
|
||||
|
||||
template <typename PoolOp, typename PoolOpAdaptor, typename ReduceOp>
|
||||
struct PoolingBaseConverter : public OpConversionPattern<PoolOp> {
|
||||
PoolingBaseConverter(MLIRContext* ctx)
|
||||
: OpConversionPattern<PoolOp>(ctx) {}
|
||||
|
||||
LogicalResult matchAndRewrite(PoolOp poolOp, PoolOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final {
|
||||
Value X = adaptor.getX();
|
||||
ShapedType xShape = mlir::cast<ShapedType>(X.getType());
|
||||
Value Y = poolOp.getResult();
|
||||
ShapedType yShape = mlir::cast<ShapedType>(Y.getType());
|
||||
|
||||
size_t stride_x, stride_y, dilation_x, dilation_y, krn_w, krn_h;
|
||||
unpackOptionalPairVector(adaptor.getStrides(), stride_x, stride_y);
|
||||
unpackOptionalPairVector(adaptor.getDilations(), dilation_x, dilation_y);
|
||||
unpackOptionalPairVector(adaptor.getKernelShape(), krn_w, krn_h);
|
||||
|
||||
if (adaptor.getAutoPad() != "NOTSET")
|
||||
return rewriter.notifyMatchFailure(poolOp, "auto_pad != NOTSET is deprecated.");
|
||||
|
||||
size_t pad_x, pad_y;
|
||||
auto padUnpackError = unpackOptionalPadsVector(adaptor.getPads(), pad_x, pad_y);
|
||||
if (padUnpackError.has_value())
|
||||
return rewriter.notifyMatchFailure(poolOp, padUnpackError.value());
|
||||
|
||||
Location loc = poolOp.getLoc();
|
||||
|
||||
size_t input_h = getImageHeight(xShape);
|
||||
size_t input_w = getImageWidth(xShape);
|
||||
size_t output_h = getImageHeight(yShape);
|
||||
size_t output_w = getImageWidth(yShape);
|
||||
size_t channelTileCount = ceilIntegerDivide(getImageChannel(xShape), crossbarSize.getValue());
|
||||
size_t channelTileRest = getImageChannel(xShape) % crossbarSize;
|
||||
|
||||
// 1: Tile the input tensor
|
||||
// Input tiles need to be indexed by:
|
||||
// a. Channel Tile
|
||||
// b. Pixel `x` position
|
||||
// c. Pixel `y` position
|
||||
// For example: inputTiles[channelTile][x][y]
|
||||
// Example complete input tensor: tensor<1x3x12x12xf32> (NxCxWxH)
|
||||
// Suppose that the input tensor is produced by concatenating the results of
|
||||
// many ComputeOps. Get the result tiles from these ComputeOps.
|
||||
SmallVector<SmallVector<SmallVector<Value>>> inputTiles(
|
||||
channelTileCount, SmallVector<SmallVector<Value>>(input_w, SmallVector<Value>(input_h)));
|
||||
|
||||
auto resolveErrorOpt =
|
||||
resolveImgInputTiles(X, inputTiles, channelTileCount, channelTileRest, input_w, input_h, rewriter);
|
||||
if (resolveErrorOpt.has_value())
|
||||
return rewriter.notifyMatchFailure(poolOp, *resolveErrorOpt);
|
||||
|
||||
// TODO: This requires a core for each input tile, which is not ideal. We
|
||||
// can do better.
|
||||
// If some input tiles come from the func.func operands, load
|
||||
// them into a computeOp and yield them
|
||||
for (size_t t = 0; t < channelTileCount; t++) {
|
||||
for (size_t x = 0; x < input_w; x++) {
|
||||
for (size_t y = 0; y < input_h; y++) {
|
||||
if (auto extractSliceOp = inputTiles[t][x][y].getDefiningOp<tensor::ExtractSliceOp>()) {
|
||||
Location tileLoc = extractSliceOp.getLoc();
|
||||
|
||||
auto tempComputeOp = spatial::SpatWeightedCompute::create(rewriter,
|
||||
tileLoc,
|
||||
extractSliceOp.getResultType(),
|
||||
/* xbarWeights =*/ValueRange(),
|
||||
extractSliceOp.getResult());
|
||||
|
||||
Block* tempComputeOpBlock = new Block();
|
||||
tempComputeOp.getBody().push_back(tempComputeOpBlock);
|
||||
auto tempComputeOpBlockArg = tempComputeOpBlock->addArgument(extractSliceOp.getType(), tileLoc);
|
||||
|
||||
rewriter.setInsertionPointToStart(tempComputeOpBlock);
|
||||
spatial::SpatYieldOp::create(rewriter, tileLoc, tempComputeOpBlockArg);
|
||||
rewriter.setInsertionPointAfter(tempComputeOp);
|
||||
inputTiles[t][x][y] = tempComputeOp.getResult(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2: Tile the output tensor
|
||||
// Output tiles need to be indexed by:
|
||||
// a. Channel Tile
|
||||
// b. Pixel `x` position
|
||||
// c. Pixel `y` position
|
||||
// For example: outputTiles[channelTile][x][y]
|
||||
// Example complete output tensor: tensor<1x3x6x6xf32> (NxCxWxH)
|
||||
SmallVector<SmallVector<SmallVector<Value>>> outputTiles(
|
||||
channelTileCount, SmallVector<SmallVector<Value>>(output_w, SmallVector<Value>(output_h, nullptr)));
|
||||
|
||||
// List of values to pool for each output pixel
|
||||
SmallVector<Value> valuesToPool;
|
||||
|
||||
// Iterate each output tile
|
||||
for (size_t outTile = 0; outTile < channelTileCount; outTile++) {
|
||||
// Iterate each output pixel
|
||||
for (size_t outX = 0; outX < output_w; outX++) {
|
||||
for (size_t outY = 0; outY < output_h; outY++) {
|
||||
|
||||
// Each output pixel tile is computed by pooling a window of input
|
||||
// pixel tiles
|
||||
valuesToPool.clear();
|
||||
size_t tilesSkippedByPadding = 0;
|
||||
|
||||
auto [start_x, end_x] = kernel_get_start_and_end(outX, input_w, krn_w, stride_x, dilation_x, pad_x);
|
||||
auto [start_y, end_y] = kernel_get_start_and_end(outY, input_h, krn_h, stride_y, dilation_y, pad_y);
|
||||
|
||||
for (size_t inX = start_x; inX < end_x; inX += dilation_x) {
|
||||
for (size_t inY = start_y; inY < end_y; inY += dilation_y) {
|
||||
if (failed(verifyWithinBoundsAndPaddings(input_w, input_h, inX, inY, pad_x, pad_y))) {
|
||||
tilesSkippedByPadding++;
|
||||
continue;
|
||||
}
|
||||
|
||||
Value inputTile = inputTiles[outTile][inX][inY];
|
||||
|
||||
Value valueToPool;
|
||||
if (auto computeProducer = inputTile.getDefiningOp<spatial::SpatWeightedCompute>()) {
|
||||
|
||||
int resultNumber = getResultIndex(computeProducer, inputTile);
|
||||
|
||||
auto yieldInComputeOp = cast<spatial::SpatYieldOp>(computeProducer.getBody().front().getTerminator());
|
||||
valueToPool = yieldInComputeOp.getOperand(resultNumber);
|
||||
}
|
||||
else if (auto receiveProducer = inputTile.getDefiningOp<spatial::SpatChannelReceiveOp>()) {
|
||||
auto sendOpOpt = getOtherEndOfChannel(receiveProducer, true, rewriter);
|
||||
if (failed(sendOpOpt)) {
|
||||
return rewriter.notifyMatchFailure(poolOp,
|
||||
"ChannelReceiveOp does not have a matching "
|
||||
"ChannelSendOp.");
|
||||
}
|
||||
auto sendOp = cast<spatial::SpatChannelSendOp>(*sendOpOpt);
|
||||
|
||||
valueToPool = sendOp.getData();
|
||||
}
|
||||
else {
|
||||
return rewriter.notifyMatchFailure(poolOp,
|
||||
"Input tile for Pooling is not produced by a "
|
||||
"WeightedComputeOp nor a receiveOp");
|
||||
}
|
||||
|
||||
valuesToPool.push_back(valueToPool);
|
||||
}
|
||||
}
|
||||
|
||||
assert(valuesToPool.size() != 0 && "Pooling computed on zero tiles make no sense.");
|
||||
// assert(computeOpsForPooling.size() != 1 &&
|
||||
// "Pooling computed on one tiles make no sense??? Or maybe
|
||||
// this " "should have been simplified earlier???");
|
||||
|
||||
std::function<Value(const Value&)> postProcessFn = nullptr;
|
||||
if (hasPostProcessPoolingWindow<PoolOp>()) {
|
||||
postProcessFn = [&](const Value prevFinalRes) {
|
||||
return postProcessPoolingWindow(
|
||||
rewriter, loc, poolOp, prevFinalRes, krn_h * krn_w, tilesSkippedByPadding);
|
||||
};
|
||||
}
|
||||
|
||||
Value reducedWithinCompute = applyReducePatternNew(
|
||||
valuesToPool,
|
||||
rewriter,
|
||||
[&](const Value lhs, const Value rhs) { return ReduceOp::create(rewriter, loc, lhs.getType(), lhs, rhs); },
|
||||
nullptr,
|
||||
postProcessFn);
|
||||
|
||||
// Send this value through a channel, and receive it in the
|
||||
// `func.func`. During lowering, we will need to "move it" into the
|
||||
// users computeOps
|
||||
auto computeOpOfReduced =
|
||||
cast<spatial::SpatWeightedCompute>(reducedWithinCompute.getDefiningOp()->getParentOp());
|
||||
|
||||
// Create a new channel before the computeOp
|
||||
rewriter.setInsertionPoint(computeOpOfReduced);
|
||||
auto reduceChannel =
|
||||
spatial::SpatChannelNewOp::create(rewriter, loc, spatial::SpatChannelType::get(rewriter.getContext()));
|
||||
|
||||
// Send value through the channel
|
||||
rewriter.setInsertionPointAfterValue(reducedWithinCompute);
|
||||
spatial::SpatChannelSendOp::create(rewriter, loc, reduceChannel, reducedWithinCompute);
|
||||
|
||||
// Receive after the computeOp
|
||||
rewriter.setInsertionPointAfter(computeOpOfReduced);
|
||||
auto receivedValue =
|
||||
spatial::SpatChannelReceiveOp::create(rewriter, loc, reducedWithinCompute.getType(), reduceChannel);
|
||||
|
||||
outputTiles[outTile][outX][outY] = receivedValue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: outputTiles are not the results of the computeOps! We need to add
|
||||
// them!
|
||||
|
||||
std::unordered_map<Operation*, SmallVector<std::tuple<size_t, size_t, size_t, Value>>> computeOpNeedingResults;
|
||||
|
||||
// Iterate each output tile
|
||||
for (size_t outTile = 0; outTile < channelTileCount; outTile++) {
|
||||
// Iterate each output pixel
|
||||
for (size_t outX = 0; outX < output_w; outX++) {
|
||||
for (size_t outY = 0; outY < output_h; outY++) {
|
||||
auto outputTile = outputTiles[outTile][outX][outY];
|
||||
auto outputTileProducer = outputTile.getDefiningOp()->getParentOp();
|
||||
if (!outputTileProducer) {
|
||||
return rewriter.notifyMatchFailure(poolOp,
|
||||
"Output tile for Pooling is not produced by a "
|
||||
"WeightedComputeOp.");
|
||||
}
|
||||
|
||||
computeOpNeedingResults[outputTileProducer].push_back(std::make_tuple(outTile, outX, outY, outputTile));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Value outputImage = createImgConcatOp(outputTiles, rewriter, loc, poolOp.getType());
|
||||
|
||||
rewriter.replaceOp(poolOp, outputImage);
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
void populatePoolingTilingPattern(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||
patterns.insert<PoolingBaseConverter<ONNXMaxPoolSingleOutOp, ONNXMaxPoolSingleOutOpAdaptor, spatial::SpatVMaxOp>>(
|
||||
ctx);
|
||||
patterns.insert<PoolingBaseConverter<ONNXAveragePoolOp, ONNXAveragePoolOpAdaptor, spatial::SpatVAddOp>>(ctx);
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -27,9 +27,21 @@ def spatToPimMVMOp : Pat<
|
||||
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
||||
>;
|
||||
|
||||
def spatToPimVAddOp : Pat<
|
||||
def spatToPimVVAddOp : Pat<
|
||||
(SpatVAddOp:$srcOpRes $a, $b),
|
||||
(PimVAddOp $a, $b,
|
||||
(PimVVAddOp $a, $b,
|
||||
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
||||
>;
|
||||
|
||||
def spatToPimVVMulOp : Pat<
|
||||
(SpatVMulOp:$srcOpRes $a, $b),
|
||||
(PimVVMulOp $a, $b,
|
||||
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
||||
>;
|
||||
|
||||
def spatToPimVVMaxOp : Pat<
|
||||
(SpatVMaxOp:$srcOpRes $a, $b),
|
||||
(PimVVMaxOp $a, $b,
|
||||
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
||||
>;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user