add shared loop creation helpers
Validate Operations / validate-operations (push) Has been cancelled

add shared checked arithmetic helpers
refactor pim passes into Pim/Transforms
more robust memory coalescing pass
This commit is contained in:
NiccoloN
2026-06-01 16:49:06 +02:00
parent 356be6ccc2
commit 636310d0cb
55 changed files with 2007 additions and 1103 deletions
+3
View File
@@ -5,9 +5,11 @@ add_pim_library(OMPimCommon
IR/ConstantUtils.cpp
IR/CoreBlockUtils.cpp
IR/EntryPointUtils.cpp
IR/LoopUtils.cpp
IR/ShapeUtils.cpp
IR/SubviewUtils.cpp
IR/WeightUtils.cpp
Support/CheckedArithmetic.cpp
Support/DebugDump.cpp
Support/Diagnostics.cpp
Support/FileSystemUtils.cpp
@@ -20,6 +22,7 @@ add_pim_library(OMPimCommon
LINK_LIBS PUBLIC
MLIRLinalgDialect
MLIRSCFDialect
onnx
SpatialOps
PimOps
+96
View File
@@ -0,0 +1,96 @@
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "llvm/Support/MathExtras.h"
#include <optional>
#include "ConstantUtils.hpp"
#include "LoopUtils.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
/// Computes the trip count of a loop whose bounds and step are all constant
/// index values.
///
/// Returns std::nullopt when any operand is not matched as a constant, when
/// the step is not strictly positive, or when `upper - lower` is not
/// representable as int64_t. Returns 0 when the upper bound does not exceed
/// the lower bound, otherwise ceil((upper - lower) / step).
static std::optional<int64_t> getStaticTripCount(Value lowerBound, Value upperBound, Value step) {
  auto lower = matchConstantIndexValue(lowerBound);
  auto upper = matchConstantIndexValue(upperBound);
  auto stepValue = matchConstantIndexValue(step);
  if (!lower || !upper || !stepValue)
    return std::nullopt;
  if (*stepValue <= 0)
    return std::nullopt;
  if (*upper <= *lower)
    return int64_t {0};

  // Bounds of opposite sign can make the span exceed int64_t. Signed overflow
  // is undefined behavior, so report the trip count as unknown instead.
  const int64_t lowerValue = *lower;
  const int64_t upperValue = *upper;
  int64_t span = 0;
  if (llvm::SubOverflow(upperValue, lowerValue, span))
    return std::nullopt;
  return llvm::divideCeil(span, *stepValue);
}
} // namespace
/// Checks that the loop body produced exactly one yielded value per iter arg.
/// On a count mismatch an error is emitted at `loc` and failure is returned.
static LogicalResult validateNormalizedLoopYields(Location loc, ValueRange initArgs, ArrayRef<Value> yieldedValues) {
  const auto numYielded = yieldedValues.size();
  const auto numIterArgs = initArgs.size();
  if (numYielded != numIterArgs) {
    emitError(loc) << "normalized loop body yielded " << numYielded << " values for " << numIterArgs << " iter args";
    return failure();
  }
  return success();
}
/// Emits a loop over [lowerBound, upperBound) with `step`, skipping the
/// scf.for op when the trip count is statically known to be 0 or 1.
///
/// - A constant, non-positive step is rejected with an error at `loc`.
/// - Static trip count 0: the body is not built; the results are `initArgs`.
/// - Static trip count 1: the body is built once at the current insertion
///   point with `lowerBound` as induction variable and `initArgs` as iter args;
///   the results are the values the body yielded.
/// - Otherwise an scf.for is created, the body is built inside it, an
///   scf.yield of the body's values is appended, the insertion point is moved
///   after the loop, and the results are the loop results.
///
/// Returns failure when `bodyBuilder` fails or yields a number of values that
/// differs from `initArgs.size()`; a loop created before the failure is erased.
FailureOr<NormalizedLoopResult> buildNormalizedScfFor(OpBuilder& builder,
                                                      Location loc,
                                                      Value lowerBound,
                                                      Value upperBound,
                                                      Value step,
                                                      ValueRange initArgs,
                                                      NormalizedLoopBodyBuilder bodyBuilder) {
  NormalizedLoopResult result;

  auto constantStep = matchConstantIndexValue(step);
  if (constantStep && *constantStep <= 0) {
    emitError(loc) << "normalized scf.for requires a positive step, got " << *constantStep;
    return failure();
  }

  const auto staticTripCount = getStaticTripCount(lowerBound, upperBound, step);

  // Zero iterations: the iter args flow through untouched.
  if (staticTripCount && *staticTripCount == 0) {
    llvm::append_range(result.results, initArgs);
    return result;
  }

  // Exactly one iteration: emit the body inline, no loop op needed.
  if (staticTripCount && *staticTripCount == 1) {
    result.inductionVar = lowerBound;
    if (failed(bodyBuilder(builder, loc, lowerBound, initArgs, result.results))
        || failed(validateNormalizedLoopYields(loc, initArgs, result.results)))
      return failure();
    return result;
  }

  // General case: materialize the scf.for and populate its body.
  scf::ForOp forOp = scf::ForOp::create(builder, loc, lowerBound, upperBound, step, initArgs);
  result.loop = forOp;
  result.inductionVar = forOp.getInductionVar();
  {
    OpBuilder::InsertionGuard guard(builder);
    Block* loopBody = forOp.getBody();

    // Drop a terminator that is already present so the body can be appended
    // and a yield carrying the body's values can be created afterwards.
    if (!loopBody->empty()) {
      if (auto existingYield = dyn_cast<scf::YieldOp>(loopBody->back()))
        existingYield->erase();
    }
    builder.setInsertionPointToEnd(loopBody);

    ValueRange regionIterArgs = forOp.getRegionIterArgs();
    const bool bodyBuilt = succeeded(bodyBuilder(builder, loc, result.inductionVar, regionIterArgs, result.results))
                        && succeeded(validateNormalizedLoopYields(loc, initArgs, result.results));
    if (!bodyBuilt) {
      forOp.erase();
      return failure();
    }
    scf::YieldOp::create(builder, loc, result.results);
  }

  builder.setInsertionPointAfter(forOp);
  result.results.assign(forOp.getResults().begin(), forOp.getResults().end());
  return result;
}
} // namespace onnx_mlir
+30
View File
@@ -0,0 +1,30 @@
#pragma once
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Builders.h"
#include "llvm/ADT/STLFunctionalExtras.h"
#include "llvm/ADT/SmallVector.h"
namespace onnx_mlir {
/// Outcome of buildNormalizedScfFor: either a materialized scf.for or the
/// values produced without creating a loop op.
struct NormalizedLoopResult {
/// Induction variable handed to the body: the loop's induction variable when
/// an scf.for was created, the lower bound when a single iteration was emitted
/// inline, and left null when the trip count is statically zero.
mlir::Value inductionVar;
/// Values carried out of the construct: the loop results when an scf.for was
/// created, the values yielded by the body when it was emitted inline, and the
/// initial iter args when the trip count is statically zero.
llvm::SmallVector<mlir::Value, 4> results;
/// The created scf.for; null when no loop op was created.
mlir::scf::ForOp loop;
/// Returns true when no scf.for op was created, i.e. `loop` is null.
bool wasInlined() const { return !loop; }
};
/// Callback that emits the loop body for buildNormalizedScfFor. It receives,
/// in order: the builder, the location, the induction variable, the current
/// iter-arg values, and an output vector to fill with the values to yield.
/// Returning failure makes buildNormalizedScfFor fail.
using NormalizedLoopBodyBuilder = llvm::function_ref<mlir::LogicalResult(
mlir::OpBuilder&, mlir::Location, mlir::Value, mlir::ValueRange, llvm::SmallVectorImpl<mlir::Value>&)>;
/// Emits a loop from `lowerBound` to `upperBound` with `step`, carrying
/// `initArgs` as iter args and using `bodyBuilder` to emit the body.
///
/// When the bounds and step are constants giving a trip count of 0, the body
/// is not built and the results are `initArgs`; for a trip count of 1 the body
/// is emitted once without a loop op. Otherwise an scf.for is created.
///
/// Fails (with an emitted error) on a constant non-positive step or when the
/// body yields a number of values different from `initArgs.size()`, and fails
/// when `bodyBuilder` itself fails.
mlir::FailureOr<NormalizedLoopResult> buildNormalizedScfFor(mlir::OpBuilder& builder,
mlir::Location loc,
mlir::Value lowerBound,
mlir::Value upperBound,
mlir::Value step,
mlir::ValueRange initArgs,
NormalizedLoopBodyBuilder bodyBuilder);
} // namespace onnx_mlir
@@ -0,0 +1,222 @@
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
#include "CheckedArithmetic.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
using namespace mlir;
namespace onnx_mlir::pim {
namespace {
/// Writes "PIM <fieldName> <message>" and a newline to llvm::errs().
static void emitCrashMessage(llvm::StringRef fieldName, llvm::StringRef message) {
  llvm::raw_ostream& errorStream = llvm::errs();
  errorStream << "PIM " << fieldName << " " << message << "\n";
}
/// Converts `value` to the integral type `To`, emitting an error at `loc` and
/// returning failure when the value is not representable in `To`.
///
/// \tparam To    destination integral type.
/// \tparam From  source integral type (deduced).
/// \param fieldName  name used in the "PIM <fieldName> ..." diagnostic.
template <typename To, typename From>
static FailureOr<To> checkedCastAtLocation(From value, Location loc, llvm::StringRef fieldName) {
  static_assert(std::is_integral_v<To> && std::is_integral_v<From>, "checkedCastAtLocation requires integral types");
  using ToLimits = std::numeric_limits<To>;
  if constexpr (std::is_signed_v<From> == std::is_signed_v<To>) {
    // Compare in the common type. With equal signedness it represents every
    // value of both From and To, so neither side is truncated. Casting To's
    // limits down to a narrower From would wrap them and reject valid
    // widening conversions.
    using Common = std::common_type_t<From, To>;
    const auto wideValue = static_cast<Common>(value);
    if (wideValue < static_cast<Common>(ToLimits::min()) || wideValue > static_cast<Common>(ToLimits::max())) {
      emitCheckedArithmeticError(loc, fieldName, "is outside representable range");
      return failure();
    }
  }
  else if constexpr (std::is_signed_v<From>) {
    // Signed source, unsigned destination: negatives never fit; the rest is
    // compared as unsigned quantities.
    using UnsignedFrom = std::make_unsigned_t<From>;
    using UnsignedTo = std::make_unsigned_t<To>;
    if (value < 0 || static_cast<UnsignedFrom>(value) > static_cast<UnsignedTo>(ToLimits::max())) {
      emitCheckedArithmeticError(loc, fieldName, "is outside representable range");
      return failure();
    }
  }
  else {
    // Unsigned source, signed destination: only the upper bound can be
    // violated.
    using UnsignedFrom = std::make_unsigned_t<From>;
    using UnsignedTo = std::conditional_t<std::is_signed_v<To>, std::make_unsigned_t<To>, To>;
    if (static_cast<UnsignedFrom>(value) > static_cast<UnsignedTo>(ToLimits::max())) {
      emitCheckedArithmeticError(loc, fieldName, "is outside representable range");
      return failure();
    }
  }
  return static_cast<To>(value);
}
/// Multiplies two unsigned values, emitting a "multiplication overflow" error
/// at `loc` and returning failure when the product does not fit in `UInt`.
template <typename UInt>
FailureOr<UInt> checkedMulAtLocation(UInt lhs, UInt rhs, Location loc, llvm::StringRef fieldName) {
  static_assert(std::is_integral_v<UInt> && std::is_unsigned_v<UInt>,
                "checkedMulAtLocation requires unsigned integral types");
  // A zero lhs can never overflow and must not be used as a divisor.
  const bool overflows = lhs != 0 && rhs > std::numeric_limits<UInt>::max() / lhs;
  if (!overflows)
    return lhs * rhs;
  emitCheckedArithmeticError(loc, fieldName, "multiplication overflow");
  return failure();
}
} // namespace
/// Starts an op error on `anchor` with the text "<fieldName> <message>" and
/// returns the in-flight diagnostic. `anchor` must be non-null.
InFlightDiagnostic emitCheckedArithmeticError(Operation* anchor, llvm::StringRef fieldName, llvm::StringRef message) {
  assert(anchor && "expected arithmetic diagnostics to have an anchor op");
  InFlightDiagnostic diagnostic = anchor->emitOpError();
  diagnostic << fieldName << " " << message;
  return diagnostic;
}
/// Starts an error at `loc` with the text "PIM <fieldName> <message>" and
/// returns the in-flight diagnostic.
InFlightDiagnostic emitCheckedArithmeticError(Location loc, llvm::StringRef fieldName, llvm::StringRef message) {
  InFlightDiagnostic diagnostic = emitError(loc);
  diagnostic << "PIM " << fieldName << " " << message;
  return diagnostic;
}
/// Narrows a signed 64-bit value to int32_t; on failure the error is reported
/// on `anchor` by checkedCast.
FailureOr<int32_t> checkedI32(int64_t value, Operation* anchor, llvm::StringRef fieldName) {
  FailureOr<int32_t> narrowed = checkedCast<int32_t, int64_t>(value, anchor, fieldName);
  return narrowed;
}
/// Narrows an unsigned 64-bit value to int32_t; on failure the error is
/// reported on `anchor` by checkedCast.
FailureOr<int32_t> checkedI32(uint64_t value, Operation* anchor, llvm::StringRef fieldName) {
  FailureOr<int32_t> narrowed = checkedCast<int32_t, uint64_t>(value, anchor, fieldName);
  return narrowed;
}
/// Narrows an unsigned 64-bit value to uint8_t; on failure the error is
/// reported on `anchor` by checkedCast.
FailureOr<uint8_t> checkedU8(uint64_t value, Operation* anchor, llvm::StringRef fieldName) {
  FailureOr<uint8_t> narrowed = checkedCast<uint8_t, uint64_t>(value, anchor, fieldName);
  return narrowed;
}
/// Converts a signed 64-bit value to size_t; on failure the error is reported
/// on `anchor` by checkedCast.
FailureOr<size_t> checkedSize(int64_t value, Operation* anchor, llvm::StringRef fieldName) {
  FailureOr<size_t> converted = checkedCast<size_t, int64_t>(value, anchor, fieldName);
  return converted;
}
/// Builds an i32 integer attribute from a signed 64-bit value, failing (with
/// an error reported on `anchor`) when the value does not fit in int32_t.
/// `anchor` must be non-null.
FailureOr<IntegerAttr>
getCheckedI32Attr(Builder& builder, Operation* anchor, int64_t value, llvm::StringRef fieldName) {
  assert(anchor && "checked op-based attrs require a non-null diagnostic anchor");
  FailureOr<int32_t> narrowed = checkedI32(value, anchor, fieldName);
  if (succeeded(narrowed))
    return builder.getI32IntegerAttr(*narrowed);
  return failure();
}
/// Builds an i32 integer attribute from an unsigned 64-bit value, failing
/// (with an error reported on `anchor`) when the value does not fit in
/// int32_t. `anchor` must be non-null.
FailureOr<IntegerAttr>
getCheckedI32Attr(Builder& builder, Operation* anchor, uint64_t value, llvm::StringRef fieldName) {
  assert(anchor && "checked op-based attrs require a non-null diagnostic anchor");
  FailureOr<int32_t> narrowed = checkedI32(value, anchor, fieldName);
  if (succeeded(narrowed))
    return builder.getI32IntegerAttr(*narrowed);
  return failure();
}
/// Builds an i32 integer attribute from a signed 64-bit value, failing (with
/// an error emitted at `loc`) when the value does not fit in int32_t.
FailureOr<IntegerAttr> getCheckedI32Attr(Builder& builder, Location loc, int64_t value, llvm::StringRef fieldName) {
  FailureOr<int32_t> narrowed = checkedCastAtLocation<int32_t>(value, loc, fieldName);
  if (succeeded(narrowed))
    return builder.getI32IntegerAttr(*narrowed);
  return failure();
}
/// Builds an i32 integer attribute from an unsigned 64-bit value, failing
/// (with an error emitted at `loc`) when the value does not fit in int32_t.
FailureOr<IntegerAttr> getCheckedI32Attr(Builder& builder, Location loc, uint64_t value, llvm::StringRef fieldName) {
  FailureOr<int32_t> narrowed = checkedCastAtLocation<int32_t>(value, loc, fieldName);
  if (succeeded(narrowed))
    return builder.getI32IntegerAttr(*narrowed);
  return failure();
}
/// Computes the byte size of a shaped type as (product of dimensions) *
/// (element size in bytes), with overflow-checked multiplication.
///
/// Fails, reporting on `anchor`, when the shape is not static, the element
/// type is not byte-sized, a dimension is negative, or a product overflows
/// uint64_t. `anchor` must be non-null.
FailureOr<uint64_t> getCheckedShapedTypeSizeInBytes(ShapedType type, Operation* anchor, llvm::StringRef fieldName) {
  assert(anchor && "checked op-based size helpers require a non-null diagnostic anchor");

  // Shared rejection path: report the reason on the anchor op and fail.
  const auto reject = [&](llvm::StringRef reason) -> FailureOr<uint64_t> {
    emitCheckedArithmeticError(anchor, fieldName, reason);
    return failure();
  };

  if (!type.hasStaticShape())
    return reject("requires static shaped type");
  if (!hasByteSizedElementType(type.getElementType()))
    return reject("requires byte-sized element type");

  uint64_t elementCount = 1;
  for (int64_t extent : type.getShape()) {
    if (extent < 0)
      return reject("requires nonnegative dimensions");
    auto scaled = checkedMul(elementCount, static_cast<uint64_t>(extent), anchor, fieldName);
    if (failed(scaled))
      return failure();
    elementCount = *scaled;
  }

  const auto bytesPerElement = static_cast<uint64_t>(getElementTypeSizeInBytes(type.getElementType()));
  return checkedMul(elementCount, bytesPerElement, anchor, fieldName);
}
/// Computes the byte size of a shaped type as (product of dimensions) *
/// (element size in bytes), with overflow-checked multiplication.
///
/// Fails, emitting an error at `loc`, when the shape is not static, the
/// element type is not byte-sized, a dimension is negative, or a product
/// overflows uint64_t.
FailureOr<uint64_t> getCheckedShapedTypeSizeInBytes(ShapedType type, Location loc, llvm::StringRef fieldName) {
  // Shared rejection path: emit the reason at the location and fail.
  const auto reject = [&](llvm::StringRef reason) -> FailureOr<uint64_t> {
    emitCheckedArithmeticError(loc, fieldName, reason);
    return failure();
  };

  if (!type.hasStaticShape())
    return reject("requires static shaped type");
  if (!hasByteSizedElementType(type.getElementType()))
    return reject("requires byte-sized element type");

  uint64_t elementCount = 1;
  for (int64_t extent : type.getShape()) {
    if (extent < 0)
      return reject("requires nonnegative dimensions");
    auto scaled = checkedMulAtLocation(elementCount, static_cast<uint64_t>(extent), loc, fieldName);
    if (failed(scaled))
      return failure();
    elementCount = *scaled;
  }

  const auto bytesPerElement = static_cast<uint64_t>(getElementTypeSizeInBytes(type.getElementType()));
  return checkedMulAtLocation(elementCount, bytesPerElement, loc, fieldName);
}
/// Narrows a signed 64-bit value to int32_t, terminating the process when the
/// value does not fit.
///
/// On failure "PIM <fieldName> is outside representable range" is printed and
/// a fatal error is reported. report_fatal_error is used instead of
/// llvm_unreachable because the latter may compile to undefined behavior in
/// builds without assertions, which would let the range check be optimized
/// away.
int32_t checkedI32OrCrash(int64_t value, llvm::StringRef fieldName) {
  if (value < std::numeric_limits<int32_t>::min() || value > std::numeric_limits<int32_t>::max()) {
    emitCrashMessage(fieldName, "is outside representable range");
    llvm::report_fatal_error("PIM checked arithmetic failure");
  }
  return static_cast<int32_t>(value);
}
/// Narrows an unsigned 64-bit value to int32_t, terminating the process when
/// the value does not fit.
///
/// On failure "PIM <fieldName> is outside representable range" is printed and
/// a fatal error is reported. report_fatal_error is used instead of
/// llvm_unreachable because the latter may compile to undefined behavior in
/// builds without assertions, which would let the range check be optimized
/// away.
int32_t checkedI32OrCrash(uint64_t value, llvm::StringRef fieldName) {
  if (value > static_cast<uint64_t>(std::numeric_limits<int32_t>::max())) {
    emitCrashMessage(fieldName, "is outside representable range");
    llvm::report_fatal_error("PIM checked arithmetic failure");
  }
  return static_cast<int32_t>(value);
}
/// Narrows an unsigned 64-bit value to uint8_t, terminating the process when
/// the value does not fit.
///
/// On failure "PIM <fieldName> is outside representable range" is printed and
/// a fatal error is reported. report_fatal_error is used instead of
/// llvm_unreachable because the latter may compile to undefined behavior in
/// builds without assertions, which would let the range check be optimized
/// away.
uint8_t checkedU8OrCrash(uint64_t value, llvm::StringRef fieldName) {
  if (value > static_cast<uint64_t>(std::numeric_limits<uint8_t>::max())) {
    emitCrashMessage(fieldName, "is outside representable range");
    llvm::report_fatal_error("PIM checked arithmetic failure");
  }
  return static_cast<uint8_t>(value);
}
/// Converts a signed 64-bit value to size_t, terminating the process when the
/// value is negative or, on targets where size_t is narrower than 64 bits,
/// too large.
///
/// On failure "PIM <fieldName> is outside representable range" is printed and
/// a fatal error is reported. report_fatal_error is used instead of
/// llvm_unreachable because the latter may compile to undefined behavior in
/// builds without assertions, which would let the range check be optimized
/// away.
size_t checkedSizeOrCrash(int64_t value, llvm::StringRef fieldName) {
  bool outOfRange = value < 0;
  // Only relevant when size_t cannot hold every non-negative int64_t value.
  if constexpr (sizeof(size_t) < sizeof(int64_t)) {
    outOfRange = outOfRange || static_cast<uint64_t>(value) > std::numeric_limits<size_t>::max();
  }
  if (outOfRange) {
    emitCrashMessage(fieldName, "is outside representable range");
    llvm::report_fatal_error("PIM checked arithmetic failure");
  }
  return static_cast<size_t>(value);
}
/// Adds two size_t values, terminating the process when the sum overflows.
///
/// On failure "PIM <fieldName> addition overflow" is printed and a fatal error
/// is reported. report_fatal_error is used instead of llvm_unreachable because
/// the latter may compile to undefined behavior in builds without assertions,
/// which would let the overflow check be optimized away.
size_t checkedAddOrCrash(size_t lhs, size_t rhs, llvm::StringRef fieldName) {
  if (rhs > std::numeric_limits<size_t>::max() - lhs) {
    emitCrashMessage(fieldName, "addition overflow");
    llvm::report_fatal_error("PIM checked arithmetic failure");
  }
  return lhs + rhs;
}
/// Multiplies two size_t values, terminating the process when the product
/// overflows.
///
/// On failure "PIM <fieldName> multiplication overflow" is printed and a fatal
/// error is reported. report_fatal_error is used instead of llvm_unreachable
/// because the latter may compile to undefined behavior in builds without
/// assertions, which would let the overflow check be optimized away.
size_t checkedMulOrCrash(size_t lhs, size_t rhs, llvm::StringRef fieldName) {
  // A zero lhs can never overflow and must not be used as a divisor.
  if (lhs != 0 && rhs > std::numeric_limits<size_t>::max() / lhs) {
    emitCrashMessage(fieldName, "multiplication overflow");
    llvm::report_fatal_error("PIM checked arithmetic failure");
  }
  return lhs * rhs;
}
} // namespace onnx_mlir::pim
@@ -0,0 +1,107 @@
#pragma once
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Operation.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/StringRef.h"
#include <cstddef>
#include <cstdint>
#include <limits>
#include <type_traits>
namespace onnx_mlir::pim {
/// Starts an op error on `anchor` whose text is "<fieldName> <message>".
/// `anchor` must be non-null (asserted in the definition).
mlir::InFlightDiagnostic
emitCheckedArithmeticError(mlir::Operation* anchor, llvm::StringRef fieldName, llvm::StringRef message);
/// Starts an error at `loc` whose text is "PIM <fieldName> <message>".
mlir::InFlightDiagnostic
emitCheckedArithmeticError(mlir::Location loc, llvm::StringRef fieldName, llvm::StringRef message);
/// Converts `value` to the integral type `To`, reporting an error on `anchor`
/// and returning failure when the value is not representable in `To`.
///
/// \tparam To    destination integral type.
/// \tparam From  source integral type (deduced).
/// \param anchor     operation the diagnostic is attached to.
/// \param fieldName  name used in the "<fieldName> is outside representable
///                   range" diagnostic.
template <typename To, typename From>
mlir::FailureOr<To> checkedCast(From value, mlir::Operation* anchor, llvm::StringRef fieldName) {
  static_assert(std::is_integral_v<To> && std::is_integral_v<From>, "checkedCast requires integral types");
  using ToLimits = std::numeric_limits<To>;
  if constexpr (std::is_signed_v<From> == std::is_signed_v<To>) {
    // Compare in the common type. With equal signedness it represents every
    // value of both From and To, so neither side is truncated. Casting To's
    // limits down to a narrower From would wrap them (e.g. int64 limits become
    // 0 and -1 as int32) and reject every valid widening conversion.
    using Common = std::common_type_t<From, To>;
    const auto wideValue = static_cast<Common>(value);
    if (wideValue < static_cast<Common>(ToLimits::min()) || wideValue > static_cast<Common>(ToLimits::max())) {
      emitCheckedArithmeticError(anchor, fieldName, "is outside representable range");
      return mlir::failure();
    }
  }
  else if constexpr (std::is_signed_v<From>) {
    // Signed source, unsigned destination: negatives never fit; the rest is
    // compared as unsigned quantities.
    using UnsignedFrom = std::make_unsigned_t<From>;
    using UnsignedTo = std::make_unsigned_t<To>;
    if (value < 0 || static_cast<UnsignedFrom>(value) > static_cast<UnsignedTo>(ToLimits::max())) {
      emitCheckedArithmeticError(anchor, fieldName, "is outside representable range");
      return mlir::failure();
    }
  }
  else {
    // Unsigned source, signed destination: only the upper bound can be
    // violated.
    using UnsignedFrom = std::make_unsigned_t<From>;
    using UnsignedTo = std::conditional_t<std::is_signed_v<To>, std::make_unsigned_t<To>, To>;
    if (static_cast<UnsignedFrom>(value) > static_cast<UnsignedTo>(ToLimits::max())) {
      emitCheckedArithmeticError(anchor, fieldName, "is outside representable range");
      return mlir::failure();
    }
  }
  return static_cast<To>(value);
}
/// Adds two unsigned values, reporting "addition overflow" on `anchor` and
/// returning failure when the sum does not fit in `UInt`.
template <typename UInt>
mlir::FailureOr<UInt> checkedAdd(UInt lhs, UInt rhs, mlir::Operation* anchor, llvm::StringRef fieldName) {
  static_assert(std::is_integral_v<UInt> && std::is_unsigned_v<UInt>, "checkedAdd requires unsigned integral types");
  // Largest rhs that can still be added to lhs without wrapping.
  const auto headroom = std::numeric_limits<UInt>::max() - lhs;
  if (rhs <= headroom)
    return lhs + rhs;
  emitCheckedArithmeticError(anchor, fieldName, "addition overflow");
  return mlir::failure();
}
/// Multiplies two unsigned values, reporting "multiplication overflow" on
/// `anchor` and returning failure when the product does not fit in `UInt`.
template <typename UInt>
mlir::FailureOr<UInt> checkedMul(UInt lhs, UInt rhs, mlir::Operation* anchor, llvm::StringRef fieldName) {
  static_assert(std::is_integral_v<UInt> && std::is_unsigned_v<UInt>, "checkedMul requires unsigned integral types");
  // A zero lhs can never overflow and must not be used as a divisor.
  const bool overflows = lhs != 0 && rhs > std::numeric_limits<UInt>::max() / lhs;
  if (!overflows)
    return lhs * rhs;
  emitCheckedArithmeticError(anchor, fieldName, "multiplication overflow");
  return mlir::failure();
}
/// Fixed-type range-checked conversions. Each forwards to checkedCast with the
/// named destination type, so an out-of-range value is reported on `anchor`
/// and yields failure.
mlir::FailureOr<int32_t> checkedI32(int64_t value, mlir::Operation* anchor, llvm::StringRef fieldName);
mlir::FailureOr<int32_t> checkedI32(uint64_t value, mlir::Operation* anchor, llvm::StringRef fieldName);
mlir::FailureOr<uint8_t> checkedU8(uint64_t value, mlir::Operation* anchor, llvm::StringRef fieldName);
mlir::FailureOr<size_t> checkedSize(int64_t value, mlir::Operation* anchor, llvm::StringRef fieldName);
/// Build an i32 integer attribute from a 64-bit value after checking that it
/// fits in int32_t. The Operation* overloads report on `anchor`, which must be
/// non-null (asserted in the definitions); the Location overloads emit the
/// error at `loc`. All return failure when the value is out of range.
mlir::FailureOr<mlir::IntegerAttr>
getCheckedI32Attr(mlir::Builder& builder, mlir::Operation* anchor, int64_t value, llvm::StringRef fieldName);
mlir::FailureOr<mlir::IntegerAttr>
getCheckedI32Attr(mlir::Builder& builder, mlir::Operation* anchor, uint64_t value, llvm::StringRef fieldName);
mlir::FailureOr<mlir::IntegerAttr>
getCheckedI32Attr(mlir::Builder& builder, mlir::Location loc, int64_t value, llvm::StringRef fieldName);
mlir::FailureOr<mlir::IntegerAttr>
getCheckedI32Attr(mlir::Builder& builder, mlir::Location loc, uint64_t value, llvm::StringRef fieldName);
/// Compute the size in bytes of a shaped type as the overflow-checked product
/// of its dimensions and its element size. Fail with a diagnostic (on `anchor`
/// or at `loc`) when the shape is not static, the element type is not
/// byte-sized, a dimension is negative, or a multiplication overflows.
/// The Operation* overload asserts that `anchor` is non-null.
mlir::FailureOr<uint64_t>
getCheckedShapedTypeSizeInBytes(mlir::ShapedType type, mlir::Operation* anchor, llvm::StringRef fieldName);
mlir::FailureOr<uint64_t>
getCheckedShapedTypeSizeInBytes(mlir::ShapedType type, mlir::Location loc, llvm::StringRef fieldName);
/// Variants that return the plain value instead of FailureOr and need no
/// diagnostic anchor. When the check fails they print
/// "PIM <fieldName> <message>" to llvm::errs() and then take a failure path
/// that is not meant to return to the caller (see the definitions for the
/// exact mechanism).
int32_t checkedI32OrCrash(int64_t value, llvm::StringRef fieldName);
int32_t checkedI32OrCrash(uint64_t value, llvm::StringRef fieldName);
uint8_t checkedU8OrCrash(uint64_t value, llvm::StringRef fieldName);
size_t checkedSizeOrCrash(int64_t value, llvm::StringRef fieldName);
size_t checkedAddOrCrash(size_t lhs, size_t rhs, llvm::StringRef fieldName);
size_t checkedMulOrCrash(size_t lhs, size_t rhs, llvm::StringRef fieldName);
} // namespace onnx_mlir::pim