add support for operations: reduceMean, add, mul, div, sigmoid
Some checks failed
Validate Operations / validate-operations (push) Failing after 51m52s
Some checks failed
Validate Operations / validate-operations (push) Failing after 51m52s
This commit is contained in:
204
src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Elementwise.cpp
Normal file
204
src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Elementwise.cpp
Normal file
@@ -0,0 +1,204 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
static SmallVector<int64_t> computeRowMajorStrides(ArrayRef<int64_t> shape) {
|
||||
SmallVector<int64_t> strides(shape.size(), 1);
|
||||
for (int64_t i = static_cast<int64_t>(shape.size()) - 2; i >= 0; --i)
|
||||
strides[i] = strides[i + 1] * shape[i + 1];
|
||||
return strides;
|
||||
}
|
||||
|
||||
static DenseElementsAttr getDenseConstantAttr(Value value) {
|
||||
if (auto constantOp = value.getDefiningOp<arith::ConstantOp>())
|
||||
return dyn_cast<DenseElementsAttr>(constantOp.getValue());
|
||||
|
||||
if (auto constantOp = value.getDefiningOp<ONNXConstantOp>())
|
||||
return dyn_cast_or_null<DenseElementsAttr>(constantOp.getValueAttr());
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
static FailureOr<Value> materializeBroadcastedConstantTensor(Value value,
|
||||
RankedTensorType resultType,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
auto denseAttr = getDenseConstantAttr(value);
|
||||
if (!denseAttr)
|
||||
return failure();
|
||||
|
||||
auto sourceType = dyn_cast<RankedTensorType>(denseAttr.getType());
|
||||
if (!sourceType || !sourceType.hasStaticShape() || !resultType.hasStaticShape())
|
||||
return failure();
|
||||
|
||||
if (sourceType == resultType)
|
||||
return value;
|
||||
|
||||
ArrayRef<int64_t> sourceShape = sourceType.getShape();
|
||||
ArrayRef<int64_t> resultShape = resultType.getShape();
|
||||
if (sourceShape.size() > resultShape.size())
|
||||
return failure();
|
||||
|
||||
const int64_t rankOffset = static_cast<int64_t>(resultShape.size() - sourceShape.size());
|
||||
for (int64_t i = 0; i < static_cast<int64_t>(resultShape.size()); ++i) {
|
||||
const int64_t sourceIndex = i - rankOffset;
|
||||
const int64_t sourceDim = sourceIndex < 0 ? 1 : sourceShape[sourceIndex];
|
||||
const int64_t resultDim = resultShape[i];
|
||||
if (sourceDim != 1 && sourceDim != resultDim)
|
||||
return failure();
|
||||
}
|
||||
|
||||
SmallVector<Attribute> sourceValues(denseAttr.getValues<Attribute>());
|
||||
SmallVector<int64_t> sourceStrides = computeRowMajorStrides(sourceShape);
|
||||
SmallVector<int64_t> resultStrides = computeRowMajorStrides(resultShape);
|
||||
|
||||
SmallVector<Attribute> resultValues;
|
||||
resultValues.reserve(resultType.getNumElements());
|
||||
|
||||
for (int64_t flatIndex = 0; flatIndex < resultType.getNumElements(); ++flatIndex) {
|
||||
int64_t remaining = flatIndex;
|
||||
int64_t sourceFlatIndex = 0;
|
||||
|
||||
for (int64_t i = 0; i < static_cast<int64_t>(resultShape.size()); ++i) {
|
||||
const int64_t resultIndex = resultStrides.empty() ? 0 : remaining / resultStrides[i];
|
||||
remaining = resultStrides.empty() ? 0 : remaining % resultStrides[i];
|
||||
|
||||
const int64_t sourceIndex = i - rankOffset;
|
||||
if (sourceIndex < 0)
|
||||
continue;
|
||||
|
||||
const int64_t sourceDim = sourceShape[sourceIndex];
|
||||
const int64_t mappedIndex = sourceDim == 1 ? 0 : resultIndex;
|
||||
sourceFlatIndex += mappedIndex * sourceStrides[sourceIndex];
|
||||
}
|
||||
|
||||
resultValues.push_back(sourceValues[sourceFlatIndex]);
|
||||
}
|
||||
|
||||
auto broadcastedAttr = DenseElementsAttr::get(resultType, resultValues);
|
||||
return arith::ConstantOp::create(rewriter, loc, resultType, broadcastedAttr).getResult();
|
||||
}
|
||||
|
||||
static FailureOr<Value> prepareElementwiseOperand(Value value,
|
||||
RankedTensorType resultType,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
auto valueType = dyn_cast<RankedTensorType>(value.getType());
|
||||
if (!valueType || !valueType.hasStaticShape())
|
||||
return failure();
|
||||
|
||||
if (valueType == resultType)
|
||||
return value;
|
||||
|
||||
return materializeBroadcastedConstantTensor(value, resultType, rewriter, loc);
|
||||
}
|
||||
|
||||
static FailureOr<Value> materializeReciprocalTensor(Value value,
|
||||
RankedTensorType resultType,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
auto broadcastedValue = materializeBroadcastedConstantTensor(value, resultType, rewriter, loc);
|
||||
if (failed(broadcastedValue))
|
||||
return failure();
|
||||
|
||||
auto denseAttr = dyn_cast<DenseFPElementsAttr>(getDenseConstantAttr(*broadcastedValue));
|
||||
if (!denseAttr)
|
||||
return failure();
|
||||
|
||||
SmallVector<APFloat> reciprocalValues;
|
||||
reciprocalValues.reserve(denseAttr.getNumElements());
|
||||
for (const APFloat& valueAttr : denseAttr.getValues<APFloat>()) {
|
||||
APFloat reciprocal(valueAttr.getSemantics(), 1);
|
||||
auto status = reciprocal.divide(valueAttr, APFloat::rmNearestTiesToEven);
|
||||
if (status & APFloat::opInvalidOp)
|
||||
return failure();
|
||||
reciprocalValues.push_back(std::move(reciprocal));
|
||||
}
|
||||
|
||||
auto reciprocalAttr = DenseFPElementsAttr::get(resultType, reciprocalValues);
|
||||
return arith::ConstantOp::create(rewriter, loc, resultType, reciprocalAttr).getResult();
|
||||
}
|
||||
|
||||
template <typename OnnxOp, typename SpatialOp>
|
||||
struct BinaryElementwiseToSpatialCompute : OpConversionPattern<OnnxOp> {
|
||||
using OpConversionPattern<OnnxOp>::OpConversionPattern;
|
||||
using Adaptor = typename OnnxOp::Adaptor;
|
||||
|
||||
LogicalResult matchAndRewrite(OnnxOp op, Adaptor adaptor, ConversionPatternRewriter& rewriter) const override {
|
||||
auto resultType = dyn_cast<RankedTensorType>(op->getResult(0).getType());
|
||||
if (!resultType || !resultType.hasStaticShape())
|
||||
return failure();
|
||||
|
||||
Location loc = op.getLoc();
|
||||
auto lhs = prepareElementwiseOperand(adaptor.getOperands()[0], resultType, rewriter, loc);
|
||||
if (failed(lhs))
|
||||
return failure();
|
||||
|
||||
auto rhs = prepareElementwiseOperand(adaptor.getOperands()[1], resultType, rewriter, loc);
|
||||
if (failed(rhs))
|
||||
return failure();
|
||||
|
||||
constexpr size_t numInputs = 2;
|
||||
auto computeOp =
|
||||
createSpatCompute<numInputs>(rewriter, loc, resultType, {}, ValueRange {*lhs, *rhs}, [&](Value x, Value y) {
|
||||
auto loweredOp = SpatialOp::create(rewriter, loc, resultType, x, y);
|
||||
spatial::SpatYieldOp::create(rewriter, loc, loweredOp.getResult());
|
||||
});
|
||||
|
||||
rewriter.replaceOp(op, computeOp);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct DivToSpatialCompute : OpConversionPattern<ONNXDivOp> {
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(ONNXDivOp op, ONNXDivOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override {
|
||||
auto resultType = dyn_cast<RankedTensorType>(op.getResult().getType());
|
||||
if (!resultType || !resultType.hasStaticShape())
|
||||
return failure();
|
||||
|
||||
Location loc = op.getLoc();
|
||||
auto lhs = prepareElementwiseOperand(adaptor.getA(), resultType, rewriter, loc);
|
||||
if (failed(lhs))
|
||||
return failure();
|
||||
|
||||
auto reciprocalRhs = materializeReciprocalTensor(adaptor.getB(), resultType, rewriter, loc);
|
||||
if (failed(reciprocalRhs))
|
||||
return failure();
|
||||
|
||||
constexpr size_t numInputs = 2;
|
||||
auto computeOp = createSpatCompute<numInputs>(
|
||||
rewriter, loc, resultType, {}, ValueRange {*lhs, *reciprocalRhs}, [&](Value x, Value reciprocal) {
|
||||
auto mulOp = spatial::SpatVMulOp::create(rewriter, loc, resultType, x, reciprocal);
|
||||
spatial::SpatYieldOp::create(rewriter, loc, mulOp.getResult());
|
||||
});
|
||||
|
||||
rewriter.replaceOp(op, computeOp);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
void populateElementwisePatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||
patterns.add<BinaryElementwiseToSpatialCompute<ONNXAddOp, spatial::SpatVAddOp>>(ctx);
|
||||
patterns.add<BinaryElementwiseToSpatialCompute<ONNXMulOp, spatial::SpatVMulOp>>(ctx);
|
||||
patterns.add<DivToSpatialCompute>(ctx);
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
Reference in New Issue
Block a user