#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallVector.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"

using namespace mlir;

namespace onnx_mlir {
namespace {

/// Returns the row-major (C-order) strides of `shape`, in elements.
/// The innermost stride is 1; an empty shape yields an empty vector.
static SmallVector<int64_t> computeRowMajorStrides(ArrayRef<int64_t> shape) {
  SmallVector<int64_t> strides(shape.size(), 1);
  for (int64_t i = static_cast<int64_t>(shape.size()) - 2; i >= 0; --i)
    strides[i] = strides[i + 1] * shape[i + 1];
  return strides;
}

/// Returns the DenseElementsAttr held by the `arith::ConstantOp` or
/// `ONNXConstantOp` that defines `value`. Returns a null attribute when `value`
/// has no such defining op or when the constant's payload is not a
/// DenseElementsAttr.
static DenseElementsAttr getDenseConstantAttr(Value value) {
  if (auto constantOp = value.getDefiningOp<arith::ConstantOp>())
    return dyn_cast<DenseElementsAttr>(constantOp.getValue());
  if (auto constantOp = value.getDefiningOp<ONNXConstantOp>())
    return dyn_cast_or_null<DenseElementsAttr>(constantOp.getValueAttr());
  return nullptr;
}

/// Returns true when `sourceShape` can be broadcast to `resultShape`.
/// The shapes are right-aligned and the result rank must be at least the
/// source rank. Every source dimension must be 1 or equal to the result
/// dimension it lines up with.
static bool isBroadcastableTo(ArrayRef<int64_t> sourceShape, ArrayRef<int64_t> resultShape) {
  if (sourceShape.size() > resultShape.size())
    return false;
  const size_t rankOffset = resultShape.size() - sourceShape.size();
  for (size_t i = 0; i < sourceShape.size(); ++i) {
    const int64_t sourceDim = sourceShape[i];
    if (sourceDim != 1 && sourceDim != resultShape[i + rankOffset])
      return false;
  }
  return true;
}

/// Computes, without creating any IR, the attribute obtained by broadcasting
/// the dense constant that defines `value` to `resultType`.
///
/// Fails if any of the following holds:
///   - `value` is not a dense constant (see getDenseConstantAttr);
///   - either type is not a statically shaped ranked tensor;
///   - the element types differ;
///   - the shapes are not broadcast-compatible.
static FailureOr<DenseElementsAttr> getBroadcastedConstantAttr(Value value, RankedTensorType resultType) {
  DenseElementsAttr denseAttr = getDenseConstantAttr(value);
  if (!denseAttr)
    return failure();

  auto sourceType = dyn_cast<RankedTensorType>(denseAttr.getType());
  if (!sourceType || !sourceType.hasStaticShape() || !resultType.hasStaticShape())
    return failure();
  // DenseElementsAttr::get requires every element attribute to match the
  // element type of the target type, so reject mismatches up front instead of
  // tripping its assertion.
  if (sourceType.getElementType() != resultType.getElementType())
    return failure();
  if (sourceType == resultType)
    return denseAttr;

  ArrayRef<int64_t> sourceShape = sourceType.getShape();
  ArrayRef<int64_t> resultShape = resultType.getShape();
  if (!isBroadcastableTo(sourceShape, resultShape))
    return failure();

  // Broadcasting a splat yields the same splat; skip the per-element expansion.
  if (denseAttr.isSplat())
    return DenseElementsAttr::get(resultType, denseAttr.getSplatValue<Attribute>());

  SmallVector<Attribute> sourceValues(denseAttr.getValues<Attribute>());
  SmallVector<int64_t> sourceStrides = computeRowMajorStrides(sourceShape);
  SmallVector<int64_t> resultStrides = computeRowMajorStrides(resultShape);
  const int64_t rankOffset = static_cast<int64_t>(resultShape.size() - sourceShape.size());
  const int64_t resultRank = static_cast<int64_t>(resultShape.size());
  const int64_t numResultElements = resultType.getNumElements();

  SmallVector<Attribute> resultValues;
  resultValues.reserve(numResultElements);
  for (int64_t flatIndex = 0; flatIndex < numResultElements; ++flatIndex) {
    // Decompose the result flat index into per-dimension coordinates and map
    // them onto the source. Leading result dimensions with no source
    // counterpart are ignored, and size-1 source dimensions always read
    // coordinate 0 (so they contribute nothing to the source offset).
    int64_t remaining = flatIndex;
    int64_t sourceFlatIndex = 0;
    for (int64_t dim = 0; dim < resultRank; ++dim) {
      const int64_t resultCoord = remaining / resultStrides[dim];
      remaining %= resultStrides[dim];
      const int64_t sourceDim = dim - rankOffset;
      if (sourceDim < 0 || sourceShape[sourceDim] == 1)
        continue;
      sourceFlatIndex += resultCoord * sourceStrides[sourceDim];
    }
    resultValues.push_back(sourceValues[sourceFlatIndex]);
  }
  return DenseElementsAttr::get(resultType, resultValues);
}

/// Returns a value of type `resultType` holding the dense constant that defines
/// `value`, broadcast as needed.
///
/// `value` itself is returned when it already has `resultType`. Otherwise a new
/// `arith.constant` is created at `loc`. Fails under the same conditions as
/// getBroadcastedConstantAttr, and in that case creates no IR.
static FailureOr<Value> materializeBroadcastedConstantTensor(Value value,
    RankedTensorType resultType,
    ConversionPatternRewriter& rewriter,
    Location loc) {
  FailureOr<DenseElementsAttr> broadcastedAttr = getBroadcastedConstantAttr(value, resultType);
  if (failed(broadcastedAttr))
    return failure();
  // Reuse the existing constant when no reshaping is required.
  if (value.getType() == resultType)
    return value;
  return arith::ConstantOp::create(rewriter, loc, resultType, *broadcastedAttr).getResult();
}

/// Prepares one operand of an element-wise op whose result has `resultType`.
///
/// An operand that already has `resultType` is used as-is. Any other operand
/// must be a dense constant that can be broadcast to `resultType`. Non-constant
/// operands that would need broadcasting, and operands without a static ranked
/// tensor type, are rejected.
static FailureOr<Value> prepareElementwiseOperand(Value value,
    RankedTensorType resultType,
    ConversionPatternRewriter& rewriter,
    Location loc) {
  auto valueType = dyn_cast<RankedTensorType>(value.getType());
  if (!valueType || !valueType.hasStaticShape())
    return failure();
  if (valueType == resultType)
    return value;
  return materializeBroadcastedConstantTensor(value, resultType, rewriter, loc);
}

/// Returns 1 / `value`, computed in `value`'s floating-point semantics with
/// round-to-nearest-even.
///
/// Fails only when APFloat reports an invalid operation. A zero divisor is
/// accepted: the resulting division-by-zero status is deliberately not
/// treated as an error.
static FailureOr<APFloat> computeReciprocal(const APFloat& value) {
  APFloat reciprocal(value.getSemantics(), 1);
  const APFloat::opStatus status = reciprocal.divide(value, APFloat::rmNearestTiesToEven);
  if (status & APFloat::opInvalidOp)
    return failure();
  return reciprocal;
}

/// Creates an `arith.constant` of `resultType` holding the element-wise
/// reciprocal of the dense floating-point constant that defines `value`,
/// broadcast to `resultType`.
///
/// Fails, without creating any IR, when any of the following holds:
///   - `value` is not a broadcastable dense constant;
///   - its elements are not floating point;
///   - a reciprocal is an invalid operation.
static FailureOr<Value> materializeReciprocalTensor(Value value,
    RankedTensorType resultType,
    ConversionPatternRewriter& rewriter,
    Location loc) {
  // Work on the broadcast attribute directly so that no intermediate
  // (immediately dead) broadcast constant op is created.
  FailureOr<DenseElementsAttr> broadcastedAttr = getBroadcastedConstantAttr(value, resultType);
  if (failed(broadcastedAttr))
    return failure();
  auto floatAttr = dyn_cast<DenseFPElementsAttr>(*broadcastedAttr);
  if (!floatAttr)
    return failure();

  SmallVector<APFloat> reciprocalValues;
  if (floatAttr.isSplat()) {
    // A single value passed to DenseElementsAttr::get is treated as a splat.
    FailureOr<APFloat> reciprocal = computeReciprocal(floatAttr.getSplatValue<APFloat>());
    if (failed(reciprocal))
      return failure();
    reciprocalValues.push_back(std::move(*reciprocal));
  }
  else {
    reciprocalValues.reserve(floatAttr.getNumElements());
    for (const APFloat& element : floatAttr.getValues<APFloat>()) {
      FailureOr<APFloat> reciprocal = computeReciprocal(element);
      if (failed(reciprocal))
        return failure();
      reciprocalValues.push_back(std::move(*reciprocal));
    }
  }

  auto reciprocalAttr = DenseFPElementsAttr::get(resultType, reciprocalValues);
  return arith::ConstantOp::create(rewriter, loc, resultType, reciprocalAttr).getResult();
}

/// Lowers a binary ONNX element-wise op to a spatial compute region whose body
/// applies `SpatialOp` to the two inputs and yields its result.
///
/// The result must be a statically shaped ranked tensor. Each operand must
/// either already have the result type or be a dense constant that can be
/// broadcast to it.
template <typename OnnxOp, typename SpatialOp>
struct BinaryElementwiseToSpatialCompute : OpConversionPattern<OnnxOp> {
  using OpConversionPattern<OnnxOp>::OpConversionPattern;
  using Adaptor = typename OnnxOp::Adaptor;

