add memory coalescing pass
Validate Operations / validate-operations (push) Has been cancelled

better reports
refactor for more code-reuse and patter usage
fixes
This commit is contained in:
NiccoloN
2026-05-12 18:17:00 +02:00
parent 4f3570520c
commit 41de3cb150
26 changed files with 930 additions and 385 deletions
+1
View File
@@ -2,6 +2,7 @@ add_onnx_mlir_dialect(Pim pim)
add_onnx_mlir_dialect_doc(pim Pim.td)
add_subdirectory(Transforms/Bufferization)
add_subdirectory(Transforms/StaticMemoryCoalescing)
add_pim_library(PimOps
PimOps.hpp
@@ -0,0 +1,14 @@
add_pim_library(OMPimStaticMemoryCoalescing
StaticMemoryCoalescing.cpp
StaticMemoryCoalescing.hpp
StaticMemoryCoalescingPass.cpp
EXCLUDE_FROM_OM_LIBS
INCLUDE_DIRS PUBLIC
${PIM_PUBLIC_INCLUDE_DIRS}
LINK_LIBS PUBLIC
OMPimCommon
PimOps
)
@@ -0,0 +1,172 @@
#include "src/Accelerators/PIM/Dialect/Pim/Transforms/StaticMemoryCoalescing/StaticMemoryCoalescing.hpp"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/STLExtras.h"
#include <limits>
using namespace mlir;
namespace onnx_mlir {
namespace pim {
namespace {
static bool isSupportedAliasOp(Operation* op) {
return isa<memref::SubViewOp, memref::CastOp, memref::CollapseShapeOp, memref::ExpandShapeOp>(op);
}
static bool isCandidateAllocType(MemRefType type) {
return type && type.hasStaticShape() && type.getLayout().isIdentity() && type.getElementTypeBitWidth() > 0;
}
static uint64_t getTypeSizeBytes(MemRefType type) {
return static_cast<uint64_t>(type.getNumElements() * type.getElementTypeBitWidth() / 8);
}
static FailureOr<uint64_t> getLastUseInstruction(memref::AllocOp allocOp,
Block& body,
const DenseMap<Operation*, uint64_t>& opOrder) {
uint64_t endInstruction = opOrder.lookup(allocOp);
SmallPtrSet<Operation*, 16> visited;
SmallVector<Value> pendingValues;
pendingValues.push_back(allocOp.getResult());
while (!pendingValues.empty()) {
Value value = pendingValues.pop_back_val();
for (Operation* user : value.getUsers()) {
if (user->getBlock() != &body)
return failure();
if (!visited.insert(user).second)
continue;
if (isSupportedAliasOp(user)) {
for (Value result : user->getResults())
pendingValues.push_back(result);
}
if (auto dpsOp = dyn_cast<DestinationStyleOpInterface>(user)) {
for (OpResult result : user->getResults()) {
OpOperand* tiedOperand = dpsOp.getTiedOpOperand(result);
if (!tiedOperand || tiedOperand->get() != value)
continue;
pendingValues.push_back(result);
}
}
auto order = opOrder.find(user);
if (order == opOrder.end())
return failure();
endInstruction = std::max(endInstruction, order->second);
}
}
return endInstruction;
}
} // namespace
StaticMemoryCoalescingAnalysis analyzeStaticMemoryCoalescingCandidates(Operation* coreLikeOp) {
StaticMemoryCoalescingAnalysis analysis;
if (!coreLikeOp || coreLikeOp->getNumRegions() != 1 || coreLikeOp->getRegion(0).empty())
return analysis;
Block& body = coreLikeOp->getRegion(0).front();
DenseMap<Operation*, uint64_t> opOrder;
uint64_t nextInstruction = 0;
for (Operation& op : body)
opOrder.try_emplace(&op, nextInstruction++);
for (Operation& op : body) {
auto allocOp = dyn_cast<memref::AllocOp>(&op);
if (!allocOp)
continue;
auto allocType = dyn_cast<MemRefType>(allocOp.getType());
if (!isCandidateAllocType(allocType)) {
++analysis.skippedAllocations;
continue;
}
auto endInstruction = getLastUseInstruction(allocOp, body, opOrder);
if (failed(endInstruction)) {
++analysis.skippedAllocations;
continue;
}
analysis.candidates.push_back(
StaticAllocationCandidate {allocOp, opOrder.lookup(allocOp), *endInstruction, getTypeSizeBytes(allocType)});
}
return analysis;
}
StaticMemoryCoalescingStats coalesceStaticMemory(Operation* coreLikeOp, RewriterBase& rewriter) {
StaticMemoryCoalescingStats stats;
auto analysis = analyzeStaticMemoryCoalescingCandidates(coreLikeOp);
stats.skippedAllocations = analysis.skippedAllocations;
llvm::sort(analysis.candidates, [](const StaticAllocationCandidate& lhs, const StaticAllocationCandidate& rhs) {
if (lhs.startInstruction != rhs.startInstruction)
return lhs.startInstruction < rhs.startInstruction;
return lhs.endInstruction < rhs.endInstruction;
});
struct ActiveStorage {
memref::AllocOp root;
uint64_t endInstruction = 0;
};
SmallVector<ActiveStorage> active;
SmallVector<memref::AllocOp> freeList;
for (StaticAllocationCandidate& candidate : analysis.candidates) {
for (auto it = active.begin(); it != active.end();) {
if (it->endInstruction < candidate.startInstruction) {
freeList.push_back(it->root);
it = active.erase(it);
continue;
}
++it;
}
auto bestFit = freeList.end();
uint64_t bestFitBytes = std::numeric_limits<uint64_t>::max();
auto candidateType = cast<MemRefType>(candidate.alloc.getType());
for (auto it = freeList.begin(); it != freeList.end(); ++it) {
auto freeType = cast<MemRefType>((*it).getType());
if (freeType != candidateType)
continue;
uint64_t freeBytes = getTypeSizeBytes(freeType);
if (freeBytes < candidate.sizeBytes || freeBytes >= bestFitBytes)
continue;
bestFit = it;
bestFitBytes = freeBytes;
}
if (bestFit == freeList.end()) {
active.push_back(ActiveStorage {candidate.alloc, candidate.endInstruction});
continue;
}
memref::AllocOp root = *bestFit;
freeList.erase(bestFit);
candidate.alloc.getResult().replaceAllUsesWith(root.getResult());
rewriter.eraseOp(candidate.alloc);
active.push_back(ActiveStorage {root, candidate.endInstruction});
++stats.removedAllocs;
stats.savedBytes += candidate.sizeBytes;
}
return stats;
}
} // namespace pim
} // namespace onnx_mlir
@@ -0,0 +1,35 @@
#pragma once
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Operation.h"
#include "llvm/ADT/SmallVector.h"
namespace onnx_mlir {
namespace pim {
struct StaticAllocationCandidate {
mlir::memref::AllocOp alloc;
uint64_t startInstruction = 0;
uint64_t endInstruction = 0;
uint64_t sizeBytes = 0;
};
struct StaticMemoryCoalescingAnalysis {
llvm::SmallVector<StaticAllocationCandidate> candidates;
uint64_t skippedAllocations = 0;
};
struct StaticMemoryCoalescingStats {
uint64_t removedAllocs = 0;
uint64_t savedBytes = 0;
uint64_t skippedAllocations = 0;
};
StaticMemoryCoalescingAnalysis analyzeStaticMemoryCoalescingCandidates(mlir::Operation* coreLikeOp);
StaticMemoryCoalescingStats coalesceStaticMemory(mlir::Operation* coreLikeOp, mlir::RewriterBase& rewriter);
} // namespace pim
} // namespace onnx_mlir
@@ -0,0 +1,203 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/raw_os_ostream.h"
#include <fstream>
#include "Common/IR/CompactAsmUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/Support/DebugDump.hpp"
#include "src/Accelerators/PIM/Common/Support/ReportUtils.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/Transforms/StaticMemoryCoalescing/StaticMemoryCoalescing.hpp"
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
using namespace mlir;
using namespace onnx_mlir::compact_asm;
namespace onnx_mlir {
namespace {
struct CoalescingReportRow {
uint64_t numCandidates = 0;
uint64_t numSkipped = 0;
uint64_t numRemoved = 0;
uint64_t savedBytes = 0;
bool operator==(const CoalescingReportRow& other) const {
return numCandidates == other.numCandidates && numSkipped == other.numSkipped && numRemoved == other.numRemoved
&& savedBytes == other.savedBytes;
}
};
struct CoalescingReportEntry {
enum class Kind {
Core,
Batch
};
Kind kind = Kind::Core;
uint64_t id = 0;
llvm::SmallVector<int32_t, 8> coreIds;
CoalescingReportRow row;
};
static std::string formatMemory(uint64_t bytes) {
return formatReportMemory(bytes);
}
static SmallVector<int32_t> getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp) {
auto coreIdsAttr = coreBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName);
assert(coreIdsAttr && "pim.core_batch requires coreIds array attribute");
return SmallVector<int32_t>(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end());
}
static void printReportRow(raw_ostream& os, const CoalescingReportRow& row) {
llvm::SmallVector<ReportField, 4> fields = {
{"Number of candidates", std::to_string(row.numCandidates)},
{"Skipped allocations", std::to_string(row.numSkipped)},
{"Removed allocations", std::to_string(row.numRemoved)},
{"Saved memory", formatMemory(row.savedBytes)}};
printReportFlatFields(os, fields);
}
static CoalescingReportRow getTotalRow(const CoalescingReportEntry& entry) {
uint64_t factor = std::max<uint64_t>(1, entry.coreIds.size());
return {entry.row.numCandidates * factor,
entry.row.numSkipped * factor,
entry.row.numRemoved * factor,
entry.row.savedBytes * factor};
}
static void emitReport(ArrayRef<CoalescingReportEntry> entries) {
std::fstream file = openReportFile("static_memory_coalescing_report");
if (!file.is_open())
return;
llvm::raw_os_ostream os(file);
CoalescingReportRow totalRow;
for (const CoalescingReportEntry& entry : entries) {
CoalescingReportRow entryTotal = getTotalRow(entry);
totalRow.numCandidates += entryTotal.numCandidates;
totalRow.numSkipped += entryTotal.numSkipped;
totalRow.numRemoved += entryTotal.numRemoved;
totalRow.savedBytes += entryTotal.savedBytes;
}
llvm::SmallVector<ReportField, 4> totalFields = {{"Number of candidates", std::to_string(totalRow.numCandidates)},
{"Skipped allocations", std::to_string(totalRow.numSkipped)},
{"Removed allocations", std::to_string(totalRow.numRemoved)},
{"Saved memory", formatMemory(totalRow.savedBytes)}};
printReportTotalsBlock(os, totalFields);
if (!entries.empty())
os << "\n";
llvm::SmallVector<CoalescingReportEntry, 32> sortedEntries(entries.begin(), entries.end());
sortReportEntriesByFirstCore(sortedEntries);
for (size_t index = 0; index < sortedEntries.size();) {
size_t runEnd = index + 1;
while (runEnd < sortedEntries.size() && sortedEntries[runEnd].kind == sortedEntries[index].kind
&& sortedEntries[runEnd].row == sortedEntries[index].row) {
++runEnd;
}
if (sortedEntries[index].kind == CoalescingReportEntry::Kind::Batch) {
os << "Batch ";
for (size_t batchIndex = index; batchIndex < runEnd; ++batchIndex) {
if (batchIndex != index)
os << ",\n ";
os << sortedEntries[batchIndex].id << " (cores ";
printCompressedIntegerEntries(os, ArrayRef<int32_t>(sortedEntries[batchIndex].coreIds));
os << ")";
}
}
else {
llvm::SmallVector<int32_t, 8> coreIds;
for (size_t coreIndex = index; coreIndex < runEnd; ++coreIndex)
coreIds.push_back(sortedEntries[coreIndex].coreIds.front());
os << "Core ";
printCompressedIntegerEntries(os, ArrayRef<int32_t>(coreIds));
}
os << ":\n";
if (sortedEntries[index].kind == CoalescingReportEntry::Kind::Batch) {
llvm::SmallVector<ReportField, 4> perCoreFields = {
{"Number of candidates", std::to_string(sortedEntries[index].row.numCandidates)},
{"Skipped allocations", std::to_string(sortedEntries[index].row.numSkipped)},
{"Removed allocations", std::to_string(sortedEntries[index].row.numRemoved)},
{"Saved memory", formatMemory(sortedEntries[index].row.savedBytes)}};
CoalescingReportRow totalRow = getTotalRow(sortedEntries[index]);
llvm::SmallVector<ReportField, 4> totalFields = {
{"Number of candidates", std::to_string(totalRow.numCandidates)},
{"Skipped allocations", std::to_string(totalRow.numSkipped)},
{"Removed allocations", std::to_string(totalRow.numRemoved)},
{"Saved memory", formatMemory(totalRow.savedBytes)}};
printReportPerCoreAndTotalFields(os, perCoreFields, totalFields);
}
else {
printReportRow(os, sortedEntries[index].row);
}
printReportEntrySeparator(os, runEnd < sortedEntries.size());
index = runEnd;
}
os.flush();
file.close();
}
struct StaticMemoryCoalescingPass : PassWrapper<StaticMemoryCoalescingPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(StaticMemoryCoalescingPass)
StringRef getArgument() const override { return "pim-static-memory-coalescing"; }
StringRef getDescription() const override { return "Analyze static local PIM memory reuse opportunities"; }
StaticMemoryCoalescingPass() = default;
StaticMemoryCoalescingPass(const StaticMemoryCoalescingPass& pass) {}
void runOnOperation() override {
IRRewriter rewriter(&getContext());
SmallVector<CoalescingReportEntry, 32> reportEntries;
uint64_t nextBatchId = 0;
getOperation().walk([&](Operation* op) {
if (!isa<pim::PimCoreOp, pim::PimCoreBatchOp>(op))
return;
auto analysis = pim::analyzeStaticMemoryCoalescingCandidates(op);
auto stats = pim::coalesceStaticMemory(op, rewriter);
CoalescingReportRow row {
analysis.candidates.size(), stats.skippedAllocations, stats.removedAllocs, stats.savedBytes};
if (auto coreOp = dyn_cast<pim::PimCoreOp>(op)) {
reportEntries.push_back({CoalescingReportEntry::Kind::Core,
static_cast<uint64_t>(coreOp.getCoreId()),
{static_cast<int32_t>(coreOp.getCoreId())},
row});
return;
}
auto coreIds = getBatchCoreIds(cast<pim::PimCoreBatchOp>(op));
CoalescingReportEntry entry;
entry.kind = CoalescingReportEntry::Kind::Batch;
entry.id = nextBatchId++;
llvm::append_range(entry.coreIds, coreIds);
entry.row = row;
reportEntries.push_back(std::move(entry));
});
emitReport(reportEntries);
dumpModule(getOperation(), "pim2_coalesced");
}
};
} // namespace
std::unique_ptr<Pass> createPimStaticMemoryCoalescingPass() {
return std::make_unique<StaticMemoryCoalescingPass>();
}
} // namespace onnx_mlir
@@ -40,6 +40,7 @@
#include "RegularOpCompaction.hpp"
#include "src/Accelerators/PIM/Common/IR/CompactAsmUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/Support/ReportUtils.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
using namespace mlir;
@@ -764,18 +765,13 @@ void emitMotifProfile(func::FuncOp funcOp) {
}
void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpuCount = 0) {
std::string outputDir = getOutputDir();
if (outputDir.empty())
std::fstream file = openReportFile(name);
if (!file.is_open())
return;
std::string reportsDir = outputDir + "/reports";
createDirectory(reportsDir);
std::fstream file(reportsDir + "/" + name + ".txt", std::ios::out);
llvm::raw_os_ostream os(file);
struct ReportRow {
uint64_t opId = 0;
uint64_t id = 0;
uint64_t logicalComputeCount = 0;
uint64_t weightCount = 0;
uint64_t instructionCount = 0;
@@ -786,6 +782,9 @@ void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpu
uint64_t totalComputeOps = 0;
uint64_t totalLogicalComputes = 0;
uint64_t totalBatchComputeOps = 0;
uint64_t totalInstructionCount = 0;
uint64_t totalWeightCount = 0;
uint64_t nextBatchId = 0;
std::vector<ReportRow> collectedData;
for (Operation& op : funcOp.getBody().front()) {
@@ -793,8 +792,13 @@ void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpu
uint64_t numInst = 0;
for (auto& _ : spatCompute.getRegion().front())
++numInst;
collectedData.push_back({totalComputeOps++, 1, spatCompute.getWeights().size(), numInst, false, {}});
SmallVector<int32_t> coreIds;
if (auto coreId = getComputeCoreId(spatCompute))
coreIds.push_back(*coreId);
collectedData.push_back({totalComputeOps++, 1, spatCompute.getWeights().size(), numInst, false, coreIds});
totalLogicalComputes += 1;
totalInstructionCount += numInst;
totalWeightCount += spatCompute.getWeights().size();
continue;
}
if (auto batch = dyn_cast<SpatComputeBatch>(&op)) {
@@ -805,44 +809,27 @@ void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpu
SmallVector<int32_t> coreIds;
if (auto coreIdsAttr = batch->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName))
llvm::append_range(coreIds, coreIdsAttr.asArrayRef());
collectedData.push_back({totalComputeOps++, logicalCount, batch.getWeights().size(), numInst, true, coreIds});
collectedData.push_back({nextBatchId++, logicalCount, batch.getWeights().size(), numInst, true, coreIds});
totalComputeOps += 1;
totalLogicalComputes += logicalCount;
totalBatchComputeOps += 1;
totalInstructionCount += numInst * logicalCount;
totalWeightCount += batch.getWeights().size();
}
}
os << "Used cores: " << usedCpuCount << "\n";
os << "Number of top-level compute ops: " << totalComputeOps << "\n";
os << "Number of logical computes: " << totalLogicalComputes << "\n";
os << "Number of top-level batch compute ops: " << totalBatchComputeOps << "\n";
os << "\n";
llvm::SmallVector<ReportField, 6> totalFields = {{"Used cores", std::to_string(usedCpuCount)},
{"Number of top-level compute ops", std::to_string(totalComputeOps)},
{"Number of logical computes", std::to_string(totalLogicalComputes)},
{"Number of top-level batch compute ops",
std::to_string(totalBatchComputeOps)},
{"Number of instructions", std::to_string(totalInstructionCount)},
{"Number of used crossbars", std::to_string(totalWeightCount)}};
printReportTotalsBlock(os, totalFields);
if (!collectedData.empty())
os << "\n";
std::stable_sort(collectedData.begin(), collectedData.end(), [](const ReportRow& lft, const ReportRow& rgt) {
if (lft.isRebatched != rgt.isRebatched)
return lft.isRebatched > rgt.isRebatched;
if (lft.instructionCount < rgt.instructionCount)
return false;
else if (rgt.instructionCount < lft.instructionCount)
return true;
if (lft.weightCount < rgt.weightCount)
return false;
else if (rgt.weightCount < lft.weightCount)
return true;
if (lft.logicalComputeCount < rgt.logicalComputeCount)
return false;
else if (rgt.logicalComputeCount < lft.logicalComputeCount)
return true;
if (lft.opId < rgt.opId)
return true;
else if (rgt.opId < lft.opId)
return false;
return true;
});
sortReportEntriesByFirstCore(collectedData);
for (uint64_t cI = 0; cI < totalComputeOps; ++cI) {
uint64_t lastIndex = cI;
@@ -863,7 +850,7 @@ void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpu
for (uint64_t index = cI; index <= lastIndex; ++index) {
if (index != cI)
os << ",\n ";
os << collectedData[index].opId << " (cores ";
os << collectedData[index].id << " (cores ";
if (collectedData[index].coreIds.empty())
os << "unknown";
else
@@ -876,14 +863,32 @@ void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpu
SmallVector<uint64_t> opIds;
opIds.reserve(lastIndex - cI + 1);
for (uint64_t index = cI; index <= lastIndex; ++index)
opIds.push_back(collectedData[index].opId);
opIds.push_back(collectedData[index].id);
printCompressedIntegerEntries(os, ArrayRef<uint64_t>(opIds));
}
os << ":\n";
os << "\tNumber of logical computes: " << current.logicalComputeCount << "\n";
os << "\tNumber of instructions: " << current.instructionCount << "\n";
os << "\tNumber of used crossbars: " << current.weightCount << "\n";
uint64_t perCoreLogicalComputeCount = current.isRebatched ? 1 : current.logicalComputeCount;
uint64_t perCoreInstructionCount = current.instructionCount;
uint64_t perCoreWeightCount =
current.logicalComputeCount == 0 ? 0 : current.weightCount / current.logicalComputeCount;
uint64_t totalEntryInstructionCount = current.instructionCount * current.logicalComputeCount;
llvm::SmallVector<ReportField, 3> perCoreFields = {
{"Number of logical computes", std::to_string(perCoreLogicalComputeCount)},
{"Number of instructions", std::to_string(perCoreInstructionCount)},
{"Number of used crossbars", std::to_string(perCoreWeightCount)}};
if (current.isRebatched) {
llvm::SmallVector<ReportField, 3> totalEntryFields = {
{"Number of logical computes", std::to_string(current.logicalComputeCount)},
{"Number of instructions", std::to_string(totalEntryInstructionCount)},
{"Number of used crossbars", std::to_string(current.weightCount)}};
printReportPerCoreAndTotalFields(os, perCoreFields, totalEntryFields);
}
else {
printReportFlatFields(os, perCoreFields);
}
printReportEntrySeparator(os, lastIndex + 1 < totalComputeOps);
cI = lastIndex;
}