Refactor ONNXToSpatial Common and diagnostics
This commit is contained in:
@@ -6,13 +6,12 @@
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <optional>
|
||||
#include <type_traits>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
@@ -31,8 +30,13 @@ static int64_t getOptionalI64(std::optional<ArrayAttrT> arrayAttr, size_t index,
|
||||
return arrayAttr ? getI64(*arrayAttr, index) : defaultValue;
|
||||
}
|
||||
|
||||
static Value concatAlongAxis(ConversionPatternRewriter& rewriter, Location loc, int64_t axis, ArrayRef<Value> values) {
|
||||
assert(!values.empty() && "Expected at least one value to concatenate.");
|
||||
template <typename PoolOp>
|
||||
static FailureOr<Value>
|
||||
concatAlongAxis(ConversionPatternRewriter& rewriter, Location loc, PoolOp poolOp, int64_t axis, ArrayRef<Value> values) {
|
||||
if (values.empty()) {
|
||||
poolOp.emitOpError("failed to build pooled output because an intermediate concatenation input list was empty");
|
||||
return failure();
|
||||
}
|
||||
return createSpatConcat(rewriter, loc, axis, values);
|
||||
}
|
||||
|
||||
@@ -51,8 +55,12 @@ static Value materializeContiguousTile(ConversionPatternRewriter& rewriter, Loca
|
||||
}
|
||||
|
||||
template <typename ReduceOp>
|
||||
static Value reduceWindowValues(ConversionPatternRewriter& rewriter, Location loc, ArrayRef<Value> windowValues) {
|
||||
assert(!windowValues.empty() && "Expected at least one pool window value.");
|
||||
static FailureOr<Value>
|
||||
reduceWindowValues(ConversionPatternRewriter& rewriter, Location loc, Operation* op, ArrayRef<Value> windowValues) {
|
||||
if (windowValues.empty()) {
|
||||
op->emitOpError("pool window resolved to zero valid elements");
|
||||
return failure();
|
||||
}
|
||||
|
||||
Value reduced = windowValues.front();
|
||||
for (Value value : windowValues.drop_front())
|
||||
@@ -60,9 +68,12 @@ static Value reduceWindowValues(ConversionPatternRewriter& rewriter, Location lo
|
||||
return reduced;
|
||||
}
|
||||
|
||||
static Value
|
||||
scaleAverageWindow(ConversionPatternRewriter& rewriter, Location loc, Value reducedWindow, int64_t divisor) {
|
||||
assert(divisor > 0 && "AveragePool divisor must be positive.");
|
||||
static FailureOr<Value>
|
||||
scaleAverageWindow(ConversionPatternRewriter& rewriter, Location loc, Operation* op, Value reducedWindow, int64_t divisor) {
|
||||
if (divisor <= 0) {
|
||||
op->emitOpError("AveragePool divisor must be positive");
|
||||
return failure();
|
||||
}
|
||||
if (divisor == 1)
|
||||
return reducedWindow;
|
||||
|
||||
@@ -70,7 +81,7 @@ scaleAverageWindow(ConversionPatternRewriter& rewriter, Location loc, Value redu
|
||||
double scale = 1.0 / static_cast<double>(divisor);
|
||||
auto scaleAttr = DenseElementsAttr::get(tileType, rewriter.getFloatAttr(tileType.getElementType(), scale));
|
||||
Value scaleTensor = arith::ConstantOp::create(rewriter, loc, tileType, scaleAttr);
|
||||
return spatial::SpatVMulOp::create(rewriter, loc, tileType, reducedWindow, scaleTensor);
|
||||
return spatial::SpatVMulOp::create(rewriter, loc, tileType, reducedWindow, scaleTensor).getResult();
|
||||
}
|
||||
|
||||
template <typename PoolOp>
|
||||
@@ -209,28 +220,45 @@ struct PoolToSpatialComputeBase : public OpConversionPattern<PoolOp> {
|
||||
if (windowValues.empty())
|
||||
return rewriter.notifyMatchFailure(poolOp, "pool window resolved to zero valid elements.");
|
||||
|
||||
Value reducedWindow = reduceWindowValues<ReduceOp>(rewriter, loc, windowValues);
|
||||
auto reducedWindow = reduceWindowValues<ReduceOp>(rewriter, loc, poolOp, windowValues);
|
||||
if (failed(reducedWindow))
|
||||
return failure();
|
||||
Value reducedWindowValue = *reducedWindow;
|
||||
if constexpr (std::is_same_v<PoolOp, ONNXAveragePoolOp>) {
|
||||
const bool countIncludePad = poolOp.getCountIncludePad() == 1;
|
||||
const int64_t divisor =
|
||||
countIncludePad ? kernelHeight * kernelWidth : static_cast<int64_t>(windowValues.size());
|
||||
reducedWindow = scaleAverageWindow(rewriter, loc, reducedWindow, divisor);
|
||||
auto scaledWindow = scaleAverageWindow(rewriter, loc, poolOp, reducedWindowValue, divisor);
|
||||
if (failed(scaledWindow))
|
||||
return failure();
|
||||
reducedWindowValue = *scaledWindow;
|
||||
}
|
||||
|
||||
outputChannelTiles.push_back(reducedWindow);
|
||||
outputChannelTiles.push_back(reducedWindowValue);
|
||||
}
|
||||
|
||||
rowPixels.push_back(concatAlongAxis(rewriter, loc, /*axis=*/1, outputChannelTiles));
|
||||
auto rowPixel = concatAlongAxis(rewriter, loc, poolOp, /*axis=*/1, outputChannelTiles);
|
||||
if (failed(rowPixel))
|
||||
return failure();
|
||||
rowPixels.push_back(*rowPixel);
|
||||
}
|
||||
|
||||
rows.push_back(concatAlongAxis(rewriter, loc, /*axis=*/3, rowPixels));
|
||||
auto row = concatAlongAxis(rewriter, loc, poolOp, /*axis=*/3, rowPixels);
|
||||
if (failed(row))
|
||||
return failure();
|
||||
rows.push_back(*row);
|
||||
}
|
||||
|
||||
batchResults.push_back(concatAlongAxis(rewriter, loc, /*axis=*/2, rows));
|
||||
auto batchResult = concatAlongAxis(rewriter, loc, poolOp, /*axis=*/2, rows);
|
||||
if (failed(batchResult))
|
||||
return failure();
|
||||
batchResults.push_back(*batchResult);
|
||||
}
|
||||
|
||||
Value pooledOutput = concatAlongAxis(rewriter, loc, /*axis=*/0, batchResults);
|
||||
spatial::SpatYieldOp::create(rewriter, loc, pooledOutput);
|
||||
auto pooledOutput = concatAlongAxis(rewriter, loc, poolOp, /*axis=*/0, batchResults);
|
||||
if (failed(pooledOutput))
|
||||
return failure();
|
||||
spatial::SpatYieldOp::create(rewriter, loc, *pooledOutput);
|
||||
return success();
|
||||
});
|
||||
if (failed(computeOp))
|
||||
|
||||
Reference in New Issue
Block a user