Fix memory allocation in PIM codegen
Fix crossbar allocation to consider only the weights used by VMM and MVM ops
This commit is contained in:
@@ -5,6 +5,7 @@
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/SmallPtrSet.h"
|
||||
#include "llvm/Support/FileSystem.h"
|
||||
#include "llvm/Support/JSON.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
@@ -33,6 +34,12 @@ MemEntry* PimMemory::gatherMemEntry(mlir::Value value) {
|
||||
return &memEntries.emplace_back(memEntry, value).first;
|
||||
}
|
||||
|
||||
void PimMemory::allocateGatheredMemory() {
|
||||
llvm::sort(memEntries, [](auto a, auto b) -> bool { return a.first.size > b.first.size; });
|
||||
for (auto& [memEntry, value] : memEntries)
|
||||
allocateMemoryForValue(value, memEntry);
|
||||
}
|
||||
|
||||
void PimMemory::allocateMemoryForValue(mlir::Value value, MemEntry& memEntry) {
|
||||
memEntry.address = firstAvailableAddress;
|
||||
firstAvailableAddress += memEntry.size;
|
||||
@@ -44,35 +51,37 @@ void PimMemory::allocateMemoryForValue(mlir::Value value, MemEntry& memEntry) {
|
||||
}
|
||||
|
||||
void PimMemory::allocateHost(ModuleOp moduleOp, func::FuncOp funcOp) {
|
||||
// More than one SSA value per single global constant:
|
||||
// Cannot call gatherMemEntry for each of them, otherwise memory will be allocated multiple times
|
||||
// Thus, call gatherMemEntry only for the first SSA value and assign the same memEntry to all others
|
||||
SmallDenseMap<memref::GlobalOp, MemEntry*, 8> globalConstants;
|
||||
SmallDenseMap<memref::GlobalOp, mlir::Value, 8> globalConstants;
|
||||
SmallVector<std::pair<mlir::Value, mlir::Value>, 16> globalAliases;
|
||||
funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
|
||||
if (!hasWeightAlways(getGlobalOp)) {
|
||||
auto globalMemrefOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
||||
auto iter = globalConstants.find(globalMemrefOp);
|
||||
if (iter == globalConstants.end())
|
||||
globalConstants[globalMemrefOp] = gatherMemEntry(getGlobalOp);
|
||||
else {
|
||||
MemEntry memEntry = *iter->second;
|
||||
globalMemEntriesMap[getGlobalOp] = memEntry;
|
||||
}
|
||||
auto [iter, inserted] = globalConstants.try_emplace(globalMemrefOp, getGlobalOp.getResult());
|
||||
if (inserted)
|
||||
gatherMemEntry(getGlobalOp.getResult());
|
||||
else
|
||||
globalAliases.push_back({getGlobalOp.getResult(), iter->second});
|
||||
}
|
||||
});
|
||||
|
||||
for (mlir::Value arg : funcOp.getArguments())
|
||||
gatherMemEntry(arg);
|
||||
|
||||
allocateCore(funcOp);
|
||||
funcOp.walk([&](memref::AllocOp allocOp) {
|
||||
if (!allocOp->getParentOfType<pim::PimCoreOp>())
|
||||
gatherMemEntry(allocOp.getResult());
|
||||
});
|
||||
|
||||
allocateGatheredMemory();
|
||||
|
||||
for (auto [alias, original] : globalAliases)
|
||||
globalMemEntriesMap[alias] = getMemEntry(original);
|
||||
}
|
||||
|
||||
void PimMemory::allocateCore(Operation* op) {
|
||||
op->walk([&](memref::AllocOp allocOp) { gatherMemEntry(allocOp); });
|
||||
|
||||
llvm::sort(memEntries, [](auto a, auto b) -> bool { return a.first.size > b.first.size; });
|
||||
for (auto& [memEntry, value] : memEntries)
|
||||
allocateMemoryForValue(value, memEntry);
|
||||
allocateGatheredMemory();
|
||||
}
|
||||
|
||||
MemEntry PimMemory::getMemEntry(mlir::Value value) const {
|
||||
@@ -465,6 +474,19 @@ std::string getMemorySizeAsString(size_t size) {
|
||||
return std::to_string(size) + " Bytes";
|
||||
}
|
||||
|
||||
/// Return the distinct weight indices referenced by the MVM and VMM operations
/// nested inside `coreOp`, sorted in ascending order.
static SmallVector<unsigned, 8> getUsedWeightIndices(pim::PimCoreOp coreOp) {
  SmallVector<unsigned, 8> usedIndices;

  // Record an index only the first time it is seen.
  const auto recordIndex = [&usedIndices](unsigned candidate) {
    if (llvm::is_contained(usedIndices, candidate))
      return;
    usedIndices.push_back(candidate);
  };

  // One traversal covers both op kinds; the visitation order does not matter
  // because the collected indices are sorted before being returned.
  coreOp.walk([&recordIndex](Operation* nestedOp) {
    if (auto mvmOp = llvm::dyn_cast<pim::PimMVMOp>(nestedOp))
      recordIndex(mvmOp.getWeightIndex());
    else if (auto vmmOp = llvm::dyn_cast<pim::PimVMMOp>(nestedOp))
      recordIndex(vmmOp.getWeightIndex());
  });

  llvm::sort(usedIndices);
  return usedIndices;
}
|
||||
|
||||
/// Write global constant data into a binary memory image at their allocated addresses.
|
||||
static OnnxMlirCompilerErrorCodes
|
||||
writeMemoryBinary(ModuleOp moduleOp, func::FuncOp funcOp, PimAcceleratorMemory& memory, StringRef outputDirPath) {
|
||||
@@ -478,12 +500,15 @@ writeMemoryBinary(ModuleOp moduleOp, func::FuncOp funcOp, PimAcceleratorMemory&
|
||||
|
||||
std::vector<char> memoryBuffer(memory.hostMem.getFirstAvailableAddress(), 0);
|
||||
|
||||
SmallPtrSet<Operation*, 16> writtenGlobals;
|
||||
funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
|
||||
if (hasWeightAlways(getGlobalOp))
|
||||
return;
|
||||
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
||||
if (!globalOp)
|
||||
return;
|
||||
if (!writtenGlobals.insert(globalOp.getOperation()).second)
|
||||
return;
|
||||
auto initialValue = globalOp.getInitialValue();
|
||||
if (!initialValue)
|
||||
return;
|
||||
@@ -658,7 +683,12 @@ createAndPopulateWeightFolder(func::FuncOp funcOp, StringRef outputDirPath) {
|
||||
llvm::DenseMap<memref::GlobalOp, std::string> mapGlobalOpToFileName;
|
||||
|
||||
for (pim::PimCoreOp coreOp : funcOp.getOps<pim::PimCoreOp>()) {
|
||||
for (auto [index, weight] : llvm::enumerate(coreOp.getWeights())) {
|
||||
for (unsigned index : getUsedWeightIndices(coreOp)) {
|
||||
if (index >= coreOp.getWeights().size()) {
|
||||
coreOp.emitWarning("Weight index " + std::to_string(index) + " is out of range");
|
||||
assert(index < coreOp.getWeights().size() && "Weight index is out of range");
|
||||
}
|
||||
mlir::Value weight = coreOp.getWeights()[index];
|
||||
|
||||
auto getGlobalOp = weight.getDefiningOp<memref::GetGlobalOp>();
|
||||
if (!getGlobalOp) {
|
||||
@@ -855,7 +885,12 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::
|
||||
|
||||
auto& mapWeightToFile = mapCoreWeightToFileName[coreOp];
|
||||
json::Array xbarsPerGroup;
|
||||
for (auto [index, weight] : llvm::enumerate(coreOp.getWeights())) {
|
||||
for (unsigned index : getUsedWeightIndices(coreOp)) {
|
||||
if (index >= coreOp.getWeights().size()) {
|
||||
coreOp.emitWarning("Weight index " + std::to_string(index) + " is out of range");
|
||||
assert(index < coreOp.getWeights().size() && "Weight index is out of range");
|
||||
}
|
||||
mlir::Value weight = coreOp.getWeights()[index];
|
||||
xbarsPerGroup.push_back(index);
|
||||
assert(mapWeightToFile.contains(weight) && "Weight was not materialized into a file!!");
|
||||
auto& fileName = mapWeightToFile[weight];
|
||||
|
||||
Reference in New Issue
Block a user