faster pim VerificationPass.cpp and pim code emission
Validate Operations / validate-operations (push) Has been cancelled

This commit is contained in:
NiccoloN
2026-05-25 15:24:12 +02:00
parent 4855a2e105
commit e8a08f6dd0
18 changed files with 1610 additions and 573 deletions
+52
View File
@@ -1,10 +1,14 @@
#pragma once
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/DenseMap.h"
#include <memory>
#include <optional>
namespace onnx_mlir {
/// Describes a value as a base addressable object plus a statically known
@@ -23,6 +27,51 @@ struct StaticValueKnowledge {
StaticValueKnowledge() {}
};
struct CompiledIndexExprNode;
struct CompiledIndexExpr {
std::shared_ptr<CompiledIndexExprNode> node;
CompiledIndexExpr() = default;
explicit CompiledIndexExpr(std::shared_ptr<CompiledIndexExprNode> node) : node(std::move(node)) {}
llvm::FailureOr<int64_t> evaluate(const StaticValueKnowledge& knowledge) const;
};
struct CompiledIndexExprNode {
enum class Kind {
Constant,
Symbol,
Add,
Sub,
Mul,
DivUI,
DivSI,
RemUI,
RemSI,
MinUI,
CmpI,
Select,
ConstantGlobalLoad
};
Kind kind = Kind::Constant;
int64_t constant = 0;
mlir::Value symbol;
mlir::arith::CmpIPredicate predicate = mlir::arith::CmpIPredicate::eq;
mlir::memref::GlobalOp globalOp;
llvm::SmallVector<int64_t, 4> globalStrides;
llvm::SmallVector<CompiledIndexExpr, 4> operands;
};
struct CompiledAddressExpr {
mlir::Value base;
CompiledIndexExpr byteOffset;
llvm::FailureOr<ResolvedContiguousAddress>
evaluate(const StaticValueKnowledge& knowledge, std::optional<unsigned> lane) const;
};
mlir::memref::GlobalOp lookupGlobalForGetGlobal(mlir::ModuleOp moduleOp, mlir::memref::GetGlobalOp getGlobalOp);
/// Resolves a value to contiguous backing storage when that storage can be
@@ -35,9 +84,12 @@ llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value
/// arithmetic and loop facts recorded in `knowledge`.
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value);
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value, const StaticValueKnowledge& knowledge);
llvm::FailureOr<CompiledIndexExpr> compileIndexExpr(mlir::Value value);
/// Follows alias, view, and DPS chains to recover the backing value of a
/// loop-carried memref/result.
mlir::Value resolveLoopCarriedAlias(mlir::Value value, const StaticValueKnowledge& knowledge);
llvm::FailureOr<CompiledAddressExpr> compileContiguousAddressExpr(mlir::Value value);
} // namespace onnx_mlir