Refactor ONNXToSpatial Common and diagnostics

This commit is contained in:
NiccoloN
2026-05-04 13:42:43 +02:00
parent b6ba1e4fea
commit f789954ad7
26 changed files with 686 additions and 486 deletions
@@ -6,13 +6,12 @@
#include "llvm/ADT/SmallVector.h"
#include <algorithm>
#include <cassert>
#include <optional>
#include <type_traits>
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -31,8 +30,13 @@ static int64_t getOptionalI64(std::optional<ArrayAttrT> arrayAttr, size_t index,
return arrayAttr ? getI64(*arrayAttr, index) : defaultValue;
}
static Value concatAlongAxis(ConversionPatternRewriter& rewriter, Location loc, int64_t axis, ArrayRef<Value> values) {
assert(!values.empty() && "Expected at least one value to concatenate.");
template <typename PoolOp>
static FailureOr<Value>
concatAlongAxis(ConversionPatternRewriter& rewriter, Location loc, PoolOp poolOp, int64_t axis, ArrayRef<Value> values) {
if (values.empty()) {
poolOp.emitOpError("failed to build pooled output because an intermediate concatenation input list was empty");
return failure();
}
return createSpatConcat(rewriter, loc, axis, values);
}
@@ -51,8 +55,12 @@ static Value materializeContiguousTile(ConversionPatternRewriter& rewriter, Loca
}
template <typename ReduceOp>
static Value reduceWindowValues(ConversionPatternRewriter& rewriter, Location loc, ArrayRef<Value> windowValues) {
assert(!windowValues.empty() && "Expected at least one pool window value.");
static FailureOr<Value>
reduceWindowValues(ConversionPatternRewriter& rewriter, Location loc, Operation* op, ArrayRef<Value> windowValues) {
if (windowValues.empty()) {
op->emitOpError("pool window resolved to zero valid elements");
return failure();
}
Value reduced = windowValues.front();
for (Value value : windowValues.drop_front())
@@ -60,9 +68,12 @@ static Value reduceWindowValues(ConversionPatternRewriter& rewriter, Location lo
return reduced;
}
static Value
scaleAverageWindow(ConversionPatternRewriter& rewriter, Location loc, Value reducedWindow, int64_t divisor) {
assert(divisor > 0 && "AveragePool divisor must be positive.");
static FailureOr<Value>
scaleAverageWindow(ConversionPatternRewriter& rewriter, Location loc, Operation* op, Value reducedWindow, int64_t divisor) {
if (divisor <= 0) {
op->emitOpError("AveragePool divisor must be positive");
return failure();
}
if (divisor == 1)
return reducedWindow;
@@ -70,7 +81,7 @@ scaleAverageWindow(ConversionPatternRewriter& rewriter, Location loc, Value redu
double scale = 1.0 / static_cast<double>(divisor);
auto scaleAttr = DenseElementsAttr::get(tileType, rewriter.getFloatAttr(tileType.getElementType(), scale));
Value scaleTensor = arith::ConstantOp::create(rewriter, loc, tileType, scaleAttr);
return spatial::SpatVMulOp::create(rewriter, loc, tileType, reducedWindow, scaleTensor);
return spatial::SpatVMulOp::create(rewriter, loc, tileType, reducedWindow, scaleTensor).getResult();
}
template <typename PoolOp>
@@ -209,28 +220,45 @@ struct PoolToSpatialComputeBase : public OpConversionPattern<PoolOp> {
if (windowValues.empty())
return rewriter.notifyMatchFailure(poolOp, "pool window resolved to zero valid elements.");
Value reducedWindow = reduceWindowValues<ReduceOp>(rewriter, loc, windowValues);
auto reducedWindow = reduceWindowValues<ReduceOp>(rewriter, loc, poolOp, windowValues);
if (failed(reducedWindow))
return failure();
Value reducedWindowValue = *reducedWindow;
if constexpr (std::is_same_v<PoolOp, ONNXAveragePoolOp>) {
const bool countIncludePad = poolOp.getCountIncludePad() == 1;
const int64_t divisor =
countIncludePad ? kernelHeight * kernelWidth : static_cast<int64_t>(windowValues.size());
reducedWindow = scaleAverageWindow(rewriter, loc, reducedWindow, divisor);
auto scaledWindow = scaleAverageWindow(rewriter, loc, poolOp, reducedWindowValue, divisor);
if (failed(scaledWindow))
return failure();
reducedWindowValue = *scaledWindow;
}
outputChannelTiles.push_back(reducedWindow);
outputChannelTiles.push_back(reducedWindowValue);
}
rowPixels.push_back(concatAlongAxis(rewriter, loc, /*axis=*/1, outputChannelTiles));
auto rowPixel = concatAlongAxis(rewriter, loc, poolOp, /*axis=*/1, outputChannelTiles);
if (failed(rowPixel))
return failure();
rowPixels.push_back(*rowPixel);
}
rows.push_back(concatAlongAxis(rewriter, loc, /*axis=*/3, rowPixels));
auto row = concatAlongAxis(rewriter, loc, poolOp, /*axis=*/3, rowPixels);
if (failed(row))
return failure();
rows.push_back(*row);
}
batchResults.push_back(concatAlongAxis(rewriter, loc, /*axis=*/2, rows));
auto batchResult = concatAlongAxis(rewriter, loc, poolOp, /*axis=*/2, rows);
if (failed(batchResult))
return failure();
batchResults.push_back(*batchResult);
}
Value pooledOutput = concatAlongAxis(rewriter, loc, /*axis=*/0, batchResults);
spatial::SpatYieldOp::create(rewriter, loc, pooledOutput);
auto pooledOutput = concatAlongAxis(rewriter, loc, poolOp, /*axis=*/0, batchResults);
if (failed(pooledOutput))
return failure();
spatial::SpatYieldOp::create(rewriter, loc, *pooledOutput);
return success();
});
if (failed(computeOp))