add support for operations: reduceMean, add, mul, div, sigmoid
Validate Operations / validate-operations (push) Failing after 51m52s

This commit is contained in:
NiccoloN
2026-03-30 15:41:12 +02:00
parent 5e7114f517
commit 39830be888
32 changed files with 1057 additions and 224 deletions
@@ -0,0 +1,204 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallVector.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
/// Returns the row-major (C-order) element strides for `shape`.
///
/// The innermost dimension gets stride 1; every other dimension gets the
/// product of all dimension sizes to its right. An empty shape yields an
/// empty stride list.
static SmallVector<int64_t> computeRowMajorStrides(ArrayRef<int64_t> shape) {
  const size_t rank = shape.size();
  SmallVector<int64_t> strides(rank, 1);
  // Walk from the innermost dimension outwards, deriving each outer stride
  // from the stride and extent of its right-hand neighbour.
  for (size_t dim = rank; dim-- > 1;)
    strides[dim - 1] = strides[dim] * shape[dim];
  return strides;
}
/// Returns the dense elements attribute carried by the constant op that
/// defines `value`, or a null attribute when there is none.
///
/// Recognized producers are `arith::ConstantOp` and `ONNXConstantOp`. A null
/// attribute is returned when `value` has no defining op, when the defining op
/// is of another kind, or when the constant payload is not a
/// `DenseElementsAttr`.
static DenseElementsAttr getDenseConstantAttr(Value value) {
  Operation* producer = value.getDefiningOp();
  if (!producer)
    return nullptr;

  if (auto arithConstant = dyn_cast<arith::ConstantOp>(producer))
    return dyn_cast<DenseElementsAttr>(arithConstant.getValue());

  // The ONNX constant's value attribute may be absent, hence the null-tolerant cast.
  if (auto onnxConstant = dyn_cast<ONNXConstantOp>(producer))
    return dyn_cast_or_null<DenseElementsAttr>(onnxConstant.getValueAttr());

  return nullptr;
}
/// Materializes a compile-time constant `value` broadcast to `resultType`.
///
/// Broadcasting follows the right-aligned rule: source dimensions are matched
/// against the trailing result dimensions, and every source dimension must be
/// either 1 or equal to the corresponding result dimension.
///
/// \param value       Value expected to be produced by a dense constant op.
/// \param resultType  Statically shaped type the constant is expanded to.
/// \param rewriter    Rewriter used to create the new `arith::ConstantOp`.
/// \param loc         Location attached to the created op.
/// \returns `value` itself when the constant already has `resultType`, a new
///          `arith::ConstantOp` result holding the expanded data otherwise, or
///          failure when `value` is not a dense constant, a shape is dynamic,
///          the element types differ, or the shapes are not broadcastable.
///          No IR is created on the failure paths.
static FailureOr<Value> materializeBroadcastedConstantTensor(Value value,
                                                             RankedTensorType resultType,
                                                             ConversionPatternRewriter& rewriter,
                                                             Location loc) {
  auto denseAttr = getDenseConstantAttr(value);
  if (!denseAttr)
    return failure();

  auto sourceType = dyn_cast<RankedTensorType>(denseAttr.getType());
  if (!sourceType || !sourceType.hasStaticShape() || !resultType.hasStaticShape())
    return failure();
  if (sourceType == resultType)
    return value;

  // The source element attributes are reused verbatim to build the result
  // attribute, so they must already have the result element type. Bail out
  // instead of tripping an assertion inside DenseElementsAttr::get.
  if (sourceType.getElementType() != resultType.getElementType())
    return failure();

  ArrayRef<int64_t> sourceShape = sourceType.getShape();
  ArrayRef<int64_t> resultShape = resultType.getShape();
  if (sourceShape.size() > resultShape.size())
    return failure();

  const int64_t resultRank = static_cast<int64_t>(resultShape.size());
  // Number of leading result dimensions that have no source counterpart.
  const int64_t rankOffset = resultRank - static_cast<int64_t>(sourceShape.size());

  // Validate broadcast compatibility; missing leading source dims count as 1.
  for (int64_t i = 0; i < resultRank; ++i) {
    const int64_t sourceIndex = i - rankOffset;
    const int64_t sourceDim = sourceIndex < 0 ? 1 : sourceShape[sourceIndex];
    if (sourceDim != 1 && sourceDim != resultShape[i])
      return failure();
  }

  SmallVector<Attribute> sourceValues(denseAttr.getValues<Attribute>());
  SmallVector<int64_t> sourceStrides = computeRowMajorStrides(sourceShape);
  SmallVector<int64_t> resultStrides = computeRowMajorStrides(resultShape);

  const int64_t numResultElements = resultType.getNumElements();
  SmallVector<Attribute> resultValues;
  resultValues.reserve(numResultElements);

  // For every result element, decompose its flat index into per-dimension
  // coordinates and map them back to the flat index of the source element.
  // A zero-sized result never enters this loop, so the strides used as
  // divisors below are always non-zero.
  for (int64_t flatIndex = 0; flatIndex < numResultElements; ++flatIndex) {
    int64_t remaining = flatIndex;
    int64_t sourceFlatIndex = 0;
    for (int64_t i = 0; i < resultRank; ++i) {
      const int64_t resultIndex = remaining / resultStrides[i];
      remaining %= resultStrides[i];

      const int64_t sourceIndex = i - rankOffset;
      if (sourceIndex < 0)
        continue;

      // A broadcast (size-1) source dimension always reads coordinate 0.
      const int64_t mappedIndex = sourceShape[sourceIndex] == 1 ? 0 : resultIndex;
      sourceFlatIndex += mappedIndex * sourceStrides[sourceIndex];
    }
    resultValues.push_back(sourceValues[sourceFlatIndex]);
  }

  auto broadcastedAttr = DenseElementsAttr::get(resultType, resultValues);
  return arith::ConstantOp::create(rewriter, loc, resultType, broadcastedAttr).getResult();
}
/// Makes `value` usable as an elementwise operand of type `resultType`.
///
/// An operand that already has `resultType` is passed through untouched. Any
/// other operand is handed to `materializeBroadcastedConstantTensor`, so it has
/// to be a broadcastable dense constant. Operands that are not statically
/// shaped ranked tensors are rejected with failure.
static FailureOr<Value> prepareElementwiseOperand(Value value,
                                                  RankedTensorType resultType,
                                                  ConversionPatternRewriter& rewriter,
                                                  Location loc) {
  auto operandType = dyn_cast<RankedTensorType>(value.getType());
  const bool isStaticRankedTensor = operandType && operandType.hasStaticShape();
  if (!isStaticRankedTensor)
    return failure();

  // Shapes differ: only a constant operand can be expanded at compile time.
  if (operandType != resultType)
    return materializeBroadcastedConstantTensor(value, resultType, rewriter, loc);

  return value;
}
/// Materializes the elementwise reciprocal (1 / x) of a floating-point
/// constant `value`, broadcast to `resultType`.
///
/// The constant is first expanded with `materializeBroadcastedConstantTensor`
/// and the reciprocal of every element is then folded into a new
/// `arith::ConstantOp`.
///
/// \returns the new constant's result, or failure when `value` is not a
///          broadcastable dense constant, its elements are not floating point,
///          or a division reports `APFloat::opInvalidOp`.
static FailureOr<Value> materializeReciprocalTensor(Value value,
                                                    RankedTensorType resultType,
                                                    ConversionPatternRewriter& rewriter,
                                                    Location loc) {
  auto broadcastedValue = materializeBroadcastedConstantTensor(value, resultType, rewriter, loc);
  if (failed(broadcastedValue))
    return failure();

  // getDenseConstantAttr may return a null attribute, which a plain dyn_cast
  // must not be applied to; use the null-tolerant cast instead.
  auto denseAttr = dyn_cast_or_null<DenseFPElementsAttr>(getDenseConstantAttr(*broadcastedValue));
  if (!denseAttr)
    return failure();

  SmallVector<APFloat> reciprocalValues;
  reciprocalValues.reserve(denseAttr.getNumElements());
  for (const APFloat& element : denseAttr.getValues<APFloat>()) {
    // Start from 1 in the element's own float semantics, then divide in place.
    APFloat reciprocal(element.getSemantics(), 1);
    auto status = reciprocal.divide(element, APFloat::rmNearestTiesToEven);
    // Only an invalid operation aborts the fold; other status flags are
    // accepted and the computed value is kept.
    if (status & APFloat::opInvalidOp)
      return failure();
    reciprocalValues.push_back(std::move(reciprocal));
  }

  auto reciprocalAttr = DenseFPElementsAttr::get(resultType, reciprocalValues);
  return arith::ConstantOp::create(rewriter, loc, resultType, reciprocalAttr).getResult();
}
/// Conversion pattern that lowers a two-operand ONNX elementwise op `OnnxOp`
/// to a compute op (built by `createSpatCompute`) whose body applies
/// `SpatialOp` to both inputs and yields the result.
///
/// The pattern only matches when the result is a statically shaped ranked
/// tensor and both operands can be prepared by `prepareElementwiseOperand`,
/// i.e. each operand either already has the result type or is a constant that
/// can be broadcast to it.
template <typename OnnxOp, typename SpatialOp>
struct BinaryElementwiseToSpatialCompute : OpConversionPattern<OnnxOp> {
  using OpConversionPattern<OnnxOp>::OpConversionPattern;
  using Adaptor = typename OnnxOp::Adaptor;

