|
|
|
@@ -6,13 +6,12 @@
|
|
|
|
|
#include "llvm/ADT/SmallVector.h"
|
|
|
|
|
|
|
|
|
|
#include <algorithm>
|
|
|
|
|
#include <cassert>
|
|
|
|
|
#include <optional>
|
|
|
|
|
#include <type_traits>
|
|
|
|
|
|
|
|
|
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
|
|
|
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
|
|
|
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
|
|
|
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
|
|
|
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
|
|
|
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
|
|
|
|
|
|
|
|
@@ -31,8 +30,13 @@ static int64_t getOptionalI64(std::optional<ArrayAttrT> arrayAttr, size_t index,
|
|
|
|
|
return arrayAttr ? getI64(*arrayAttr, index) : defaultValue;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static Value concatAlongAxis(ConversionPatternRewriter& rewriter, Location loc, int64_t axis, ArrayRef<Value> values) {
|
|
|
|
|
assert(!values.empty() && "Expected at least one value to concatenate.");
|
|
|
|
|
template <typename PoolOp>
|
|
|
|
|
static FailureOr<Value>
|
|
|
|
|
concatAlongAxis(ConversionPatternRewriter& rewriter, Location loc, PoolOp poolOp, int64_t axis, ArrayRef<Value> values) {
|
|
|
|
|
if (values.empty()) {
|
|
|
|
|
poolOp.emitOpError("failed to build pooled output because an intermediate concatenation input list was empty");
|
|
|
|
|
return failure();
|
|
|
|
|
}
|
|
|
|
|
return createSpatConcat(rewriter, loc, axis, values);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@@ -51,8 +55,12 @@ static Value materializeContiguousTile(ConversionPatternRewriter& rewriter, Loca
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename ReduceOp>
|
|
|
|
|
static Value reduceWindowValues(ConversionPatternRewriter& rewriter, Location loc, ArrayRef<Value> windowValues) {
|
|
|
|
|
assert(!windowValues.empty() && "Expected at least one pool window value.");
|
|
|
|
|
static FailureOr<Value>
|
|
|
|
|
reduceWindowValues(ConversionPatternRewriter& rewriter, Location loc, Operation* op, ArrayRef<Value> windowValues) {
|
|
|
|
|
if (windowValues.empty()) {
|
|
|
|
|
op->emitOpError("pool window resolved to zero valid elements");
|
|
|
|
|
return failure();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Value reduced = windowValues.front();
|
|
|
|
|
for (Value value : windowValues.drop_front())
|
|
|
|
@@ -60,9 +68,12 @@ static Value reduceWindowValues(ConversionPatternRewriter& rewriter, Location lo
|
|
|
|
|
return reduced;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static Value
|
|
|
|
|
scaleAverageWindow(ConversionPatternRewriter& rewriter, Location loc, Value reducedWindow, int64_t divisor) {
|
|
|
|
|
assert(divisor > 0 && "AveragePool divisor must be positive.");
|
|
|
|
|
static FailureOr<Value>
|
|
|
|
|
scaleAverageWindow(ConversionPatternRewriter& rewriter, Location loc, Operation* op, Value reducedWindow, int64_t divisor) {
|
|
|
|
|
if (divisor <= 0) {
|
|
|
|
|
op->emitOpError("AveragePool divisor must be positive");
|
|
|
|
|
return failure();
|
|
|
|
|
}
|
|
|
|
|
if (divisor == 1)
|
|
|
|
|
return reducedWindow;
|
|
|
|
|
|
|
|
|
@@ -70,7 +81,7 @@ scaleAverageWindow(ConversionPatternRewriter& rewriter, Location loc, Value redu
|
|
|
|
|
double scale = 1.0 / static_cast<double>(divisor);
|
|
|
|
|
auto scaleAttr = DenseElementsAttr::get(tileType, rewriter.getFloatAttr(tileType.getElementType(), scale));
|
|
|
|
|
Value scaleTensor = arith::ConstantOp::create(rewriter, loc, tileType, scaleAttr);
|
|
|
|
|
return spatial::SpatVMulOp::create(rewriter, loc, tileType, reducedWindow, scaleTensor);
|
|
|
|
|
return spatial::SpatVMulOp::create(rewriter, loc, tileType, reducedWindow, scaleTensor).getResult();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename PoolOp>
|
|
|
|
@@ -209,28 +220,45 @@ struct PoolToSpatialComputeBase : public OpConversionPattern<PoolOp> {
|
|
|
|
|
if (windowValues.empty())
|
|
|
|
|
return rewriter.notifyMatchFailure(poolOp, "pool window resolved to zero valid elements.");
|
|
|
|
|
|
|
|
|
|
Value reducedWindow = reduceWindowValues<ReduceOp>(rewriter, loc, windowValues);
|
|
|
|
|
auto reducedWindow = reduceWindowValues<ReduceOp>(rewriter, loc, poolOp, windowValues);
|
|
|
|
|
if (failed(reducedWindow))
|
|
|
|
|
return failure();
|
|
|
|
|
Value reducedWindowValue = *reducedWindow;
|
|
|
|
|
if constexpr (std::is_same_v<PoolOp, ONNXAveragePoolOp>) {
|
|
|
|
|
const bool countIncludePad = poolOp.getCountIncludePad() == 1;
|
|
|
|
|
const int64_t divisor =
|
|
|
|
|
countIncludePad ? kernelHeight * kernelWidth : static_cast<int64_t>(windowValues.size());
|
|
|
|
|
reducedWindow = scaleAverageWindow(rewriter, loc, reducedWindow, divisor);
|
|
|
|
|
auto scaledWindow = scaleAverageWindow(rewriter, loc, poolOp, reducedWindowValue, divisor);
|
|
|
|
|
if (failed(scaledWindow))
|
|
|
|
|
return failure();
|
|
|
|
|
reducedWindowValue = *scaledWindow;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
outputChannelTiles.push_back(reducedWindow);
|
|
|
|
|
outputChannelTiles.push_back(reducedWindowValue);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
rowPixels.push_back(concatAlongAxis(rewriter, loc, /*axis=*/1, outputChannelTiles));
|
|
|
|
|
auto rowPixel = concatAlongAxis(rewriter, loc, poolOp, /*axis=*/1, outputChannelTiles);
|
|
|
|
|
if (failed(rowPixel))
|
|
|
|
|
return failure();
|
|
|
|
|
rowPixels.push_back(*rowPixel);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
rows.push_back(concatAlongAxis(rewriter, loc, /*axis=*/3, rowPixels));
|
|
|
|
|
auto row = concatAlongAxis(rewriter, loc, poolOp, /*axis=*/3, rowPixels);
|
|
|
|
|
if (failed(row))
|
|
|
|
|
return failure();
|
|
|
|
|
rows.push_back(*row);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
batchResults.push_back(concatAlongAxis(rewriter, loc, /*axis=*/2, rows));
|
|
|
|
|
auto batchResult = concatAlongAxis(rewriter, loc, poolOp, /*axis=*/2, rows);
|
|
|
|
|
if (failed(batchResult))
|
|
|
|
|
return failure();
|
|
|
|
|
batchResults.push_back(*batchResult);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Value pooledOutput = concatAlongAxis(rewriter, loc, /*axis=*/0, batchResults);
|
|
|
|
|
spatial::SpatYieldOp::create(rewriter, loc, pooledOutput);
|
|
|
|
|
auto pooledOutput = concatAlongAxis(rewriter, loc, poolOp, /*axis=*/0, batchResults);
|
|
|
|
|
if (failed(pooledOutput))
|
|
|
|
|
return failure();
|
|
|
|
|
spatial::SpatYieldOp::create(rewriter, loc, *pooledOutput);
|
|
|
|
|
return success();
|
|
|
|
|
});
|
|
|
|
|
if (failed(computeOp))
|
|
|
|
|