// File: Raptor/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/ReduceMean.cpp
// Commit: 909c4acfdd (NiccoloN, 2026-05-12 10:35:44 +02:00)
//   "huge refactor for high RewritePatterns usage and less ad-hoc cpp code;
//    remove Spatial many ops in favor of tensor ops like in pim"
// CI: Validate Operations / validate-operations (push) — cancelled
// 188 lines, 7.4 KiB, C++

#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallVector.h"
#include <algorithm>
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
/// Turns the ONNX `axes` attribute into a sorted, duplicate-free list of
/// non-negative axis indices. A missing attribute means "reduce over every
/// axis" (the ONNX ReduceMean default); negative entries are wrapped into
/// [0, rank) by adding `rank`.
static SmallVector<int64_t> normalizeAxes(ArrayAttr axesAttr, int64_t rank) {
  SmallVector<int64_t> result;
  if (!axesAttr) {
    // No attribute: reduce across all axes.
    result.reserve(rank);
    for (int64_t i = 0; i < rank; ++i)
      result.push_back(i);
    return result;
  }
  // Wrap negative axes, then sort and deduplicate so downstream code can
  // rely on a canonical ordering.
  result.reserve(axesAttr.size());
  for (Attribute a : axesAttr) {
    int64_t axis = cast<IntegerAttr>(a).getInt();
    result.push_back(axis < 0 ? rank + axis : axis);
  }
  llvm::sort(result);
  result.erase(std::unique(result.begin(), result.end()), result.end());
  return result;
}
/// Builds a per-axis boolean mask of length `rank` with `true` at every
/// reduced axis. Returns an empty vector to signal that some axis is out of
/// the valid range [0, rank).
static SmallVector<bool> buildReducedAxesMask(ArrayRef<int64_t> axes, int64_t rank) {
  SmallVector<bool> mask(rank, false);
  for (int64_t axis : axes) {
    // Out-of-range axis: report failure via the empty-vector sentinel.
    bool inRange = axis >= 0 && axis < rank;
    if (!inRange)
      return {};
    mask[axis] = true;
  }
  return mask;
}
/// Returns a tensor type of the same rank as `inputType` whose dimensions are
/// all 1, carrying `elementType` — the shape a full keepdims reduction
/// collapses to.
static RankedTensorType getAllOnesType(RankedTensorType inputType, Type elementType) {
  SmallVector<int64_t> onesShape(inputType.getRank(), 1);
  return RankedTensorType::get(onesShape, elementType);
}
/// Builds the reassociation map for a tensor.collapse_shape that squeezes the
/// reduced (size-1, keepdims) axes: each run of reduced axes is merged into
/// the next kept axis; trailing reduced axes are merged into the last group
/// (or form the only group if every axis is reduced).
static SmallVector<ReassociationIndices> buildCollapseReassociation(ArrayRef<bool> reducedAxes) {
  SmallVector<ReassociationIndices> groups;
  ReassociationIndices pending;
  int64_t axis = 0;
  for (bool isReduced : reducedAxes) {
    pending.push_back(axis++);
    // A kept axis closes the current group.
    if (!isReduced) {
      groups.push_back(pending);
      pending.clear();
    }
  }
  // Leftover reduced axes at the end have no following kept axis: fold them
  // into the last group, or make them the sole group when nothing is kept.
  if (!pending.empty()) {
    if (groups.empty())
      groups.push_back(std::move(pending));
    else
      groups.back().append(pending.begin(), pending.end());
  }
  return groups;
}
/// Emits a spatial compute region that averages `input` down to `resultType`
/// with a single SpatVAvg op, yielding its result. Returns the compute op's
/// first (and only) result.
static Value
createAverageCompute(Value input, RankedTensorType resultType, ConversionPatternRewriter& rewriter, Location loc) {
  auto bodyBuilder = [&](Value operand) {
    auto avg = spatial::SpatVAvgOp::create(rewriter, loc, resultType, operand);
    spatial::SpatYieldOp::create(rewriter, loc, avg.getResult());
  };
  auto compute = createSpatCompute<1>(rewriter, loc, resultType, {}, ValueRange {input}, bodyBuilder);
  return compute.getResult(0);
}
/// Concatenates `inputs` along `axis`. When every input is host-foldable the
/// concat is emitted directly on the host; otherwise it is wrapped in a
/// spatial compute region whose result type sums the inputs' sizes along
/// `axis` (all other dimensions and the encoding come from the first input).
static Value concatValues(ValueRange inputs, int64_t axis, ConversionPatternRewriter& rewriter, Location loc) {
  // Fully host-foldable inputs: no compute region is needed.
  if (llvm::all_of(inputs, isHostFoldableValue))
    return createSpatConcat(rewriter, loc, axis, inputs);
  // Derive the concatenated result type from the inputs.
  auto firstType = cast<RankedTensorType>(inputs.front().getType());
  int64_t concatDimSize = 0;
  for (Value input : inputs)
    concatDimSize += cast<RankedTensorType>(input.getType()).getDimSize(axis);
  SmallVector<int64_t> outputShape(firstType.getShape().begin(), firstType.getShape().end());
  outputShape[axis] = concatDimSize;
  auto resultType = RankedTensorType::get(outputShape, firstType.getElementType(), firstType.getEncoding());
  auto compute = createSpatCompute(rewriter, loc, TypeRange {resultType}, {}, inputs, [&](ValueRange args) {
    spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, axis, args));
  });
  return compute.getResult(0);
}
/// Recursively lowers a keepdims ReduceMean over `input`, walking one axis at
/// a time starting at `axis`:
///   - past the last axis, the remaining tensor is averaged down to
///     `leafType` (the all-ones shape) in one compute region;
///   - a reduced axis is skipped — it is folded into that final average;
///   - a kept axis is split into size-1 slices, each reduced recursively,
///     and the results are concatenated back along the same axis.
static Value buildReduceMeanKeepdims(Value input,
    ArrayRef<bool> reducedAxes,
    int64_t axis,
    RankedTensorType leafType,
    ConversionPatternRewriter& rewriter,
    Location loc) {
  int64_t inputRank = cast<RankedTensorType>(input.getType()).getRank();
  // Base case: all axes visited — average whatever is left.
  if (axis == inputRank)
    return createAverageCompute(input, leafType, rewriter, loc);
  // Reduced axis: nothing to split here, recurse on the next axis.
  if (reducedAxes[axis])
    return buildReduceMeanKeepdims(input, reducedAxes, axis + 1, leafType, rewriter, loc);
  // Kept axis: reduce each unit slice independently, then reassemble.
  SmallVector<Value> perSlice;
  for (Value slice : sliceTensor(input, axis, /*sliceSize=*/1, rewriter, loc))
    perSlice.push_back(buildReduceMeanKeepdims(slice, reducedAxes, axis + 1, leafType, rewriter, loc));
  return concatValues(perSlice, axis, rewriter, loc);
}
/// Removes the reduced (size-1) dimensions of `keepdimsValue`, producing
/// `resultType`. A rank-0 result is built by extracting the single element
/// and rebuilding a 0-d tensor; otherwise a tensor.collapse_shape is emitted,
/// either directly (host-foldable value) or inside a spatial compute region.
static Value squeezeReducedAxes(Value keepdimsValue,
    RankedTensorType resultType,
    ArrayRef<bool> reducedAxes,
    ConversionPatternRewriter& rewriter,
    Location loc) {
  if (resultType.getRank() == 0) {
    // Every dimension is 1: read element [0, 0, ...] and wrap it back up.
    Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
    int64_t sourceRank = cast<RankedTensorType>(keepdimsValue.getType()).getRank();
    SmallVector<Value> zeros(sourceRank, zero);
    Value scalar = tensor::ExtractOp::create(rewriter, loc, keepdimsValue, zeros);
    return tensor::FromElementsOp::create(rewriter, loc, resultType, ValueRange {scalar});
  }
  auto reassociation = buildCollapseReassociation(reducedAxes);
  // Host-foldable values can be collapsed without a compute region.
  if (isHostFoldableValue(keepdimsValue))
    return tensor::CollapseShapeOp::create(rewriter, loc, resultType, keepdimsValue, reassociation).getResult();
  auto compute =
      createSpatCompute<1>(rewriter, loc, TypeRange {resultType}, {}, ValueRange {keepdimsValue}, [&](Value arg) {
        Value collapsed = tensor::CollapseShapeOp::create(rewriter, loc, resultType, arg, reassociation);
        spatial::SpatYieldOp::create(rewriter, loc, collapsed);
      });
  return compute.getResult(0);
}
/// Converts ONNXReduceMeanV13Op into spatial compute regions: the reduction
/// is built in keepdims form (slice / average / concat), and when the op's
/// `keepdims` attribute is 0 the reduced dimensions are squeezed afterwards.
/// Only ranked tensors with static shapes are handled.
struct ReduceMeanToSpatialCompute : OpConversionPattern<ONNXReduceMeanV13Op> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(ONNXReduceMeanV13Op reduceMeanOp,
      ONNXReduceMeanV13OpAdaptor adaptor,
      ConversionPatternRewriter& rewriter) const override {
    auto dataType = dyn_cast<RankedTensorType>(adaptor.getData().getType());
    auto reducedType = dyn_cast<RankedTensorType>(reduceMeanOp.getReduced().getType());
    // Dynamic or unranked shapes are out of scope for this lowering.
    if (!dataType || !reducedType || !dataType.hasStaticShape() || !reducedType.hasStaticShape())
      return failure();
    int64_t rank = dataType.getRank();
    SmallVector<int64_t> axes = normalizeAxes(reduceMeanOp.getAxesAttr(), rank);
    SmallVector<bool> mask = buildReducedAxesMask(axes, rank);
    // An empty mask means an out-of-range axis — except for rank-0 inputs,
    // where the mask is legitimately empty.
    if (mask.empty() && rank != 0)
      return failure();
    Location loc = reduceMeanOp.getLoc();
    RankedTensorType leafType = getAllOnesType(dataType, reducedType.getElementType());
    Value keepdimsResult = buildReduceMeanKeepdims(adaptor.getData(), mask, /*axis=*/0, leafType, rewriter, loc);
    // keepdims == 0 requires an extra squeeze of the reduced dimensions.
    Value replacement = reduceMeanOp.getKeepdims() != 0
        ? keepdimsResult
        : squeezeReducedAxes(keepdimsResult, reducedType, mask, rewriter, loc);
    rewriter.replaceOp(reduceMeanOp, replacement);
    return success();
  }
};
} // namespace
/// Registers the ReduceMean → Spatial conversion pattern into `patterns`.
void populateReduceMeanPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.add<ReduceMeanToSpatialCompute>(ctx);
}
} // namespace onnx_mlir