  LogicalResult matchAndRewrite(OnnxOp op, Adaptor adaptor, ConversionPatternRewriter& rewriter) const override {
    auto outputType = dyn_cast<RankedTensorType>(op->getResult(0).getType());
    if (!(outputType && outputType.hasStaticShape()))
      return failure();

    Location loc = op.getLoc();

    // Operands are prepared left to right; the first failure aborts the match.
    FailureOr<Value> lhsOperand = prepareElementwiseOperand(adaptor.getOperands()[0], outputType, rewriter, loc);
    if (failed(lhsOperand))
      return failure();
    FailureOr<Value> rhsOperand = prepareElementwiseOperand(adaptor.getOperands()[1], outputType, rewriter, loc);
    if (failed(rhsOperand))
      return failure();

    constexpr size_t kNumInputs = 2;
    auto spatCompute = createSpatCompute<kNumInputs>(
        rewriter, loc, outputType, {}, ValueRange {*lhsOperand, *rhsOperand}, [&](Value a, Value b) {
          // Body of the compute op: one spatial elementwise op, then the yield.
          auto elementwiseOp = SpatialOp::create(rewriter, loc, outputType, a, b);
          spatial::SpatYieldOp::create(rewriter, loc, elementwiseOp.getResult());
        });

    rewriter.replaceOp(op, spatCompute);
    return success();
  }
};
/// Conversion pattern that lowers `ONNXDivOp` to a compute op whose body
/// multiplies the dividend by the reciprocal of the divisor.
///
/// The reciprocal is folded at compile time by `materializeReciprocalTensor`,
/// so the pattern only matches when the divisor `B` is a floating-point dense
/// constant broadcastable to the result type. The dividend `A` must either
/// already have the result type or be a broadcastable constant.
struct DivToSpatialCompute : OpConversionPattern<ONNXDivOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ONNXDivOp op, ONNXDivOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override {
    auto outputType = dyn_cast<RankedTensorType>(op.getResult().getType());
    if (!(outputType && outputType.hasStaticShape()))
      return failure();

