add support for operations: reduceMean, add, mul, div, sigmoid
Some checks failed
Validate Operations / validate-operations (push) Failing after 51m52s
This commit is contained in:
163
src/PIM/Conversion/ONNXToSpatial/Patterns/Math/ReduceMean.cpp
Normal file
@@ -0,0 +1,163 @@
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
/// Turns an optional ONNX `axes` attribute into a sorted, duplicate-free list
/// of non-negative axis indices.
///
/// A missing attribute means "reduce every axis" (ONNX ReduceMean semantics),
/// so in that case all indices in [0, rank) are returned. Negative entries
/// are wrapped by adding `rank`. Out-of-range entries are NOT rejected here;
/// buildReducedAxesMask performs that validation.
static SmallVector<int64_t> normalizeAxes(ArrayAttr axesAttr, int64_t rank) {
  SmallVector<int64_t> result;

  if (!axesAttr) {
    // No attribute: reduce across the full rank.
    result.reserve(rank);
    int64_t axis = 0;
    while (axis < rank)
      result.push_back(axis++);
    return result;
  }

  result.reserve(axesAttr.size());
  for (Attribute attr : axesAttr) {
    int64_t raw = cast<IntegerAttr>(attr).getInt();
    if (raw < 0)
      raw += rank; // Wrap negative axes per ONNX convention.
    result.push_back(raw);
  }

  // Canonical order with duplicates removed.
  llvm::sort(result);
  result.erase(std::unique(result.begin(), result.end()), result.end());
  return result;
}
|
||||
|
||||
/// Expands a list of axis indices into a per-axis boolean mask of length
/// `rank`, where true marks an axis being reduced.
///
/// Returns an empty vector when any index lies outside [0, rank), which the
/// caller treats as invalid input.
static SmallVector<bool> buildReducedAxesMask(ArrayRef<int64_t> axes, int64_t rank) {
  auto outOfRange = [rank](int64_t axis) { return axis < 0 || axis >= rank; };
  if (llvm::any_of(axes, outOfRange))
    return {};

  SmallVector<bool> mask(rank, false);
  for (int64_t axis : axes)
    mask[axis] = true;
  return mask;
}
|
||||
|
||||
/// Builds a tensor type with the same rank as `inputType` whose every
/// dimension is 1, using `elementType` — the shape a full keepdims reduction
/// of one sub-tensor produces.
static RankedTensorType getAllOnesType(RankedTensorType inputType, Type elementType) {
  SmallVector<int64_t> unitShape(inputType.getRank(), 1);
  return RankedTensorType::get(unitShape, elementType);
}
|
||||
|
||||
/// Computes the reassociation map used to collapse away the size-1 reduced
/// dimensions.
///
/// Each reduced axis is folded into the group of the next kept axis; reduced
/// axes trailing the last kept axis are folded into the final group instead
/// (or form the only group when every axis is reduced).
static SmallVector<ReassociationIndices> buildCollapseReassociation(ArrayRef<bool> reducedAxes) {
  SmallVector<ReassociationIndices> groups;
  ReassociationIndices pending;

  for (auto [axis, isReduced] : llvm::enumerate(reducedAxes)) {
    pending.push_back(axis);
    if (isReduced)
      continue;
    // A kept axis closes the current group.
    groups.push_back(pending);
    pending.clear();
  }

  // Trailing reduced axes have no kept axis after them; attach them to the
  // last group, or make them the sole group if nothing was kept.
  if (!pending.empty()) {
    if (groups.empty())
      groups.push_back(std::move(pending));
    else
      groups.back().append(pending.begin(), pending.end());
  }

  return groups;
}
|
||||
|
||||
/// Emits a spatial compute region that averages every element of `input`,
/// yielding one result of the all-ones `resultType` shape.
static Value createAverageCompute(Value input,
    RankedTensorType resultType,
    ConversionPatternRewriter& rewriter,
    Location loc) {
  constexpr size_t numInputs = 1;
  // Region body: one vector-average op whose result is yielded.
  auto buildBody = [&](Value operand) {
    auto avg = spatial::SpatVAvgOp::create(rewriter, loc, resultType, operand);
    spatial::SpatYieldOp::create(rewriter, loc, avg.getResult());
  };
  auto compute = createSpatCompute<numInputs>(
      rewriter, loc, resultType, {}, ValueRange {input}, buildBody);
  return compute.getResult(0);
}
|
||||
|
||||
/// Recursively lowers a keepdims ReduceMean by walking the axes from `axis`
/// up to the input's rank.
///
/// Reduced axes are simply skipped — the leaf average absorbs them. Kept
/// axes are split into unit slices, each reduced independently and then
/// concatenated back along the same axis. Once every axis has been visited,
/// the remaining sub-tensor is averaged into a single `leafType` value.
static Value buildReduceMeanKeepdims(Value input,
    ArrayRef<bool> reducedAxes,
    int64_t axis,
    RankedTensorType leafType,
    ConversionPatternRewriter& rewriter,
    Location loc) {
  int64_t rank = cast<RankedTensorType>(input.getType()).getRank();

  // All axes visited: collapse what remains into one mean value.
  if (axis == rank)
    return createAverageCompute(input, leafType, rewriter, loc);

  // Reduced axis: no slicing needed, just descend.
  if (reducedAxes[axis])
    return buildReduceMeanKeepdims(input, reducedAxes, axis + 1, leafType, rewriter, loc);

  // Kept axis: reduce each unit slice separately, then stitch them back.
  SmallVector<Value> parts;
  for (Value piece : sliceTensor(input, axis, /*sliceSize=*/1, rewriter, loc))
    parts.push_back(
        buildReduceMeanKeepdims(piece, reducedAxes, axis + 1, leafType, rewriter, loc));

  if (parts.size() == 1)
    return parts.front();
  return tensor::ConcatOp::create(rewriter, loc, axis, parts).getResult();
}
|
||||
|
||||
/// Strips the size-1 reduced dimensions from the keepdims result so its type
/// matches `resultType` (the keepdims=0 output).
///
/// A rank-0 result cannot be expressed as a shape collapse, so in that case
/// the single element is extracted at an all-zero index and rebuilt as a
/// scalar tensor instead.
static Value squeezeReducedAxes(Value keepdimsValue,
    RankedTensorType resultType,
    ArrayRef<bool> reducedAxes,
    ConversionPatternRewriter& rewriter,
    Location loc) {
  if (resultType.getRank() != 0) {
    auto reassociation = buildCollapseReassociation(reducedAxes);
    return tensor::CollapseShapeOp::create(
        rewriter, loc, resultType, keepdimsValue, reassociation)
        .getResult();
  }

  // Scalar output: pull the one element out at index (0, 0, ..., 0).
  int64_t keepdimsRank = cast<RankedTensorType>(keepdimsValue.getType()).getRank();
  Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
  SmallVector<Value> zeroIndices(keepdimsRank, zero);
  Value element = tensor::ExtractOp::create(rewriter, loc, keepdimsValue, zeroIndices);
  return tensor::FromElementsOp::create(rewriter, loc, resultType, ValueRange {element});
}
|
||||
|
||||
/// Lowers ONNXReduceMeanV13Op over statically shaped ranked tensors into a
/// tree of spatial compute averages (see buildReduceMeanKeepdims), collapsing
/// the reduced unit dimensions afterwards when keepdims is 0.
struct ReduceMeanToSpatialCompute : OpConversionPattern<ONNXReduceMeanV13Op> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(ONNXReduceMeanV13Op reduceMeanOp,
      ONNXReduceMeanV13OpAdaptor adaptor,
      ConversionPatternRewriter& rewriter) const override {
    auto inputType = dyn_cast<RankedTensorType>(adaptor.getData().getType());
    auto resultType = dyn_cast<RankedTensorType>(reduceMeanOp.getReduced().getType());
    // The slicing-based lowering only handles fully static ranked shapes.
    if (!inputType || !resultType)
      return failure();
    if (!inputType.hasStaticShape() || !resultType.hasStaticShape())
      return failure();

    int64_t rank = inputType.getRank();
    SmallVector<int64_t> axes = normalizeAxes(reduceMeanOp.getAxesAttr(), rank);
    SmallVector<bool> reducedAxes = buildReducedAxesMask(axes, rank);
    // An empty mask signals an out-of-range axis — except for rank-0 inputs,
    // where an empty mask is the legitimate result.
    if (reducedAxes.empty() && rank != 0)
      return failure();

    Location loc = reduceMeanOp.getLoc();
    RankedTensorType leafType = getAllOnesType(inputType, resultType.getElementType());
    Value keepdimsResult = buildReduceMeanKeepdims(
        adaptor.getData(), reducedAxes, /*axis=*/0, leafType, rewriter, loc);

    // keepdims != 0 keeps the unit dims; otherwise squeeze them away.
    Value replacement = reduceMeanOp.getKeepdims() != 0
        ? keepdimsResult
        : squeezeReducedAxes(keepdimsResult, resultType, reducedAxes, rewriter, loc);
    rewriter.replaceOp(reduceMeanOp, replacement);
    return success();
  }
};
|
||||
|
||||
} // namespace
|
||||
|
||||
/// Registers the ReduceMean-to-spatial-compute lowering pattern with the
/// given pattern set.
void populateReduceMeanPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
  patterns.add<ReduceMeanToSpatialCompute>(ctx);
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
Reference in New Issue
Block a user