Files
Raptor/src/PIM/Common/IR/WeightUtils.hpp
T
NiccoloN 8bb0babf1b
Validate Operations / validate-operations (push) Has been cancelled
finish helper refactoring
use uniqued constant helpers everywhere
materialize transposed constants directly
2026-05-29 17:05:45 +02:00

65 lines
2.2 KiB
C++

#pragma once
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/STLFunctionalExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"

#include <algorithm>
#include <optional>

#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
/// Attribute name "weightAlways". Presumably this is the attribute written by
/// onnx_mlir::markWeightAlways and queried by onnx_mlir::hasWeightAlways; their
/// definitions are not visible here, so confirm. Declared `inline constexpr` so
/// that every translation unit including this header shares one definition.
inline constexpr llvm::StringRef PimWeightAlwaysAttrName{"weightAlways"};
namespace onnx_mlir {

/// A weight access resolved back to the `memref.global` that backs it.
/// It records the global op plus the shape, strides and offset describing the
/// view. The unit of `offset` is set by resolveWeightView, whose definition is
/// not visible here; presumably it is elements, but confirm.
struct ResolvedWeightView {
  mlir::memref::GlobalOp globalOp;
  llvm::SmallVector<int64_t> shape;
  llvm::SmallVector<int64_t> strides;
  int64_t offset = 0;

  /// Two views are equal when they reference the same global op with
  /// identical shape, strides and offset.
  bool operator==(const ResolvedWeightView& other) const {
    return globalOp == other.globalOp && shape == other.shape && strides == other.strides && offset == other.offset;
  }

  // C++17 does not derive `!=` from `==`, so provide it explicitly. It
  // delegates to operator== so the two can never disagree.
  bool operator!=(const ResolvedWeightView& other) const { return !(*this == other); }
};

/// Returns whether `op` has been tagged as a persistent weight producer.
/// Presumably this checks for the `PimWeightAlwaysAttrName` attribute set by
/// markWeightAlways; the definition is not visible here, so confirm.
bool hasWeightAlways(mlir::Operation* op);

/// Tags an op as producing a value that should stay materialized as a reusable
/// weight across later PIM lowering/codegen stages.
void markWeightAlways(mlir::Operation* op);

/// Returns whether `use` is consumed as the weight operand of a Spatial
/// MVM/VMM op. This is inferred from the name; the definition is not visible
/// here.
bool isSpatialMvmVmmWeightUse(mlir::OpOperand& use);

/// Returns true when a value flows only into Spatial weighted MVM/VMM operands,
/// allowing later passes to preserve it as a dedicated weight-like object.
bool hasOnlySpatialMvmVmmWeightUses(mlir::Value value);

/// Visits weight operands consumed by Pim core ops/core batches so downstream
/// passes can identify globals that must remain weight-backed.
void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref<void(mlir::OpOperand&)> callback);

/// Maps `weight` to an index among the weights associated with `weightOwner`.
/// Returns std::nullopt when no index can be resolved. getUsedWeightIndices
/// below uses the result as a per-owner weight index.
std::optional<unsigned> resolveWeightIndex(mlir::Operation* weightOwner, mlir::Value weight);

/// Statically resolves `weight`, as used inside `weightOwner`, to the global it
/// reads and the view into that global. Returns failure if the view cannot be
/// resolved. `knowledge` carries static value information from
/// AddressAnalysis.hpp and defaults to empty.
llvm::FailureOr<ResolvedWeightView>
resolveWeightView(mlir::Operation* weightOwner, mlir::Value weight, const StaticValueKnowledge& knowledge = {});

/// Collects the distinct weight indices used by `pim::PimVMMOp`s nested inside
/// `coreLikeOp`.
///
/// Each VMM's weight operand is mapped through resolveWeightIndex, with
/// `coreLikeOp` as the owner. Weights that do not resolve are skipped.
///
/// @param coreLikeOp  an op type providing `walk` and `getOperation()`.
/// @return the resolved indices in ascending order, without duplicates.
template <typename CoreLikeOpTy>
llvm::SmallVector<unsigned, 8> getUsedWeightIndices(CoreLikeOpTy coreLikeOp) {
  llvm::SmallVector<unsigned, 8> indices;
  // NOTE(review): only PimVMMOp weight uses are collected here, whereas
  // walkPimMvmVmmWeightUses mentions MVM/VMM. Confirm that no other op kind
  // needs to be counted.
  coreLikeOp.walk([&](pim::PimVMMOp vmmOp) {
    if (std::optional<unsigned> weightIndex = resolveWeightIndex(coreLikeOp.getOperation(), vmmOp.getWeight()))
      indices.push_back(*weightIndex);
  });
  // Sort, then drop adjacent duplicates. This gives the same sorted unique
  // result as checking `is_contained` on every push, but in O(n log n)
  // instead of O(n^2).
  llvm::sort(indices);
  indices.erase(std::unique(indices.begin(), indices.end()), indices.end());
  return indices;
}

} // namespace onnx_mlir