#pragma once #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/Operation.h" #include "mlir/Support/LogicalResult.h" #include "llvm/ADT/StringRef.h" #include #include #include #include namespace onnx_mlir::pim { mlir::InFlightDiagnostic emitCheckedArithmeticError(mlir::Operation* anchor, llvm::StringRef fieldName, llvm::StringRef message); mlir::InFlightDiagnostic emitCheckedArithmeticError(mlir::Location loc, llvm::StringRef fieldName, llvm::StringRef message); template mlir::FailureOr checkedCast(From value, mlir::Operation* anchor, llvm::StringRef fieldName) { static_assert(std::is_integral_v && std::is_integral_v, "checkedCast requires integral types"); using ToLimits = std::numeric_limits; if constexpr (std::is_signed_v == std::is_signed_v) { if (value < static_cast(ToLimits::min()) || value > static_cast(ToLimits::max())) { emitCheckedArithmeticError(anchor, fieldName, "is outside representable range"); return mlir::failure(); } } else if constexpr (std::is_signed_v) { using UnsignedFrom = std::make_unsigned_t; using UnsignedTo = std::make_unsigned_t; if (value < 0 || static_cast(value) > static_cast(ToLimits::max())) { emitCheckedArithmeticError(anchor, fieldName, "is outside representable range"); return mlir::failure(); } } else { using UnsignedFrom = std::make_unsigned_t; using UnsignedTo = std::conditional_t, std::make_unsigned_t, To>; if (static_cast(value) > static_cast(ToLimits::max())) { emitCheckedArithmeticError(anchor, fieldName, "is outside representable range"); return mlir::failure(); } } return static_cast(value); } template mlir::FailureOr checkedAdd(UInt lhs, UInt rhs, mlir::Operation* anchor, llvm::StringRef fieldName) { static_assert(std::is_integral_v && std::is_unsigned_v, "checkedAdd requires unsigned integral types"); if (rhs > std::numeric_limits::max() - lhs) { emitCheckedArithmeticError(anchor, fieldName, "addition overflow"); return mlir::failure(); } return lhs + rhs; } template mlir::FailureOr checkedMul(UInt lhs, UInt rhs, mlir::Operation* anchor, llvm::StringRef fieldName) { static_assert(std::is_integral_v && std::is_unsigned_v, "checkedMul requires unsigned integral types"); if (lhs != 0 && rhs > std::numeric_limits::max() / lhs) { emitCheckedArithmeticError(anchor, fieldName, "multiplication overflow"); return mlir::failure(); } return lhs * rhs; } mlir::FailureOr checkedI32(int64_t value, mlir::Operation* anchor, llvm::StringRef fieldName); mlir::FailureOr checkedI32(uint64_t value, mlir::Operation* anchor, llvm::StringRef fieldName); mlir::FailureOr checkedU8(uint64_t value, mlir::Operation* anchor, llvm::StringRef fieldName); mlir::FailureOr checkedSize(int64_t value, mlir::Operation* anchor, llvm::StringRef fieldName); mlir::FailureOr getCheckedI32Attr(mlir::Builder& builder, mlir::Operation* anchor, int64_t value, llvm::StringRef fieldName); mlir::FailureOr getCheckedI32Attr(mlir::Builder& builder, mlir::Operation* anchor, uint64_t value, llvm::StringRef fieldName); mlir::FailureOr getCheckedI32Attr(mlir::Builder& builder, mlir::Location loc, int64_t value, llvm::StringRef fieldName); mlir::FailureOr getCheckedI32Attr(mlir::Builder& builder, mlir::Location loc, uint64_t value, llvm::StringRef fieldName); mlir::FailureOr getCheckedShapedTypeSizeInBytes(mlir::ShapedType type, mlir::Operation* anchor, llvm::StringRef fieldName); mlir::FailureOr getCheckedShapedTypeSizeInBytes(mlir::ShapedType type, mlir::Location loc, llvm::StringRef fieldName); int32_t checkedI32OrCrash(int64_t value, llvm::StringRef fieldName); int32_t checkedI32OrCrash(uint64_t value, llvm::StringRef fieldName); uint8_t checkedU8OrCrash(uint64_t value, llvm::StringRef fieldName); size_t checkedSizeOrCrash(int64_t value, llvm::StringRef fieldName); size_t checkedAddOrCrash(size_t lhs, size_t rhs, llvm::StringRef fieldName); size_t checkedMulOrCrash(size_t lhs, size_t rhs, llvm::StringRef fieldName); } // namespace onnx_mlir::pim