diff --git a/backend-simulators/pim/pim-simulator/src/lib/memory_manager/mod.rs b/backend-simulators/pim/pim-simulator/src/lib/memory_manager/mod.rs index fc2ecb4..d217584 100644 --- a/backend-simulators/pim/pim-simulator/src/lib/memory_manager/mod.rs +++ b/backend-simulators/pim/pim-simulator/src/lib/memory_manager/mod.rs @@ -1,3 +1,4 @@ +use std::cmp::min; use std::fmt::Debug; use anyhow::{Context, Result, bail, ensure}; @@ -86,7 +87,7 @@ where { size, }; if self.memory.len() < address + size { - self.memory.resize((address + size) * 2, 0); + self.memory.resize(min((address + size) * 2, u32::MAX as usize), 0); } self.load_requests.push(load_request); Ok(self) diff --git a/src/PIM/Compiler/PimCodeGen.cpp b/src/PIM/Compiler/PimCodeGen.cpp index 57043ad..cee4516 100644 --- a/src/PIM/Compiler/PimCodeGen.cpp +++ b/src/PIM/Compiler/PimCodeGen.cpp @@ -74,6 +74,22 @@ static MemoryValueKey getMemoryValueKey(mlir::Value value, std::optionalgetParentOfType() || allocOp->getParentOfType(); +} + +static MemoryReportKind classifyMemoryReportKind(mlir::Value value) { + if (isa(value)) + return MemoryReportKind::Input; + if (auto* op = value.getDefiningOp()) { + if (isa(op)) + return MemoryReportKind::Alloca; + if (isa(op)) + return MemoryReportKind::Global; + } + return MemoryReportKind::None; +} + static int32_t getVectorByteSizeOrCrash(ShapedType type) { auto byteSize = pim::getCheckedShapedTypeSizeInBytes(type, UnknownLoc::get(type.getContext()), "vector byte size"); if (failed(byteSize)) @@ -90,19 +106,23 @@ MemEntry* PimMemory::gatherMemEntry(mlir::Value value, std::optional l pim::getCheckedShapedTypeSizeInBytes(type, UnknownLoc::get(type.getContext()), "memory allocation byte size"); if (failed(checkedAllocSize)) llvm_unreachable("Failed to compute checked allocation byte size"); - size_t allocSize = static_cast(*checkedAllocSize); - MemEntry memEntry = {0, allocSize}; - return &memEntries.emplace_back(memEntry, getMemoryValueKey(value, lane)).first; + PendingMemEntry pending; + pending.memEntry = {0, *checkedAllocSize}; + pending.key = getMemoryValueKey(value, lane); + pending.reportKind = classifyMemoryReportKind(value); + return &memEntries.emplace_back(std::move(pending)).memEntry; } void PimMemory::allocateGatheredMemory() { - llvm::sort(memEntries, [](auto a, auto b) -> bool { return a.first.size > b.first.size; }); - for (auto& [memEntry, key] : memEntries) - allocateMemoryForValue(key, memEntry); + llvm::sort(memEntries, [](const PendingMemEntry& lhs, const PendingMemEntry& rhs) { + return lhs.memEntry.size > rhs.memEntry.size; + }); + for (PendingMemEntry& pending : memEntries) + allocateMemoryForValue(pending.key, pending.memEntry, pending.reportKind); memEntries.clear(); } -void PimMemory::allocateMemoryForValue(const MemoryValueKey& key, MemEntry& memEntry) { +void PimMemory::allocateMemoryForValue(const MemoryValueKey& key, MemEntry& memEntry, MemoryReportKind reportKind) { memEntry.address = firstAvailableAddress; assert(memEntry.address < (size_t) INT_MAX && "Address allocated bigger than 32bit"); firstAvailableAddress += memEntry.size; @@ -112,6 +132,19 @@ void PimMemory::allocateMemoryForValue(const MemoryValueKey& key, MemEntry& memE ownedMemEntriesMap[key] = memEntry; globalMemEntriesMap[key] = memEntry; + + switch (reportKind) { + case MemoryReportKind::Alloca: + ++reportRow.numAlloca; + reportRow.sizeAlloca += memEntry.size; + break; + case MemoryReportKind::Global: + ++reportRow.numGlobal; + reportRow.sizeGlobal += memEntry.size; + break; + case MemoryReportKind::Input: + case MemoryReportKind::None: break; + } } void PimMemory::allocateHost(ModuleOp moduleOp, func::FuncOp funcOp) { @@ -142,7 +175,7 @@ void PimMemory::allocateHost(ModuleOp moduleOp, func::FuncOp funcOp) { }); funcOp.walk([&](memref::AllocOp allocOp) { - if (!allocOp->getParentOfType() && !allocOp->getParentOfType()) + if (!isInsidePimCoreLikeOp(allocOp)) gatherMemEntry(allocOp.getResult()); }); @@ -195,23 +228,7 @@ static MemoryReportRow addMemoryReportRows(const MemoryReportRow& lhs, const Mem return result; } -MemoryReportRow PimMemory::getReportRow() const { - MemoryReportRow row; - for (auto& [key, memEntry] : ownedMemEntriesMap) { - if (auto op = key.value.getDefiningOp()) { - if (isa(op)) { - row.numAlloca++; - row.sizeAlloca += memEntry.size; - } - - if (isa(op)) { - row.numGlobal++; - row.sizeGlobal += memEntry.size; - } - } - } - return row; -} +MemoryReportRow PimMemory::getReportRow() const { return reportRow; } void PimMemory::remove(mlir::Value val) { for (auto it = ownedMemEntriesMap.begin(); it != ownedMemEntriesMap.end();) @@ -336,13 +353,15 @@ void PimAcceleratorMemory::flushReport() { llvm::raw_os_ostream os(fileReport); uint64_t totalGlobalMemory = hostReportRow.has_value() ? hostReportRow->sizeGlobal : 0; + uint64_t totalWeightsMemory = totalWeightBytes; uint64_t totalCoresMemory = 0; for (const MemoryReportEntry& entry : reportEntries) totalCoresMemory += entry.totalAllocaBytes; - llvm::SmallVector totalFields = { - {"Global memory", formatReportMemory(totalGlobalMemory)}, - {"Cores memory", formatReportMemory(totalCoresMemory) } + llvm::SmallVector totalFields = { + {"Global memory", formatReportMemory(totalGlobalMemory) }, + {"Weights memory", formatReportMemory(totalWeightsMemory)}, + {"Cores memory", formatReportMemory(totalCoresMemory) } }; printReportTotalsBlock(os, totalFields); @@ -1360,7 +1379,9 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std:: request.weights = jobResults[jobIndex].usedWeights; weightRequests.push_back(std::move(request)); } - auto mapCoreWeightToFileName = createAndPopulateWeightFolder(weightRequests, outputDirPath); + auto weightEmission = createAndPopulateWeightFolder(weightRequests, outputDirPath); + memory.setTotalWeightBytes(weightEmission.totalWeightBytes); + auto& mapCoreWeightToFileName = weightEmission.mapCoreWeightToFileName; for (size_t jobIndex = 0; jobIndex < jobs.size(); ++jobIndex) { const CoreEmissionJob& job = jobs[jobIndex]; diff --git a/src/PIM/Compiler/PimCodeGen.hpp b/src/PIM/Compiler/PimCodeGen.hpp index 67ab3c2..8d45349 100644 --- a/src/PIM/Compiler/PimCodeGen.hpp +++ b/src/PIM/Compiler/PimCodeGen.hpp @@ -45,6 +45,19 @@ struct MemoryReportRow { } }; +enum class MemoryReportKind { + None, + Alloca, + Global, + Input +}; + +struct PendingMemEntry { + MemEntry memEntry; + MemoryValueKey key; + MemoryReportKind reportKind = MemoryReportKind::None; +}; + struct MemoryReportEntry { enum class Kind { Core, @@ -60,16 +73,17 @@ struct MemoryReportEntry { }; class PimMemory { - llvm::SmallVector, 32> memEntries; + llvm::SmallVector memEntries; llvm::SmallDenseMap& globalMemEntriesMap; llvm::SmallDenseMap ownedMemEntriesMap; + MemoryReportRow reportRow; size_t minAlignment = 4; size_t firstAvailableAddress = 0; MemEntry* gatherMemEntry(mlir::Value value, std::optional lane = std::nullopt); void allocateGatheredMemory(); - void allocateMemoryForValue(const MemoryValueKey& key, MemEntry& memEntry); + void allocateMemoryForValue(const MemoryValueKey& key, MemEntry& memEntry, MemoryReportKind reportKind); public: PimMemory(llvm::SmallDenseMap& globalMemEntriesMap) @@ -94,6 +108,7 @@ private: std::fstream fileReport; std::optional hostReportRow; llvm::SmallVector reportEntries; + uint64_t totalWeightBytes = 0; mutable llvm::DenseMap compiledIndexExprs; mutable llvm::DenseMap compiledAddressExprs; @@ -118,6 +133,7 @@ public: const MemoryReportRow& perCoreRow, uint64_t totalAllocaCount, uint64_t totalAllocaBytes); + void setTotalWeightBytes(uint64_t bytes) { totalWeightBytes = bytes; } void flushReport(); void clean(mlir::Operation* op); }; diff --git a/src/PIM/Compiler/PimWeightEmitter.cpp b/src/PIM/Compiler/PimWeightEmitter.cpp index f16c320..8ad6932 100644 --- a/src/PIM/Compiler/PimWeightEmitter.cpp +++ b/src/PIM/Compiler/PimWeightEmitter.cpp @@ -7,6 +7,7 @@ #include +#include "Common/Support/CheckedArithmetic.hpp" #include "Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" @@ -18,7 +19,7 @@ using namespace mlir; namespace onnx_mlir { namespace {} // namespace -llvm::DenseMap> +WeightEmissionResult createAndPopulateWeightFolder(ArrayRef requests, StringRef outputDirPath) { auto coreWeightsDirPath = outputDirPath + "/weights"; auto error = sys::fs::create_directory(coreWeightsDirPath); @@ -26,7 +27,7 @@ createAndPopulateWeightFolder(ArrayRef requests, StringRef ou size_t indexFileName = 0; int64_t xbarSize = crossbarSize.getValue(); - llvm::DenseMap> mapCoreWeightToFileName; + WeightEmissionResult result; llvm::SmallVector, 16> materializedWeights; auto materializeWeight = [&](const ResolvedWeightView& weightView) -> std::string { @@ -72,17 +73,22 @@ createAndPopulateWeightFolder(ArrayRef requests, StringRef ou weightFileStream.close(); materializedWeights.push_back({weightView, newFileName}); + uint64_t weightBytes = pim::checkedMulOrCrash( + pim::checkedMulOrCrash(static_cast(xbarSize), static_cast(xbarSize), "weight element count"), + elementByteWidth, + "weight byte size"); + result.totalWeightBytes = pim::checkedAddOrCrash(result.totalWeightBytes, weightBytes, "total weight bytes"); return newFileName; }; for (const WeightFileRequest& request : requests) { - auto& coreFiles = mapCoreWeightToFileName[request.coreId]; + auto& coreFiles = result.mapCoreWeightToFileName[request.coreId]; coreFiles.reserve(request.weights.size()); for (const ResolvedWeightView& weight : request.weights) coreFiles.push_back(materializeWeight(weight)); } - return mapCoreWeightToFileName; + return result; } } // namespace onnx_mlir diff --git a/src/PIM/Compiler/PimWeightEmitter.hpp b/src/PIM/Compiler/PimWeightEmitter.hpp index 2daa5ae..53058e1 100644 --- a/src/PIM/Compiler/PimWeightEmitter.hpp +++ b/src/PIM/Compiler/PimWeightEmitter.hpp @@ -6,6 +6,7 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" +#include #include #include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp" @@ -17,7 +18,12 @@ struct WeightFileRequest { llvm::SmallVector weights; }; -llvm::DenseMap> +struct WeightEmissionResult { + llvm::DenseMap> mapCoreWeightToFileName; + uint64_t totalWeightBytes = 0; +}; + +WeightEmissionResult createAndPopulateWeightFolder(llvm::ArrayRef requests, llvm::StringRef outputDirPath); } // namespace onnx_mlir diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp index 997f6a6..53c85e0 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp @@ -663,7 +663,7 @@ public: return; } emitMergeIrCounts("final-post-merge", func); - dumpModule(cast(func->getParentOp()), "spatial1_dcp_merged"); + dumpModule(cast(func->getParentOp()), "spatial1_merged"); generateReport(func, "spatial_merge_report", analysisResult->cpuToLastComputeMap.size()); } }