fix memory allocation in pim codegen
fix crossbar allocation to only consider weights from vmm and mvm
This commit is contained in:
@@ -5,6 +5,7 @@
|
|||||||
#include "mlir/IR/BuiltinTypes.h"
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
|
|
||||||
#include "llvm/ADT/DenseMap.h"
|
#include "llvm/ADT/DenseMap.h"
|
||||||
|
#include "llvm/ADT/SmallPtrSet.h"
|
||||||
#include "llvm/Support/FileSystem.h"
|
#include "llvm/Support/FileSystem.h"
|
||||||
#include "llvm/Support/JSON.h"
|
#include "llvm/Support/JSON.h"
|
||||||
#include "llvm/Support/raw_ostream.h"
|
#include "llvm/Support/raw_ostream.h"
|
||||||
@@ -33,6 +34,12 @@ MemEntry* PimMemory::gatherMemEntry(mlir::Value value) {
|
|||||||
return &memEntries.emplace_back(memEntry, value).first;
|
return &memEntries.emplace_back(memEntry, value).first;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void PimMemory::allocateGatheredMemory() {
|
||||||
|
llvm::sort(memEntries, [](auto a, auto b) -> bool { return a.first.size > b.first.size; });
|
||||||
|
for (auto& [memEntry, value] : memEntries)
|
||||||
|
allocateMemoryForValue(value, memEntry);
|
||||||
|
}
|
||||||
|
|
||||||
void PimMemory::allocateMemoryForValue(mlir::Value value, MemEntry& memEntry) {
|
void PimMemory::allocateMemoryForValue(mlir::Value value, MemEntry& memEntry) {
|
||||||
memEntry.address = firstAvailableAddress;
|
memEntry.address = firstAvailableAddress;
|
||||||
firstAvailableAddress += memEntry.size;
|
firstAvailableAddress += memEntry.size;
|
||||||
@@ -44,35 +51,37 @@ void PimMemory::allocateMemoryForValue(mlir::Value value, MemEntry& memEntry) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void PimMemory::allocateHost(ModuleOp moduleOp, func::FuncOp funcOp) {
|
void PimMemory::allocateHost(ModuleOp moduleOp, func::FuncOp funcOp) {
|
||||||
// More than one SSA value per single global constant:
|
SmallDenseMap<memref::GlobalOp, mlir::Value, 8> globalConstants;
|
||||||
// Cannot call gatherMemEntry for each of them, otherwise memory will be allocated multiple times
|
SmallVector<std::pair<mlir::Value, mlir::Value>, 16> globalAliases;
|
||||||
// Thus, call gatherMemEntry only for the first SSA value and assign the same memEntry to all others
|
|
||||||
SmallDenseMap<memref::GlobalOp, MemEntry*, 8> globalConstants;
|
|
||||||
funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
|
funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
|
||||||
if (!hasWeightAlways(getGlobalOp)) {
|
if (!hasWeightAlways(getGlobalOp)) {
|
||||||
auto globalMemrefOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
auto globalMemrefOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
||||||
auto iter = globalConstants.find(globalMemrefOp);
|
auto [iter, inserted] = globalConstants.try_emplace(globalMemrefOp, getGlobalOp.getResult());
|
||||||
if (iter == globalConstants.end())
|
if (inserted)
|
||||||
globalConstants[globalMemrefOp] = gatherMemEntry(getGlobalOp);
|
gatherMemEntry(getGlobalOp.getResult());
|
||||||
else {
|
else
|
||||||
MemEntry memEntry = *iter->second;
|
globalAliases.push_back({getGlobalOp.getResult(), iter->second});
|
||||||
globalMemEntriesMap[getGlobalOp] = memEntry;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
for (mlir::Value arg : funcOp.getArguments())
|
for (mlir::Value arg : funcOp.getArguments())
|
||||||
gatherMemEntry(arg);
|
gatherMemEntry(arg);
|
||||||
|
|
||||||
allocateCore(funcOp);
|
funcOp.walk([&](memref::AllocOp allocOp) {
|
||||||
|
if (!allocOp->getParentOfType<pim::PimCoreOp>())
|
||||||
|
gatherMemEntry(allocOp.getResult());
|
||||||
|
});
|
||||||
|
|
||||||
|
allocateGatheredMemory();
|
||||||
|
|
||||||
|
for (auto [alias, original] : globalAliases)
|
||||||
|
globalMemEntriesMap[alias] = getMemEntry(original);
|
||||||
}
|
}
|
||||||
|
|
||||||
void PimMemory::allocateCore(Operation* op) {
|
void PimMemory::allocateCore(Operation* op) {
|
||||||
op->walk([&](memref::AllocOp allocOp) { gatherMemEntry(allocOp); });
|
op->walk([&](memref::AllocOp allocOp) { gatherMemEntry(allocOp); });
|
||||||
|
|
||||||
llvm::sort(memEntries, [](auto a, auto b) -> bool { return a.first.size > b.first.size; });
|
allocateGatheredMemory();
|
||||||
for (auto& [memEntry, value] : memEntries)
|
|
||||||
allocateMemoryForValue(value, memEntry);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
MemEntry PimMemory::getMemEntry(mlir::Value value) const {
|
MemEntry PimMemory::getMemEntry(mlir::Value value) const {
|
||||||
@@ -465,6 +474,19 @@ std::string getMemorySizeAsString(size_t size) {
|
|||||||
return std::to_string(size) + " Bytes";
|
return std::to_string(size) + " Bytes";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static SmallVector<unsigned, 8> getUsedWeightIndices(pim::PimCoreOp coreOp) {
|
||||||
|
SmallVector<unsigned, 8> indices;
|
||||||
|
auto addIndex = [&](unsigned weightIndex) {
|
||||||
|
if (!llvm::is_contained(indices, weightIndex))
|
||||||
|
indices.push_back(weightIndex);
|
||||||
|
};
|
||||||
|
|
||||||
|
coreOp.walk([&](pim::PimMVMOp mvmOp) { addIndex(mvmOp.getWeightIndex()); });
|
||||||
|
coreOp.walk([&](pim::PimVMMOp vmmOp) { addIndex(vmmOp.getWeightIndex()); });
|
||||||
|
llvm::sort(indices);
|
||||||
|
return indices;
|
||||||
|
}
|
||||||
|
|
||||||
/// Write global constant data into a binary memory image at their allocated addresses.
|
/// Write global constant data into a binary memory image at their allocated addresses.
|
||||||
static OnnxMlirCompilerErrorCodes
|
static OnnxMlirCompilerErrorCodes
|
||||||
writeMemoryBinary(ModuleOp moduleOp, func::FuncOp funcOp, PimAcceleratorMemory& memory, StringRef outputDirPath) {
|
writeMemoryBinary(ModuleOp moduleOp, func::FuncOp funcOp, PimAcceleratorMemory& memory, StringRef outputDirPath) {
|
||||||
@@ -478,12 +500,15 @@ writeMemoryBinary(ModuleOp moduleOp, func::FuncOp funcOp, PimAcceleratorMemory&
|
|||||||
|
|
||||||
std::vector<char> memoryBuffer(memory.hostMem.getFirstAvailableAddress(), 0);
|
std::vector<char> memoryBuffer(memory.hostMem.getFirstAvailableAddress(), 0);
|
||||||
|
|
||||||
|
SmallPtrSet<Operation*, 16> writtenGlobals;
|
||||||
funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
|
funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
|
||||||
if (hasWeightAlways(getGlobalOp))
|
if (hasWeightAlways(getGlobalOp))
|
||||||
return;
|
return;
|
||||||
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
||||||
if (!globalOp)
|
if (!globalOp)
|
||||||
return;
|
return;
|
||||||
|
if (!writtenGlobals.insert(globalOp.getOperation()).second)
|
||||||
|
return;
|
||||||
auto initialValue = globalOp.getInitialValue();
|
auto initialValue = globalOp.getInitialValue();
|
||||||
if (!initialValue)
|
if (!initialValue)
|
||||||
return;
|
return;
|
||||||
@@ -658,7 +683,12 @@ createAndPopulateWeightFolder(func::FuncOp funcOp, StringRef outputDirPath) {
|
|||||||
llvm::DenseMap<memref::GlobalOp, std::string> mapGlobalOpToFileName;
|
llvm::DenseMap<memref::GlobalOp, std::string> mapGlobalOpToFileName;
|
||||||
|
|
||||||
for (pim::PimCoreOp coreOp : funcOp.getOps<pim::PimCoreOp>()) {
|
for (pim::PimCoreOp coreOp : funcOp.getOps<pim::PimCoreOp>()) {
|
||||||
for (auto [index, weight] : llvm::enumerate(coreOp.getWeights())) {
|
for (unsigned index : getUsedWeightIndices(coreOp)) {
|
||||||
|
if (index >= coreOp.getWeights().size()) {
|
||||||
|
coreOp.emitWarning("Weight index " + std::to_string(index) + " is out of range");
|
||||||
|
assert(index < coreOp.getWeights().size() && "Weight index is out of range");
|
||||||
|
}
|
||||||
|
mlir::Value weight = coreOp.getWeights()[index];
|
||||||
|
|
||||||
auto getGlobalOp = weight.getDefiningOp<memref::GetGlobalOp>();
|
auto getGlobalOp = weight.getDefiningOp<memref::GetGlobalOp>();
|
||||||
if (!getGlobalOp) {
|
if (!getGlobalOp) {
|
||||||
@@ -855,7 +885,12 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::
|
|||||||
|
|
||||||
auto& mapWeightToFile = mapCoreWeightToFileName[coreOp];
|
auto& mapWeightToFile = mapCoreWeightToFileName[coreOp];
|
||||||
json::Array xbarsPerGroup;
|
json::Array xbarsPerGroup;
|
||||||
for (auto [index, weight] : llvm::enumerate(coreOp.getWeights())) {
|
for (unsigned index : getUsedWeightIndices(coreOp)) {
|
||||||
|
if (index >= coreOp.getWeights().size()) {
|
||||||
|
coreOp.emitWarning("Weight index " + std::to_string(index) + " is out of range");
|
||||||
|
assert(index < coreOp.getWeights().size() && "Weight index is out of range");
|
||||||
|
}
|
||||||
|
mlir::Value weight = coreOp.getWeights()[index];
|
||||||
xbarsPerGroup.push_back(index);
|
xbarsPerGroup.push_back(index);
|
||||||
assert(mapWeightToFile.contains(weight) && "Weight was not materialized into a file!!");
|
assert(mapWeightToFile.contains(weight) && "Weight was not materialized into a file!!");
|
||||||
auto& fileName = mapWeightToFile[weight];
|
auto& fileName = mapWeightToFile[weight];
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ class PimMemory {
|
|||||||
size_t firstAvailableAddress = 0;
|
size_t firstAvailableAddress = 0;
|
||||||
|
|
||||||
MemEntry* gatherMemEntry(mlir::Value value);
|
MemEntry* gatherMemEntry(mlir::Value value);
|
||||||
|
void allocateGatheredMemory();
|
||||||
void allocateMemoryForValue(mlir::Value value, MemEntry& memEntry);
|
void allocateMemoryForValue(mlir::Value value, MemEntry& memEntry);
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
|||||||
@@ -93,15 +93,22 @@ void PimBufferizationPass::runOnOperation() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void PimBufferizationPass::annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const {
|
void PimBufferizationPass::annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const {
|
||||||
funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
|
funcOp.walk([&](PimCoreOp coreOp) {
|
||||||
bool isAlwaysWeight = !getGlobalOp->getUsers().empty()
|
auto annotateWeight = [&](unsigned weightIndex) {
|
||||||
&& all_of(getGlobalOp->getUsers(), [](auto user) -> bool { return isa<PimCoreOp>(user); });
|
if (weightIndex >= coreOp.getWeights().size())
|
||||||
if (isAlwaysWeight) {
|
return;
|
||||||
|
Value weight = coreOp.getWeights()[weightIndex];
|
||||||
|
auto getGlobalOp = weight.getDefiningOp<memref::GetGlobalOp>();
|
||||||
|
if (!getGlobalOp)
|
||||||
|
return;
|
||||||
auto globalMemrefOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
auto globalMemrefOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
||||||
assert("Weights must be constants" && globalMemrefOp.getConstant());
|
assert("Weights must be constants" && globalMemrefOp.getConstant());
|
||||||
markWeightAlways(getGlobalOp);
|
markWeightAlways(getGlobalOp);
|
||||||
markWeightAlways(globalMemrefOp);
|
markWeightAlways(globalMemrefOp);
|
||||||
}
|
};
|
||||||
|
|
||||||
|
coreOp.walk([&](PimMVMOp mvmOp) { annotateWeight(mvmOp.getWeightIndex()); });
|
||||||
|
coreOp.walk([&](PimVMMOp vmmOp) { annotateWeight(vmmOp.getWeightIndex()); });
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user