8bb0babf1b
Validate Operations / validate-operations (push) Has been cancelled
use uniqued constant helpers everywhere materialize transposed constants directly
95 lines
2.8 KiB
C++
95 lines
2.8 KiB
C++
#pragma once
|
|
|
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
|
#include "mlir/IR/Value.h"
|
|
|
|
#include "llvm/ADT/DenseMap.h"
|
|
|
|
#include <memory>
|
|
#include <optional>
|
|
|
|
namespace onnx_mlir {
|
|
|
|
/// Describes a value as a base addressable object plus a statically known
|
|
/// byte offset after peeling aliases, casts, and contiguous subviews.
|
|
struct ResolvedContiguousAddress {
|
|
mlir::Value base;
|
|
int64_t byteOffset = 0;
|
|
};
|
|
|
|
/// Records compile-time facts used when interpreting address arithmetic and
|
|
/// loop-carried aliases inside PIM regions.
|
|
struct StaticValueKnowledge {
|
|
llvm::DenseMap<mlir::Value, int64_t> indexValues;
|
|
llvm::DenseMap<mlir::Value, mlir::Value> aliases;
|
|
|
|
StaticValueKnowledge() {}
|
|
};
|
|
|
|
/// Forward declaration: CompiledIndexExpr holds a shared_ptr to this node
/// type, and the node type in turn stores CompiledIndexExpr operands, so the
/// full definition has to come after CompiledIndexExpr.
struct CompiledIndexExprNode;
|
|
|
|
struct CompiledIndexExpr {
|
|
std::shared_ptr<CompiledIndexExprNode> node;
|
|
|
|
CompiledIndexExpr() = default;
|
|
explicit CompiledIndexExpr(std::shared_ptr<CompiledIndexExprNode> node)
|
|
: node(std::move(node)) {}
|
|
|
|
llvm::FailureOr<int64_t> evaluate(const StaticValueKnowledge& knowledge) const;
|
|
};
|
|
|
|
struct CompiledIndexExprNode {
|
|
enum class Kind {
|
|
Constant,
|
|
Symbol,
|
|
Add,
|
|
Sub,
|
|
Mul,
|
|
DivUI,
|
|
DivSI,
|
|
RemUI,
|
|
RemSI,
|
|
MinUI,
|
|
CmpI,
|
|
Select,
|
|
ConstantGlobalLoad
|
|
};
|
|
|
|
Kind kind = Kind::Constant;
|
|
int64_t constant = 0;
|
|
mlir::Value symbol;
|
|
mlir::arith::CmpIPredicate predicate = mlir::arith::CmpIPredicate::eq;
|
|
mlir::memref::GlobalOp globalOp;
|
|
llvm::SmallVector<int64_t, 4> globalStrides;
|
|
llvm::SmallVector<CompiledIndexExpr, 4> operands;
|
|
};
|
|
|
|
struct CompiledAddressExpr {
|
|
mlir::Value base;
|
|
CompiledIndexExpr byteOffset;
|
|
|
|
llvm::FailureOr<ResolvedContiguousAddress> evaluate(const StaticValueKnowledge& knowledge,
|
|
std::optional<unsigned> lane) const;
|
|
};
|
|
|
|
/// Returns the memref.global operation associated with `getGlobalOp`, looked
/// up in the context of `moduleOp`.
/// NOTE(review): presumably resolves the symbol named by `getGlobalOp` inside
/// `moduleOp` and returns a null op when it is not found -- confirm against
/// the definition, which is not in this header.
mlir::memref::GlobalOp lookupGlobalForGetGlobal(mlir::ModuleOp moduleOp, mlir::memref::GetGlobalOp getGlobalOp);
|
|
|
|
/// Resolves a value to contiguous backing storage when that storage can be
|
|
/// proven statically from aliases, DPS ties, casts, and subviews.
|
|
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value,
|
|
const StaticValueKnowledge& knowledge = {});
|
|
|
|
/// Statically evaluates index-like SSA values, including simple integer
|
|
/// arithmetic and loop facts recorded in `knowledge`.
|
|
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value, const StaticValueKnowledge& knowledge = {});
|
|
/// Builds a CompiledIndexExpr for `value`, or yields failure.
/// NOTE(review): presumably fails when `value` is defined by an operation
/// with no CompiledIndexExprNode::Kind counterpart -- confirm against the
/// definition, which is not in this header.
llvm::FailureOr<CompiledIndexExpr> compileIndexExpr(mlir::Value value);
|
|
|
|
/// Follows alias, view, and DPS chains to recover the backing value of a
|
|
/// loop-carried memref/result.
|
|
mlir::Value resolveLoopCarriedAlias(mlir::Value value, const StaticValueKnowledge& knowledge);
|
|
|
|
/// Builds a CompiledAddressExpr (base value plus compiled byte offset) for
/// `value`, or yields failure.
/// NOTE(review): presumably the compiled counterpart of
/// resolveContiguousAddress, deferring evaluation to
/// CompiledAddressExpr::evaluate -- confirm against the definition.
llvm::FailureOr<CompiledAddressExpr> compileContiguousAddressExpr(mlir::Value value);
|
|
|
|
} // namespace onnx_mlir
|