  LogicalResult matchAndRewrite(OnnxOp op, Adaptor adaptor, ConversionPatternRewriter& rewriter) const override {
    auto resultType = dyn_cast<RankedTensorType>(op->getResult(0).getType());
    if (!resultType || !resultType.hasStaticShape())
      return failure();

    Location loc = op.getLoc();
    FailureOr<Value> lhs = prepareElementwiseOperand(adaptor.getOperands()[0], resultType, rewriter, loc);
    if (failed(lhs))
      return failure();
    FailureOr<Value> rhs = prepareElementwiseOperand(adaptor.getOperands()[1], resultType, rewriter, loc);
    if (failed(rhs))
      return failure();

    constexpr size_t numInputs = 2;
    auto computeOp = createSpatCompute<numInputs>(
        rewriter, loc, resultType, {}, ValueRange {*lhs, *rhs}, [&](Value x, Value y) {
          auto loweredOp = SpatialOp::create(rewriter, loc, resultType, x, y);
          spatial::SpatYieldOp::create(rewriter, loc, loweredOp.getResult());
        });
    rewriter.replaceOp(op, computeOp);
    return success();
  }
};

/// Lowers `ONNXDivOp` to a spatial compute region that multiplies the dividend
/// by the precomputed reciprocal of a constant floating-point divisor.
///
/// Only constant, floating-point divisors are handled; other divisors leave
/// the op unmatched. Multiplying by a rounded reciprocal rounds twice, so
/// results may differ from a direct division in the last bit(s).
struct DivToSpatialCompute : OpConversionPattern<ONNXDivOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ONNXDivOp op, ONNXDivOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override {
    auto resultType = dyn_cast<RankedTensorType>(op.getResult().getType());
    if (!resultType || !resultType.hasStaticShape())
      return failure();

