#pragma once #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Value.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/STLFunctionalExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include #include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" inline constexpr llvm::StringRef PimWeightAlwaysAttrName = "weightAlways"; namespace onnx_mlir { struct ResolvedWeightView { mlir::memref::GlobalOp globalOp; llvm::SmallVector shape; llvm::SmallVector strides; int64_t offset = 0; bool operator==(const ResolvedWeightView& other) const { return globalOp == other.globalOp && shape == other.shape && strides == other.strides && offset == other.offset; } }; bool hasWeightAlways(mlir::Operation* op); /// Tags an op as producing a value that should stay materialized as a reusable /// weight across later PIM lowering/codegen stages. void markWeightAlways(mlir::Operation* op); bool isSpatialMvmVmmWeightUse(mlir::OpOperand& use); /// Returns true when a value flows only into Spatial weighted MVM/VMM operands, /// allowing later passes to preserve it as a dedicated weight-like object. bool hasOnlySpatialMvmVmmWeightUses(mlir::Value value); /// Visits weight operands consumed by Pim core ops/core batches so downstream /// passes can identify globals that must remain weight-backed. void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref callback); std::optional resolveWeightIndex(mlir::Operation* weightOwner, mlir::Value weight); llvm::FailureOr resolveWeightView(mlir::Operation* weightOwner, mlir::Value weight, const StaticValueKnowledge& knowledge = {}); template llvm::SmallVector getUsedWeightIndices(CoreLikeOpTy coreLikeOp) { llvm::SmallVector indices; coreLikeOp.walk([&](pim::PimVMMOp vmmOp) { auto weightIndex = resolveWeightIndex(coreLikeOp.getOperation(), vmmOp.getWeight()); if (weightIndex && !llvm::is_contained(indices, *weightIndex)) indices.push_back(*weightIndex); }); llvm::sort(indices); return indices; } } // namespace onnx_mlir