From 25ade1bd632f5ec59d681e9dddbbdb9b06f31f50 Mon Sep 17 00:00:00 2001 From: NiccoloN Date: Tue, 21 Apr 2026 13:31:10 +0200 Subject: [PATCH] fix memory allocation in pim codegen fix crossbar allocation to only consider weights from vmm and mvm --- src/PIM/Compiler/PimCodeGen.cpp | 69 ++++++++++++++----- src/PIM/Compiler/PimCodeGen.hpp | 1 + .../Bufferization/PimBufferizationPass.cpp | 17 +++-- 3 files changed, 65 insertions(+), 22 deletions(-) diff --git a/src/PIM/Compiler/PimCodeGen.cpp b/src/PIM/Compiler/PimCodeGen.cpp index 0d744da..a2deda2 100644 --- a/src/PIM/Compiler/PimCodeGen.cpp +++ b/src/PIM/Compiler/PimCodeGen.cpp @@ -5,6 +5,7 @@ #include "mlir/IR/BuiltinTypes.h" #include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallPtrSet.h" #include "llvm/Support/FileSystem.h" #include "llvm/Support/JSON.h" #include "llvm/Support/raw_ostream.h" @@ -33,6 +34,12 @@ MemEntry* PimMemory::gatherMemEntry(mlir::Value value) { return &memEntries.emplace_back(memEntry, value).first; } +void PimMemory::allocateGatheredMemory() { + llvm::sort(memEntries, [](auto a, auto b) -> bool { return a.first.size > b.first.size; }); + for (auto& [memEntry, value] : memEntries) + allocateMemoryForValue(value, memEntry); +} + void PimMemory::allocateMemoryForValue(mlir::Value value, MemEntry& memEntry) { memEntry.address = firstAvailableAddress; firstAvailableAddress += memEntry.size; @@ -44,35 +51,37 @@ void PimMemory::allocateMemoryForValue(mlir::Value value, MemEntry& memEntry) { } void PimMemory::allocateHost(ModuleOp moduleOp, func::FuncOp funcOp) { - // More than one SSA value per single global constant: - // Cannot call gatherMemEntry for each of them, otherwise memory will be allocated multiple times - // Thus, call gatherMemEntry only for the first SSA value and assign the same memEntry to all others - SmallDenseMap globalConstants; + SmallDenseMap globalConstants; + SmallVector, 16> globalAliases; funcOp.walk([&](memref::GetGlobalOp getGlobalOp) { if (!hasWeightAlways(getGlobalOp)) { auto globalMemrefOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp); - auto iter = globalConstants.find(globalMemrefOp); - if (iter == globalConstants.end()) - globalConstants[globalMemrefOp] = gatherMemEntry(getGlobalOp); - else { - MemEntry memEntry = *iter->second; - globalMemEntriesMap[getGlobalOp] = memEntry; - } + auto [iter, inserted] = globalConstants.try_emplace(globalMemrefOp, getGlobalOp.getResult()); + if (inserted) + gatherMemEntry(getGlobalOp.getResult()); + else + globalAliases.push_back({getGlobalOp.getResult(), iter->second}); } }); for (mlir::Value arg : funcOp.getArguments()) gatherMemEntry(arg); - allocateCore(funcOp); + funcOp.walk([&](memref::AllocOp allocOp) { + if (!allocOp->getParentOfType()) + gatherMemEntry(allocOp.getResult()); + }); + + allocateGatheredMemory(); + + for (auto [alias, original] : globalAliases) + globalMemEntriesMap[alias] = getMemEntry(original); } void PimMemory::allocateCore(Operation* op) { op->walk([&](memref::AllocOp allocOp) { gatherMemEntry(allocOp); }); - llvm::sort(memEntries, [](auto a, auto b) -> bool { return a.first.size > b.first.size; }); - for (auto& [memEntry, value] : memEntries) - allocateMemoryForValue(value, memEntry); + allocateGatheredMemory(); } MemEntry PimMemory::getMemEntry(mlir::Value value) const { @@ -465,6 +474,19 @@ std::string getMemorySizeAsString(size_t size) { return std::to_string(size) + " Bytes"; } +static SmallVector getUsedWeightIndices(pim::PimCoreOp coreOp) { + SmallVector indices; + auto addIndex = [&](unsigned weightIndex) { + if (!llvm::is_contained(indices, weightIndex)) + indices.push_back(weightIndex); + }; + + coreOp.walk([&](pim::PimMVMOp mvmOp) { addIndex(mvmOp.getWeightIndex()); }); + coreOp.walk([&](pim::PimVMMOp vmmOp) { addIndex(vmmOp.getWeightIndex()); }); + llvm::sort(indices); + return indices; +} + /// Write global constant data into a binary memory image at their allocated addresses. static OnnxMlirCompilerErrorCodes writeMemoryBinary(ModuleOp moduleOp, func::FuncOp funcOp, PimAcceleratorMemory& memory, StringRef outputDirPath) { @@ -478,12 +500,15 @@ writeMemoryBinary(ModuleOp moduleOp, func::FuncOp funcOp, PimAcceleratorMemory& std::vector memoryBuffer(memory.hostMem.getFirstAvailableAddress(), 0); + SmallPtrSet writtenGlobals; funcOp.walk([&](memref::GetGlobalOp getGlobalOp) { if (hasWeightAlways(getGlobalOp)) return; auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp); if (!globalOp) return; + if (!writtenGlobals.insert(globalOp.getOperation()).second) + return; auto initialValue = globalOp.getInitialValue(); if (!initialValue) return; @@ -658,7 +683,12 @@ createAndPopulateWeightFolder(func::FuncOp funcOp, StringRef outputDirPath) { llvm::DenseMap mapGlobalOpToFileName; for (pim::PimCoreOp coreOp : funcOp.getOps()) { - for (auto [index, weight] : llvm::enumerate(coreOp.getWeights())) { + for (unsigned index : getUsedWeightIndices(coreOp)) { + if (index >= coreOp.getWeights().size()) { + coreOp.emitWarning("Weight index " + std::to_string(index) + " is out of range"); + assert(index < coreOp.getWeights().size() && "Weight index is out of range"); + } + mlir::Value weight = coreOp.getWeights()[index]; auto getGlobalOp = weight.getDefiningOp(); if (!getGlobalOp) { @@ -855,7 +885,12 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std:: auto& mapWeightToFile = mapCoreWeightToFileName[coreOp]; json::Array xbarsPerGroup; - for (auto [index, weight] : llvm::enumerate(coreOp.getWeights())) { + for (unsigned index : getUsedWeightIndices(coreOp)) { + if (index >= coreOp.getWeights().size()) { + coreOp.emitWarning("Weight index " + std::to_string(index) + " is out of range"); + assert(index < coreOp.getWeights().size() && "Weight index is out of range"); + } + mlir::Value weight = coreOp.getWeights()[index]; xbarsPerGroup.push_back(index); assert(mapWeightToFile.contains(weight) && "Weight was not materialized into a file!!"); auto& fileName = mapWeightToFile[weight]; diff --git a/src/PIM/Compiler/PimCodeGen.hpp b/src/PIM/Compiler/PimCodeGen.hpp index cd0fd4a..bad6b55 100644 --- a/src/PIM/Compiler/PimCodeGen.hpp +++ b/src/PIM/Compiler/PimCodeGen.hpp @@ -24,6 +24,7 @@ class PimMemory { size_t firstAvailableAddress = 0; MemEntry* gatherMemEntry(mlir::Value value); + void allocateGatheredMemory(); void allocateMemoryForValue(mlir::Value value, MemEntry& memEntry); public: diff --git a/src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferizationPass.cpp b/src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferizationPass.cpp index a2b99a2..69da36a 100644 --- a/src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferizationPass.cpp +++ b/src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferizationPass.cpp @@ -93,15 +93,22 @@ void PimBufferizationPass::runOnOperation() { } void PimBufferizationPass::annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const { - funcOp.walk([&](memref::GetGlobalOp getGlobalOp) { - bool isAlwaysWeight = !getGlobalOp->getUsers().empty() - && all_of(getGlobalOp->getUsers(), [](auto user) -> bool { return isa(user); }); - if (isAlwaysWeight) { + funcOp.walk([&](PimCoreOp coreOp) { + auto annotateWeight = [&](unsigned weightIndex) { + if (weightIndex >= coreOp.getWeights().size()) + return; + Value weight = coreOp.getWeights()[weightIndex]; + auto getGlobalOp = weight.getDefiningOp(); + if (!getGlobalOp) + return; auto globalMemrefOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp); assert("Weights must be constants" && globalMemrefOp.getConstant()); markWeightAlways(getGlobalOp); markWeightAlways(globalMemrefOp); - } + }; + + coreOp.walk([&](PimMVMOp mvmOp) { annotateWeight(mvmOp.getWeightIndex()); }); + coreOp.walk([&](PimVMMOp vmmOp) { annotateWeight(vmmOp.getWeightIndex()); }); }); }