    Location loc = op.getLoc();

    FailureOr<Value> dividend = prepareElementwiseOperand(adaptor.getA(), outputType, rewriter, loc);
    if (failed(dividend))
      return failure();

    FailureOr<Value> divisorReciprocal = materializeReciprocalTensor(adaptor.getB(), outputType, rewriter, loc);
    if (failed(divisorReciprocal))
      return failure();

    constexpr size_t kNumInputs = 2;
    auto spatCompute = createSpatCompute<kNumInputs>(
        rewriter, loc, outputType, {}, ValueRange {*dividend, *divisorReciprocal}, [&](Value a, Value invB) {
          // a / b is expressed as a * (1 / b) using the pre-folded reciprocal.
          auto product = spatial::SpatVMulOp::create(rewriter, loc, outputType, a, invB);
          spatial::SpatYieldOp::create(rewriter, loc, product.getResult());
        });

    rewriter.replaceOp(op, spatCompute);
    return success();
  }
};
} // namespace
/// Registers the elementwise lowering patterns defined in this file
/// (ONNX Add, Mul and Div) into `patterns`, constructing each with `ctx`.
void populateElementwisePatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
  // One variadic call; patterns are added in the order they are listed.
  patterns.add<BinaryElementwiseToSpatialCompute<ONNXAddOp, spatial::SpatVAddOp>,
               BinaryElementwiseToSpatialCompute<ONNXMulOp, spatial::SpatVMulOp>,
               DivToSpatialCompute>(ctx);
}
} // namespace onnx_mlir
@@ -0,0 +1,163 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallVector.h"
#include <algorithm>
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
/// Converts the optional `axes` attribute into a sorted, duplicate-free list
/// of axis indices.
///
/// A null attribute selects every axis in `[0, rank)`. Otherwise each negative
/// entry is shifted by `rank`; entries are NOT range-checked here, so values
/// that are still out of bounds are returned as-is for the caller to reject.
static SmallVector<int64_t> normalizeAxes(ArrayAttr axesAttr, int64_t rank) {
  SmallVector<int64_t> axes;

  if (axesAttr) {
    axes.reserve(axesAttr.size());
    for (Attribute axisAttr : axesAttr) {
      const int64_t rawAxis = cast<IntegerAttr>(axisAttr).getInt();
      axes.push_back(rawAxis < 0 ? rank + rawAxis : rawAxis);
    }
    // Sort so that std::unique can drop repeated axes.
    llvm::sort(axes);
    axes.erase(std::unique(axes.begin(), axes.end()), axes.end());
    return axes;
  }

  // No attribute: every axis is selected.
  axes.reserve(rank);
  for (int64_t axis = 0; axis < rank; ++axis)
    axes.push_back(axis);
  return axes;
}
/// Builds a per-axis boolean mask of length `rank` in which the entries listed
/// in `axes` are set to true.
///
/// Returns an empty vector when any axis lies outside `[0, rank)`. Note that a
/// rank of 0 also produces an empty vector, so callers have to tell the two
/// cases apart themselves.
static SmallVector<bool> buildReducedAxesMask(ArrayRef<int64_t> axes, int64_t rank) {
  // Validate everything up front; a single bad axis invalidates the mask.
  const bool hasInvalidAxis = llvm::any_of(axes, [rank](int64_t axis) { return axis < 0 || axis >= rank; });
  if (hasInvalidAxis)
    return {};

  SmallVector<bool> mask(rank, false);
  for (int64_t axis : axes)
    mask[axis] = true;
  return mask;
}
/// Returns a ranked tensor type with the same rank as `inputType`, every
/// dimension equal to 1, and `elementType` as element type.
static RankedTensorType getAllOnesType(RankedTensorType inputType, Type elementType) {
  const SmallVector<int64_t> unitShape(inputType.getRank(), 1);
  return RankedTensorType::get(unitShape, elementType);
}
/// Builds the reassociation groups used to collapse the reduced axes away.
///
/// Axes are visited in order and accumulated into a group; a group is closed
/// as soon as it contains a non-reduced axis, so each closed group holds
/// exactly one kept axis preceded by any reduced axes before it. Reduced axes
/// left over at the end are appended to the last closed group, or form a
/// single group of their own when no axis is kept at all.
static SmallVector<ReassociationIndices> buildCollapseReassociation(ArrayRef<bool> reducedAxes) {
  SmallVector<ReassociationIndices> groups;
  ReassociationIndices pending;

  const int64_t rank = static_cast<int64_t>(reducedAxes.size());
  for (int64_t axis = 0; axis < rank; ++axis) {
    pending.push_back(axis);
    if (reducedAxes[axis])
      continue;
    // A kept axis terminates the current group.
    groups.push_back(pending);
    pending.clear();
  }

  // No trailing reduced axes: nothing left to attach.
  if (pending.empty())
    return groups;

  if (groups.empty())
    groups.push_back(std::move(pending));
  else
    groups.back().append(pending.begin(), pending.end());
  return groups;
}
/// Creates a single-input compute op (via `createSpatCompute`) whose body
/// applies `spatial::SpatVAvgOp` to `input` and yields the outcome typed as
/// `resultType`. Returns the first result of the compute op.
static Value createAverageCompute(Value input,
                                  RankedTensorType resultType,
                                  ConversionPatternRewriter& rewriter,
                                  Location loc) {
  constexpr size_t kNumInputs = 1;
  auto spatCompute = createSpatCompute<kNumInputs>(rewriter, loc, resultType, {}, ValueRange {input}, [&](Value arg) {
    auto average = spatial::SpatVAvgOp::create(rewriter, loc, resultType, arg);
    spatial::SpatYieldOp::create(rewriter, loc, average.getResult());
  });
  return spatCompute.getResult(0);
}
/// Recursively builds the keepdims-style mean of `input` over the axes flagged
/// in `reducedAxes`, considering only axes at or after `axis`.
///
/// Reduced axes are skipped. At the first non-reduced axis the input is cut
/// into size-1 slices with `sliceTensor`, each slice is processed recursively
/// starting from the next axis, and the partial results are concatenated back
/// along that axis. Once every axis has been visited the remaining tensor is
/// averaged into a value of type `leafType` by `createAverageCompute`.
static Value buildReduceMeanKeepdims(Value input,
                                     ArrayRef<bool> reducedAxes,
                                     int64_t axis,
                                     RankedTensorType leafType,
                                     ConversionPatternRewriter& rewriter,
                                     Location loc) {
  const int64_t rank = cast<RankedTensorType>(input.getType()).getRank();

  // Advance past the reduced axes; they stay untouched until the leaf average.
  int64_t splitAxis = axis;
  while (splitAxis != rank && reducedAxes[splitAxis])
    ++splitAxis;

  // Every remaining axis is reduced: average what is left.
  if (splitAxis == rank)
    return createAverageCompute(input, leafType, rewriter, loc);

  SmallVector<Value> slices = sliceTensor(input, splitAxis, /*sliceSize=*/1, rewriter, loc);
  SmallVector<Value> partialResults;
  partialResults.reserve(slices.size());
  for (Value slice : slices) {
    Value partial = buildReduceMeanKeepdims(slice, reducedAxes, splitAxis + 1, leafType, rewriter, loc);
    partialResults.push_back(partial);
  }

  // A single slice needs no concatenation.
  if (partialResults.size() == 1)
    return partialResults.front();
  return tensor::ConcatOp::create(rewriter, loc, splitAxis, partialResults).getResult();
}
/// Converts the keepdims-style reduction result `keepdimsValue` into
/// `resultType` by dropping the reduced axes.
///
/// For a rank-0 result the element at index 0 along every axis is extracted
/// and wrapped into a new tensor with `tensor::FromElementsOp`. Otherwise a
/// `tensor::CollapseShapeOp` is created with the groups computed by
/// `buildCollapseReassociation`.
static Value squeezeReducedAxes(Value keepdimsValue,
                                RankedTensorType resultType,
                                ArrayRef<bool> reducedAxes,
                                ConversionPatternRewriter& rewriter,
                                Location loc) {
  if (resultType.getRank() != 0) {
    SmallVector<ReassociationIndices> reassociation = buildCollapseReassociation(reducedAxes);
    auto collapsed = tensor::CollapseShapeOp::create(rewriter, loc, resultType, keepdimsValue, reassociation);
    return collapsed.getResult();
  }

  // Rank-0 result: read the element at index 0 on every axis. One shared
  // zero constant serves as the index for all axes.
  const int64_t keepdimsRank = cast<RankedTensorType>(keepdimsValue.getType()).getRank();
  SmallVector<Value> zeroIndices(keepdimsRank, arith::ConstantIndexOp::create(rewriter, loc, 0));
  Value scalar = tensor::ExtractOp::create(rewriter, loc, keepdimsValue, zeroIndices);
  return tensor::FromElementsOp::create(rewriter, loc, resultType, ValueRange {scalar});
}
/// Conversion pattern that lowers `ONNXReduceMeanV13Op` to averaging compute
/// ops assembled by `buildReduceMeanKeepdims`.
///
/// The pattern only matches when input and result are statically shaped ranked
/// tensors and every requested axis is within range. The reduction is always
/// built in keepdims form first; when the op's `keepdims` attribute is 0 the
/// reduced axes are then removed with `squeezeReducedAxes`.
struct ReduceMeanToSpatialCompute : OpConversionPattern<ONNXReduceMeanV13Op> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(ONNXReduceMeanV13Op reduceMeanOp,
                                ONNXReduceMeanV13OpAdaptor adaptor,
                                ConversionPatternRewriter& rewriter) const override {
    auto dataType = dyn_cast<RankedTensorType>(adaptor.getData().getType());
    auto outputType = dyn_cast<RankedTensorType>(reduceMeanOp.getReduced().getType());
    if (!dataType || !outputType)
      return failure();
    if (!dataType.hasStaticShape() || !outputType.hasStaticShape())
      return failure();

    const int64_t rank = dataType.getRank();
    SmallVector<int64_t> axes = normalizeAxes(reduceMeanOp.getAxesAttr(), rank);
    SmallVector<bool> reducedMask = buildReducedAxesMask(axes, rank);
    // An empty mask signals an out-of-range axis, except for rank-0 inputs
    // where the mask is legitimately empty.
    if (rank != 0 && reducedMask.empty())
      return failure();

    Location loc = reduceMeanOp.getLoc();
    RankedTensorType leafType = getAllOnesType(dataType, outputType.getElementType());
    Value replacement =
        buildReduceMeanKeepdims(adaptor.getData(), reducedMask, /*axis=*/0, leafType, rewriter, loc);

    // keepdims == 0: drop the reduced axes from the keepdims-form result.
    if (reduceMeanOp.getKeepdims() == 0)
      replacement = squeezeReducedAxes(replacement, outputType, reducedMask, rewriter, loc);

    rewriter.replaceOp(reduceMeanOp, replacement);
    return success();
  }
};
} // namespace
/// Registers the ReduceMean lowering pattern defined in this file
/// (`ReduceMeanToSpatialCompute`) into `patterns`, constructing it with `ctx`.
void populateReduceMeanPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.add<ReduceMeanToSpatialCompute>(ctx);
}
} // namespace onnx_mlir
@@ -0,0 +1,36 @@
#include "mlir/Transforms/DialectConversion.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
/// Conversion pattern that lowers `ONNXSigmoidOp` to a single-input compute op
/// (built by `createSpatCompute`) whose body applies `spatial::SpatSigmoidOp`
/// to the input and yields the result.
///
/// The op's own result type is reused as-is; this pattern performs no shape or
/// type validation and always reports success.
struct SigmoidToSpatialCompute : OpConversionPattern<ONNXSigmoidOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(ONNXSigmoidOp sigmoidOp,
                                ONNXSigmoidOpAdaptor adaptor,
                                ConversionPatternRewriter& rewriter) const override {
    Type outputType = sigmoidOp.getResult().getType();
    Location loc = sigmoidOp.getLoc();

    constexpr size_t kNumInputs = 1;
    auto spatCompute = createSpatCompute<kNumInputs>(rewriter, loc, outputType, {}, adaptor.getX(), [&](Value arg) {
      auto sigmoid = spatial::SpatSigmoidOp::create(rewriter, loc, outputType, arg);
      spatial::SpatYieldOp::create(rewriter, loc, sigmoid.getResult());
    });

    rewriter.replaceOp(sigmoidOp, spatCompute);
    return success();
  }
};
} // namespace
/// Registers the Sigmoid lowering pattern defined in this file
/// (`SigmoidToSpatialCompute`) into `patterns`, constructing it with `ctx`.
void populateSigmoidPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.add<SigmoidToSpatialCompute>(ctx);
}
} // namespace onnx_mlir