636310d0cb
Validate Operations / validate-operations (push) Has been cancelled
add shared checked arithmetic helpers refactor pim passes into Pim/Transforms more robust memory coalescing pass
108 lines
4.4 KiB
C++
108 lines
4.4 KiB
C++
#pragma once
|
|
|
|
#include "mlir/IR/Builders.h"
|
|
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
|
#include "mlir/IR/Operation.h"
|
|
#include "mlir/Support/LogicalResult.h"
|
|
|
|
#include "llvm/ADT/StringRef.h"
|
|
|
|
#include <cstddef>
|
|
#include <cstdint>
|
|
#include <limits>
|
|
#include <type_traits>
|
|
|
|
namespace onnx_mlir::pim {
|
|
|
|
/// Reports a failed checked-arithmetic operation as an MLIR error diagnostic.
///
/// The two overloads differ only in how the report is anchored: on an operation, or on a bare location.
/// Their definitions live outside this header. Presumably the emitted text combines `fieldName` with
/// `message`; confirm against the implementation.
///
/// @param anchor / loc  Where the error is reported.
/// @param fieldName     Name of the quantity that failed the check.
/// @param message       Short description of the failure, e.g. "addition overflow".
/// @return The in-flight diagnostic. Callers may stream extra context into it before it is reported.
///         The helpers in this header discard it immediately.
mlir::InFlightDiagnostic emitCheckedArithmeticError(mlir::Operation* anchor, llvm::StringRef fieldName, llvm::StringRef message);
mlir::InFlightDiagnostic emitCheckedArithmeticError(mlir::Location loc, llvm::StringRef fieldName, llvm::StringRef message);
|
|
|
|
template <typename To, typename From>
|
|
mlir::FailureOr<To> checkedCast(From value, mlir::Operation* anchor, llvm::StringRef fieldName) {
|
|
static_assert(std::is_integral_v<To> && std::is_integral_v<From>, "checkedCast requires integral types");
|
|
|
|
using ToLimits = std::numeric_limits<To>;
|
|
|
|
if constexpr (std::is_signed_v<From> == std::is_signed_v<To>) {
|
|
if (value < static_cast<From>(ToLimits::min()) || value > static_cast<From>(ToLimits::max())) {
|
|
emitCheckedArithmeticError(anchor, fieldName, "is outside representable range");
|
|
return mlir::failure();
|
|
}
|
|
}
|
|
else if constexpr (std::is_signed_v<From>) {
|
|
using UnsignedFrom = std::make_unsigned_t<From>;
|
|
using UnsignedTo = std::make_unsigned_t<To>;
|
|
if (value < 0 || static_cast<UnsignedFrom>(value) > static_cast<UnsignedTo>(ToLimits::max())) {
|
|
emitCheckedArithmeticError(anchor, fieldName, "is outside representable range");
|
|
return mlir::failure();
|
|
}
|
|
}
|
|
else {
|
|
using UnsignedFrom = std::make_unsigned_t<From>;
|
|
using UnsignedTo = std::conditional_t<std::is_signed_v<To>, std::make_unsigned_t<To>, To>;
|
|
if (static_cast<UnsignedFrom>(value) > static_cast<UnsignedTo>(ToLimits::max())) {
|
|
emitCheckedArithmeticError(anchor, fieldName, "is outside representable range");
|
|
return mlir::failure();
|
|
}
|
|
}
|
|
|
|
return static_cast<To>(value);
|
|
}
|
|
|
|
template <typename UInt>
|
|
mlir::FailureOr<UInt> checkedAdd(UInt lhs, UInt rhs, mlir::Operation* anchor, llvm::StringRef fieldName) {
|
|
static_assert(std::is_integral_v<UInt> && std::is_unsigned_v<UInt>, "checkedAdd requires unsigned integral types");
|
|
if (rhs > std::numeric_limits<UInt>::max() - lhs) {
|
|
emitCheckedArithmeticError(anchor, fieldName, "addition overflow");
|
|
return mlir::failure();
|
|
}
|
|
return lhs + rhs;
|
|
}
|
|
|
|
template <typename UInt>
|
|
mlir::FailureOr<UInt> checkedMul(UInt lhs, UInt rhs, mlir::Operation* anchor, llvm::StringRef fieldName) {
|
|
static_assert(std::is_integral_v<UInt> && std::is_unsigned_v<UInt>, "checkedMul requires unsigned integral types");
|
|
if (lhs != 0 && rhs > std::numeric_limits<UInt>::max() / lhs) {
|
|
emitCheckedArithmeticError(anchor, fieldName, "multiplication overflow");
|
|
return mlir::failure();
|
|
}
|
|
return lhs * rhs;
|
|
}
|
|
|
|
/// Narrows a 64-bit value to int32_t, returning failure when it is not representable.
///
/// Defined outside this header. Presumably equivalent to `checkedCast<int32_t>(value, anchor, fieldName)`,
/// including the diagnostic emitted at `anchor`; confirm in the implementation.
///
/// NOTE: arguments of other integer types (e.g. `int` or `unsigned`) need an integral conversion to
/// reach either overload, so the call is ambiguous. Cast to int64_t or uint64_t first.
mlir::FailureOr<int32_t> checkedI32(int64_t value, mlir::Operation* anchor, llvm::StringRef fieldName);
mlir::FailureOr<int32_t> checkedI32(uint64_t value, mlir::Operation* anchor, llvm::StringRef fieldName);
|
|
|
|
/// Narrows an unsigned 64-bit value to uint8_t. Presumably fails, with a diagnostic at `anchor`, for
/// values above 255 (definition not in this header; confirm).
mlir::FailureOr<uint8_t> checkedU8(uint64_t value, mlir::Operation* anchor, llvm::StringRef fieldName);
|
|
|
|
/// Converts a signed 64-bit value to size_t. Presumably fails, with a diagnostic at `anchor`, for
/// negative inputs (definition not in this header; confirm).
mlir::FailureOr<size_t> checkedSize(int64_t value, mlir::Operation* anchor, llvm::StringRef fieldName);
|
|
|
|
/// Builds an i32 `IntegerAttr` from a 64-bit value after checking that it fits in int32_t.
///
/// The overloads cover signed and unsigned sources, with the failure diagnostic anchored either on an
/// operation or on a location. Definitions are not in this header. Presumably they narrow via the
/// checkedI32 helpers and then call `builder.getI32IntegerAttr`; confirm in the implementation.
///
/// @return The attribute, or failure when `value` is outside the int32_t range.
mlir::FailureOr<mlir::IntegerAttr> getCheckedI32Attr(mlir::Builder& builder, mlir::Operation* anchor, int64_t value, llvm::StringRef fieldName);
mlir::FailureOr<mlir::IntegerAttr> getCheckedI32Attr(mlir::Builder& builder, mlir::Operation* anchor, uint64_t value, llvm::StringRef fieldName);
mlir::FailureOr<mlir::IntegerAttr> getCheckedI32Attr(mlir::Builder& builder, mlir::Location loc, int64_t value, llvm::StringRef fieldName);
mlir::FailureOr<mlir::IntegerAttr> getCheckedI32Attr(mlir::Builder& builder, mlir::Location loc, uint64_t value, llvm::StringRef fieldName);
|
|
|
|
/// Computes the byte size of a shaped type, returning failure instead of a wrapped result.
///
/// Defined outside this header. Presumably it multiplies the dimensions and the element byte width
/// using overflow-checked arithmetic, and may also reject shapes it cannot size (e.g. dynamic
/// dimensions); confirm in the implementation.
///
/// @param type       Shaped type to measure.
/// @param anchor/loc Where a diagnostic is reported on failure.
/// @param fieldName  Name of the measured quantity, used in the diagnostic.
mlir::FailureOr<uint64_t> getCheckedShapedTypeSizeInBytes(mlir::ShapedType type, mlir::Operation* anchor, llvm::StringRef fieldName);
mlir::FailureOr<uint64_t> getCheckedShapedTypeSizeInBytes(mlir::ShapedType type, mlir::Location loc, llvm::StringRef fieldName);
|
|
|
|
/// Diagnostic-free variants of the checked helpers for contexts without an anchor operation.
///
/// They return the plain value rather than `FailureOr`. Going by the name, presumably they terminate
/// the process (e.g. via a fatal-error report that mentions `fieldName`) when the check fails, instead
/// of returning failure. The definitions are not in this header; confirm there.
int32_t checkedI32OrCrash(int64_t value, llvm::StringRef fieldName);
int32_t checkedI32OrCrash(uint64_t value, llvm::StringRef fieldName);
uint8_t checkedU8OrCrash(uint64_t value, llvm::StringRef fieldName);
size_t checkedSizeOrCrash(int64_t value, llvm::StringRef fieldName);
size_t checkedAddOrCrash(size_t lhs, size_t rhs, llvm::StringRef fieldName);
size_t checkedMulOrCrash(size_t lhs, size_t rhs, llvm::StringRef fieldName);
|
|
|
|
} // namespace onnx_mlir::pim
|