This commit is contained in:
@@ -6,6 +6,8 @@
|
||||
|
||||
#include <algorithm>
|
||||
#include <numeric>
|
||||
#include <optional>
|
||||
#include <type_traits>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/AffineUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
@@ -19,6 +21,85 @@ using namespace mlir;
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
struct ReduceMeanSemantics {
|
||||
SmallVector<int64_t> axes;
|
||||
int64_t keepdims = 1;
|
||||
bool isIdentity = false;
|
||||
};
|
||||
|
||||
/// Returns true when `value` is defined by an ONNXNoneOp; values without a defining op yield false.
static bool isNoneValueLike(Value value) {
  Operation* definingOp = value.getDefiningOp();
  return definingOp != nullptr && isa<ONNXNoneOp>(definingOp);
}
|
||||
|
||||
/// Reads the elements of the DenseIntElementsAttr that getHostConstDenseElementsAttr resolves for `value`.
/// Fails when no such integer attribute is available.
static FailureOr<SmallVector<int64_t>> getConstantIntValues(Value value) {
  if (auto intElements = dyn_cast_or_null<DenseIntElementsAttr>(getHostConstDenseElementsAttr(value))) {
    auto elementRange = intElements.getValues<int64_t>();
    return SmallVector<int64_t>(elementRange.begin(), elementRange.end());
  }
  return failure();
}
|
||||
|
||||
/// Normalizes every entry of `axes` against `rank` through normalizeAxisChecked and returns the
/// results sorted in ascending order with duplicates removed.
/// Fails as soon as a single axis is rejected by normalizeAxisChecked.
static FailureOr<SmallVector<int64_t>> normalizeAxesChecked(ArrayRef<int64_t> axes, int64_t rank) {
  SmallVector<int64_t> result;
  result.reserve(axes.size());

  for (const int64_t rawAxis : axes) {
    auto checkedAxis = normalizeAxisChecked(rawAxis, rank);
    if (failed(checkedAxis))
      return failure();
    result.push_back(*checkedAxis);
  }

  // Sorting first puts equal axes next to each other so std::unique can drop them.
  llvm::sort(result);
  auto firstDuplicate = std::unique(result.begin(), result.end());
  result.erase(firstDuplicate, result.end());
  return result;
}
|
||||
|
||||
template <typename ReduceMeanOp, typename ReduceMeanOpAdaptor>
|
||||
static FailureOr<ReduceMeanSemantics>
|
||||
getReduceMeanSemantics(ReduceMeanOp reduceMeanOp, ReduceMeanOpAdaptor adaptor, int64_t inputRank) {
|
||||
ReduceMeanSemantics semantics;
|
||||
semantics.keepdims = reduceMeanOp.getKeepdims();
|
||||
|
||||
if constexpr (std::is_same_v<ReduceMeanOp, ONNXReduceMeanV13Op>) {
|
||||
auto axes = onnx_mlir::normalizeAxesChecked(std::optional<ArrayAttr>(reduceMeanOp.getAxesAttr()), inputRank);
|
||||
if (failed(axes))
|
||||
return failure();
|
||||
semantics.axes = std::move(*axes);
|
||||
return semantics;
|
||||
}
|
||||
else {
|
||||
if (isNoneValueLike(adaptor.getAxes())) {
|
||||
if (reduceMeanOp.getNoopWithEmptyAxes() != 0) {
|
||||
semantics.isIdentity = true;
|
||||
return semantics;
|
||||
}
|
||||
|
||||
semantics.axes.reserve(inputRank);
|
||||
for (int64_t axis = 0; axis < inputRank; ++axis)
|
||||
semantics.axes.push_back(axis);
|
||||
return semantics;
|
||||
}
|
||||
|
||||
auto axes = getConstantIntValues(adaptor.getAxes());
|
||||
if (failed(axes))
|
||||
return failure();
|
||||
|
||||
if (axes->empty()) {
|
||||
if (reduceMeanOp.getNoopWithEmptyAxes() != 0) {
|
||||
semantics.isIdentity = true;
|
||||
return semantics;
|
||||
}
|
||||
|
||||
semantics.axes.reserve(inputRank);
|
||||
for (int64_t axis = 0; axis < inputRank; ++axis)
|
||||
semantics.axes.push_back(axis);
|
||||
return semantics;
|
||||
}
|
||||
|
||||
auto normalizedAxes = normalizeAxesChecked(*axes, inputRank);
|
||||
if (failed(normalizedAxes))
|
||||
return failure();
|
||||
semantics.axes = std::move(*normalizedAxes);
|
||||
return semantics;
|
||||
}
|
||||
}
|
||||
|
||||
static SmallVector<bool> buildReducedAxesMask(ArrayRef<int64_t> axes, int64_t rank) {
|
||||
SmallVector<bool> reducedAxes(rank, false);
|
||||
for (int64_t axis : axes) {
|
||||
@@ -251,11 +332,13 @@ static Value squeezeReducedAxes(Value keepdimsValue,
|
||||
return squeezeCompute.getResult(0);
|
||||
}
|
||||
|
||||
struct ReduceMeanToSpatialCompute : OpConversionPattern<ONNXReduceMeanV13Op> {
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
template <typename ReduceMeanOp>
|
||||
struct ReduceMeanToSpatialCompute : OpConversionPattern<ReduceMeanOp> {
|
||||
using OpConversionPattern<ReduceMeanOp>::OpConversionPattern;
|
||||
using Adaptor = typename ReduceMeanOp::Adaptor;
|
||||
|
||||
LogicalResult matchAndRewrite(ONNXReduceMeanV13Op reduceMeanOp,
|
||||
ONNXReduceMeanV13OpAdaptor adaptor,
|
||||
LogicalResult matchAndRewrite(ReduceMeanOp reduceMeanOp,
|
||||
Adaptor adaptor,
|
||||
ConversionPatternRewriter& rewriter) const override {
|
||||
auto inputType = dyn_cast<RankedTensorType>(adaptor.getData().getType());
|
||||
auto resultType = dyn_cast<RankedTensorType>(reduceMeanOp.getReduced().getType());
|
||||
@@ -266,10 +349,18 @@ struct ReduceMeanToSpatialCompute : OpConversionPattern<ONNXReduceMeanV13Op> {
|
||||
return success();
|
||||
}
|
||||
|
||||
auto axes = normalizeAxesChecked(std::optional<ArrayAttr>(reduceMeanOp.getAxesAttr()), inputType.getRank());
|
||||
if (failed(axes))
|
||||
return failure();
|
||||
SmallVector<bool> reducedAxes = buildReducedAxesMask(*axes, inputType.getRank());
|
||||
auto semantics = getReduceMeanSemantics(reduceMeanOp, adaptor, inputType.getRank());
|
||||
if (failed(semantics))
|
||||
return rewriter.notifyMatchFailure(reduceMeanOp, "requires compile-time constant, in-range ReduceMean axes");
|
||||
if (semantics->isIdentity) {
|
||||
if (inputType != resultType)
|
||||
return rewriter.notifyMatchFailure(
|
||||
reduceMeanOp, "noop_with_empty_axes identity requires the result type to match the input type");
|
||||
rewriter.replaceOp(reduceMeanOp, adaptor.getData());
|
||||
return success();
|
||||
}
|
||||
|
||||
SmallVector<bool> reducedAxes = buildReducedAxesMask(semantics->axes, inputType.getRank());
|
||||
if (reducedAxes.empty() && inputType.getRank() != 0)
|
||||
return failure();
|
||||
|
||||
@@ -289,7 +380,7 @@ struct ReduceMeanToSpatialCompute : OpConversionPattern<ONNXReduceMeanV13Op> {
|
||||
Value reducedKeepdims =
|
||||
buildKeepdimsFromLanePackedBatch(*lanePackedKeepdims, keepdimsType, compactKeptType, reducedAxes, rewriter, loc);
|
||||
|
||||
if (reduceMeanOp.getKeepdims() != 0) {
|
||||
if (semantics->keepdims != 0) {
|
||||
rewriter.replaceOp(reduceMeanOp, reducedKeepdims);
|
||||
return success();
|
||||
}
|
||||
@@ -303,7 +394,7 @@ struct ReduceMeanToSpatialCompute : OpConversionPattern<ONNXReduceMeanV13Op> {
|
||||
} // namespace
|
||||
|
||||
/// Adds the ReduceMeanToSpatialCompute conversion pattern to `patterns`, instantiated once for
/// ONNXReduceMeanV13Op and once for ONNXReduceMeanOp.
///
/// ReduceMeanToSpatialCompute is a class template, so it must always be registered with an
/// explicit op type; a bare `patterns.add<ReduceMeanToSpatialCompute>(ctx)` does not compile.
void populateReduceMeanPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
  patterns.add<ReduceMeanToSpatialCompute<ONNXReduceMeanV13Op>, ReduceMeanToSpatialCompute<ONNXReduceMeanOp>>(ctx);
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
Reference in New Issue
Block a user