    Location loc = op.getLoc();
    FailureOr<Value> lhs = prepareElementwiseOperand(adaptor.getA(), resultType, rewriter, loc);
    if (failed(lhs))
      return failure();
    FailureOr<Value> reciprocalRhs = materializeReciprocalTensor(adaptor.getB(), resultType, rewriter, loc);
    if (failed(reciprocalRhs))
      return failure();

    constexpr size_t numInputs = 2;
    auto computeOp = createSpatCompute<numInputs>(
        rewriter, loc, resultType, {}, ValueRange {*lhs, *reciprocalRhs}, [&](Value x, Value reciprocal) {
          auto mulOp = spatial::SpatVMulOp::create(rewriter, loc, resultType, x, reciprocal);
          spatial::SpatYieldOp::create(rewriter, loc, mulOp.getResult());
        });
    rewriter.replaceOp(op, computeOp);
    return success();
  }
};

} // namespace

/// Registers the element-wise ONNX-to-spatial conversion patterns in
/// `patterns`: two binary element-wise lowerings plus the Div lowering
/// (multiply by the constant reciprocal).
void populateElementwisePatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
  patterns.add<BinaryElementwiseToSpatialCompute<ONNXAddOp, spatial::SpatVAddOp>>(ctx);
  patterns.add<BinaryElementwiseToSpatialCompute<ONNXMulOp, spatial::SpatVMulOp>>(ctx);
  patterns.add<DivToSpatialCompute>(ctx);
}

} // namespace onnx_mlir