add shared loop creation helpers
Validate Operations / validate-operations (push) Has been cancelled
Validate Operations / validate-operations (push) Has been cancelled
add shared checked arithmetic helpers refactor pim passes into Pim/Transforms more robust memory coalescing pass
This commit is contained in:
@@ -0,0 +1,222 @@
|
||||
#include "llvm/Support/ErrorHandling.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
#include "CheckedArithmetic.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir::pim {
|
||||
|
||||
namespace {
|
||||
|
||||
static void emitCrashMessage(llvm::StringRef fieldName, llvm::StringRef message) {
|
||||
llvm::errs() << "PIM " << fieldName << " " << message << "\n";
|
||||
}
|
||||
|
||||
template <typename To, typename From>
|
||||
static FailureOr<To> checkedCastAtLocation(From value, Location loc, llvm::StringRef fieldName) {
|
||||
static_assert(std::is_integral_v<To> && std::is_integral_v<From>, "checkedCastAtLocation requires integral types");
|
||||
|
||||
using ToLimits = std::numeric_limits<To>;
|
||||
|
||||
if constexpr (std::is_signed_v<From> == std::is_signed_v<To>) {
|
||||
if (value < static_cast<From>(ToLimits::min()) || value > static_cast<From>(ToLimits::max())) {
|
||||
emitCheckedArithmeticError(loc, fieldName, "is outside representable range");
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
else if constexpr (std::is_signed_v<From>) {
|
||||
using UnsignedFrom = std::make_unsigned_t<From>;
|
||||
using UnsignedTo = std::make_unsigned_t<To>;
|
||||
if (value < 0 || static_cast<UnsignedFrom>(value) > static_cast<UnsignedTo>(ToLimits::max())) {
|
||||
emitCheckedArithmeticError(loc, fieldName, "is outside representable range");
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
else {
|
||||
using UnsignedFrom = std::make_unsigned_t<From>;
|
||||
using UnsignedTo = std::conditional_t<std::is_signed_v<To>, std::make_unsigned_t<To>, To>;
|
||||
if (static_cast<UnsignedFrom>(value) > static_cast<UnsignedTo>(ToLimits::max())) {
|
||||
emitCheckedArithmeticError(loc, fieldName, "is outside representable range");
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
|
||||
return static_cast<To>(value);
|
||||
}
|
||||
|
||||
template <typename UInt>
|
||||
FailureOr<UInt> checkedMulAtLocation(UInt lhs, UInt rhs, Location loc, llvm::StringRef fieldName) {
|
||||
static_assert(std::is_integral_v<UInt> && std::is_unsigned_v<UInt>,
|
||||
"checkedMulAtLocation requires unsigned integral types");
|
||||
if (lhs != 0 && rhs > std::numeric_limits<UInt>::max() / lhs) {
|
||||
emitCheckedArithmeticError(loc, fieldName, "multiplication overflow");
|
||||
return failure();
|
||||
}
|
||||
return lhs * rhs;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
InFlightDiagnostic emitCheckedArithmeticError(Operation* anchor, llvm::StringRef fieldName, llvm::StringRef message) {
|
||||
assert(anchor && "expected arithmetic diagnostics to have an anchor op");
|
||||
return anchor->emitOpError() << fieldName << " " << message;
|
||||
}
|
||||
|
||||
InFlightDiagnostic emitCheckedArithmeticError(Location loc, llvm::StringRef fieldName, llvm::StringRef message) {
|
||||
return emitError(loc) << "PIM " << fieldName << " " << message;
|
||||
}
|
||||
|
||||
FailureOr<int32_t> checkedI32(int64_t value, Operation* anchor, llvm::StringRef fieldName) {
|
||||
return checkedCast<int32_t>(value, anchor, fieldName);
|
||||
}
|
||||
|
||||
FailureOr<int32_t> checkedI32(uint64_t value, Operation* anchor, llvm::StringRef fieldName) {
|
||||
return checkedCast<int32_t>(value, anchor, fieldName);
|
||||
}
|
||||
|
||||
FailureOr<uint8_t> checkedU8(uint64_t value, Operation* anchor, llvm::StringRef fieldName) {
|
||||
return checkedCast<uint8_t>(value, anchor, fieldName);
|
||||
}
|
||||
|
||||
FailureOr<size_t> checkedSize(int64_t value, Operation* anchor, llvm::StringRef fieldName) {
|
||||
return checkedCast<size_t>(value, anchor, fieldName);
|
||||
}
|
||||
|
||||
FailureOr<IntegerAttr>
|
||||
getCheckedI32Attr(Builder& builder, Operation* anchor, int64_t value, llvm::StringRef fieldName) {
|
||||
assert(anchor && "checked op-based attrs require a non-null diagnostic anchor");
|
||||
auto checkedValue = checkedI32(value, anchor, fieldName);
|
||||
if (failed(checkedValue))
|
||||
return failure();
|
||||
return builder.getI32IntegerAttr(*checkedValue);
|
||||
}
|
||||
|
||||
FailureOr<IntegerAttr>
|
||||
getCheckedI32Attr(Builder& builder, Operation* anchor, uint64_t value, llvm::StringRef fieldName) {
|
||||
assert(anchor && "checked op-based attrs require a non-null diagnostic anchor");
|
||||
auto checkedValue = checkedI32(value, anchor, fieldName);
|
||||
if (failed(checkedValue))
|
||||
return failure();
|
||||
return builder.getI32IntegerAttr(*checkedValue);
|
||||
}
|
||||
|
||||
FailureOr<IntegerAttr> getCheckedI32Attr(Builder& builder, Location loc, int64_t value, llvm::StringRef fieldName) {
|
||||
auto checkedValue = checkedCastAtLocation<int32_t>(value, loc, fieldName);
|
||||
if (failed(checkedValue))
|
||||
return failure();
|
||||
return builder.getI32IntegerAttr(*checkedValue);
|
||||
}
|
||||
|
||||
FailureOr<IntegerAttr> getCheckedI32Attr(Builder& builder, Location loc, uint64_t value, llvm::StringRef fieldName) {
|
||||
auto checkedValue = checkedCastAtLocation<int32_t>(value, loc, fieldName);
|
||||
if (failed(checkedValue))
|
||||
return failure();
|
||||
return builder.getI32IntegerAttr(*checkedValue);
|
||||
}
|
||||
|
||||
FailureOr<uint64_t> getCheckedShapedTypeSizeInBytes(ShapedType type, Operation* anchor, llvm::StringRef fieldName) {
|
||||
assert(anchor && "checked op-based size helpers require a non-null diagnostic anchor");
|
||||
if (!type.hasStaticShape()) {
|
||||
emitCheckedArithmeticError(anchor, fieldName, "requires static shaped type");
|
||||
return failure();
|
||||
}
|
||||
if (!hasByteSizedElementType(type.getElementType())) {
|
||||
emitCheckedArithmeticError(anchor, fieldName, "requires byte-sized element type");
|
||||
return failure();
|
||||
}
|
||||
|
||||
uint64_t elements = 1;
|
||||
for (int64_t dim : type.getShape()) {
|
||||
if (dim < 0) {
|
||||
emitCheckedArithmeticError(anchor, fieldName, "requires nonnegative dimensions");
|
||||
return failure();
|
||||
}
|
||||
|
||||
auto nextElements = checkedMul(elements, static_cast<uint64_t>(dim), anchor, fieldName);
|
||||
if (failed(nextElements))
|
||||
return failure();
|
||||
elements = *nextElements;
|
||||
}
|
||||
|
||||
return checkedMul(
|
||||
elements, static_cast<uint64_t>(getElementTypeSizeInBytes(type.getElementType())), anchor, fieldName);
|
||||
}
|
||||
|
||||
FailureOr<uint64_t> getCheckedShapedTypeSizeInBytes(ShapedType type, Location loc, llvm::StringRef fieldName) {
|
||||
if (!type.hasStaticShape()) {
|
||||
emitCheckedArithmeticError(loc, fieldName, "requires static shaped type");
|
||||
return failure();
|
||||
}
|
||||
if (!hasByteSizedElementType(type.getElementType())) {
|
||||
emitCheckedArithmeticError(loc, fieldName, "requires byte-sized element type");
|
||||
return failure();
|
||||
}
|
||||
|
||||
uint64_t elements = 1;
|
||||
for (int64_t dim : type.getShape()) {
|
||||
if (dim < 0) {
|
||||
emitCheckedArithmeticError(loc, fieldName, "requires nonnegative dimensions");
|
||||
return failure();
|
||||
}
|
||||
|
||||
auto nextElements = checkedMulAtLocation(elements, static_cast<uint64_t>(dim), loc, fieldName);
|
||||
if (failed(nextElements))
|
||||
return failure();
|
||||
elements = *nextElements;
|
||||
}
|
||||
|
||||
return checkedMulAtLocation(
|
||||
elements, static_cast<uint64_t>(getElementTypeSizeInBytes(type.getElementType())), loc, fieldName);
|
||||
}
|
||||
|
||||
int32_t checkedI32OrCrash(int64_t value, llvm::StringRef fieldName) {
|
||||
if (value < std::numeric_limits<int32_t>::min() || value > std::numeric_limits<int32_t>::max()) {
|
||||
emitCrashMessage(fieldName, "is outside representable range");
|
||||
llvm_unreachable("PIM checked arithmetic failure");
|
||||
}
|
||||
return static_cast<int32_t>(value);
|
||||
}
|
||||
|
||||
int32_t checkedI32OrCrash(uint64_t value, llvm::StringRef fieldName) {
|
||||
if (value > static_cast<uint64_t>(std::numeric_limits<int32_t>::max())) {
|
||||
emitCrashMessage(fieldName, "is outside representable range");
|
||||
llvm_unreachable("PIM checked arithmetic failure");
|
||||
}
|
||||
return static_cast<int32_t>(value);
|
||||
}
|
||||
|
||||
uint8_t checkedU8OrCrash(uint64_t value, llvm::StringRef fieldName) {
|
||||
if (value > static_cast<uint64_t>(std::numeric_limits<uint8_t>::max())) {
|
||||
emitCrashMessage(fieldName, "is outside representable range");
|
||||
llvm_unreachable("PIM checked arithmetic failure");
|
||||
}
|
||||
return static_cast<uint8_t>(value);
|
||||
}
|
||||
|
||||
size_t checkedSizeOrCrash(int64_t value, llvm::StringRef fieldName) {
|
||||
if (value < 0) {
|
||||
emitCrashMessage(fieldName, "is outside representable range");
|
||||
llvm_unreachable("PIM checked arithmetic failure");
|
||||
}
|
||||
return static_cast<size_t>(value);
|
||||
}
|
||||
|
||||
size_t checkedAddOrCrash(size_t lhs, size_t rhs, llvm::StringRef fieldName) {
|
||||
if (rhs > std::numeric_limits<size_t>::max() - lhs) {
|
||||
emitCrashMessage(fieldName, "addition overflow");
|
||||
llvm_unreachable("PIM checked arithmetic failure");
|
||||
}
|
||||
return lhs + rhs;
|
||||
}
|
||||
|
||||
size_t checkedMulOrCrash(size_t lhs, size_t rhs, llvm::StringRef fieldName) {
|
||||
if (lhs != 0 && rhs > std::numeric_limits<size_t>::max() / lhs) {
|
||||
emitCrashMessage(fieldName, "multiplication overflow");
|
||||
llvm_unreachable("PIM checked arithmetic failure");
|
||||
}
|
||||
return lhs * rhs;
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir::pim
|
||||
@@ -0,0 +1,107 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <limits>
|
||||
#include <type_traits>
|
||||
|
||||
namespace onnx_mlir::pim {
|
||||
|
||||
mlir::InFlightDiagnostic
|
||||
emitCheckedArithmeticError(mlir::Operation* anchor, llvm::StringRef fieldName, llvm::StringRef message);
|
||||
|
||||
mlir::InFlightDiagnostic
|
||||
emitCheckedArithmeticError(mlir::Location loc, llvm::StringRef fieldName, llvm::StringRef message);
|
||||
|
||||
template <typename To, typename From>
|
||||
mlir::FailureOr<To> checkedCast(From value, mlir::Operation* anchor, llvm::StringRef fieldName) {
|
||||
static_assert(std::is_integral_v<To> && std::is_integral_v<From>, "checkedCast requires integral types");
|
||||
|
||||
using ToLimits = std::numeric_limits<To>;
|
||||
|
||||
if constexpr (std::is_signed_v<From> == std::is_signed_v<To>) {
|
||||
if (value < static_cast<From>(ToLimits::min()) || value > static_cast<From>(ToLimits::max())) {
|
||||
emitCheckedArithmeticError(anchor, fieldName, "is outside representable range");
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
else if constexpr (std::is_signed_v<From>) {
|
||||
using UnsignedFrom = std::make_unsigned_t<From>;
|
||||
using UnsignedTo = std::make_unsigned_t<To>;
|
||||
if (value < 0 || static_cast<UnsignedFrom>(value) > static_cast<UnsignedTo>(ToLimits::max())) {
|
||||
emitCheckedArithmeticError(anchor, fieldName, "is outside representable range");
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
else {
|
||||
using UnsignedFrom = std::make_unsigned_t<From>;
|
||||
using UnsignedTo = std::conditional_t<std::is_signed_v<To>, std::make_unsigned_t<To>, To>;
|
||||
if (static_cast<UnsignedFrom>(value) > static_cast<UnsignedTo>(ToLimits::max())) {
|
||||
emitCheckedArithmeticError(anchor, fieldName, "is outside representable range");
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
|
||||
return static_cast<To>(value);
|
||||
}
|
||||
|
||||
template <typename UInt>
|
||||
mlir::FailureOr<UInt> checkedAdd(UInt lhs, UInt rhs, mlir::Operation* anchor, llvm::StringRef fieldName) {
|
||||
static_assert(std::is_integral_v<UInt> && std::is_unsigned_v<UInt>, "checkedAdd requires unsigned integral types");
|
||||
if (rhs > std::numeric_limits<UInt>::max() - lhs) {
|
||||
emitCheckedArithmeticError(anchor, fieldName, "addition overflow");
|
||||
return mlir::failure();
|
||||
}
|
||||
return lhs + rhs;
|
||||
}
|
||||
|
||||
template <typename UInt>
|
||||
mlir::FailureOr<UInt> checkedMul(UInt lhs, UInt rhs, mlir::Operation* anchor, llvm::StringRef fieldName) {
|
||||
static_assert(std::is_integral_v<UInt> && std::is_unsigned_v<UInt>, "checkedMul requires unsigned integral types");
|
||||
if (lhs != 0 && rhs > std::numeric_limits<UInt>::max() / lhs) {
|
||||
emitCheckedArithmeticError(anchor, fieldName, "multiplication overflow");
|
||||
return mlir::failure();
|
||||
}
|
||||
return lhs * rhs;
|
||||
}
|
||||
|
||||
mlir::FailureOr<int32_t> checkedI32(int64_t value, mlir::Operation* anchor, llvm::StringRef fieldName);
|
||||
mlir::FailureOr<int32_t> checkedI32(uint64_t value, mlir::Operation* anchor, llvm::StringRef fieldName);
|
||||
|
||||
mlir::FailureOr<uint8_t> checkedU8(uint64_t value, mlir::Operation* anchor, llvm::StringRef fieldName);
|
||||
|
||||
mlir::FailureOr<size_t> checkedSize(int64_t value, mlir::Operation* anchor, llvm::StringRef fieldName);
|
||||
|
||||
mlir::FailureOr<mlir::IntegerAttr>
|
||||
getCheckedI32Attr(mlir::Builder& builder, mlir::Operation* anchor, int64_t value, llvm::StringRef fieldName);
|
||||
|
||||
mlir::FailureOr<mlir::IntegerAttr>
|
||||
getCheckedI32Attr(mlir::Builder& builder, mlir::Operation* anchor, uint64_t value, llvm::StringRef fieldName);
|
||||
|
||||
mlir::FailureOr<mlir::IntegerAttr>
|
||||
getCheckedI32Attr(mlir::Builder& builder, mlir::Location loc, int64_t value, llvm::StringRef fieldName);
|
||||
|
||||
mlir::FailureOr<mlir::IntegerAttr>
|
||||
getCheckedI32Attr(mlir::Builder& builder, mlir::Location loc, uint64_t value, llvm::StringRef fieldName);
|
||||
|
||||
mlir::FailureOr<uint64_t>
|
||||
getCheckedShapedTypeSizeInBytes(mlir::ShapedType type, mlir::Operation* anchor, llvm::StringRef fieldName);
|
||||
|
||||
mlir::FailureOr<uint64_t>
|
||||
getCheckedShapedTypeSizeInBytes(mlir::ShapedType type, mlir::Location loc, llvm::StringRef fieldName);
|
||||
|
||||
int32_t checkedI32OrCrash(int64_t value, llvm::StringRef fieldName);
|
||||
int32_t checkedI32OrCrash(uint64_t value, llvm::StringRef fieldName);
|
||||
uint8_t checkedU8OrCrash(uint64_t value, llvm::StringRef fieldName);
|
||||
size_t checkedSizeOrCrash(int64_t value, llvm::StringRef fieldName);
|
||||
size_t checkedAddOrCrash(size_t lhs, size_t rhs, llvm::StringRef fieldName);
|
||||
size_t checkedMulOrCrash(size_t lhs, size_t rhs, llvm::StringRef fieldName);
|
||||
|
||||
} // namespace onnx_mlir::pim
|
||||
Reference in New Issue
Block a user