Merge branch 'refactorone' of chef.heaplab.deib.polimi.it:nnicolosi/Raptor into refactorone

This commit is contained in:
ilgeco
2026-06-03 13:49:42 +02:00
6 changed files with 88 additions and 38 deletions
@@ -1,3 +1,4 @@
use std::cmp::min;
use std::fmt::Debug;
use anyhow::{Context, Result, bail, ensure};
@@ -86,7 +87,7 @@ where {
size,
};
if self.memory.len() < address + size {
self.memory.resize((address + size) * 2, 0);
self.memory.resize(min((address + size) * 2, u32::MAX as usize), 0);
}
self.load_requests.push(load_request);
Ok(self)
+50 -29
View File
@@ -74,6 +74,22 @@ static MemoryValueKey getMemoryValueKey(mlir::Value value, std::optional<unsigne
return {value, getLaneForMemoryValue(value, lane)};
}
/// Reports whether `allocOp` has an enclosing pim::PimCoreOp or
/// pim::PimCoreBatchOp among its parent operations.
static bool isInsidePimCoreLikeOp(memref::AllocOp allocOp) {
  if (allocOp->getParentOfType<pim::PimCoreOp>())
    return true;
  if (allocOp->getParentOfType<pim::PimCoreBatchOp>())
    return true;
  return false;
}
/// Maps a memory value to the report bucket it is counted under.
///
/// - Block arguments are classified as `Input`.
/// - Results of `memref::AllocOp` are classified as `Alloca`.
/// - Results of `memref::GetGlobalOp` are classified as `Global`.
/// - Anything else is `None`.
static MemoryReportKind classifyMemoryReportKind(mlir::Value value) {
  // Block arguments have no defining operation, so handle them up front.
  if (isa<mlir::BlockArgument>(value))
    return MemoryReportKind::Input;

  mlir::Operation* producer = value.getDefiningOp();
  if (!producer)
    return MemoryReportKind::None;

  if (isa<memref::AllocOp>(producer))
    return MemoryReportKind::Alloca;
  return isa<memref::GetGlobalOp>(producer) ? MemoryReportKind::Global : MemoryReportKind::None;
}
static int32_t getVectorByteSizeOrCrash(ShapedType type) {
auto byteSize = pim::getCheckedShapedTypeSizeInBytes(type, UnknownLoc::get(type.getContext()), "vector byte size");
if (failed(byteSize))
@@ -90,19 +106,23 @@ MemEntry* PimMemory::gatherMemEntry(mlir::Value value, std::optional<unsigned> l
pim::getCheckedShapedTypeSizeInBytes(type, UnknownLoc::get(type.getContext()), "memory allocation byte size");
if (failed(checkedAllocSize))
llvm_unreachable("Failed to compute checked allocation byte size");
size_t allocSize = static_cast<size_t>(*checkedAllocSize);
MemEntry memEntry = {0, allocSize};
return &memEntries.emplace_back(memEntry, getMemoryValueKey(value, lane)).first;
PendingMemEntry pending;
pending.memEntry = {0, *checkedAllocSize};
pending.key = getMemoryValueKey(value, lane);
pending.reportKind = classifyMemoryReportKind(value);
return &memEntries.emplace_back(std::move(pending)).memEntry;
}
// Allocates every entry gathered so far: the pending entries are sorted by
// descending size, each one is handed to allocateMemoryForValue, and the
// pending list is emptied afterwards.
void PimMemory::allocateGatheredMemory() {
llvm::sort(memEntries, [](auto a, auto b) -> bool { return a.first.size > b.first.size; });
for (auto& [memEntry, key] : memEntries)
allocateMemoryForValue(key, memEntry);
llvm::sort(memEntries, [](const PendingMemEntry& lhs, const PendingMemEntry& rhs) {
return lhs.memEntry.size > rhs.memEntry.size;
});
for (PendingMemEntry& pending : memEntries)
allocateMemoryForValue(pending.key, pending.memEntry, pending.reportKind);
memEntries.clear();
}
void PimMemory::allocateMemoryForValue(const MemoryValueKey& key, MemEntry& memEntry) {
void PimMemory::allocateMemoryForValue(const MemoryValueKey& key, MemEntry& memEntry, MemoryReportKind reportKind) {
memEntry.address = firstAvailableAddress;
assert(memEntry.address < (size_t) INT_MAX && "Address allocated bigger than 32bit");
firstAvailableAddress += memEntry.size;
@@ -112,6 +132,19 @@ void PimMemory::allocateMemoryForValue(const MemoryValueKey& key, MemEntry& memE
ownedMemEntriesMap[key] = memEntry;
globalMemEntriesMap[key] = memEntry;
switch (reportKind) {
case MemoryReportKind::Alloca:
++reportRow.numAlloca;
reportRow.sizeAlloca += memEntry.size;
break;
case MemoryReportKind::Global:
++reportRow.numGlobal;
reportRow.sizeGlobal += memEntry.size;
break;
case MemoryReportKind::Input:
case MemoryReportKind::None: break;
}
}
void PimMemory::allocateHost(ModuleOp moduleOp, func::FuncOp funcOp) {
@@ -142,7 +175,7 @@ void PimMemory::allocateHost(ModuleOp moduleOp, func::FuncOp funcOp) {
});
funcOp.walk([&](memref::AllocOp allocOp) {
if (!allocOp->getParentOfType<pim::PimCoreOp>() && !allocOp->getParentOfType<pim::PimCoreBatchOp>())
if (!isInsidePimCoreLikeOp(allocOp))
gatherMemEntry(allocOp.getResult());
});
@@ -195,23 +228,7 @@ static MemoryReportRow addMemoryReportRows(const MemoryReportRow& lhs, const Mem
return result;
}
// Builds a report row by scanning ownedMemEntriesMap: entries whose value is
// defined by a memref::AllocOp add to numAlloca/sizeAlloca, and entries whose
// value is defined by a memref::GetGlobalOp add to numGlobal/sizeGlobal.
// Values without a defining operation are not counted.
// NOTE(review): a one-line definition of this same method follows directly
// below; this scan-based body looks like the pre-change side of the diff —
// confirm which definition is the live one.
MemoryReportRow PimMemory::getReportRow() const {
MemoryReportRow row;
for (auto& [key, memEntry] : ownedMemEntriesMap) {
if (auto op = key.value.getDefiningOp()) {
if (isa<memref::AllocOp>(op)) {
row.numAlloca++;
row.sizeAlloca += memEntry.size;
}
if (isa<memref::GetGlobalOp>(op)) {
row.numGlobal++;
row.sizeGlobal += memEntry.size;
}
}
}
return row;
}
// Returns a copy of the `reportRow` member, whose alloca/global counters are
// updated by allocateMemoryForValue as entries are allocated.
MemoryReportRow PimMemory::getReportRow() const { return reportRow; }
void PimMemory::remove(mlir::Value val) {
for (auto it = ownedMemEntriesMap.begin(); it != ownedMemEntriesMap.end();)
@@ -336,13 +353,15 @@ void PimAcceleratorMemory::flushReport() {
llvm::raw_os_ostream os(fileReport);
uint64_t totalGlobalMemory = hostReportRow.has_value() ? hostReportRow->sizeGlobal : 0;
uint64_t totalWeightsMemory = totalWeightBytes;
uint64_t totalCoresMemory = 0;
for (const MemoryReportEntry& entry : reportEntries)
totalCoresMemory += entry.totalAllocaBytes;
llvm::SmallVector<ReportField, 2> totalFields = {
{"Global memory", formatReportMemory(totalGlobalMemory)},
{"Cores memory", formatReportMemory(totalCoresMemory) }
llvm::SmallVector<ReportField, 3> totalFields = {
{"Global memory", formatReportMemory(totalGlobalMemory) },
{"Weights memory", formatReportMemory(totalWeightsMemory)},
{"Cores memory", formatReportMemory(totalCoresMemory) }
};
printReportTotalsBlock(os, totalFields);
@@ -1360,7 +1379,9 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std::
request.weights = jobResults[jobIndex].usedWeights;
weightRequests.push_back(std::move(request));
}
auto mapCoreWeightToFileName = createAndPopulateWeightFolder(weightRequests, outputDirPath);
auto weightEmission = createAndPopulateWeightFolder(weightRequests, outputDirPath);
memory.setTotalWeightBytes(weightEmission.totalWeightBytes);
auto& mapCoreWeightToFileName = weightEmission.mapCoreWeightToFileName;
for (size_t jobIndex = 0; jobIndex < jobs.size(); ++jobIndex) {
const CoreEmissionJob& job = jobs[jobIndex];
+18 -2
View File
@@ -45,6 +45,19 @@ struct MemoryReportRow {
}
};
/// Report bucket a memory value is accounted under.
enum class MemoryReportKind {
  None = 0,   ///< Not counted in any report bucket.
  Alloca = 1, ///< Value produced by a memref::AllocOp.
  Global = 2, ///< Value produced by a memref::GetGlobalOp.
  Input = 3   ///< Value that is a block argument.
};
/// A memory entry that has been gathered but not yet allocated.
/// It bundles the entry with the key it will be stored under and the
/// report bucket it will be counted in.
struct PendingMemEntry {
  /// Address/size record for the value.
  MemEntry memEntry;
  /// Key used when the entry is stored in the memory-entry maps.
  MemoryValueKey key;
  /// Report bucket for this entry; `None` unless set explicitly.
  MemoryReportKind reportKind{MemoryReportKind::None};
};
struct MemoryReportEntry {
enum class Kind {
Core,
@@ -60,16 +73,17 @@ struct MemoryReportEntry {
};
class PimMemory {
llvm::SmallVector<std::pair<MemEntry, MemoryValueKey>, 32> memEntries;
llvm::SmallVector<PendingMemEntry, 32> memEntries;
llvm::SmallDenseMap<MemoryValueKey, MemEntry, 32>& globalMemEntriesMap;
llvm::SmallDenseMap<MemoryValueKey, MemEntry, 32> ownedMemEntriesMap;
MemoryReportRow reportRow;
size_t minAlignment = 4;
size_t firstAvailableAddress = 0;
MemEntry* gatherMemEntry(mlir::Value value, std::optional<unsigned> lane = std::nullopt);
void allocateGatheredMemory();
void allocateMemoryForValue(const MemoryValueKey& key, MemEntry& memEntry);
void allocateMemoryForValue(const MemoryValueKey& key, MemEntry& memEntry, MemoryReportKind reportKind);
public:
PimMemory(llvm::SmallDenseMap<MemoryValueKey, MemEntry, 32>& globalMemEntriesMap)
@@ -94,6 +108,7 @@ private:
std::fstream fileReport;
std::optional<MemoryReportRow> hostReportRow;
llvm::SmallVector<MemoryReportEntry, 32> reportEntries;
uint64_t totalWeightBytes = 0;
mutable llvm::DenseMap<mlir::Value, CompiledIndexExpr> compiledIndexExprs;
mutable llvm::DenseMap<mlir::Value, CompiledAddressExpr> compiledAddressExprs;
@@ -118,6 +133,7 @@ public:
const MemoryReportRow& perCoreRow,
uint64_t totalAllocaCount,
uint64_t totalAllocaBytes);
// Stores the total weight byte count that the report prints as "Weights memory".
void setTotalWeightBytes(uint64_t bytes) {
  totalWeightBytes = bytes;
}
void flushReport();
void clean(mlir::Operation* op);
};
+10 -4
View File
@@ -7,6 +7,7 @@
#include <cassert>
#include "Common/Support/CheckedArithmetic.hpp"
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
@@ -18,7 +19,7 @@ using namespace mlir;
namespace onnx_mlir {
namespace {} // namespace
llvm::DenseMap<size_t, llvm::SmallVector<std::string, 8>>
WeightEmissionResult
createAndPopulateWeightFolder(ArrayRef<WeightFileRequest> requests, StringRef outputDirPath) {
auto coreWeightsDirPath = outputDirPath + "/weights";
auto error = sys::fs::create_directory(coreWeightsDirPath);
@@ -26,7 +27,7 @@ createAndPopulateWeightFolder(ArrayRef<WeightFileRequest> requests, StringRef ou
size_t indexFileName = 0;
int64_t xbarSize = crossbarSize.getValue();
llvm::DenseMap<size_t, llvm::SmallVector<std::string, 8>> mapCoreWeightToFileName;
WeightEmissionResult result;
llvm::SmallVector<std::pair<ResolvedWeightView, std::string>, 16> materializedWeights;
auto materializeWeight = [&](const ResolvedWeightView& weightView) -> std::string {
@@ -72,17 +73,22 @@ createAndPopulateWeightFolder(ArrayRef<WeightFileRequest> requests, StringRef ou
weightFileStream.close();
materializedWeights.push_back({weightView, newFileName});
uint64_t weightBytes = pim::checkedMulOrCrash(
pim::checkedMulOrCrash(static_cast<size_t>(xbarSize), static_cast<size_t>(xbarSize), "weight element count"),
elementByteWidth,
"weight byte size");
result.totalWeightBytes = pim::checkedAddOrCrash(result.totalWeightBytes, weightBytes, "total weight bytes");
return newFileName;
};
for (const WeightFileRequest& request : requests) {
auto& coreFiles = mapCoreWeightToFileName[request.coreId];
auto& coreFiles = result.mapCoreWeightToFileName[request.coreId];
coreFiles.reserve(request.weights.size());
for (const ResolvedWeightView& weight : request.weights)
coreFiles.push_back(materializeWeight(weight));
}
return mapCoreWeightToFileName;
return result;
}
} // namespace onnx_mlir
+7 -1
View File
@@ -6,6 +6,7 @@
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include <cstdint>
#include <string>
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
@@ -17,7 +18,12 @@ struct WeightFileRequest {
llvm::SmallVector<ResolvedWeightView, 8> weights;
};
llvm::DenseMap<size_t, llvm::SmallVector<std::string, 8>>
/// Result of createAndPopulateWeightFolder.
struct WeightEmissionResult {
  /// For each core id, the weight file names produced for that core's
  /// requested weights, in request order.
  llvm::DenseMap<size_t, llvm::SmallVector<std::string, 8>> mapCoreWeightToFileName;
  /// Sum of the byte sizes of the weights that were materialized.
  uint64_t totalWeightBytes{0};
};
WeightEmissionResult
createAndPopulateWeightFolder(llvm::ArrayRef<WeightFileRequest> requests, llvm::StringRef outputDirPath);
} // namespace onnx_mlir
@@ -663,7 +663,7 @@ public:
return;
}
emitMergeIrCounts("final-post-merge", func);
dumpModule(cast<ModuleOp>(func->getParentOp()), "spatial1_dcp_merged");
dumpModule(cast<ModuleOp>(func->getParentOp()), "spatial1_merged");
generateReport(func, "spatial_merge_report", analysisResult->cpuToLastComputeMap.size());
}
}