fix memory allocation in pim codegen

fix crossbar allocation to only consider weights from vmm and mvm
This commit is contained in:
NiccoloN
2026-04-21 13:31:10 +02:00
parent 85e2750d6c
commit 25ade1bd63
3 changed files with 65 additions and 22 deletions

View File

@@ -5,6 +5,7 @@
#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/FileSystem.h" #include "llvm/Support/FileSystem.h"
#include "llvm/Support/JSON.h" #include "llvm/Support/JSON.h"
#include "llvm/Support/raw_ostream.h" #include "llvm/Support/raw_ostream.h"
@@ -33,6 +34,12 @@ MemEntry* PimMemory::gatherMemEntry(mlir::Value value) {
return &memEntries.emplace_back(memEntry, value).first; return &memEntries.emplace_back(memEntry, value).first;
} }
void PimMemory::allocateGatheredMemory() {
llvm::sort(memEntries, [](auto a, auto b) -> bool { return a.first.size > b.first.size; });
for (auto& [memEntry, value] : memEntries)
allocateMemoryForValue(value, memEntry);
}
void PimMemory::allocateMemoryForValue(mlir::Value value, MemEntry& memEntry) { void PimMemory::allocateMemoryForValue(mlir::Value value, MemEntry& memEntry) {
memEntry.address = firstAvailableAddress; memEntry.address = firstAvailableAddress;
firstAvailableAddress += memEntry.size; firstAvailableAddress += memEntry.size;
@@ -44,35 +51,37 @@ void PimMemory::allocateMemoryForValue(mlir::Value value, MemEntry& memEntry) {
} }
void PimMemory::allocateHost(ModuleOp moduleOp, func::FuncOp funcOp) { void PimMemory::allocateHost(ModuleOp moduleOp, func::FuncOp funcOp) {
// More than one SSA value per single global constant: SmallDenseMap<memref::GlobalOp, mlir::Value, 8> globalConstants;
// Cannot call gatherMemEntry for each of them, otherwise memory will be allocated multiple times SmallVector<std::pair<mlir::Value, mlir::Value>, 16> globalAliases;
// Thus, call gatherMemEntry only for the first SSA value and assign the same memEntry to all others
SmallDenseMap<memref::GlobalOp, MemEntry*, 8> globalConstants;
funcOp.walk([&](memref::GetGlobalOp getGlobalOp) { funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
if (!hasWeightAlways(getGlobalOp)) { if (!hasWeightAlways(getGlobalOp)) {
auto globalMemrefOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp); auto globalMemrefOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
auto iter = globalConstants.find(globalMemrefOp); auto [iter, inserted] = globalConstants.try_emplace(globalMemrefOp, getGlobalOp.getResult());
if (iter == globalConstants.end()) if (inserted)
globalConstants[globalMemrefOp] = gatherMemEntry(getGlobalOp); gatherMemEntry(getGlobalOp.getResult());
else { else
MemEntry memEntry = *iter->second; globalAliases.push_back({getGlobalOp.getResult(), iter->second});
globalMemEntriesMap[getGlobalOp] = memEntry;
}
} }
}); });
for (mlir::Value arg : funcOp.getArguments()) for (mlir::Value arg : funcOp.getArguments())
gatherMemEntry(arg); gatherMemEntry(arg);
allocateCore(funcOp); funcOp.walk([&](memref::AllocOp allocOp) {
if (!allocOp->getParentOfType<pim::PimCoreOp>())
gatherMemEntry(allocOp.getResult());
});
allocateGatheredMemory();
for (auto [alias, original] : globalAliases)
globalMemEntriesMap[alias] = getMemEntry(original);
} }
void PimMemory::allocateCore(Operation* op) { void PimMemory::allocateCore(Operation* op) {
op->walk([&](memref::AllocOp allocOp) { gatherMemEntry(allocOp); }); op->walk([&](memref::AllocOp allocOp) { gatherMemEntry(allocOp); });
llvm::sort(memEntries, [](auto a, auto b) -> bool { return a.first.size > b.first.size; }); allocateGatheredMemory();
for (auto& [memEntry, value] : memEntries)
allocateMemoryForValue(value, memEntry);
} }
MemEntry PimMemory::getMemEntry(mlir::Value value) const { MemEntry PimMemory::getMemEntry(mlir::Value value) const {
@@ -465,6 +474,19 @@ std::string getMemorySizeAsString(size_t size) {
return std::to_string(size) + " Bytes"; return std::to_string(size) + " Bytes";
} }
static SmallVector<unsigned, 8> getUsedWeightIndices(pim::PimCoreOp coreOp) {
SmallVector<unsigned, 8> indices;
auto addIndex = [&](unsigned weightIndex) {
if (!llvm::is_contained(indices, weightIndex))
indices.push_back(weightIndex);
};
coreOp.walk([&](pim::PimMVMOp mvmOp) { addIndex(mvmOp.getWeightIndex()); });
coreOp.walk([&](pim::PimVMMOp vmmOp) { addIndex(vmmOp.getWeightIndex()); });
llvm::sort(indices);
return indices;
}
/// Write global constant data into a binary memory image at their allocated addresses. /// Write global constant data into a binary memory image at their allocated addresses.
static OnnxMlirCompilerErrorCodes static OnnxMlirCompilerErrorCodes
writeMemoryBinary(ModuleOp moduleOp, func::FuncOp funcOp, PimAcceleratorMemory& memory, StringRef outputDirPath) { writeMemoryBinary(ModuleOp moduleOp, func::FuncOp funcOp, PimAcceleratorMemory& memory, StringRef outputDirPath) {
@@ -478,12 +500,15 @@ writeMemoryBinary(ModuleOp moduleOp, func::FuncOp funcOp, PimAcceleratorMemory&
std::vector<char> memoryBuffer(memory.hostMem.getFirstAvailableAddress(), 0); std::vector<char> memoryBuffer(memory.hostMem.getFirstAvailableAddress(), 0);
SmallPtrSet<Operation*, 16> writtenGlobals;
funcOp.walk([&](memref::GetGlobalOp getGlobalOp) { funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
if (hasWeightAlways(getGlobalOp)) if (hasWeightAlways(getGlobalOp))
return; return;
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp); auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
if (!globalOp) if (!globalOp)
return; return;
if (!writtenGlobals.insert(globalOp.getOperation()).second)
return;
auto initialValue = globalOp.getInitialValue(); auto initialValue = globalOp.getInitialValue();
if (!initialValue) if (!initialValue)
return; return;
@@ -658,7 +683,12 @@ createAndPopulateWeightFolder(func::FuncOp funcOp, StringRef outputDirPath) {
llvm::DenseMap<memref::GlobalOp, std::string> mapGlobalOpToFileName; llvm::DenseMap<memref::GlobalOp, std::string> mapGlobalOpToFileName;
for (pim::PimCoreOp coreOp : funcOp.getOps<pim::PimCoreOp>()) { for (pim::PimCoreOp coreOp : funcOp.getOps<pim::PimCoreOp>()) {
for (auto [index, weight] : llvm::enumerate(coreOp.getWeights())) { for (unsigned index : getUsedWeightIndices(coreOp)) {
if (index >= coreOp.getWeights().size()) {
coreOp.emitWarning("Weight index " + std::to_string(index) + " is out of range");
assert(index < coreOp.getWeights().size() && "Weight index is out of range");
}
mlir::Value weight = coreOp.getWeights()[index];
auto getGlobalOp = weight.getDefiningOp<memref::GetGlobalOp>(); auto getGlobalOp = weight.getDefiningOp<memref::GetGlobalOp>();
if (!getGlobalOp) { if (!getGlobalOp) {
@@ -855,7 +885,12 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::
auto& mapWeightToFile = mapCoreWeightToFileName[coreOp]; auto& mapWeightToFile = mapCoreWeightToFileName[coreOp];
json::Array xbarsPerGroup; json::Array xbarsPerGroup;
for (auto [index, weight] : llvm::enumerate(coreOp.getWeights())) { for (unsigned index : getUsedWeightIndices(coreOp)) {
if (index >= coreOp.getWeights().size()) {
coreOp.emitWarning("Weight index " + std::to_string(index) + " is out of range");
assert(index < coreOp.getWeights().size() && "Weight index is out of range");
}
mlir::Value weight = coreOp.getWeights()[index];
xbarsPerGroup.push_back(index); xbarsPerGroup.push_back(index);
assert(mapWeightToFile.contains(weight) && "Weight was not materialized into a file!!"); assert(mapWeightToFile.contains(weight) && "Weight was not materialized into a file!!");
auto& fileName = mapWeightToFile[weight]; auto& fileName = mapWeightToFile[weight];

View File

@@ -24,6 +24,7 @@ class PimMemory {
size_t firstAvailableAddress = 0; size_t firstAvailableAddress = 0;
MemEntry* gatherMemEntry(mlir::Value value); MemEntry* gatherMemEntry(mlir::Value value);
void allocateGatheredMemory();
void allocateMemoryForValue(mlir::Value value, MemEntry& memEntry); void allocateMemoryForValue(mlir::Value value, MemEntry& memEntry);
public: public:

View File

@@ -93,15 +93,22 @@ void PimBufferizationPass::runOnOperation() {
} }
void PimBufferizationPass::annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const { void PimBufferizationPass::annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const {
funcOp.walk([&](memref::GetGlobalOp getGlobalOp) { funcOp.walk([&](PimCoreOp coreOp) {
bool isAlwaysWeight = !getGlobalOp->getUsers().empty() auto annotateWeight = [&](unsigned weightIndex) {
&& all_of(getGlobalOp->getUsers(), [](auto user) -> bool { return isa<PimCoreOp>(user); }); if (weightIndex >= coreOp.getWeights().size())
if (isAlwaysWeight) { return;
Value weight = coreOp.getWeights()[weightIndex];
auto getGlobalOp = weight.getDefiningOp<memref::GetGlobalOp>();
if (!getGlobalOp)
return;
auto globalMemrefOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp); auto globalMemrefOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
assert("Weights must be constants" && globalMemrefOp.getConstant()); assert("Weights must be constants" && globalMemrefOp.getConstant());
markWeightAlways(getGlobalOp); markWeightAlways(getGlobalOp);
markWeightAlways(globalMemrefOp); markWeightAlways(globalMemrefOp);
} };
coreOp.walk([&](PimMVMOp mvmOp) { annotateWeight(mvmOp.getWeightIndex()); });
coreOp.walk([&](PimVMMOp vmmOp) { annotateWeight(vmmOp.getWeightIndex()); });
}); });
} }