ReduceMean + resnet
Validate Operations / validate-operations (push) Has been cancelled

This commit is contained in:
ilgeco
2026-06-10 14:30:10 +02:00
parent 237654dadf
commit 852bef7605
12 changed files with 199 additions and 10 deletions
@@ -6,6 +6,8 @@
#include <algorithm>
#include <numeric>
#include <optional>
#include <type_traits>
#include "src/Accelerators/PIM/Common/IR/AffineUtils.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
@@ -19,6 +21,85 @@ using namespace mlir;
namespace onnx_mlir {
namespace {
struct ReduceMeanSemantics {
SmallVector<int64_t> axes;
int64_t keepdims = 1;
bool isIdentity = false;
};
static bool isNoneValueLike(Value value) { return isa_and_nonnull<ONNXNoneOp>(value.getDefiningOp()); }
static FailureOr<SmallVector<int64_t>> getConstantIntValues(Value value) {
auto denseAttr = dyn_cast_or_null<DenseIntElementsAttr>(getHostConstDenseElementsAttr(value));
if (!denseAttr)
return failure();
return SmallVector<int64_t>(denseAttr.getValues<int64_t>().begin(), denseAttr.getValues<int64_t>().end());
}
static FailureOr<SmallVector<int64_t>> normalizeAxesChecked(ArrayRef<int64_t> axes, int64_t rank) {
SmallVector<int64_t> normalizedAxes;
normalizedAxes.reserve(axes.size());
for (int64_t axis : axes) {
auto normalizedAxis = normalizeAxisChecked(axis, rank);
if (failed(normalizedAxis))
return failure();
normalizedAxes.push_back(*normalizedAxis);
}
llvm::sort(normalizedAxes);
normalizedAxes.erase(std::unique(normalizedAxes.begin(), normalizedAxes.end()), normalizedAxes.end());
return normalizedAxes;
}
template <typename ReduceMeanOp, typename ReduceMeanOpAdaptor>
static FailureOr<ReduceMeanSemantics>
getReduceMeanSemantics(ReduceMeanOp reduceMeanOp, ReduceMeanOpAdaptor adaptor, int64_t inputRank) {
ReduceMeanSemantics semantics;
semantics.keepdims = reduceMeanOp.getKeepdims();
if constexpr (std::is_same_v<ReduceMeanOp, ONNXReduceMeanV13Op>) {
auto axes = onnx_mlir::normalizeAxesChecked(std::optional<ArrayAttr>(reduceMeanOp.getAxesAttr()), inputRank);
if (failed(axes))
return failure();
semantics.axes = std::move(*axes);
return semantics;
}
else {
if (isNoneValueLike(adaptor.getAxes())) {
if (reduceMeanOp.getNoopWithEmptyAxes() != 0) {
semantics.isIdentity = true;
return semantics;
}
semantics.axes.reserve(inputRank);
for (int64_t axis = 0; axis < inputRank; ++axis)
semantics.axes.push_back(axis);
return semantics;
}
auto axes = getConstantIntValues(adaptor.getAxes());
if (failed(axes))
return failure();
if (axes->empty()) {
if (reduceMeanOp.getNoopWithEmptyAxes() != 0) {
semantics.isIdentity = true;
return semantics;
}
semantics.axes.reserve(inputRank);
for (int64_t axis = 0; axis < inputRank; ++axis)
semantics.axes.push_back(axis);
return semantics;
}
auto normalizedAxes = normalizeAxesChecked(*axes, inputRank);
if (failed(normalizedAxes))
return failure();
semantics.axes = std::move(*normalizedAxes);
return semantics;
}
}
static SmallVector<bool> buildReducedAxesMask(ArrayRef<int64_t> axes, int64_t rank) {
SmallVector<bool> reducedAxes(rank, false);
for (int64_t axis : axes) {
@@ -251,11 +332,13 @@ static Value squeezeReducedAxes(Value keepdimsValue,
return squeezeCompute.getResult(0);
}
struct ReduceMeanToSpatialCompute : OpConversionPattern<ONNXReduceMeanV13Op> {
using OpConversionPattern::OpConversionPattern;
template <typename ReduceMeanOp>
struct ReduceMeanToSpatialCompute : OpConversionPattern<ReduceMeanOp> {
using OpConversionPattern<ReduceMeanOp>::OpConversionPattern;
using Adaptor = typename ReduceMeanOp::Adaptor;
LogicalResult matchAndRewrite(ONNXReduceMeanV13Op reduceMeanOp,
ONNXReduceMeanV13OpAdaptor adaptor,
LogicalResult matchAndRewrite(ReduceMeanOp reduceMeanOp,
Adaptor adaptor,
ConversionPatternRewriter& rewriter) const override {
auto inputType = dyn_cast<RankedTensorType>(adaptor.getData().getType());
auto resultType = dyn_cast<RankedTensorType>(reduceMeanOp.getReduced().getType());
@@ -266,10 +349,18 @@ struct ReduceMeanToSpatialCompute : OpConversionPattern<ONNXReduceMeanV13Op> {
return success();
}
auto axes = normalizeAxesChecked(std::optional<ArrayAttr>(reduceMeanOp.getAxesAttr()), inputType.getRank());
if (failed(axes))
return failure();
SmallVector<bool> reducedAxes = buildReducedAxesMask(*axes, inputType.getRank());
auto semantics = getReduceMeanSemantics(reduceMeanOp, adaptor, inputType.getRank());
if (failed(semantics))
return rewriter.notifyMatchFailure(reduceMeanOp, "requires compile-time constant, in-range ReduceMean axes");
if (semantics->isIdentity) {
if (inputType != resultType)
return rewriter.notifyMatchFailure(
reduceMeanOp, "noop_with_empty_axes identity requires the result type to match the input type");
rewriter.replaceOp(reduceMeanOp, adaptor.getData());
return success();
}
SmallVector<bool> reducedAxes = buildReducedAxesMask(semantics->axes, inputType.getRank());
if (reducedAxes.empty() && inputType.getRank() != 0)
return failure();
@@ -289,7 +380,7 @@ struct ReduceMeanToSpatialCompute : OpConversionPattern<ONNXReduceMeanV13Op> {
Value reducedKeepdims =
buildKeepdimsFromLanePackedBatch(*lanePackedKeepdims, keepdimsType, compactKeptType, reducedAxes, rewriter, loc);
if (reduceMeanOp.getKeepdims() != 0) {
if (semantics->keepdims != 0) {
rewriter.replaceOp(reduceMeanOp, reducedKeepdims);
return success();
}
@@ -303,7 +394,7 @@ struct ReduceMeanToSpatialCompute : OpConversionPattern<ONNXReduceMeanV13Op> {
} // namespace
void populateReduceMeanPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.add<ReduceMeanToSpatialCompute>(ctx);
patterns.add<ReduceMeanToSpatialCompute<ONNXReduceMeanV13Op>, ReduceMeanToSpatialCompute<ONNXReduceMeanOp>>(ctx);
}
} // namespace onnx_mlir