#pragma once #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/Value.h" #include "llvm/ADT/DenseMap.h" #include #include namespace onnx_mlir { /// Describes a value as a base addressable object plus a statically known /// byte offset after peeling aliases, casts, and contiguous subviews. struct ResolvedContiguousAddress { mlir::Value base; int64_t byteOffset = 0; }; /// Records compile-time facts used when interpreting address arithmetic and /// loop-carried aliases inside PIM regions. struct StaticValueKnowledge { llvm::DenseMap indexValues; llvm::DenseMap aliases; StaticValueKnowledge() {} }; struct CompiledIndexExprNode; struct CompiledIndexExpr { std::shared_ptr node; CompiledIndexExpr() = default; explicit CompiledIndexExpr(std::shared_ptr node) : node(std::move(node)) {} llvm::FailureOr evaluate(const StaticValueKnowledge& knowledge) const; }; struct CompiledIndexExprNode { enum class Kind { Constant, Symbol, Add, Sub, Mul, DivUI, DivSI, RemUI, RemSI, MinUI, CmpI, Select, ConstantGlobalLoad }; Kind kind = Kind::Constant; int64_t constant = 0; mlir::Value symbol; mlir::arith::CmpIPredicate predicate = mlir::arith::CmpIPredicate::eq; mlir::memref::GlobalOp globalOp; llvm::SmallVector globalStrides; llvm::SmallVector operands; }; struct CompiledAddressExpr { mlir::Value base; CompiledIndexExpr byteOffset; llvm::FailureOr evaluate(const StaticValueKnowledge& knowledge, std::optional lane) const; }; mlir::memref::GlobalOp lookupGlobalForGetGlobal(mlir::ModuleOp moduleOp, mlir::memref::GetGlobalOp getGlobalOp); /// Resolves a value to contiguous backing storage when that storage can be /// proven statically from aliases, DPS ties, casts, and subviews. llvm::FailureOr resolveContiguousAddress(mlir::Value value, const StaticValueKnowledge& knowledge = {}); /// Statically evaluates index-like SSA values, including simple integer /// arithmetic and loop facts recorded in `knowledge`. llvm::FailureOr resolveIndexValue(mlir::Value value, const StaticValueKnowledge& knowledge = {}); llvm::FailureOr compileIndexExpr(mlir::Value value); /// Follows alias, view, and DPS chains to recover the backing value of a /// loop-carried memref/result. mlir::Value resolveLoopCarriedAlias(mlir::Value value, const StaticValueKnowledge& knowledge); llvm::FailureOr compileContiguousAddressExpr(mlir::Value value); } // namespace onnx_mlir