From 20cf40c9ba78860bea532b306a8e8e97864ee480 Mon Sep 17 00:00:00 2001 From: ilgeco Date: Wed, 3 Jun 2026 18:15:30 +0200 Subject: [PATCH] Memory Liveness --- src/PIM/Common/Support/ReportUtils.cpp | 6 +- src/PIM/Common/Support/ReportUtils.hpp | 1 + src/PIM/Compiler/CMakeLists.txt | 1 + src/PIM/Compiler/PimCodeGen.cpp | 203 ++++- src/PIM/Compiler/PimCodeGen.hpp | 17 + src/PIM/Compiler/PimCompilerOptions.cpp | 9 + src/PIM/Compiler/PimCompilerOptions.hpp | 7 + src/PIM/Compiler/PimMemoryLiveness.cpp | 742 ++++++++++++++++++ src/PIM/Compiler/PimMemoryLiveness.hpp | 63 ++ .../Bufferization/BufferizationUtils.cpp | 2 +- .../MemoryCoalescing/MemoryCoalescing.cpp | 213 ++--- .../MemoryCoalescing/MemoryCoalescing.hpp | 11 +- .../MemoryCoalescing/MemoryCoalescingPass.cpp | 8 +- test/PIM/CMakeLists.txt | 6 + test/PIM/PimMemoryLivenessPlannerTest.cpp | 86 ++ 15 files changed, 1263 insertions(+), 112 deletions(-) create mode 100644 src/PIM/Compiler/PimMemoryLiveness.cpp create mode 100644 src/PIM/Compiler/PimMemoryLiveness.hpp create mode 100644 test/PIM/PimMemoryLivenessPlannerTest.cpp diff --git a/src/PIM/Common/Support/ReportUtils.cpp b/src/PIM/Common/Support/ReportUtils.cpp index 334b43b..4f1e918 100644 --- a/src/PIM/Common/Support/ReportUtils.cpp +++ b/src/PIM/Common/Support/ReportUtils.cpp @@ -5,16 +5,18 @@ namespace onnx_mlir { -std::fstream openReportFile(const std::string& name) { +std::fstream openReportFileWithExtension(const std::string& name, llvm::StringRef extension) { std::string outputDir = getOutputDir(); if (outputDir.empty()) return {}; std::string reportsDir = outputDir + "/reports"; createDirectory(reportsDir); - return std::fstream(reportsDir + "/" + name + ".txt", std::ios::out); + return std::fstream(reportsDir + "/" + name + "." + extension.str(), std::ios::out); } +std::fstream openReportFile(const std::string& name) { return openReportFileWithExtension(name, "txt"); } + std::string formatReportMemory(uint64_t bytes) { const char* units[] = {"B", "KB", "MB", "GB", "TB", "PB", "EB"}; int i = 0; diff --git a/src/PIM/Common/Support/ReportUtils.hpp b/src/PIM/Common/Support/ReportUtils.hpp index 0c3a470..d722fe7 100644 --- a/src/PIM/Common/Support/ReportUtils.hpp +++ b/src/PIM/Common/Support/ReportUtils.hpp @@ -11,6 +11,7 @@ namespace onnx_mlir { std::fstream openReportFile(const std::string& name); +std::fstream openReportFileWithExtension(const std::string& name, llvm::StringRef extension); std::string formatReportMemory(uint64_t bytes); struct ReportField { diff --git a/src/PIM/Compiler/CMakeLists.txt b/src/PIM/Compiler/CMakeLists.txt index 85a3453..5e66728 100644 --- a/src/PIM/Compiler/CMakeLists.txt +++ b/src/PIM/Compiler/CMakeLists.txt @@ -17,6 +17,7 @@ add_pim_library(OMPimCompilerUtils PimCompilerUtils.cpp PimArtifactWriter.cpp PimCodeGen.cpp + PimMemoryLiveness.cpp PimWeightEmitter.cpp EXCLUDE_FROM_OM_LIBS diff --git a/src/PIM/Compiler/PimCodeGen.cpp b/src/PIM/Compiler/PimCodeGen.cpp index cee4516..947cf18 100644 --- a/src/PIM/Compiler/PimCodeGen.cpp +++ b/src/PIM/Compiler/PimCodeGen.cpp @@ -2,7 +2,6 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/IR/AsmState.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" @@ -23,11 +22,11 @@ #include #include #include -#include #include #include #include #include +#include #include #include @@ -38,10 +37,12 @@ #include "Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp" #include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp" +#include "src/Accelerators/PIM/Common/Support/FileSystemUtils.hpp" #include "src/Accelerators/PIM/Compiler/PimArtifactWriter.hpp" #include "src/Accelerators/PIM/Compiler/PimBinaryFormat.hpp" #include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" +#include "src/Accelerators/PIM/Compiler/PimMemoryLiveness.hpp" #include "src/Accelerators/PIM/Compiler/PimWeightEmitter.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" @@ -97,6 +98,51 @@ static int32_t getVectorByteSizeOrCrash(ShapedType type) { return pim::checkedI32OrCrash(*byteSize, "vector byte size"); } +static Operation *getDiagnosticAnchor(mlir::Value value) { + if (Operation *definingOp = value.getDefiningOp()) + return definingOp; + if (auto blockArg = dyn_cast(value)) + return blockArg.getOwner()->getParentOp(); + return nullptr; +} + +// PIM instruction immediates are serialized as signed int32_t fields today +// (`sldi` goes through checkedI32OrCrash), so local addresses must stay within +// the non-negative int32_t range. +static constexpr size_t kPimAddressLimit = static_cast(std::numeric_limits::max()); + +static FailureOr checkedAlignTo(size_t value, size_t alignment, Operation *anchor, StringRef fieldName) { + if (alignment == 0) + return value; + size_t remainder = value % alignment; + if (remainder == 0) + return value; + return pim::checkedAdd(value, alignment - remainder, anchor, fieldName); +} + +static void printMemoryOverflowDiagnostic(mlir::Value value, + const MemoryValueKey &key, + size_t requestedSize, + size_t currentFirstAvailableAddress, + size_t alignedEndAddress) { + llvm::errs() << "PIM local memory allocation overflow\n"; + llvm::errs() << "Requested allocation size: " << requestedSize << " bytes\n"; + llvm::errs() << "Current firstAvailableAddress: " << currentFirstAvailableAddress << "\n"; + llvm::errs() << "Aligned end address: " << alignedEndAddress << "\n"; + llvm::errs() << "Address limit: " << kPimAddressLimit << " (signed int32_t immediate range)\n"; + if (key.lane) + llvm::errs() << "Lane: " << *key.lane << "\n"; + llvm::errs() << "Value: "; + value.print(llvm::errs()); + llvm::errs() << "\n"; + llvm::errs() << "Value type: " << value.getType() << "\n"; + if (Operation *definingOp = value.getDefiningOp()) { + llvm::errs() << "Defining op:\n"; + definingOp->print(llvm::errs()); + llvm::errs() << "\n"; + } +} + } // namespace MemEntry* PimMemory::gatherMemEntry(mlir::Value value, std::optional lane) { @@ -124,20 +170,30 @@ void PimMemory::allocateGatheredMemory() { void PimMemory::allocateMemoryForValue(const MemoryValueKey& key, MemEntry& memEntry, MemoryReportKind reportKind) { memEntry.address = firstAvailableAddress; - assert(memEntry.address < (size_t) INT_MAX && "Address allocated bigger than 32bit"); - firstAvailableAddress += memEntry.size; - // Alignment - if (size_t remainder = firstAvailableAddress % minAlignment) - firstAvailableAddress += minAlignment - remainder; + Operation *anchor = getDiagnosticAnchor(key.value); + auto checkedEnd = pim::checkedAdd(memEntry.address, memEntry.size, anchor, "local memory end"); + FailureOr checkedAlignedEnd = failure(); + if (succeeded(checkedEnd)) + checkedAlignedEnd = checkedAlignTo(*checkedEnd, minAlignment, anchor, "local memory alignment"); + bool startFits = memEntry.address <= kPimAddressLimit; + bool endFits = succeeded(checkedEnd) && *checkedEnd <= kPimAddressLimit; + bool alignedEndFits = succeeded(checkedAlignedEnd) && *checkedAlignedEnd <= kPimAddressLimit; + if (!startFits || !endFits || !alignedEndFits) { + printMemoryOverflowDiagnostic( + key.value, + key, + memEntry.size, + firstAvailableAddress, + succeeded(checkedAlignedEnd) ? *checkedAlignedEnd : kPimAddressLimit); + llvm_unreachable("PIM local memory allocation overflow"); + } + firstAvailableAddress = *checkedAlignedEnd; ownedMemEntriesMap[key] = memEntry; globalMemEntriesMap[key] = memEntry; switch (reportKind) { - case MemoryReportKind::Alloca: - ++reportRow.numAlloca; - reportRow.sizeAlloca += memEntry.size; - break; + case MemoryReportKind::Alloca: break; case MemoryReportKind::Global: ++reportRow.numGlobal; reportRow.sizeGlobal += memEntry.size; @@ -147,6 +203,31 @@ void PimMemory::allocateMemoryForValue(const MemoryValueKey& key, MemEntry& memE } } +PhysicalSlotInfo PimMemory::allocatePhysicalSlot(size_t slotSize, const MemoryValueKey& key) { + PhysicalSlotInfo slot; + slot.id = nextPhysicalSlotId++; + slot.address = firstAvailableAddress; + slot.size = slotSize; + + Operation *anchor = getDiagnosticAnchor(key.value); + auto checkedEnd = pim::checkedAdd(slot.address, slot.size, anchor, "local memory end"); + FailureOr checkedAlignedEnd = failure(); + if (succeeded(checkedEnd)) + checkedAlignedEnd = checkedAlignTo(*checkedEnd, minAlignment, anchor, "local memory alignment"); + bool startFits = slot.address <= kPimAddressLimit; + bool endFits = succeeded(checkedEnd) && *checkedEnd <= kPimAddressLimit; + bool alignedEndFits = succeeded(checkedAlignedEnd) && *checkedAlignedEnd <= kPimAddressLimit; + if (!startFits || !endFits || !alignedEndFits) { + printMemoryOverflowDiagnostic( + key.value, key, slot.size, firstAvailableAddress, succeeded(checkedAlignedEnd) ? *checkedAlignedEnd : kPimAddressLimit); + llvm_unreachable("PIM local memory allocation overflow"); + } + + firstAvailableAddress = *checkedAlignedEnd; + localPhysicalSlots.push_back(slot); + return slot; +} + void PimMemory::allocateHost(ModuleOp moduleOp, func::FuncOp funcOp) { SmallDenseMap globalConstants; SmallVector, 16> globalAliases; @@ -186,9 +267,71 @@ void PimMemory::allocateHost(ModuleOp moduleOp, func::FuncOp funcOp) { } void PimMemory::allocateCore(Operation* op, std::optional lane) { - op->walk([&](memref::AllocOp allocOp) { gatherMemEntry(allocOp, lane); }); + auto intervals = buildLocalAllocIntervals(op, lane); + SmallVector plannedSlots = planPhysicalSlots(intervals); - allocateGatheredMemory(); + SmallVector slotOrder(plannedSlots.size()); + std::iota(slotOrder.begin(), slotOrder.end(), 0); + llvm::stable_sort(slotOrder, [&](size_t lhsIndex, size_t rhsIndex) { + const PlannedPhysicalSlot &lhs = plannedSlots[lhsIndex]; + const PlannedPhysicalSlot &rhs = plannedSlots[rhsIndex]; + if (lhs.requiredSize != rhs.requiredSize) + return lhs.requiredSize > rhs.requiredSize; + return lhs.id < rhs.id; + }); + + SmallVector usedExistingSlots(localPhysicalSlots.size(), false); + for (size_t slotIndex : slotOrder) { + PlannedPhysicalSlot &slot = plannedSlots[slotIndex]; + size_t bestExistingIndex = std::numeric_limits::max(); + auto bestKey = std::tuple( + std::numeric_limits::max(), std::numeric_limits::max(), std::numeric_limits::max()); + + for (size_t existingIndex = 0; existingIndex < localPhysicalSlots.size(); ++existingIndex) { + if (usedExistingSlots[existingIndex]) + continue; + const PhysicalSlotInfo &existingSlot = localPhysicalSlots[existingIndex]; + if (existingSlot.size < slot.requiredSize) + continue; + auto candidateKey = std::tuple( + existingSlot.size - slot.requiredSize, existingSlot.size, existingSlot.id); + if (candidateKey < bestKey) { + bestKey = candidateKey; + bestExistingIndex = existingIndex; + } + } + + if (bestExistingIndex != std::numeric_limits::max()) { + const PhysicalSlotInfo &existingSlot = localPhysicalSlots[bestExistingIndex]; + slot.id = existingSlot.id; + slot.address = existingSlot.address; + slot.size = existingSlot.size; + usedExistingSlots[bestExistingIndex] = true; + } + else { + PhysicalSlotInfo newSlot = allocatePhysicalSlot(slot.requiredSize, intervals[slot.intervalIndices.front()].key); + slot.id = newSlot.id; + slot.address = newSlot.address; + slot.size = newSlot.size; + usedExistingSlots.push_back(true); + } + + for (size_t intervalIndex : slot.intervalIndices) { + LocalAllocInterval &interval = intervals[intervalIndex]; + interval.physicalSlotId = slot.id; + interval.assignedAddress = slot.address; + interval.physicalSlotSize = slot.size; + MemEntry memEntry {slot.address, interval.size}; + ownedMemEntriesMap[interval.key] = memEntry; + globalMemEntriesMap[interval.key] = memEntry; + } + } + + if (pimMemoryReport != PimMemoryReportNone) { + MemoryPlanArtifacts artifacts = + buildMemoryPlanArtifacts(op, lane, intervals, plannedSlots, kPimAddressLimit, pimMemoryReport); + livenessArtifacts.textReport += artifacts.textReport; + } } static void printHostMemoryReportRow(raw_ostream& os, const MemoryReportRow& row) { @@ -228,7 +371,14 @@ static MemoryReportRow addMemoryReportRows(const MemoryReportRow& lhs, const Mem return result; } -MemoryReportRow PimMemory::getReportRow() const { return reportRow; } +MemoryReportRow PimMemory::getReportRow() const { + MemoryReportRow row = reportRow; + row.numAlloca = localPhysicalSlots.size(); + row.sizeAlloca = 0; + for (const PhysicalSlotInfo &slot : localPhysicalSlots) + row.sizeAlloca += slot.size; + return row; +} void PimMemory::remove(mlir::Value val) { for (auto it = ownedMemEntriesMap.begin(); it != ownedMemEntriesMap.end();) @@ -847,6 +997,7 @@ struct CoreEmissionResult { OnnxMlirCompilerErrorCodes status = CompilerSuccess; MemoryReportRow reportRow; llvm::SmallVector usedWeights; + MemoryPlanArtifacts livenessArtifacts; }; template @@ -1319,6 +1470,7 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std:: assert(processedOperations > 0); result.reportRow = deviceMemory.getReportRow(); result.usedWeights = std::move(usedWeights); + result.livenessArtifacts = deviceMemory.getLivenessArtifacts(); } else { auto coreBatchOp = cast(job.coreLikeOp); @@ -1349,6 +1501,7 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std:: result.reportRow = deviceMemory.getReportRow(); result.usedWeights = std::move(usedWeights); + result.livenessArtifacts = deviceMemory.getLivenessArtifacts(); } pim_binary::patchInstructionCount(coreBinaryStream, coreCodeGen.getEmittedInstructionCount()); @@ -1382,6 +1535,18 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std:: auto weightEmission = createAndPopulateWeightFolder(weightRequests, outputDirPath); memory.setTotalWeightBytes(weightEmission.totalWeightBytes); auto& mapCoreWeightToFileName = weightEmission.mapCoreWeightToFileName; + if (std::string reportsRoot = getOutputDir(); !reportsRoot.empty()) { + std::string reportsDir = reportsRoot + "/reports"; + sys::fs::remove(reportsDir + "/pim_memory_liveness_report.txt"); + sys::fs::remove(reportsDir + "/pim_memory_liveness_report.json"); + sys::fs::remove(reportsDir + "/pim_memory_liveness_timeline.dot"); + } + std::fstream livenessReportFile; + std::unique_ptr livenessReportOs; + if (pimMemoryReport != PimMemoryReportNone) { + livenessReportFile = openReportFileWithExtension("pim_memory_liveness_report", "txt"); + livenessReportOs = std::make_unique(livenessReportFile); + } for (size_t jobIndex = 0; jobIndex < jobs.size(); ++jobIndex) { const CoreEmissionJob& job = jobs[jobIndex]; @@ -1393,6 +1558,8 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std:: return err; xbarsPerArrayGroup["core" + std::to_string(job.emittedCoreId)] = std::move(xbarsPerGroup); memory.recordCoreReport(job.emittedCoreId, result.reportRow); + if (livenessReportFile.is_open()) + *livenessReportOs << "Core " << job.emittedCoreId << ":\n" << result.livenessArtifacts.textReport; continue; } } @@ -1421,10 +1588,18 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std:: batchPerCoreRow.value_or(MemoryReportRow {}), batchRow.numAlloca, batchRow.sizeAlloca); + if (livenessReportFile.is_open()) + for (size_t jobIndex : group) + *livenessReportOs << "Batch " << batchReportId << " core " << jobs[jobIndex].emittedCoreId << ":\n" + << jobResults[jobIndex].livenessArtifacts.textReport; } maxCoreId = nextEmittedCoreId == 0 ? 0 : nextEmittedCoreId - 1; + if (livenessReportFile.is_open()) { + livenessReportOs->flush(); + livenessReportFile.close(); + } memory.flushReport(); return writeConfigJson(funcOp, memory, maxCoreId, std::move(xbarsPerArrayGroup), outputDirPath); } diff --git a/src/PIM/Compiler/PimCodeGen.hpp b/src/PIM/Compiler/PimCodeGen.hpp index 8d45349..59a043e 100644 --- a/src/PIM/Compiler/PimCodeGen.hpp +++ b/src/PIM/Compiler/PimCodeGen.hpp @@ -5,12 +5,14 @@ #include "llvm-project/clang/include/clang/Basic/LLVM.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/Hashing.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/Support/JSON.h" #include "llvm/Support/raw_os_ostream.h" #include #include #include +#include #include "onnx-mlir/Compiler/OMCompilerTypes.h" #include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp" @@ -26,6 +28,16 @@ struct MemEntry { size_t size; }; +struct PhysicalSlotInfo { + size_t id = 0; + size_t address = 0; + size_t size = 0; +}; + +struct MemoryPlanArtifacts { + std::string textReport; +}; + struct MemoryValueKey { mlir::Value value; std::optional lane; @@ -74,16 +86,20 @@ struct MemoryReportEntry { class PimMemory { llvm::SmallVector memEntries; + llvm::SmallVector localPhysicalSlots; llvm::SmallDenseMap& globalMemEntriesMap; llvm::SmallDenseMap ownedMemEntriesMap; MemoryReportRow reportRow; + MemoryPlanArtifacts livenessArtifacts; size_t minAlignment = 4; size_t firstAvailableAddress = 0; + size_t nextPhysicalSlotId = 0; MemEntry* gatherMemEntry(mlir::Value value, std::optional lane = std::nullopt); void allocateGatheredMemory(); void allocateMemoryForValue(const MemoryValueKey& key, MemEntry& memEntry, MemoryReportKind reportKind); + PhysicalSlotInfo allocatePhysicalSlot(size_t slotSize, const MemoryValueKey& key); public: PimMemory(llvm::SmallDenseMap& globalMemEntriesMap) @@ -92,6 +108,7 @@ public: void allocateHost(mlir::ModuleOp moduleOp, mlir::func::FuncOp funcOp); void allocateCore(mlir::Operation* op, std::optional lane = std::nullopt); MemoryReportRow getReportRow() const; + const MemoryPlanArtifacts& getLivenessArtifacts() const { return livenessArtifacts; } void remove(mlir::Value val); size_t getFirstAvailableAddress() const { return firstAvailableAddress; } diff --git a/src/PIM/Compiler/PimCompilerOptions.cpp b/src/PIM/Compiler/PimCompilerOptions.cpp index 4eee65c..242aea0 100644 --- a/src/PIM/Compiler/PimCompilerOptions.cpp +++ b/src/PIM/Compiler/PimCompilerOptions.cpp @@ -22,6 +22,15 @@ llvm::cl::opt llvm::cl::init(MergeSchedulerPeft), llvm::cl::cat(OnnxMlirOptions)); +llvm::cl::opt pimMemoryReport( + "pim-memory-report", + llvm::cl::desc("Emit a human-readable PIM memory planning report"), + llvm::cl::values(clEnumValN(PimMemoryReportNone, "none", "Do not emit any PIM memory planning report")), + llvm::cl::values(clEnumValN(PimMemoryReportSummary, "summary", "Emit a concise slot reuse report with key offenders")), + llvm::cl::values(clEnumValN(PimMemoryReportFull, "full", "Emit the full detailed PIM memory planning report")), + llvm::cl::init(PimMemoryReportNone), + llvm::cl::cat(OnnxMlirOptions)); + llvm::cl::opt pimOnlyCodegen("pim-only-codegen", llvm::cl::desc("Only generate code for PIM (assume input is already in bufferized PIM IR)"), diff --git a/src/PIM/Compiler/PimCompilerOptions.hpp b/src/PIM/Compiler/PimCompilerOptions.hpp index 05e51ab..b486070 100644 --- a/src/PIM/Compiler/PimCompilerOptions.hpp +++ b/src/PIM/Compiler/PimCompilerOptions.hpp @@ -24,9 +24,16 @@ typedef enum { MergeSchedulerPeft = 0, } PimMergeSchedulerType; +typedef enum { + PimMemoryReportNone = 0, + PimMemoryReportSummary = 1, + PimMemoryReportFull = 2, +} PimMemoryReportLevel; + extern llvm::cl::OptionCategory OnnxMlirOptions; extern llvm::cl::opt pimEmissionTarget; extern llvm::cl::opt pimMergeScheduler; +extern llvm::cl::opt pimMemoryReport; extern llvm::cl::opt pimOnlyCodegen; extern llvm::cl::opt useExperimentalConvImpl; diff --git a/src/PIM/Compiler/PimMemoryLiveness.cpp b/src/PIM/Compiler/PimMemoryLiveness.cpp new file mode 100644 index 0000000..d8c8471 --- /dev/null +++ b/src/PIM/Compiler/PimMemoryLiveness.cpp @@ -0,0 +1,742 @@ +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/Value.h" +#include "mlir/Interfaces/DestinationStyleOpInterface.h" + +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/raw_ostream.h" + +#include +#include +#include +#include + +#include "Common/Support/CheckedArithmetic.hpp" +#include "Common/Support/ReportUtils.hpp" +#include "src/Accelerators/PIM/Common/PimCommon.hpp" +#include "src/Accelerators/PIM/Compiler/PimMemoryLiveness.hpp" +#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" + +using namespace llvm; +using namespace mlir; +using namespace onnx_mlir; + +namespace { + +static std::optional getLaneForMemoryValue(mlir::Value value, std::optional lane) { + if (!lane) + return std::nullopt; + auto allocOp = value.getDefiningOp(); + if (!allocOp || !allocOp->getParentOfType()) + return std::nullopt; + return lane; +} + +static MemoryValueKey getMemoryValueKey(mlir::Value value, std::optional lane = std::nullopt) { + return {value, getLaneForMemoryValue(value, lane)}; +} + +struct MemoryTouchInterval { + uint64_t start = 0; + uint64_t end = 0; + Operation *startOp = nullptr; + Operation *endOp = nullptr; + Operation *firstTouchOp = nullptr; + Operation *lastTouchOp = nullptr; + uint64_t firstTouchPosition = 0; + uint64_t lastTouchPosition = 0; + bool hasRuntimeUse = false; + bool startUsedAllocFallback = false; + bool endUsedFallback = false; + bool escapesLoop = false; + std::string fallbackReason; + llvm::SmallVector aliasesFollowed; +}; + +struct OperationOrdering { + llvm::DenseMap position; + llvm::DenseMap subtreeEnd; + uint64_t nextPosition = 0; +}; + +static std::string printValueToString(mlir::Value value) { + std::string text; + llvm::raw_string_ostream os(text); + value.print(os); + os.flush(); + return text; +} + +static std::string printOperationToString(Operation *op) { + if (!op) + return ""; + std::string text; + llvm::raw_string_ostream os(text); + op->print(os); + os.flush(); + return text; +} + +static std::string printLocationToString(Location loc) { + std::string text; + llvm::raw_string_ostream os(text); + loc.print(os); + os.flush(); + return text; +} + +static std::string collapseWhitespace(StringRef text) { + std::string out; + out.reserve(text.size()); + bool lastWasSpace = false; + for (char c : text) { + bool isSpace = c == ' ' || c == '\n' || c == '\t' || c == '\r'; + if (isSpace) { + if (!lastWasSpace && !out.empty()) + out.push_back(' '); + lastWasSpace = true; + continue; + } + out.push_back(c); + lastWasSpace = false; + } + return out; +} + +static std::string abbreviate(StringRef text, size_t maxLen) { + if (text.size() <= maxLen) + return text.str(); + return (text.take_front(maxLen - 3) + "...").str(); +} + +static std::string summarizeValue(mlir::Value value, size_t maxLen = 72) { + return abbreviate(collapseWhitespace(printValueToString(value)), maxLen); +} + +static std::string summarizeOperation(Operation *op, size_t maxLen = 96) { + if (!op) + return ""; + std::string prefix = op->getName().getStringRef().str(); + std::string full = collapseWhitespace(printOperationToString(op)); + if (full == prefix) + return prefix; + return abbreviate(prefix + " :: " + full, maxLen); +} + +static std::string summarizeLocation(Location loc, size_t maxLen = 88) { + return abbreviate(collapseWhitespace(printLocationToString(loc)), maxLen); +} + +static void assignOperationOrdering(Operation *op, OperationOrdering &ordering) { + uint64_t position = ordering.nextPosition++; + ordering.position[op] = position; + uint64_t end = position; + for (Region ®ion : op->getRegions()) + for (Block &block : region) + for (Operation &nestedOp : block) { + assignOperationOrdering(&nestedOp, ordering); + end = std::max(end, ordering.subtreeEnd.lookup(&nestedOp)); + } + ordering.subtreeEnd[op] = end; +} + +static OperationOrdering buildOperationOrdering(Operation *coreLikeOp) { + OperationOrdering ordering; + if (!coreLikeOp || coreLikeOp->getNumRegions() != 1 || coreLikeOp->getRegion(0).empty()) + return ordering; + + for (Operation &op : coreLikeOp->getRegion(0).front()) + assignOperationOrdering(&op, ordering); + return ordering; +} + +static bool isSupportedAliasOp(Operation *op) { + return isa(op); +} + +static bool isRuntimeMemoryTouchOp(Operation *op) { + return isa(op); +} + +static bool isIgnoredLivenessUser(Operation *op) { + return isSupportedAliasOp(op) || isa(op) || isCoreStaticAddressOp(op); +} + +static bool isWithin(mlir::Value value, Region *region) { + if (!region) + return false; + if (auto blockArg = dyn_cast(value)) + return blockArg.getOwner()->getParent() == region; + if (Operation *definingOp = value.getDefiningOp()) + return definingOp->getParentRegion() == region || region->isAncestor(definingOp->getParentRegion()); + return false; +} + +static bool isNestedAllocation(Operation *coreLikeOp, memref::AllocOp allocOp) { + if (!coreLikeOp || coreLikeOp->getNumRegions() != 1 || coreLikeOp->getRegion(0).empty()) + return false; + return allocOp->getBlock() != &coreLikeOp->getRegion(0).front(); +} + +static void addFallbackReason(std::string &reason, StringRef newReason) { + if (newReason.empty()) + return; + if (!reason.empty()) + reason += "; "; + reason += newReason.str(); +} + +static void appendAliasDescription(llvm::SmallVectorImpl &aliases, mlir::Value value) { + std::string text = printValueToString(value); + if (!llvm::is_contained(aliases, text)) + aliases.push_back(std::move(text)); +} + +struct OrderedTouchRange { + uint64_t start = 0; + uint64_t end = 0; + Operation *startOp = nullptr; + Operation *endOp = nullptr; + bool escapedLoop = false; +}; + +static OrderedTouchRange +getEffectiveTouchRange(mlir::Value definingValue, Operation *user, const OperationOrdering &ordering) { + OrderedTouchRange range { + ordering.position.lookup(user), ordering.position.lookup(user), user, user, false}; + for (Operation *current = user; current; current = current->getParentOp()) { + auto forOp = dyn_cast(current); + if (!forOp || isWithin(definingValue, &forOp.getRegion())) + continue; + range.start = std::min(range.start, ordering.position.lookup(forOp)); + range.end = std::max(range.end, ordering.subtreeEnd.lookup(forOp)); + range.startOp = forOp; + range.endOp = forOp; + range.escapedLoop = true; + } + return range; +} + +static MemoryTouchInterval +computeMemoryTouchInterval(memref::AllocOp allocOp, const OperationOrdering &ordering, uint64_t fallbackEnd) { + MemoryTouchInterval interval; + interval.start = ordering.position.lookup(allocOp); + interval.end = interval.start; + interval.startOp = allocOp; + interval.endOp = allocOp; + + SmallPtrSet visitedValues; + SmallPtrSet visitedUsers; + SmallVector pendingValues; + pendingValues.push_back(allocOp.getResult()); + auto parentLoop = allocOp->getParentOfType(); + + while (!pendingValues.empty()) { + mlir::Value value = pendingValues.pop_back_val(); + if (!visitedValues.insert(value).second) + continue; + + for (Operation *user : value.getUsers()) { + if (!visitedUsers.insert(user).second) + continue; + + if (isSupportedAliasOp(user)) { + for (mlir::Value result : user->getResults()) { + pendingValues.push_back(result); + appendAliasDescription(interval.aliasesFollowed, result); + } + } + + if (auto dpsOp = dyn_cast(user)) { + for (OpResult result : user->getResults()) { + OpOperand *tiedOperand = dpsOp.getTiedOpOperand(result); + if (!tiedOperand || tiedOperand->get() != value) + continue; + pendingValues.push_back(result); + appendAliasDescription(interval.aliasesFollowed, result); + } + } + + if (auto forOp = dyn_cast(user)) { + for (auto [index, initArg] : llvm::enumerate(forOp.getInitArgs())) { + if (initArg != value) + continue; + pendingValues.push_back(forOp.getRegionIterArgs()[index]); + pendingValues.push_back(forOp.getResult(index)); + appendAliasDescription(interval.aliasesFollowed, forOp.getRegionIterArgs()[index]); + appendAliasDescription(interval.aliasesFollowed, forOp.getResult(index)); + if (parentLoop && forOp != parentLoop) + interval.escapesLoop = true; + } + } + + if (auto yieldOp = dyn_cast(user)) { + auto forOp = dyn_cast(yieldOp->getParentOp()); + if (!forOp) { + addFallbackReason(interval.fallbackReason, "yield without scf.for parent"); + } + else { + for (auto [index, operand] : llvm::enumerate(yieldOp.getOperands())) { + if (operand != value) + continue; + pendingValues.push_back(forOp.getResult(index)); + appendAliasDescription(interval.aliasesFollowed, forOp.getResult(index)); + if (parentLoop && forOp == parentLoop) + interval.escapesLoop = true; + } + } + } + + if (isRuntimeMemoryTouchOp(user)) { + uint64_t touchPosition = ordering.position.lookup(user); + if (!interval.hasRuntimeUse || touchPosition < interval.firstTouchPosition) { + interval.firstTouchPosition = touchPosition; + interval.firstTouchOp = user; + } + if (!interval.hasRuntimeUse || touchPosition > interval.lastTouchPosition) { + interval.lastTouchPosition = touchPosition; + interval.lastTouchOp = user; + } + + OrderedTouchRange range = getEffectiveTouchRange(allocOp.getResult(), user, ordering); + interval.escapesLoop |= range.escapedLoop; + if (!interval.hasRuntimeUse) { + interval.start = range.start; + interval.end = range.end; + interval.startOp = range.startOp; + interval.endOp = range.endOp; + interval.hasRuntimeUse = true; + } + else { + if (range.start < interval.start) { + interval.start = range.start; + interval.startOp = range.startOp; + } + if (range.end > interval.end) { + interval.end = range.end; + interval.endOp = range.endOp; + } + } + continue; + } + + if (isIgnoredLivenessUser(user)) + continue; + + addFallbackReason(interval.fallbackReason, "unhandled user op"); + interval.endUsedFallback = true; + } + } + + if (!interval.hasRuntimeUse) { + interval.startUsedAllocFallback = true; + interval.endUsedFallback = true; + interval.start = ordering.position.lookup(allocOp); + interval.end = fallbackEnd; + interval.startOp = allocOp; + interval.endOp = allocOp->getParentOp(); + interval.firstTouchPosition = interval.start; + interval.lastTouchPosition = interval.end; + addFallbackReason(interval.fallbackReason, "no runtime memory touch"); + return interval; + } + + if (interval.endUsedFallback) { + interval.end = std::max(interval.end, fallbackEnd); + interval.endOp = allocOp->getParentOp(); + } + + return interval; +} + +static FailureOr getAllocSizeBytes(memref::AllocOp allocOp) { + auto type = dyn_cast(allocOp.getType()); + if (!type) + return failure(); + auto checkedBytes = pim::getCheckedShapedTypeSizeInBytes(type, allocOp, "memory allocation byte size"); + if (failed(checkedBytes)) + return failure(); + return pim::checkedSize(*checkedBytes, allocOp, "memory allocation byte size"); +} + +static bool intervalsOverlap(const LocalAllocInterval &lhs, const LocalAllocInterval &rhs) { + return !(lhs.end < rhs.start || rhs.end < lhs.start); +} + +static uint64_t getSlotLogicalBytes(const PlannedPhysicalSlot &slot, ArrayRef intervals) { + uint64_t slotLogicalBytes = 0; + for (size_t intervalIndex : slot.intervalIndices) + slotLogicalBytes += intervals[intervalIndex].size; + return slotLogicalBytes; +} + +} // namespace + +SmallVector onnx_mlir::buildLocalAllocIntervals(Operation *coreLikeOp, + std::optional lane) { + SmallVector intervals; + OperationOrdering ordering = buildOperationOrdering(coreLikeOp); + if (ordering.position.empty()) + return intervals; + + uint64_t fallbackEnd = ordering.nextPosition == 0 ? 0 : ordering.nextPosition - 1; + size_t nextIntervalId = 0; + coreLikeOp->walk([&](memref::AllocOp allocOp) { + auto checkedSize = getAllocSizeBytes(allocOp); + if (failed(checkedSize)) { + llvm::errs() << "Failed to compute local allocation size for value: "; + allocOp.getResult().print(llvm::errs()); + llvm::errs() << "\n"; + llvm_unreachable("Failed to compute local allocation size"); + } + + MemoryTouchInterval touchInterval = computeMemoryTouchInterval(allocOp, ordering, fallbackEnd); + LocalAllocInterval interval; + interval.id = nextIntervalId++; + interval.alloc = allocOp; + interval.key = getMemoryValueKey(allocOp.getResult(), lane); + interval.start = touchInterval.start; + interval.end = touchInterval.end; + interval.size = *checkedSize; + interval.startOp = touchInterval.startOp; + interval.endOp = touchInterval.endOp; + interval.firstTouchOp = touchInterval.firstTouchOp; + interval.lastTouchOp = touchInterval.lastTouchOp; + interval.firstTouchPosition = touchInterval.firstTouchPosition; + interval.lastTouchPosition = touchInterval.lastTouchPosition; + interval.startUsedAllocFallback = touchInterval.startUsedAllocFallback; + interval.endUsedFallback = touchInterval.endUsedFallback; + interval.hasRuntimeUse = touchInterval.hasRuntimeUse; + interval.insideNestedRegion = isNestedAllocation(coreLikeOp, allocOp); + interval.escapesLoop = touchInterval.escapesLoop; + interval.fallbackReason = std::move(touchInterval.fallbackReason); + interval.aliasesFollowed = std::move(touchInterval.aliasesFollowed); + intervals.push_back(std::move(interval)); + }); + + return intervals; +} + +SmallVector onnx_mlir::planPhysicalSlots(MutableArrayRef intervals) { + SmallVector slots; + SmallVector intervalOrder(intervals.size()); + std::iota(intervalOrder.begin(), intervalOrder.end(), 0); + llvm::stable_sort(intervalOrder, [&](size_t lhsIndex, size_t rhsIndex) { + const LocalAllocInterval &lhs = intervals[lhsIndex]; + const LocalAllocInterval &rhs = intervals[rhsIndex]; + if (lhs.size != rhs.size) + return lhs.size > rhs.size; + if (lhs.start != rhs.start) + return lhs.start < rhs.start; + if (lhs.end != rhs.end) + return lhs.end < rhs.end; + return lhs.id < rhs.id; + }); + + for (size_t intervalIndex : intervalOrder) { + LocalAllocInterval &interval = intervals[intervalIndex]; + PlannedPhysicalSlot *bestSlot = nullptr; + auto bestKey = std::tuple( + std::numeric_limits::max(), + std::numeric_limits::max(), + std::numeric_limits::max(), + std::numeric_limits::max()); + + for (size_t slotIndex = 0; slotIndex < slots.size(); ++slotIndex) { + PlannedPhysicalSlot &slot = slots[slotIndex]; + bool compatible = true; + for (size_t otherIndex : slot.intervalIndices) { + if (intervalsOverlap(interval, intervals[otherIndex])) { + compatible = false; + break; + } + } + if (!compatible) + continue; + + size_t resultingSize = std::max(slot.requiredSize, interval.size); + size_t growth = resultingSize - slot.requiredSize; + auto candidateKey = std::tuple( + growth, resultingSize, slot.intervalIndices.size(), slot.id); + if (candidateKey < bestKey) { + bestKey = candidateKey; + bestSlot = &slot; + } + } + + if (!bestSlot) { + slots.push_back({slots.size(), interval.size, interval.size, 0, {intervalIndex}}); + interval.slotPlanIndex = slots.size() - 1; + interval.physicalSlotId = slots.back().id; + interval.physicalSlotSize = slots.back().requiredSize; + continue; + } + + bestSlot->requiredSize = std::max(bestSlot->requiredSize, interval.size); + bestSlot->size = bestSlot->requiredSize; + bestSlot->intervalIndices.push_back(intervalIndex); + interval.slotPlanIndex = static_cast(bestSlot - slots.data()); + interval.physicalSlotId = bestSlot->id; + interval.physicalSlotSize = bestSlot->requiredSize; + } + + return slots; +} + +MemoryPlanArtifacts onnx_mlir::buildMemoryPlanArtifacts(Operation *coreLikeOp, + std::optional lane, + ArrayRef intervals, + ArrayRef slots, + size_t addressLimit, + PimMemoryReportLevel reportLevel) { + MemoryPlanArtifacts artifacts; + + uint64_t totalLogicalBytes = 0; + uint64_t totalPhysicalBytes = 0; + uint64_t fallbackIntervals = 0; + uint64_t noRuntimeTouchIntervals = 0; + uint64_t reusedAllocations = 0; + uint64_t nestedIntervals = 0; + uint64_t loopEscapingIntervals = 0; + size_t largestLogicalAllocation = 0; + size_t largestPhysicalSlot = 0; + size_t maximumAssignedAddress = 0; + + for (const LocalAllocInterval &interval : intervals) { + totalLogicalBytes += interval.size; + largestLogicalAllocation = std::max(largestLogicalAllocation, interval.size); + maximumAssignedAddress = std::max(maximumAssignedAddress, interval.assignedAddress + interval.physicalSlotSize); + if (interval.startUsedAllocFallback || interval.endUsedFallback) + ++fallbackIntervals; + if (!interval.hasRuntimeUse) + ++noRuntimeTouchIntervals; + if (interval.insideNestedRegion) + ++nestedIntervals; + if (interval.escapesLoop) + ++loopEscapingIntervals; + } + for (const PlannedPhysicalSlot &slot : slots) { + totalPhysicalBytes += slot.size; + largestPhysicalSlot = std::max(largestPhysicalSlot, slot.size); + if (slot.intervalIndices.size() > 1) + reusedAllocations += slot.intervalIndices.size() - 1; + } + + uint64_t savedBytes = totalLogicalBytes >= totalPhysicalBytes ? totalLogicalBytes - totalPhysicalBytes : 0; + double savedPercent = + totalLogicalBytes == 0 ? 0.0 : 100.0 * static_cast(savedBytes) / static_cast(totalLogicalBytes); + + raw_string_ostream os(artifacts.textReport); + os << "=== PIM Memory Liveness Report ===\n"; + os << "Op: " << coreLikeOp->getName() << "\n"; + if (lane) + os << "Lane: " << *lane << "\n"; + os << "Summary:\n"; + os << " logical allocation bytes: " << formatReportMemory(totalLogicalBytes) << " (" << totalLogicalBytes << ")\n"; + os << " physical allocation bytes: " << formatReportMemory(totalPhysicalBytes) << " (" << totalPhysicalBytes << ")\n"; + os << " saved bytes: " << formatReportMemory(savedBytes) << " (" << savedBytes << ")\n"; + os << " saved percent: " << format("%.2f%%", savedPercent) << "\n"; + os << " intervals: " << intervals.size() << "\n"; + os << " physical slots: " << slots.size() << "\n"; + os << " reused allocations: " << reusedAllocations << "\n"; + os << " fallback intervals: " << fallbackIntervals << "\n"; + os << " intervals with no runtime memory touch: " << noRuntimeTouchIntervals << "\n"; + os << " nested allocations: " << nestedIntervals << "\n"; + os << " loop-escaping allocations: " << loopEscapingIntervals << "\n"; + os << " largest logical allocation: " << largestLogicalAllocation << "\n"; + os << " largest physical slot: " << largestPhysicalSlot << "\n"; + os << " address limit: " << addressLimit << "\n"; + os << " peak physical memory: " << formatReportMemory(maximumAssignedAddress) << " (" << maximumAssignedAddress << ")\n"; + os << " maximum assigned address: " << maximumAssignedAddress << "\n"; + + os << "\nHow To Read:\n"; + os << " `summary` only shows the strongest reuse cases and the worst offenders.\n"; + os << " Use `--pim-memory-report=full` when you need the complete slot-by-slot and interval-by-interval dump.\n"; + os << " Large single-use slots, fallback intervals, and nested single-use allocations are the best places\n"; + os << " to inspect if allocations should be moved, sunk, or made easier to coalesce earlier in the pipeline.\n"; + + SmallVector reusedSlots; + SmallVector singleUseSlots; + for (const PlannedPhysicalSlot &slot : slots) { + if (slot.intervalIndices.size() > 1) + reusedSlots.push_back(&slot); + else + singleUseSlots.push_back(&slot); + } + + llvm::stable_sort(reusedSlots, [&](const PlannedPhysicalSlot *lhs, const PlannedPhysicalSlot *rhs) { + uint64_t lhsLogicalBytes = getSlotLogicalBytes(*lhs, intervals); + uint64_t rhsLogicalBytes = getSlotLogicalBytes(*rhs, intervals); + if (lhs->intervalIndices.size() != rhs->intervalIndices.size()) + return lhs->intervalIndices.size() > rhs->intervalIndices.size(); + if (lhsLogicalBytes != rhsLogicalBytes) + return lhsLogicalBytes > rhsLogicalBytes; + if (lhs->size != rhs->size) + return lhs->size > rhs->size; + return lhs->id < rhs->id; + }); + llvm::stable_sort(singleUseSlots, [&](const PlannedPhysicalSlot *lhs, const PlannedPhysicalSlot *rhs) { + if (lhs->size != rhs->size) + return lhs->size > rhs->size; + return lhs->id < rhs->id; + }); + + constexpr size_t kSummaryReuseLimit = 6; + constexpr size_t kSummaryOffenderLimit = 10; + + os << "\nBest Reuse:\n"; + if (reusedSlots.empty()) { + os << " no slots were shared by multiple intervals\n"; + } else { + for (const PlannedPhysicalSlot *slot : ArrayRef(reusedSlots).take_front(kSummaryReuseLimit)) { + uint64_t slotLogicalBytes = getSlotLogicalBytes(*slot, intervals); + os << " slot #" << slot->id + << " addr=" << slot->address + << " size=" << formatReportMemory(slot->size) + << " intervals=" << slot->intervalIndices.size() + << " logical_sum=" << formatReportMemory(slotLogicalBytes) << "\n"; + for (size_t intervalIndex : slot->intervalIndices) { + const LocalAllocInterval &interval = intervals[intervalIndex]; + os << " #" << interval.id + << " [" << interval.start << "," << interval.end << "]" + << " logical=" << formatReportMemory(interval.size) + << " first=" << summarizeOperation(interval.firstTouchOp, 40) + << " last=" << summarizeOperation(interval.lastTouchOp, 40) << "\n"; + } + } + } + + os << "\nTop Offenders:\n"; + bool printedAttention = false; + for (const PlannedPhysicalSlot *slot : ArrayRef(singleUseSlots).take_front(kSummaryOffenderLimit)) { + const LocalAllocInterval &interval = intervals[slot->intervalIndices.front()]; + printedAttention = true; + os << " slot #" << slot->id << " is single-use" + << " size=" << formatReportMemory(slot->size) + << " interval=#" << interval.id + << " value=" << summarizeValue(interval.key.value, 56) << "\n"; + os << " first=" << summarizeOperation(interval.firstTouchOp, 40) + << " last=" << summarizeOperation(interval.lastTouchOp, 40) + << " nested=" << (interval.insideNestedRegion ? "yes" : "no") + << " escapes_loop=" << (interval.escapesLoop ? "yes" : "no") << "\n"; + } + size_t fallbackPrinted = 0; + for (const LocalAllocInterval &interval : intervals) { + if (!(interval.startUsedAllocFallback || interval.endUsedFallback) || fallbackPrinted >= kSummaryOffenderLimit) + continue; + printedAttention = true; + ++fallbackPrinted; + os << " fallback interval #" << interval.id + << " size=" << formatReportMemory(interval.size) + << " value=" << summarizeValue(interval.key.value, 56) << "\n"; + os << " reason: " << (interval.fallbackReason.empty() ? "" : interval.fallbackReason) << "\n"; + } + size_t nestedPrinted = 0; + for (const LocalAllocInterval &interval : intervals) { + if (nestedPrinted >= kSummaryOffenderLimit) + break; + if (!(interval.insideNestedRegion && slots[interval.slotPlanIndex].intervalIndices.size() == 1)) + continue; + printedAttention = true; + ++nestedPrinted; + os << " nested single-use interval #" << interval.id + << " slot #" << interval.physicalSlotId + << " size=" << formatReportMemory(interval.size) + << " value=" << summarizeValue(interval.key.value, 56) << "\n"; + os << " hint: move or sink this alloc inside the nested region if the IR allows it.\n"; + } + if (!printedAttention) + os << " no obvious blockers detected in this core\n"; + + if (reportLevel == PimMemoryReportFull) { + os << "\nSlot Reuse:\n"; + for (const PlannedPhysicalSlot &slot : slots) { + uint64_t slotLogicalBytes = getSlotLogicalBytes(slot, intervals); + os << " slot #" << slot.id << " addr=" << slot.address << " size=" << formatReportMemory(slot.size) << " (" + << slot.size << ")" + << " intervals=" << slot.intervalIndices.size() + << " logical_sum=" << formatReportMemory(slotLogicalBytes) << "\n"; + for (size_t intervalIndex : slot.intervalIndices) { + const LocalAllocInterval &interval = intervals[intervalIndex]; + mlir::Value allocValue = interval.key.value; + os << " [" << interval.start << "," << interval.end << "]" + << " #" << interval.id + << " logical=" << formatReportMemory(interval.size) + << " nested=" << (interval.insideNestedRegion ? "yes" : "no") + << " escapes_loop=" << (interval.escapesLoop ? "yes" : "no") + << " first=" << summarizeOperation(interval.firstTouchOp, 48) + << " last=" << summarizeOperation(interval.lastTouchOp, 48) << "\n"; + os << " value=" << summarizeValue(allocValue) << "\n"; + } + } + } + + if (reportLevel == PimMemoryReportFull) { + os << "\nInterval Details:\n"; + for (const LocalAllocInterval &interval : intervals) { + const PlannedPhysicalSlot &slot = slots[interval.slotPlanIndex]; + mlir::Value allocValue = interval.key.value; + Operation *definingOp = allocValue.getDefiningOp(); + os << " #" << interval.id + << " slot=" << slot.id + << " live=[" << interval.start << "," << interval.end << "]" + << " logical=" << formatReportMemory(interval.size) + << " slot_size=" << formatReportMemory(interval.physicalSlotSize) + << " addr=" << interval.assignedAddress << "\n"; + os << " value=" << summarizeValue(allocValue, 88) << "\n"; + os << " type=" << allocValue.getType() << "\n"; + os << " loc=" + << summarizeLocation(definingOp ? definingOp->getLoc() : UnknownLoc::get(coreLikeOp->getContext())) << "\n"; + os << " nested=" << (interval.insideNestedRegion ? "yes" : "no") + << " escapes_loop=" << (interval.escapesLoop ? "yes" : "no") + << " start_fallback=" << (interval.startUsedAllocFallback ? "yes" : "no") + << " end_fallback=" << (interval.endUsedFallback ? "yes" : "no") << "\n"; + os << " first_use=" << summarizeOperation(interval.firstTouchOp) << " @" << interval.firstTouchPosition + << "\n"; + os << " last_use=" << summarizeOperation(interval.lastTouchOp) << " @" << interval.lastTouchPosition << "\n"; + os << " slot_peers="; + bool first = true; + for (size_t otherIndex : slot.intervalIndices) { + if (intervals[otherIndex].id == interval.id) + continue; + if (!first) + os << ", "; + os << "#" << intervals[otherIndex].id; + first = false; + } + if (first) + os << ""; + os << "\n"; + if (!interval.fallbackReason.empty()) + os << " fallback_reason=" << interval.fallbackReason << "\n"; + if (!interval.aliasesFollowed.empty()) { + os << " aliases_followed=" << interval.aliasesFollowed.size() << "\n"; + for (const std::string &alias : interval.aliasesFollowed) + os << " - " << abbreviate(collapseWhitespace(alias), 108) << "\n"; + } + } + } + os.flush(); + + return artifacts; +} diff --git a/src/PIM/Compiler/PimMemoryLiveness.hpp b/src/PIM/Compiler/PimMemoryLiveness.hpp new file mode 100644 index 0000000..5925fcd --- /dev/null +++ b/src/PIM/Compiler/PimMemoryLiveness.hpp @@ -0,0 +1,63 @@ +#pragma once + +#include "mlir/Dialect/MemRef/IR/MemRef.h" + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" + +#include +#include +#include + +#include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp" +#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" + +namespace onnx_mlir { + +struct LocalAllocInterval { + size_t id = 0; + mlir::memref::AllocOp alloc; + MemoryValueKey key; + uint64_t start = 0; + uint64_t end = 0; + size_t size = 0; + mlir::Operation *startOp = nullptr; + mlir::Operation *endOp = nullptr; + mlir::Operation *firstTouchOp = nullptr; + mlir::Operation *lastTouchOp = nullptr; + uint64_t firstTouchPosition = 0; + uint64_t lastTouchPosition = 0; + bool startUsedAllocFallback = false; + bool endUsedFallback = false; + bool hasRuntimeUse = false; + bool insideNestedRegion = false; + bool escapesLoop = false; + std::string fallbackReason; + llvm::SmallVector aliasesFollowed; + size_t slotPlanIndex = std::numeric_limits::max(); + size_t physicalSlotId = std::numeric_limits::max(); + size_t assignedAddress = 0; + size_t physicalSlotSize = 0; +}; + +struct PlannedPhysicalSlot { + size_t id = std::numeric_limits::max(); + size_t requiredSize = 0; + size_t size = 0; + size_t address = 0; + llvm::SmallVector intervalIndices; +}; + +llvm::SmallVector buildLocalAllocIntervals(mlir::Operation *coreLikeOp, + std::optional lane); + +llvm::SmallVector planPhysicalSlots(llvm::MutableArrayRef intervals); + +MemoryPlanArtifacts buildMemoryPlanArtifacts(mlir::Operation *coreLikeOp, + std::optional lane, + llvm::ArrayRef intervals, + llvm::ArrayRef slots, + size_t addressLimit, + PimMemoryReportLevel reportLevel); + +} // namespace onnx_mlir diff --git a/src/PIM/Dialect/Pim/Transforms/Bufferization/BufferizationUtils.cpp b/src/PIM/Dialect/Pim/Transforms/Bufferization/BufferizationUtils.cpp index 260389b..de22489 100644 --- a/src/PIM/Dialect/Pim/Transforms/Bufferization/BufferizationUtils.cpp +++ b/src/PIM/Dialect/Pim/Transforms/Bufferization/BufferizationUtils.cpp @@ -35,7 +35,7 @@ FailureOr materializeContiguousInputMemRef(Value memrefValue, Location lo } Value allocateContiguousResultMemRefLike(Value memrefValue, Location loc, RewriterBase& rewriter) { - if (succeeded(resolveContiguousAddress(memrefValue))) + if (succeeded(resolveContiguousAddress(memrefValue)) || succeeded(compileContiguousAddressExpr(memrefValue))) return memrefValue; auto shapedType = cast(memrefValue.getType()); diff --git a/src/PIM/Dialect/Pim/Transforms/MemoryCoalescing/MemoryCoalescing.cpp b/src/PIM/Dialect/Pim/Transforms/MemoryCoalescing/MemoryCoalescing.cpp index 4153489..ec1fe81 100644 --- a/src/PIM/Dialect/Pim/Transforms/MemoryCoalescing/MemoryCoalescing.cpp +++ b/src/PIM/Dialect/Pim/Transforms/MemoryCoalescing/MemoryCoalescing.cpp @@ -19,7 +19,7 @@ namespace pim { namespace { -static bool isSupportedAliasOp(Operation* op) { +static bool isSupportedAliasOp(Operation *op) { return isa(op); } @@ -32,32 +32,51 @@ static uint64_t getTypeSizeBytes(MemRefType type) { return static_cast(type.getNumElements() * getElementTypeSizeInBytes(type.getElementType())); } -static Operation* getTopLevelAncestorInBody(Operation* op, Block& body) { - Operation* current = op; - while (current && current->getBlock() != &body) +static Operation *getTopLevelAncestorInBlock(Operation *op, Block &block) { + Operation *current = op; + while (current && current->getBlock() != &block) current = current->getParentOp(); return current; } +static void analyzeBlock(Block &block, MemoryCoalescingAnalysis &analysis); + static FailureOr -getLastUseInstruction(memref::AllocOp allocOp, Block& body, const DenseMap& opOrder) { +getLastUseInstruction(memref::AllocOp allocOp, Block &scopeBlock, const DenseMap &opOrder) { uint64_t endInstruction = opOrder.lookup(allocOp); - SmallPtrSet visited; + SmallPtrSet visitedValues; + SmallPtrSet visitedUsers; SmallVector pendingValues; pendingValues.push_back(allocOp.getResult()); while (!pendingValues.empty()) { Value value = pendingValues.pop_back_val(); - for (Operation* user : value.getUsers()) { - Operation* orderedUser = getTopLevelAncestorInBody(user, body); - if (!orderedUser) - return failure(); - if (!visited.insert(user).second) + if (!visitedValues.insert(value).second) + continue; + + for (Operation *user : value.getUsers()) { + if (!visitedUsers.insert(user).second) continue; if (isSupportedAliasOp(user)) - for (Value result : user->getResults()) - pendingValues.push_back(result); + llvm::append_range(pendingValues, user->getResults()); + + if (auto dpsOp = dyn_cast(user)) { + for (OpResult result : user->getResults()) { + OpOperand *tiedOperand = dpsOp.getTiedOpOperand(result); + if (tiedOperand && tiedOperand->get() == value) + pendingValues.push_back(result); + } + } + + if (auto forOp = dyn_cast(user)) { + for (auto [index, initArg] : llvm::enumerate(forOp.getInitArgs())) { + if (initArg != value) + continue; + pendingValues.push_back(forOp.getRegionIterArgs()[index]); + pendingValues.push_back(forOp.getResult(index)); + } + } if (auto yieldOp = dyn_cast(user)) { auto forOp = dyn_cast(yieldOp->getParentOp()); @@ -68,20 +87,9 @@ getLastUseInstruction(memref::AllocOp allocOp, Block& body, const DenseMap(user)) { - for (auto [index, initArg] : llvm::enumerate(forOp.getInitArgs())) - if (initArg == value) - pendingValues.push_back(forOp.getResult(index)); - } - - if (auto dpsOp = dyn_cast(user)) { - for (OpResult result : user->getResults()) { - OpOperand* tiedOperand = dpsOp.getTiedOpOperand(result); - if (!tiedOperand || tiedOperand->get() != value) - continue; - pendingValues.push_back(result); - } - } + Operation *orderedUser = getTopLevelAncestorInBlock(user, scopeBlock); + if (!orderedUser) + return failure(); auto order = opOrder.find(orderedUser); if (order == opOrder.end()) @@ -93,101 +101,126 @@ getLastUseInstruction(memref::AllocOp allocOp, Block& body, const DenseMapgetNumRegions() != 1 || coreLikeOp->getRegion(0).empty()) - return analysis; - - Block& body = coreLikeOp->getRegion(0).front(); - DenseMap opOrder; + DenseMap opOrder; uint64_t nextInstruction = 0; - for (Operation& op : body) + for (Operation &op : block) opOrder.try_emplace(&op, nextInstruction++); - for (Operation& op : body) { + MemoryCoalescingBlockAnalysis blockAnalysis; + blockAnalysis.block = █ + + for (Operation &op : block) { auto allocOp = dyn_cast(&op); if (!allocOp) continue; auto allocType = dyn_cast(allocOp.getType()); if (!isCandidateAllocType(allocType)) { - ++analysis.skippedAllocations; + ++blockAnalysis.skippedAllocations; continue; } - auto endInstruction = getLastUseInstruction(allocOp, body, opOrder); + auto endInstruction = getLastUseInstruction(allocOp, block, opOrder); if (failed(endInstruction)) { - ++analysis.skippedAllocations; + ++blockAnalysis.skippedAllocations; continue; } - analysis.candidates.push_back( - AllocationCandidate {allocOp, opOrder.lookup(allocOp), *endInstruction, getTypeSizeBytes(allocType)}); + blockAnalysis.candidates.push_back( + AllocationCandidate {allocOp, &block, opOrder.lookup(allocOp), *endInstruction, getTypeSizeBytes(allocType)}); } + analysis.skippedAllocations += blockAnalysis.skippedAllocations; + if (!blockAnalysis.candidates.empty() || blockAnalysis.skippedAllocations != 0) + analysis.blocks.push_back(std::move(blockAnalysis)); +} + +} // namespace + +uint64_t MemoryCoalescingAnalysis::getCandidateCount() const { + uint64_t total = 0; + for (const MemoryCoalescingBlockAnalysis &block : blocks) + total += block.candidates.size(); + return total; +} + +MemoryCoalescingAnalysis analyzeMemoryCoalescingCandidates(Operation *coreLikeOp) { + MemoryCoalescingAnalysis analysis; + if (!coreLikeOp || coreLikeOp->getNumRegions() != 1 || coreLikeOp->getRegion(0).empty()) + return analysis; + + analyzeBlock(coreLikeOp->getRegion(0).front(), analysis); return analysis; } MemoryCoalescingStats -coalesceMemory(Operation* coreLikeOp, const MemoryCoalescingAnalysis& analysis, RewriterBase& rewriter) { +coalesceMemory(Operation *coreLikeOp, const MemoryCoalescingAnalysis &analysis, RewriterBase &rewriter) { + (void) coreLikeOp; + MemoryCoalescingStats stats; stats.skippedAllocations = analysis.skippedAllocations; - auto candidates = analysis.candidates; - llvm::sort(candidates, [](const AllocationCandidate& lhs, const AllocationCandidate& rhs) { - if (lhs.startInstruction != rhs.startInstruction) - return lhs.startInstruction < rhs.startInstruction; - return lhs.endInstruction < rhs.endInstruction; - }); + for (const MemoryCoalescingBlockAnalysis &blockAnalysis : analysis.blocks) { + auto candidates = blockAnalysis.candidates; + llvm::sort(candidates, [](const AllocationCandidate &lhs, const AllocationCandidate &rhs) { + if (lhs.startInstruction != rhs.startInstruction) + return lhs.startInstruction < rhs.startInstruction; + return lhs.endInstruction < rhs.endInstruction; + }); - struct ActiveStorage { - memref::AllocOp root; - uint64_t endInstruction = 0; - }; + struct ActiveStorage { + memref::AllocOp root; + uint64_t endInstruction = 0; + }; - SmallVector active; - SmallVector freeList; + SmallVector active; + SmallVector freeList; - for (AllocationCandidate& candidate : candidates) { - for (auto it = active.begin(); it != active.end();) { - if (it->endInstruction < candidate.startInstruction) { - freeList.push_back(it->root); - it = active.erase(it); + for (AllocationCandidate &candidate : candidates) { + for (auto it = active.begin(); it != active.end();) { + if (it->endInstruction < candidate.startInstruction) { + freeList.push_back(it->root); + it = active.erase(it); + continue; + } + ++it; + } + + auto bestFit = freeList.end(); + uint64_t bestFitBytes = std::numeric_limits::max(); + auto candidateType = cast(candidate.alloc.getType()); + for (auto it = freeList.begin(); it != freeList.end(); ++it) { + auto freeType = cast((*it).getType()); + if (freeType != candidateType) + continue; + + uint64_t freeBytes = getTypeSizeBytes(freeType); + if (freeBytes < candidate.sizeBytes || freeBytes >= bestFitBytes) + continue; + + bestFit = it; + bestFitBytes = freeBytes; + } + + if (bestFit == freeList.end()) { + active.push_back(ActiveStorage {candidate.alloc, candidate.endInstruction}); continue; } - ++it; + + memref::AllocOp root = *bestFit; + freeList.erase(bestFit); + candidate.alloc.getResult().replaceAllUsesWith(root.getResult()); + rewriter.eraseOp(candidate.alloc); + active.push_back(ActiveStorage {root, candidate.endInstruction}); + ++stats.removedAllocs; + stats.savedBytes += candidate.sizeBytes; } - - auto bestFit = freeList.end(); - uint64_t bestFitBytes = std::numeric_limits::max(); - auto candidateType = cast(candidate.alloc.getType()); - for (auto it = freeList.begin(); it != freeList.end(); ++it) { - auto freeType = cast((*it).getType()); - if (freeType != candidateType) - continue; - - uint64_t freeBytes = getTypeSizeBytes(freeType); - if (freeBytes < candidate.sizeBytes || freeBytes >= bestFitBytes) - continue; - - bestFit = it; - bestFitBytes = freeBytes; - } - - if (bestFit == freeList.end()) { - active.push_back(ActiveStorage {candidate.alloc, candidate.endInstruction}); - continue; - } - - memref::AllocOp root = *bestFit; - freeList.erase(bestFit); - candidate.alloc.getResult().replaceAllUsesWith(root.getResult()); - rewriter.eraseOp(candidate.alloc); - active.push_back(ActiveStorage {root, candidate.endInstruction}); - ++stats.removedAllocs; - stats.savedBytes += candidate.sizeBytes; } return stats; diff --git a/src/PIM/Dialect/Pim/Transforms/MemoryCoalescing/MemoryCoalescing.hpp b/src/PIM/Dialect/Pim/Transforms/MemoryCoalescing/MemoryCoalescing.hpp index e0b4025..6b839e9 100644 --- a/src/PIM/Dialect/Pim/Transforms/MemoryCoalescing/MemoryCoalescing.hpp +++ b/src/PIM/Dialect/Pim/Transforms/MemoryCoalescing/MemoryCoalescing.hpp @@ -10,16 +10,25 @@ namespace pim { struct AllocationCandidate { mlir::memref::AllocOp alloc; + mlir::Block *scopeBlock = nullptr; uint64_t startInstruction = 0; uint64_t endInstruction = 0; uint64_t sizeBytes = 0; }; -struct MemoryCoalescingAnalysis { +struct MemoryCoalescingBlockAnalysis { + mlir::Block *block = nullptr; llvm::SmallVector candidates; uint64_t skippedAllocations = 0; }; +struct MemoryCoalescingAnalysis { + llvm::SmallVector blocks; + uint64_t skippedAllocations = 0; + + uint64_t getCandidateCount() const; +}; + struct MemoryCoalescingStats { uint64_t removedAllocs = 0; uint64_t savedBytes = 0; diff --git a/src/PIM/Dialect/Pim/Transforms/MemoryCoalescing/MemoryCoalescingPass.cpp b/src/PIM/Dialect/Pim/Transforms/MemoryCoalescing/MemoryCoalescingPass.cpp index 5422a14..a6bb54c 100644 --- a/src/PIM/Dialect/Pim/Transforms/MemoryCoalescing/MemoryCoalescingPass.cpp +++ b/src/PIM/Dialect/Pim/Transforms/MemoryCoalescing/MemoryCoalescingPass.cpp @@ -23,9 +23,9 @@ using namespace onnx_mlir::compact_asm; namespace onnx_mlir { namespace { -// This pass assumes bufferization has already normalized executable PIM -// operands. It only reuses compatible local allocations with non-overlapping -// lifetimes; it does not repair memory contiguity. +// This pass is an IR cleanup step after bufferization. It only rewrites +// obviously compatible local allocations with non-overlapping lifetimes inside +// the same block and leaves the final physical memory planning to codegen. struct CoalescingReportRow { uint64_t numCandidates = 0; @@ -174,7 +174,7 @@ struct PimMemoryCoalescingPass : PassWrapper(op)) { auto checkedCoreId = diff --git a/test/PIM/CMakeLists.txt b/test/PIM/CMakeLists.txt index 7e9be67..4af9680 100644 --- a/test/PIM/CMakeLists.txt +++ b/test/PIM/CMakeLists.txt @@ -30,3 +30,9 @@ add_pim_unittest(LabeledListTest OMPimCommon ) +add_pim_unittest(PimMemoryLivenessPlannerTest + PimMemoryLivenessPlannerTest.cpp + + LINK_LIBS PRIVATE + OMPimCompilerUtils +) diff --git a/test/PIM/PimMemoryLivenessPlannerTest.cpp b/test/PIM/PimMemoryLivenessPlannerTest.cpp new file mode 100644 index 0000000..246d722 --- /dev/null +++ b/test/PIM/PimMemoryLivenessPlannerTest.cpp @@ -0,0 +1,86 @@ +#include +#include +#include + +#include "src/Accelerators/PIM/Compiler/PimMemoryLiveness.hpp" + +using onnx_mlir::LocalAllocInterval; +using onnx_mlir::planPhysicalSlots; + +namespace { + +LocalAllocInterval makeInterval(size_t id, size_t size, uint64_t start, uint64_t end) { + LocalAllocInterval interval; + interval.id = id; + interval.size = size; + interval.start = start; + interval.end = end; + return interval; +} + +void assertSingleSlotCase(LocalAllocInterval a, LocalAllocInterval b, size_t expectedSlotSize) { + llvm::SmallVector intervals = {a, b}; + auto slots = planPhysicalSlots(intervals); + assert(slots.size() == 1); + assert(slots.front().requiredSize == expectedSlotSize); + assert(intervals[0].physicalSlotId == intervals[1].physicalSlotId); +} + +int testSameSizeNonOverlap() { + std::cout << "testSameSizeNonOverlap:" << std::endl; + assertSingleSlotCase(makeInterval(0, 64, 0, 10), makeInterval(1, 64, 11, 20), 64); + return 0; +} + +int testLargerFirst() { + std::cout << "testLargerFirst:" << std::endl; + assertSingleSlotCase(makeInterval(0, 100, 0, 10), makeInterval(1, 40, 11, 20), 100); + return 0; +} + +int testSmallerFirst() { + std::cout << "testSmallerFirst:" << std::endl; + assertSingleSlotCase(makeInterval(0, 40, 0, 10), makeInterval(1, 100, 11, 20), 100); + return 0; +} + +int testOverlapNeedsTwoSlots() { + std::cout << "testOverlapNeedsTwoSlots:" << std::endl; + llvm::SmallVector intervals = { + makeInterval(0, 100, 0, 20), makeInterval(1, 40, 10, 30)}; + auto slots = planPhysicalSlots(intervals); + assert(slots.size() == 2); + assert(intervals[0].physicalSlotId != intervals[1].physicalSlotId); + return 0; +} + +int testReuseChain() { + std::cout << "testReuseChain:" << std::endl; + llvm::SmallVector intervals = { + makeInterval(0, 40, 0, 10), makeInterval(1, 100, 11, 20), makeInterval(2, 20, 21, 30)}; + auto slots = planPhysicalSlots(intervals); + assert(slots.size() == 1); + assert(slots.front().requiredSize == 100); + assert(intervals[0].physicalSlotId == intervals[1].physicalSlotId); + assert(intervals[1].physicalSlotId == intervals[2].physicalSlotId); + return 0; +} + +} // namespace + +int main(int argc, char *argv[]) { + (void) argc; + (void) argv; + + int failures = 0; + failures += testSameSizeNonOverlap(); + failures += testLargerFirst(); + failures += testSmallerFirst(); + failures += testOverlapNeedsTwoSlots(); + failures += testReuseChain(); + if (failures != 0) { + std::cerr << failures << " test failures\n"; + return EXIT_FAILURE; + } + return EXIT_SUCCESS; +}