#pragma once

// Shared helpers for the PIM lowering pipeline: output/dump utilities, entry-function
// lookup, weight-use queries, channel matching, shape/stride arithmetic, and static
// (compile-time) resolution of addresses, index values and loop-carried aliases inside
// pim.core bodies.
//
// NOTE(review): in this copy the template argument lists on llvm::DenseMap,
// llvm::FailureOr, llvm::SmallVector, llvm::ArrayRef and llvm::function_ref are absent,
// so the declarations below do not compile as shown; restore them from the matching
// definitions before relying on this header.

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLFunctionalExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"

#include "src/Compiler/CompilerOptions.hpp"

/// Attribute name "weightAlways". Presumably this is the attribute queried by
/// hasWeightAlways() and set by markWeightAlways() below (TODO confirm).
/// Note: declared at global scope, outside namespace onnx_mlir.
inline constexpr llvm::StringRef PimWeightAlwaysAttrName = "weightAlways";

namespace onnx_mlir {

/// An address expressed as a base value plus a constant offset. Presumably the
/// successful result of resolveContiguousAddress() (TODO confirm).
struct ResolvedContiguousAddress {
  /// The value the address is expressed relative to.
  mlir::Value base;
  /// Offset from `base` (in bytes, per the field name); defaults to 0.
  int64_t byteOffset = 0;
};

/// Facts known statically at a given program point. It is passed to the
/// knowledge-taking overloads of the resolve* helpers, and walkPimCoreBlock() hands
/// the in-scope instance to its callback (e.g. per unrolled scf.for iteration).
struct StaticValueKnowledge {
  /// Presumably maps values to their statically known integer/index values (TODO confirm).
  llvm::DenseMap indexValues;
  /// Value aliases, presumably followed by resolveLoopCarriedAlias() (TODO confirm).
  llvm::DenseMap aliases;

  StaticValueKnowledge() {}
};

// ---------------------------------------------------------------------------
// Output / debugging utilities
// ---------------------------------------------------------------------------

/// Presumably returns the directory where compiler artifacts are written (TODO confirm).
std::string getOutputDir();

/// Creates `directory` (error-reporting behavior not visible from this header).
void createDirectory(const std::string& directory);

/// Dumps `moduleOp` under `name` (presumably a file in getOutputDir(); TODO confirm).
void dumpModule(mlir::ModuleOp moduleOp, const std::string& name);

/// Finds the PIM entry function of `moduleOp`; returns failure if it cannot be determined.
llvm::FailureOr getPimEntryFunc(mlir::ModuleOp moduleOp);

// ---------------------------------------------------------------------------
// Weight-use queries
// ---------------------------------------------------------------------------

/// Returns whether `op` is tagged "weight always" (see PimWeightAlwaysAttrName).
bool hasWeightAlways(mlir::Operation* op);

/// Tags `op` as "weight always" (see PimWeightAlwaysAttrName).
void markWeightAlways(mlir::Operation* op);

/// Returns whether `use` is a weight operand of a spatial MVM/VMM op (per the name).
bool isSpatialMvmVmmWeightUse(mlir::OpOperand& use);

/// Presumably true when every use of `value` satisfies isSpatialMvmVmmWeightUse()
/// (behavior for a value with no uses not visible here; TODO confirm).
bool hasOnlySpatialMvmVmmWeightUses(mlir::Value value);

/// Walks the IR under `root` and invokes `callback` for each PIM MVM/VMM weight use
/// (callback signature not visible in this copy).
void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref callback);

/// Returns the memref.global in `moduleOp` that `getGlobalOp` refers to
/// (result when no such global exists is not visible here; TODO confirm).
mlir::memref::GlobalOp lookupGlobalForGetGlobal(mlir::ModuleOp moduleOp, mlir::memref::GetGlobalOp getGlobalOp);

/// Given one end of a channel (`op`; `opIsReceive` tells which side it is), finds the
/// opposite end; returns failure if it cannot be matched.
/// NOTE(review): takes a RewriterBase, so it may modify IR — confirm before calling
/// from analysis-only code.
llvm::FailureOr getOtherEndOfChannel(mlir::Operation* op, bool opIsReceive, mlir::RewriterBase& rewriter);

// ---------------------------------------------------------------------------
// Shape / stride arithmetic
// ---------------------------------------------------------------------------

/// Row-major strides for `shape` (presumably in elements; TODO confirm).
llvm::SmallVector computeRowMajorStrides(llvm::ArrayRef shape);

/// Converts `linearIndex` into per-dimension indices using `shape` and `strides`.
llvm::SmallVector delinearizeIndex(int64_t linearIndex, llvm::ArrayRef shape, llvm::ArrayRef strides);

/// Converts per-dimension `indices` into a linear index using `strides`.
int64_t linearizeIndex(llvm::ArrayRef indices, llvm::ArrayRef strides);

/// Number of elements in `shape` (presumably the product of its dimensions).
int64_t getNumElements(llvm::ArrayRef shape);

/// Returns whether the view described by `offsets`/`sizes`/`strides` into a buffer of
/// shape `srcShape` covers one contiguous memory region.
bool isMemoryContiguous(llvm::ArrayRef srcShape, llvm::ArrayRef offsets, llvm::ArrayRef sizes, llvm::ArrayRef strides);

// ---------------------------------------------------------------------------
// Static resolution of addresses and index values
// ---------------------------------------------------------------------------

/// Resolves `value` to a contiguous base + offset; returns failure if it cannot.
llvm::FailureOr resolveContiguousAddress(mlir::Value value);
/// Same as above, additionally consulting the statically known facts in `knowledge`.
llvm::FailureOr resolveContiguousAddress(mlir::Value value, const StaticValueKnowledge& knowledge);

/// Resolves `value` to a statically known index value; returns failure if it cannot.
llvm::FailureOr resolveIndexValue(mlir::Value value);
/// Same as above, additionally consulting the statically known facts in `knowledge`.
llvm::FailureOr resolveIndexValue(mlir::Value value, const StaticValueKnowledge& knowledge);

/// Follows alias and view/DPS chains using `knowledge` to find the value an scf.for
/// iter-arg is ultimately backed by. Used when interpreting scf.for loop carries.
mlir::Value resolveLoopCarriedAlias(mlir::Value value, const StaticValueKnowledge& knowledge);

/// Returns true for ops inside a pim.core body that do not emit any PIM instruction and
/// only contribute to static addressing or index computations (arith integer math,
/// memref view ops, memref.alloc, arith.constant).
bool isCoreStaticAddressOp(mlir::Operation* op);

/// Walks `block` (the body of a pim.core region or an scf.for nested in it), statically
/// unrolling any scf.for with resolvable bounds using `knowledge`. For each remaining op
/// that is not skipped (pim.halt, scf.yield, or isCoreStaticAddressOp), `callback` is
/// invoked with the op and the in-scope knowledge. The walker keeps going after a callback
/// failure so callers can collect multiple diagnostics, but propagates the overall result.
mlir::LogicalResult walkPimCoreBlock(mlir::Block& block, const StaticValueKnowledge& knowledge, llvm::function_ref callback);

} // namespace onnx_mlir