Merge branch 'refactorone' of chef.heaplab.deib.polimi.it:nnicolosi/Raptor into refactorone

This commit is contained in:
ilgeco
2026-06-03 13:49:42 +02:00
6 changed files with 88 additions and 38 deletions
@@ -1,3 +1,4 @@
use std::cmp::min;
use std::fmt::Debug; use std::fmt::Debug;
use anyhow::{Context, Result, bail, ensure}; use anyhow::{Context, Result, bail, ensure};
@@ -86,7 +87,7 @@ where {
size, size,
}; };
if self.memory.len() < address + size { if self.memory.len() < address + size {
self.memory.resize((address + size) * 2, 0); self.memory.resize(min((address + size) * 2, u32::MAX as usize), 0);
} }
self.load_requests.push(load_request); self.load_requests.push(load_request);
Ok(self) Ok(self)
+50 -29
View File
@@ -74,6 +74,22 @@ static MemoryValueKey getMemoryValueKey(mlir::Value value, std::optional<unsigne
return {value, getLaneForMemoryValue(value, lane)}; return {value, getLaneForMemoryValue(value, lane)};
} }
static bool isInsidePimCoreLikeOp(memref::AllocOp allocOp) {
return allocOp->getParentOfType<pim::PimCoreOp>() || allocOp->getParentOfType<pim::PimCoreBatchOp>();
}
static MemoryReportKind classifyMemoryReportKind(mlir::Value value) {
if (isa<mlir::BlockArgument>(value))
return MemoryReportKind::Input;
if (auto* op = value.getDefiningOp()) {
if (isa<memref::AllocOp>(op))
return MemoryReportKind::Alloca;
if (isa<memref::GetGlobalOp>(op))
return MemoryReportKind::Global;
}
return MemoryReportKind::None;
}
static int32_t getVectorByteSizeOrCrash(ShapedType type) { static int32_t getVectorByteSizeOrCrash(ShapedType type) {
auto byteSize = pim::getCheckedShapedTypeSizeInBytes(type, UnknownLoc::get(type.getContext()), "vector byte size"); auto byteSize = pim::getCheckedShapedTypeSizeInBytes(type, UnknownLoc::get(type.getContext()), "vector byte size");
if (failed(byteSize)) if (failed(byteSize))
@@ -90,19 +106,23 @@ MemEntry* PimMemory::gatherMemEntry(mlir::Value value, std::optional<unsigned> l
pim::getCheckedShapedTypeSizeInBytes(type, UnknownLoc::get(type.getContext()), "memory allocation byte size"); pim::getCheckedShapedTypeSizeInBytes(type, UnknownLoc::get(type.getContext()), "memory allocation byte size");
if (failed(checkedAllocSize)) if (failed(checkedAllocSize))
llvm_unreachable("Failed to compute checked allocation byte size"); llvm_unreachable("Failed to compute checked allocation byte size");
size_t allocSize = static_cast<size_t>(*checkedAllocSize); PendingMemEntry pending;
MemEntry memEntry = {0, allocSize}; pending.memEntry = {0, *checkedAllocSize};
return &memEntries.emplace_back(memEntry, getMemoryValueKey(value, lane)).first; pending.key = getMemoryValueKey(value, lane);
pending.reportKind = classifyMemoryReportKind(value);
return &memEntries.emplace_back(std::move(pending)).memEntry;
} }
void PimMemory::allocateGatheredMemory() { void PimMemory::allocateGatheredMemory() {
llvm::sort(memEntries, [](auto a, auto b) -> bool { return a.first.size > b.first.size; }); llvm::sort(memEntries, [](const PendingMemEntry& lhs, const PendingMemEntry& rhs) {
for (auto& [memEntry, key] : memEntries) return lhs.memEntry.size > rhs.memEntry.size;
allocateMemoryForValue(key, memEntry); });
for (PendingMemEntry& pending : memEntries)
allocateMemoryForValue(pending.key, pending.memEntry, pending.reportKind);
memEntries.clear(); memEntries.clear();
} }
void PimMemory::allocateMemoryForValue(const MemoryValueKey& key, MemEntry& memEntry) { void PimMemory::allocateMemoryForValue(const MemoryValueKey& key, MemEntry& memEntry, MemoryReportKind reportKind) {
memEntry.address = firstAvailableAddress; memEntry.address = firstAvailableAddress;
assert(memEntry.address < (size_t) INT_MAX && "Address allocated bigger than 32bit"); assert(memEntry.address < (size_t) INT_MAX && "Address allocated bigger than 32bit");
firstAvailableAddress += memEntry.size; firstAvailableAddress += memEntry.size;
@@ -112,6 +132,19 @@ void PimMemory::allocateMemoryForValue(const MemoryValueKey& key, MemEntry& memE
ownedMemEntriesMap[key] = memEntry; ownedMemEntriesMap[key] = memEntry;
globalMemEntriesMap[key] = memEntry; globalMemEntriesMap[key] = memEntry;
switch (reportKind) {
case MemoryReportKind::Alloca:
++reportRow.numAlloca;
reportRow.sizeAlloca += memEntry.size;
break;
case MemoryReportKind::Global:
++reportRow.numGlobal;
reportRow.sizeGlobal += memEntry.size;
break;
case MemoryReportKind::Input:
case MemoryReportKind::None: break;
}
} }
void PimMemory::allocateHost(ModuleOp moduleOp, func::FuncOp funcOp) { void PimMemory::allocateHost(ModuleOp moduleOp, func::FuncOp funcOp) {
@@ -142,7 +175,7 @@ void PimMemory::allocateHost(ModuleOp moduleOp, func::FuncOp funcOp) {
}); });
funcOp.walk([&](memref::AllocOp allocOp) { funcOp.walk([&](memref::AllocOp allocOp) {
if (!allocOp->getParentOfType<pim::PimCoreOp>() && !allocOp->getParentOfType<pim::PimCoreBatchOp>()) if (!isInsidePimCoreLikeOp(allocOp))
gatherMemEntry(allocOp.getResult()); gatherMemEntry(allocOp.getResult());
}); });
@@ -195,23 +228,7 @@ static MemoryReportRow addMemoryReportRows(const MemoryReportRow& lhs, const Mem
return result; return result;
} }
MemoryReportRow PimMemory::getReportRow() const { MemoryReportRow PimMemory::getReportRow() const { return reportRow; }
MemoryReportRow row;
for (auto& [key, memEntry] : ownedMemEntriesMap) {
if (auto op = key.value.getDefiningOp()) {
if (isa<memref::AllocOp>(op)) {
row.numAlloca++;
row.sizeAlloca += memEntry.size;
}
if (isa<memref::GetGlobalOp>(op)) {
row.numGlobal++;
row.sizeGlobal += memEntry.size;
}
}
}
return row;
}
void PimMemory::remove(mlir::Value val) { void PimMemory::remove(mlir::Value val) {
for (auto it = ownedMemEntriesMap.begin(); it != ownedMemEntriesMap.end();) for (auto it = ownedMemEntriesMap.begin(); it != ownedMemEntriesMap.end();)
@@ -336,13 +353,15 @@ void PimAcceleratorMemory::flushReport() {
llvm::raw_os_ostream os(fileReport); llvm::raw_os_ostream os(fileReport);
uint64_t totalGlobalMemory = hostReportRow.has_value() ? hostReportRow->sizeGlobal : 0; uint64_t totalGlobalMemory = hostReportRow.has_value() ? hostReportRow->sizeGlobal : 0;
uint64_t totalWeightsMemory = totalWeightBytes;
uint64_t totalCoresMemory = 0; uint64_t totalCoresMemory = 0;
for (const MemoryReportEntry& entry : reportEntries) for (const MemoryReportEntry& entry : reportEntries)
totalCoresMemory += entry.totalAllocaBytes; totalCoresMemory += entry.totalAllocaBytes;
llvm::SmallVector<ReportField, 2> totalFields = { llvm::SmallVector<ReportField, 3> totalFields = {
{"Global memory", formatReportMemory(totalGlobalMemory)}, {"Global memory", formatReportMemory(totalGlobalMemory) },
{"Cores memory", formatReportMemory(totalCoresMemory) } {"Weights memory", formatReportMemory(totalWeightsMemory)},
{"Cores memory", formatReportMemory(totalCoresMemory) }
}; };
printReportTotalsBlock(os, totalFields); printReportTotalsBlock(os, totalFields);
@@ -1360,7 +1379,9 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std::
request.weights = jobResults[jobIndex].usedWeights; request.weights = jobResults[jobIndex].usedWeights;
weightRequests.push_back(std::move(request)); weightRequests.push_back(std::move(request));
} }
auto mapCoreWeightToFileName = createAndPopulateWeightFolder(weightRequests, outputDirPath); auto weightEmission = createAndPopulateWeightFolder(weightRequests, outputDirPath);
memory.setTotalWeightBytes(weightEmission.totalWeightBytes);
auto& mapCoreWeightToFileName = weightEmission.mapCoreWeightToFileName;
for (size_t jobIndex = 0; jobIndex < jobs.size(); ++jobIndex) { for (size_t jobIndex = 0; jobIndex < jobs.size(); ++jobIndex) {
const CoreEmissionJob& job = jobs[jobIndex]; const CoreEmissionJob& job = jobs[jobIndex];
+18 -2
View File
@@ -45,6 +45,19 @@ struct MemoryReportRow {
} }
}; };
enum class MemoryReportKind {
None,
Alloca,
Global,
Input
};
struct PendingMemEntry {
MemEntry memEntry;
MemoryValueKey key;
MemoryReportKind reportKind = MemoryReportKind::None;
};
struct MemoryReportEntry { struct MemoryReportEntry {
enum class Kind { enum class Kind {
Core, Core,
@@ -60,16 +73,17 @@ struct MemoryReportEntry {
}; };
class PimMemory { class PimMemory {
llvm::SmallVector<std::pair<MemEntry, MemoryValueKey>, 32> memEntries; llvm::SmallVector<PendingMemEntry, 32> memEntries;
llvm::SmallDenseMap<MemoryValueKey, MemEntry, 32>& globalMemEntriesMap; llvm::SmallDenseMap<MemoryValueKey, MemEntry, 32>& globalMemEntriesMap;
llvm::SmallDenseMap<MemoryValueKey, MemEntry, 32> ownedMemEntriesMap; llvm::SmallDenseMap<MemoryValueKey, MemEntry, 32> ownedMemEntriesMap;
MemoryReportRow reportRow;
size_t minAlignment = 4; size_t minAlignment = 4;
size_t firstAvailableAddress = 0; size_t firstAvailableAddress = 0;
MemEntry* gatherMemEntry(mlir::Value value, std::optional<unsigned> lane = std::nullopt); MemEntry* gatherMemEntry(mlir::Value value, std::optional<unsigned> lane = std::nullopt);
void allocateGatheredMemory(); void allocateGatheredMemory();
void allocateMemoryForValue(const MemoryValueKey& key, MemEntry& memEntry); void allocateMemoryForValue(const MemoryValueKey& key, MemEntry& memEntry, MemoryReportKind reportKind);
public: public:
PimMemory(llvm::SmallDenseMap<MemoryValueKey, MemEntry, 32>& globalMemEntriesMap) PimMemory(llvm::SmallDenseMap<MemoryValueKey, MemEntry, 32>& globalMemEntriesMap)
@@ -94,6 +108,7 @@ private:
std::fstream fileReport; std::fstream fileReport;
std::optional<MemoryReportRow> hostReportRow; std::optional<MemoryReportRow> hostReportRow;
llvm::SmallVector<MemoryReportEntry, 32> reportEntries; llvm::SmallVector<MemoryReportEntry, 32> reportEntries;
uint64_t totalWeightBytes = 0;
mutable llvm::DenseMap<mlir::Value, CompiledIndexExpr> compiledIndexExprs; mutable llvm::DenseMap<mlir::Value, CompiledIndexExpr> compiledIndexExprs;
mutable llvm::DenseMap<mlir::Value, CompiledAddressExpr> compiledAddressExprs; mutable llvm::DenseMap<mlir::Value, CompiledAddressExpr> compiledAddressExprs;
@@ -118,6 +133,7 @@ public:
const MemoryReportRow& perCoreRow, const MemoryReportRow& perCoreRow,
uint64_t totalAllocaCount, uint64_t totalAllocaCount,
uint64_t totalAllocaBytes); uint64_t totalAllocaBytes);
void setTotalWeightBytes(uint64_t bytes) { totalWeightBytes = bytes; }
void flushReport(); void flushReport();
void clean(mlir::Operation* op); void clean(mlir::Operation* op);
}; };
+10 -4
View File
@@ -7,6 +7,7 @@
#include <cassert> #include <cassert>
#include "Common/Support/CheckedArithmetic.hpp"
#include "Conversion/ONNXToSpatial/Common/Common.hpp" #include "Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp" #include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
@@ -18,7 +19,7 @@ using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
namespace {} // namespace namespace {} // namespace
llvm::DenseMap<size_t, llvm::SmallVector<std::string, 8>> WeightEmissionResult
createAndPopulateWeightFolder(ArrayRef<WeightFileRequest> requests, StringRef outputDirPath) { createAndPopulateWeightFolder(ArrayRef<WeightFileRequest> requests, StringRef outputDirPath) {
auto coreWeightsDirPath = outputDirPath + "/weights"; auto coreWeightsDirPath = outputDirPath + "/weights";
auto error = sys::fs::create_directory(coreWeightsDirPath); auto error = sys::fs::create_directory(coreWeightsDirPath);
@@ -26,7 +27,7 @@ createAndPopulateWeightFolder(ArrayRef<WeightFileRequest> requests, StringRef ou
size_t indexFileName = 0; size_t indexFileName = 0;
int64_t xbarSize = crossbarSize.getValue(); int64_t xbarSize = crossbarSize.getValue();
llvm::DenseMap<size_t, llvm::SmallVector<std::string, 8>> mapCoreWeightToFileName; WeightEmissionResult result;
llvm::SmallVector<std::pair<ResolvedWeightView, std::string>, 16> materializedWeights; llvm::SmallVector<std::pair<ResolvedWeightView, std::string>, 16> materializedWeights;
auto materializeWeight = [&](const ResolvedWeightView& weightView) -> std::string { auto materializeWeight = [&](const ResolvedWeightView& weightView) -> std::string {
@@ -72,17 +73,22 @@ createAndPopulateWeightFolder(ArrayRef<WeightFileRequest> requests, StringRef ou
weightFileStream.close(); weightFileStream.close();
materializedWeights.push_back({weightView, newFileName}); materializedWeights.push_back({weightView, newFileName});
uint64_t weightBytes = pim::checkedMulOrCrash(
pim::checkedMulOrCrash(static_cast<size_t>(xbarSize), static_cast<size_t>(xbarSize), "weight element count"),
elementByteWidth,
"weight byte size");
result.totalWeightBytes = pim::checkedAddOrCrash(result.totalWeightBytes, weightBytes, "total weight bytes");
return newFileName; return newFileName;
}; };
for (const WeightFileRequest& request : requests) { for (const WeightFileRequest& request : requests) {
auto& coreFiles = mapCoreWeightToFileName[request.coreId]; auto& coreFiles = result.mapCoreWeightToFileName[request.coreId];
coreFiles.reserve(request.weights.size()); coreFiles.reserve(request.weights.size());
for (const ResolvedWeightView& weight : request.weights) for (const ResolvedWeightView& weight : request.weights)
coreFiles.push_back(materializeWeight(weight)); coreFiles.push_back(materializeWeight(weight));
} }
return mapCoreWeightToFileName; return result;
} }
} // namespace onnx_mlir } // namespace onnx_mlir
+7 -1
View File
@@ -6,6 +6,7 @@
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringRef.h"
#include <cstdint>
#include <string> #include <string>
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp" #include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
@@ -17,7 +18,12 @@ struct WeightFileRequest {
llvm::SmallVector<ResolvedWeightView, 8> weights; llvm::SmallVector<ResolvedWeightView, 8> weights;
}; };
llvm::DenseMap<size_t, llvm::SmallVector<std::string, 8>> struct WeightEmissionResult {
llvm::DenseMap<size_t, llvm::SmallVector<std::string, 8>> mapCoreWeightToFileName;
uint64_t totalWeightBytes = 0;
};
WeightEmissionResult
createAndPopulateWeightFolder(llvm::ArrayRef<WeightFileRequest> requests, llvm::StringRef outputDirPath); createAndPopulateWeightFolder(llvm::ArrayRef<WeightFileRequest> requests, llvm::StringRef outputDirPath);
} // namespace onnx_mlir } // namespace onnx_mlir
@@ -663,7 +663,7 @@ public:
return; return;
} }
emitMergeIrCounts("final-post-merge", func); emitMergeIrCounts("final-post-merge", func);
dumpModule(cast<ModuleOp>(func->getParentOp()), "spatial1_dcp_merged"); dumpModule(cast<ModuleOp>(func->getParentOp()), "spatial1_merged");
generateReport(func, "spatial_merge_report", analysisResult->cpuToLastComputeMap.size()); generateReport(func, "spatial_merge_report", analysisResult->cpuToLastComputeMap.size());
} }
} }