Files
Raptor/src/PIM/Common/Support/CheckedArithmetic.hpp
T
NiccoloN 636310d0cb
Validate Operations / validate-operations (push) Has been cancelled
add shared loop creation helpers
add shared checked arithmetic helpers
refactor pim passes into Pim/Transforms
more robust memory coalescing pass
2026-06-01 16:49:06 +02:00

108 lines
4.4 KiB
C++

#pragma once
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Operation.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/StringRef.h"
#include <cstddef>
#include <cstdint>
#include <limits>
#include <type_traits>
namespace onnx_mlir::pim {

/// Reports a failed checked-arithmetic operation on the quantity `fieldName`,
/// anchored at `anchor`; `message` describes the failure (the templates in
/// this header pass "is outside representable range", "addition overflow" and
/// "multiplication overflow").
/// NOTE(review): defined out of line — the exact diagnostic wording is decided
/// there. The templates below discard the returned diagnostic, presumably
/// relying on MLIR reporting an InFlightDiagnostic when it is destroyed.
mlir::InFlightDiagnostic
emitCheckedArithmeticError(mlir::Operation* anchor, llvm::StringRef fieldName, llvm::StringRef message);

/// Location-based overload of emitCheckedArithmeticError, for callers that
/// have a source location but no operation to anchor the diagnostic to.
mlir::InFlightDiagnostic
emitCheckedArithmeticError(mlir::Location loc, llvm::StringRef fieldName, llvm::StringRef message);

namespace detail {

/// Returns true iff `value` can be converted to `To` without changing its
/// numeric value.
///
/// Every comparison is carried out in a type that represents both operands
/// exactly. In particular, the bounds of `To` are never narrowed to `From`:
/// when `To` is wider, that narrowing wraps (e.g. the int64_t bounds squeezed
/// into int32_t become 0 and -1) and would reject every value.
template <typename To, typename From>
constexpr bool isRepresentableAs(From value) {
  static_assert(std::is_integral_v<To> && std::is_integral_v<From>, "isRepresentableAs requires integral types");
  using ToLimits = std::numeric_limits<To>;
  if constexpr (std::is_signed_v<From> == std::is_signed_v<To>) {
    // Same signedness: the common type is at least as wide as both From and
    // To, so the value and both bounds of To convert into it losslessly.
    using Common = std::common_type_t<From, To>;
    const auto widened = static_cast<Common>(value);
    return widened >= static_cast<Common>(ToLimits::min()) && widened <= static_cast<Common>(ToLimits::max());
  }
  else if constexpr (std::is_signed_v<From>) {
    // Signed -> unsigned: negative values never fit; non-negative values are
    // compared against To's maximum as unsigned quantities (the usual
    // arithmetic conversions widen to the larger unsigned type).
    return value >= 0 && static_cast<std::make_unsigned_t<From>>(value) <= ToLimits::max();
  }
  else {
    // Unsigned -> signed: To's minimum is negative, so only the upper bound
    // can be violated; compare it as an unsigned quantity.
    return value <= static_cast<std::make_unsigned_t<To>>(ToLimits::max());
  }
}

} // namespace detail

/// Converts the integral `value` to the integral type `To`, failing instead of
/// silently truncating or changing sign.
///
/// @param value     The value to convert.
/// @param anchor    Operation the error diagnostic is attached to on failure.
/// @param fieldName Name of the converted quantity, used in the diagnostic.
/// @return `value` as `To`, or failure (after emitting an error at `anchor`)
///         when `value` lies outside the range of `To`.
template <typename To, typename From>
mlir::FailureOr<To> checkedCast(From value, mlir::Operation* anchor, llvm::StringRef fieldName) {
  static_assert(std::is_integral_v<To> && std::is_integral_v<From>, "checkedCast requires integral types");
  if (!detail::isRepresentableAs<To>(value)) {
    emitCheckedArithmeticError(anchor, fieldName, "is outside representable range");
    return mlir::failure();
  }
  return static_cast<To>(value);
}

/// Adds two unsigned integers, failing instead of wrapping around.
///
/// @return `lhs + rhs`, or failure (after emitting an "addition overflow"
///         error at `anchor` for `fieldName`) when the sum exceeds UInt's max.
template <typename UInt>
mlir::FailureOr<UInt> checkedAdd(UInt lhs, UInt rhs, mlir::Operation* anchor, llvm::StringRef fieldName) {
  static_assert(std::is_integral_v<UInt> && std::is_unsigned_v<UInt>, "checkedAdd requires unsigned integral types");
  // Rearranged so the check itself cannot overflow: lhs + rhs > max  <=>  rhs > max - lhs.
  if (rhs > std::numeric_limits<UInt>::max() - lhs) {
    emitCheckedArithmeticError(anchor, fieldName, "addition overflow");
    return mlir::failure();
  }
  // Explicit cast: narrow unsigned types are promoted to int by the addition.
  return static_cast<UInt>(lhs + rhs);
}

/// Multiplies two unsigned integers, failing instead of wrapping around.
///
/// @return `lhs * rhs`, or failure (after emitting a "multiplication overflow"
///         error at `anchor` for `fieldName`) when the product exceeds UInt's max.
template <typename UInt>
mlir::FailureOr<UInt> checkedMul(UInt lhs, UInt rhs, mlir::Operation* anchor, llvm::StringRef fieldName) {
  static_assert(std::is_integral_v<UInt> && std::is_unsigned_v<UInt>, "checkedMul requires unsigned integral types");
  // Division-based check avoids computing the (possibly overflowing) product;
  // lhs == 0 always succeeds and is excluded to avoid dividing by zero.
  if (lhs != 0 && rhs > std::numeric_limits<UInt>::max() / lhs) {
    emitCheckedArithmeticError(anchor, fieldName, "multiplication overflow");
    return mlir::failure();
  }
  // Explicit cast: narrow unsigned types are promoted to int by the product.
  return static_cast<UInt>(lhs * rhs);
}

// Out-of-line, non-template range checks for commonly used target types.
// NOTE(review): definitions are not visible here; presumably they mirror
// checkedCast<int32_t / uint8_t / size_t> — confirm in the implementation file.
mlir::FailureOr<int32_t> checkedI32(int64_t value, mlir::Operation* anchor, llvm::StringRef fieldName);
mlir::FailureOr<int32_t> checkedI32(uint64_t value, mlir::Operation* anchor, llvm::StringRef fieldName);
mlir::FailureOr<uint8_t> checkedU8(uint64_t value, mlir::Operation* anchor, llvm::StringRef fieldName);
mlir::FailureOr<size_t> checkedSize(int64_t value, mlir::Operation* anchor, llvm::StringRef fieldName);

// Builds an IntegerAttr holding `value` after a range check.
// NOTE(review): presumably a 32-bit integer attribute created via `builder`,
// failing when `value` does not fit in int32_t — confirm in the .cpp.
mlir::FailureOr<mlir::IntegerAttr>
getCheckedI32Attr(mlir::Builder& builder, mlir::Operation* anchor, int64_t value, llvm::StringRef fieldName);
mlir::FailureOr<mlir::IntegerAttr>
getCheckedI32Attr(mlir::Builder& builder, mlir::Operation* anchor, uint64_t value, llvm::StringRef fieldName);
mlir::FailureOr<mlir::IntegerAttr>
getCheckedI32Attr(mlir::Builder& builder, mlir::Location loc, int64_t value, llvm::StringRef fieldName);
mlir::FailureOr<mlir::IntegerAttr>
getCheckedI32Attr(mlir::Builder& builder, mlir::Location loc, uint64_t value, llvm::StringRef fieldName);

// Size in bytes of `type`, computed with overflow checks.
// NOTE(review): presumably element count times element byte width; behavior
// for dynamic shapes is decided in the implementation file — confirm there.
mlir::FailureOr<uint64_t>
getCheckedShapedTypeSizeInBytes(mlir::ShapedType type, mlir::Operation* anchor, llvm::StringRef fieldName);
mlir::FailureOr<uint64_t>
getCheckedShapedTypeSizeInBytes(mlir::ShapedType type, mlir::Location loc, llvm::StringRef fieldName);

// Variants that take no anchor and return the plain value.
// NOTE(review): presumably they abort the process (rather than returning
// failure) when the check fails — confirm in the implementation file.
int32_t checkedI32OrCrash(int64_t value, llvm::StringRef fieldName);
int32_t checkedI32OrCrash(uint64_t value, llvm::StringRef fieldName);
uint8_t checkedU8OrCrash(uint64_t value, llvm::StringRef fieldName);
size_t checkedSizeOrCrash(int64_t value, llvm::StringRef fieldName);
size_t checkedAddOrCrash(size_t lhs, size_t rhs, llvm::StringRef fieldName);
size_t checkedMulOrCrash(size_t lhs, size_t rhs, llvm::StringRef fieldName);

} // namespace onnx_mlir::pim