#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallVector.h"

#include <algorithm>

#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"

using namespace mlir;

namespace onnx_mlir {

namespace {

/// Normalizes the optional `axes` attribute of an ONNX ReduceMean (v13) into
/// a sorted, de-duplicated list of non-negative axes.
///
/// A missing attribute means "reduce over all axes"; negative axes count from
/// the back (rank + axis), as in ONNX/NumPy.
static SmallVector<int64_t> normalizeAxes(ArrayAttr axesAttr, int64_t rank) {
  SmallVector<int64_t> normalizedAxes;
  if (!axesAttr) {
    // No axes given: reduce every dimension.
    normalizedAxes.reserve(rank);
    for (int64_t axis = 0; axis < rank; axis++)
      normalizedAxes.push_back(axis);
    return normalizedAxes;
  }
  normalizedAxes.reserve(axesAttr.size());
  for (Attribute attr : axesAttr) {
    int64_t axis = cast<IntegerAttr>(attr).getInt();
    normalizedAxes.push_back(axis >= 0 ? axis : rank + axis);
  }
  llvm::sort(normalizedAxes);
  normalizedAxes.erase(
      std::unique(normalizedAxes.begin(), normalizedAxes.end()),
      normalizedAxes.end());
  return normalizedAxes;
}

/// Builds a rank-sized boolean mask with `true` at each reduced axis.
/// Returns an empty vector if any axis is out of [0, rank) — the caller
/// treats that as a match failure.
static SmallVector<bool> buildReducedAxesMask(
    ArrayRef<int64_t> axes, int64_t rank) {
  SmallVector<bool> reducedAxes(rank, false);
  for (int64_t axis : axes) {
    if (axis < 0 || axis >= rank)
      return {};
    reducedAxes[axis] = true;
  }
  return reducedAxes;
}

/// Returns a tensor type of the same rank as `inputType` with every
/// dimension equal to 1 (the shape of a fully-reduced keepdims result).
static RankedTensorType getAllOnesType(
    RankedTensorType inputType, Type elementType) {
  return RankedTensorType::get(
      SmallVector<int64_t>(inputType.getRank(), 1), elementType);
}

/// Builds the reassociation groups that collapse away the size-1 reduced
/// dimensions of a keepdims result. Each kept axis closes a group containing
/// itself plus any reduced axes preceding it; trailing reduced axes are
/// folded into the last group (or form the sole group when every axis is
/// reduced — callers handle the rank-0 result separately).
static SmallVector<ReassociationIndices> buildCollapseReassociation(
    ArrayRef<bool> reducedAxes) {
  SmallVector<ReassociationIndices> reassociation;
  ReassociationIndices currentGroup;
  for (auto [axis, isReduced] : llvm::enumerate(reducedAxes)) {
    currentGroup.push_back(axis);
    if (!isReduced) {
      reassociation.push_back(currentGroup);
      currentGroup.clear();
    }
  }
  if (!currentGroup.empty()) {
    if (reassociation.empty())
      reassociation.push_back(std::move(currentGroup));
    else
      reassociation.back().append(currentGroup.begin(), currentGroup.end());
  }
  return reassociation;
}

/// Wraps a single spatial vector-average op in a spatial compute region that
/// reduces `input` down to `resultType` (an all-ones-shaped tensor).
static Value createAverageCompute(Value input, RankedTensorType resultType,
    ConversionPatternRewriter &rewriter, Location loc) {
  // NOTE(review): `numInputs` is assumed to be the template arity of
  // createSpatCompute — confirm against Common.hpp.
  constexpr size_t numInputs = 1;
  auto computeOp = createSpatCompute<numInputs>(rewriter, loc, resultType, {},
      ValueRange{input}, [&](Value x) {
        auto avgOp = spatial::SpatVAvgOp::create(rewriter, loc, resultType, x);
        spatial::SpatYieldOp::create(rewriter, loc, avgOp.getResult());
      });
  return computeOp.getResult(0);
}

/// Recursively lowers ReduceMean (keepdims semantics) one axis at a time.
///
/// Reduced axes are skipped (they are averaged all at once by the leaf
/// compute); each kept axis is sliced into size-1 pieces, each piece is
/// reduced recursively, and the pieces are re-concatenated along that axis.
/// At `axis == rank` the remaining tensor is averaged into `leafType`.
static Value buildReduceMeanKeepdims(Value input, ArrayRef<bool> reducedAxes,
    int64_t axis, RankedTensorType leafType,
    ConversionPatternRewriter &rewriter, Location loc) {
  int64_t rank = cast<RankedTensorType>(input.getType()).getRank();
  if (axis == rank)
    return createAverageCompute(input, leafType, rewriter, loc);
  if (reducedAxes[axis])
    return buildReduceMeanKeepdims(
        input, reducedAxes, axis + 1, leafType, rewriter, loc);
  SmallVector<Value> slices =
      sliceTensor(input, axis, /*sliceSize=*/1, rewriter, loc);
  SmallVector<Value> reducedSlices;
  reducedSlices.reserve(slices.size());
  for (Value slice : slices)
    reducedSlices.push_back(buildReduceMeanKeepdims(
        slice, reducedAxes, axis + 1, leafType, rewriter, loc));
  return reducedSlices.size() == 1
             ? reducedSlices.front()
             : tensor::ConcatOp::create(rewriter, loc, axis, reducedSlices)
                   .getResult();
}

/// Removes the size-1 reduced dimensions from a keepdims result to obtain
/// the keepdims=0 shape. A rank-0 result is produced by extracting the
/// single element; otherwise a collapse_shape folds the unit dims away.
static Value squeezeReducedAxes(Value keepdimsValue,
    RankedTensorType resultType, ArrayRef<bool> reducedAxes,
    ConversionPatternRewriter &rewriter, Location loc) {
  if (resultType.getRank() == 0) {
    SmallVector<Value> indices(
        cast<RankedTensorType>(keepdimsValue.getType()).getRank(),
        arith::ConstantIndexOp::create(rewriter, loc, 0));
    Value element =
        tensor::ExtractOp::create(rewriter, loc, keepdimsValue, indices);
    return tensor::FromElementsOp::create(
        rewriter, loc, resultType, ValueRange{element});
  }
  return tensor::CollapseShapeOp::create(rewriter, loc, resultType,
      keepdimsValue, buildCollapseReassociation(reducedAxes))
      .getResult();
}

/// Converts ONNXReduceMeanV13Op into spatial compute ops: the input is
/// sliced along each kept axis, every fully-reduced leaf is averaged by a
/// spatial vector-average, and the results are concatenated back together.
/// Only static shapes are supported.
struct ReduceMeanToSpatialCompute
    : OpConversionPattern<ONNXReduceMeanV13Op> {
  using OpConversionPattern<ONNXReduceMeanV13Op>::OpConversionPattern;

  LogicalResult matchAndRewrite(ONNXReduceMeanV13Op reduceMeanOp,
      ONNXReduceMeanV13OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    auto inputType = dyn_cast<RankedTensorType>(adaptor.getData().getType());
    auto resultType =
        dyn_cast<RankedTensorType>(reduceMeanOp.getReduced().getType());
    // Dynamic or unranked shapes cannot be sliced statically.
    if (!inputType || !resultType || !inputType.hasStaticShape() ||
        !resultType.hasStaticShape())
      return failure();
    SmallVector<int64_t> axes =
        normalizeAxes(reduceMeanOp.getAxesAttr(), inputType.getRank());
    SmallVector<bool> reducedAxes =
        buildReducedAxesMask(axes, inputType.getRank());
    // An empty mask signals an out-of-range axis, except for rank-0 inputs
    // where an empty mask is the legitimate (trivial) result.
    if (reducedAxes.empty() && inputType.getRank() != 0)
      return failure();
    Location loc = reduceMeanOp.getLoc();
    RankedTensorType leafType =
        getAllOnesType(inputType, resultType.getElementType());
    Value reducedKeepdims = buildReduceMeanKeepdims(
        adaptor.getData(), reducedAxes, /*axis=*/0, leafType, rewriter, loc);
    if (reduceMeanOp.getKeepdims() != 0) {
      // keepdims=1: reduced dims stay as size-1; nothing to squeeze.
      rewriter.replaceOp(reduceMeanOp, reducedKeepdims);
      return success();
    }
    Value reduced = squeezeReducedAxes(
        reducedKeepdims, resultType, reducedAxes, rewriter, loc);
    rewriter.replaceOp(reduceMeanOp, reduced);
    return success();
  }
};

} // namespace

/// Registers the ReduceMean-to-spatial conversion pattern.
void populateReduceMeanPatterns(RewritePatternSet &patterns,
    MLIRContext *ctx) {
  patterns.add<ReduceMeanToSpatialCompute>(ctx);
}

} // namespace onnx_mlir