add memory coalescing pass
Validate Operations / validate-operations (push) Has been cancelled

better reports
refactor for more code-reuse and patter usage
fixes
This commit is contained in:
NiccoloN
2026-05-12 18:17:00 +02:00
parent 4f3570520c
commit 41de3cb150
26 changed files with 930 additions and 385 deletions
+1
View File
@@ -29,6 +29,7 @@ add_pim_library(OMPimCompilerUtils
OMPimCompilerOptions
OMPimCommon
OMPimBufferization
OMPimStaticMemoryCoalescing
OMPimPasses
OMONNXToSpatial
OMSpatialToPim
+73 -63
View File
@@ -24,6 +24,78 @@ static SmallVector<int32_t> getLaneChunkCoreIds(ArrayRef<int32_t> coreIds, size_
return laneCoreIds;
}
static void scalarizeBatchOpsInCore(pim::PimCoreOp scalarCore, size_t laneCount, unsigned lane) {
IRRewriter rewriter(scalarCore.getContext());
SmallVector<Operation*> batchOps;
scalarCore.walk([&](Operation* op) {
if (isa<pim::PimSendBatchOp,
pim::PimSendTensorBatchOp,
pim::PimReceiveBatchOp,
pim::PimReceiveTensorBatchOp,
pim::PimMemCopyHostToDevBatchOp>(op)) {
batchOps.push_back(op);
}
});
for (Operation* op : batchOps) {
rewriter.setInsertionPoint(op);
if (auto sendBatchOp = dyn_cast<pim::PimSendBatchOp>(op)) {
pim::PimSendOp::create(rewriter,
sendBatchOp.getLoc(),
sendBatchOp.getInput(),
sendBatchOp.getSizeAttr(),
rewriter.getI32IntegerAttr(sendBatchOp.getTargetCoreIds()[lane]));
rewriter.eraseOp(op);
continue;
}
if (auto sendTensorBatchOp = dyn_cast<pim::PimSendTensorBatchOp>(op)) {
pim::PimSendTensorOp::create(
rewriter,
sendTensorBatchOp.getLoc(),
sendTensorBatchOp.getInput(),
rewriter.getDenseI32ArrayAttr(getLaneChunkCoreIds(sendTensorBatchOp.getTargetCoreIds(), laneCount, lane)));
rewriter.eraseOp(op);
continue;
}
if (auto receiveBatchOp = dyn_cast<pim::PimReceiveBatchOp>(op)) {
auto scalarReceive =
pim::PimReceiveOp::create(rewriter,
receiveBatchOp.getLoc(),
receiveBatchOp.getOutput().getType(),
receiveBatchOp.getOutputBuffer(),
receiveBatchOp.getSizeAttr(),
rewriter.getI32IntegerAttr(receiveBatchOp.getSourceCoreIds()[lane]));
rewriter.replaceOp(op, scalarReceive->getResults());
continue;
}
if (auto receiveTensorBatchOp = dyn_cast<pim::PimReceiveTensorBatchOp>(op)) {
auto scalarReceive = pim::PimReceiveTensorOp::create(
rewriter,
receiveTensorBatchOp.getLoc(),
receiveTensorBatchOp.getOutput().getType(),
receiveTensorBatchOp.getOutputBuffer(),
rewriter.getDenseI32ArrayAttr(getLaneChunkCoreIds(receiveTensorBatchOp.getSourceCoreIds(), laneCount, lane)));
rewriter.replaceOp(op, scalarReceive->getResults());
continue;
}
auto memcpBatchOp = cast<pim::PimMemCopyHostToDevBatchOp>(op);
auto scalarCopy = pim::PimMemCopyHostToDevOp::create(rewriter,
memcpBatchOp.getLoc(),
memcpBatchOp.getOutput().getType(),
memcpBatchOp.getDeviceTarget(),
memcpBatchOp.getHostSource(),
memcpBatchOp.getDeviceTargetOffsetAttr(),
memcpBatchOp.getHostSourceOffsetAttr(),
memcpBatchOp.getSizeAttr());
rewriter.replaceOp(op, scalarCopy->getResults());
}
}
} // namespace
LogicalResult withScalarCoreFromBatchLane(pim::PimCoreBatchOp coreBatchOp,
@@ -50,69 +122,6 @@ LogicalResult withScalarCoreFromBatchLane(pim::PimCoreBatchOp coreBatchOp,
builder.setInsertionPointToEnd(block);
for (Operation& op : coreBatchOp.getBody().front()) {
if (isa<pim::PimHaltOp>(op)) {
pim::PimHaltOp::create(builder, op.getLoc());
continue;
}
if (auto sendBatchOp = dyn_cast<pim::PimSendBatchOp>(op)) {
pim::PimSendOp::create(builder,
sendBatchOp.getLoc(),
mapper.lookup(sendBatchOp.getInput()),
sendBatchOp.getSizeAttr(),
builder.getI32IntegerAttr(sendBatchOp.getTargetCoreIds()[lane]));
continue;
}
if (auto sendTensorBatchOp = dyn_cast<pim::PimSendTensorBatchOp>(op)) {
pim::PimSendTensorOp::create(
builder,
sendTensorBatchOp.getLoc(),
mapper.lookup(sendTensorBatchOp.getInput()),
builder.getDenseI32ArrayAttr(getLaneChunkCoreIds(sendTensorBatchOp.getTargetCoreIds(), laneCount, lane)));
continue;
}
if (auto receiveBatchOp = dyn_cast<pim::PimReceiveBatchOp>(op)) {
auto scalarReceive =
pim::PimReceiveOp::create(builder,
receiveBatchOp.getLoc(),
receiveBatchOp.getOutput().getType(),
mapper.lookup(receiveBatchOp.getOutputBuffer()),
receiveBatchOp.getSizeAttr(),
builder.getI32IntegerAttr(receiveBatchOp.getSourceCoreIds()[lane]));
mapper.map(receiveBatchOp.getOutput(), scalarReceive.getOutput());
continue;
}
if (auto receiveTensorBatchOp = dyn_cast<pim::PimReceiveTensorBatchOp>(op)) {
auto scalarReceive = pim::PimReceiveTensorOp::create(
builder,
receiveTensorBatchOp.getLoc(),
receiveTensorBatchOp.getOutput().getType(),
mapper.lookup(receiveTensorBatchOp.getOutputBuffer()),
builder.getDenseI32ArrayAttr(getLaneChunkCoreIds(receiveTensorBatchOp.getSourceCoreIds(), laneCount, lane)));
mapper.map(receiveTensorBatchOp.getOutput(), scalarReceive.getOutput());
continue;
}
if (auto memcpBatchOp = dyn_cast<pim::PimMemCopyHostToDevBatchOp>(op)) {
Value hostSource = mapper.lookupOrNull(memcpBatchOp.getHostSource());
if (!hostSource)
hostSource = memcpBatchOp.getHostSource();
auto scalarCopy = pim::PimMemCopyHostToDevOp::create(builder,
memcpBatchOp.getLoc(),
memcpBatchOp.getOutput().getType(),
mapper.lookup(memcpBatchOp.getDeviceTarget()),
hostSource,
memcpBatchOp.getDeviceTargetOffsetAttr(),
memcpBatchOp.getHostSourceOffsetAttr(),
memcpBatchOp.getSizeAttr());
mapper.map(memcpBatchOp.getOutput(), scalarCopy.getOutput());
continue;
}
Operation* cloned = builder.clone(op, mapper);
for (auto [originalResult, clonedResult] : llvm::zip(op.getResults(), cloned->getResults()))
mapper.map(originalResult, clonedResult);
@@ -120,6 +129,7 @@ LogicalResult withScalarCoreFromBatchLane(pim::PimCoreBatchOp coreBatchOp,
if (block->empty() || !isa<pim::PimHaltOp>(block->back()))
pim::PimHaltOp::create(builder, coreBatchOp.getLoc());
scalarizeBatchOpsInCore(scalarCore, laneCount, lane);
return callback(scalarCore);
}
+84 -55
View File
@@ -26,6 +26,7 @@
#include "Common/IR/CompactAsmUtils.hpp"
#include "Common/PimCommon.hpp"
#include "Common/Support/ReportUtils.hpp"
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Compiler/PimArtifactWriter.hpp"
#include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp"
@@ -65,6 +66,7 @@ void PimMemory::allocateMemoryForValue(mlir::Value value, MemEntry& memEntry) {
if (size_t remainder = firstAvailableAddress % minAlignment)
firstAvailableAddress += minAlignment - remainder;
ownedMemEntriesMap[value] = memEntry;
globalMemEntriesMap[value] = memEntry;
}
@@ -112,26 +114,28 @@ void PimMemory::allocateCore(Operation* op) {
allocateGatheredMemory();
}
std::string formatMemory(uint64_t bytes) {
const char* units[] = {"B", "KB", "MB", "GB", "TB", "PB", "EB"};
int i = 0;
double size = static_cast<double>(bytes);
while (size >= 1024 && i < 6) {
size /= 1024;
i++;
}
// Formats to 2 decimal places
std::string out;
llvm::raw_string_ostream rss(out);
rss << llvm::format("%.2f ", size) << units[i];
return rss.str();
static void printHostMemoryReportRow(raw_ostream& os, const MemoryReportRow& row) {
llvm::SmallVector<ReportField, 2> fields = {
{"Number of globals", std::to_string(row.numGlobal)},
{"Global memory", formatReportMemory(row.sizeGlobal)}};
printReportFlatFields(os, fields);
}
static void printMemoryReportRow(raw_ostream& os, const MemoryReportRow& row) {
os << "\tNumber of allocas: " << row.numAlloca << "\n";
os << "\tAllocated memory: " << formatMemory(row.sizeAlloca) << "\n";
os << "\tNumber of globals: " << row.numGlobal << "\n";
os << "\tGlobal memory: " << formatMemory(row.sizeGlobal) << "\n";
static void printCoreMemoryReportRow(raw_ostream& os, const MemoryReportEntry& entry) {
llvm::SmallVector<ReportField, 2> fields = {
{"Number of allocas", std::to_string(entry.row.numAlloca)},
{"Allocated memory", formatReportMemory(entry.row.sizeAlloca)}};
printReportFlatFields(os, fields);
}
static void printBatchMemoryReportRow(raw_ostream& os, const MemoryReportEntry& entry) {
llvm::SmallVector<ReportField, 2> perCoreFields = {
{"Number of allocas", std::to_string(entry.row.numAlloca)},
{"Allocated memory", formatReportMemory(entry.row.sizeAlloca)}};
llvm::SmallVector<ReportField, 2> totalFields = {
{"Number of allocas", std::to_string(entry.totalAllocaCount)},
{"Batch memory", formatReportMemory(entry.totalAllocaBytes)}};
printReportPerCoreAndTotalFields(os, perCoreFields, totalFields);
}
static MemoryReportRow addMemoryReportRows(const MemoryReportRow& lhs, const MemoryReportRow& rhs) {
@@ -145,7 +149,7 @@ static MemoryReportRow addMemoryReportRows(const MemoryReportRow& lhs, const Mem
MemoryReportRow PimMemory::getReportRow() const {
MemoryReportRow row;
for (auto& [val, memEntry] : globalMemEntriesMap) {
for (auto& [val, memEntry] : ownedMemEntriesMap) {
if (auto op = val.getDefiningOp()) {
if (isa<memref::AllocOp>(op)) {
row.numAlloca++;
@@ -162,6 +166,8 @@ MemoryReportRow PimMemory::getReportRow() const {
}
void PimMemory::remove(mlir::Value val) {
if (auto removeIter = ownedMemEntriesMap.find(val); removeIter != ownedMemEntriesMap.end())
ownedMemEntriesMap.erase(removeIter);
if (auto removeIter = globalMemEntriesMap.find(val); removeIter != globalMemEntriesMap.end())
globalMemEntriesMap.erase(removeIter);
}
@@ -209,15 +215,26 @@ size_t PimAcceleratorMemory::getValueAddress(mlir::Value value, const StaticValu
void PimAcceleratorMemory::reportHost() { hostReportRow = hostMem.getReportRow(); }
void PimAcceleratorMemory::recordCoreReport(size_t coreId, const MemoryReportRow& row) {
reportEntries.push_back({MemoryReportEntry::Kind::Core, coreId, {static_cast<int32_t>(coreId)}, row});
reportEntries.push_back({MemoryReportEntry::Kind::Core,
coreId,
{static_cast<int32_t>(coreId)},
row,
row.numAlloca,
row.sizeAlloca});
}
void PimAcceleratorMemory::recordBatchReport(uint64_t batchId, ArrayRef<int32_t> coreIds, const MemoryReportRow& row) {
void PimAcceleratorMemory::recordBatchReport(uint64_t batchId,
ArrayRef<int32_t> coreIds,
const MemoryReportRow& perCoreRow,
uint64_t totalAllocaCount,
uint64_t totalAllocaBytes) {
MemoryReportEntry entry;
entry.kind = MemoryReportEntry::Kind::Batch;
entry.id = batchId;
llvm::append_range(entry.coreIds, coreIds);
entry.row = row;
entry.row = perCoreRow;
entry.totalAllocaCount = totalAllocaCount;
entry.totalAllocaBytes = totalAllocaBytes;
reportEntries.push_back(std::move(entry));
}
@@ -226,36 +243,32 @@ void PimAcceleratorMemory::flushReport() {
return;
llvm::raw_os_ostream os(fileReport);
uint64_t totalGlobalMemory = hostReportRow.has_value() ? hostReportRow->sizeGlobal : 0;
uint64_t totalCoresMemory = 0;
for (const MemoryReportEntry& entry : reportEntries)
totalCoresMemory += entry.totalAllocaBytes;
llvm::SmallVector<ReportField, 2> totalFields = {
{"Global memory", formatReportMemory(totalGlobalMemory)},
{"Cores memory", formatReportMemory(totalCoresMemory)}};
printReportTotalsBlock(os, totalFields);
if (hostReportRow.has_value()) {
os << "Host:\n";
printMemoryReportRow(os, *hostReportRow);
os << "\nHost:\n";
printHostMemoryReportRow(os, *hostReportRow);
}
if (!reportEntries.empty()) {
if (hostReportRow.has_value())
os << "\n";
llvm::stable_sort(reportEntries, [](const MemoryReportEntry& lhs, const MemoryReportEntry& rhs) {
if (lhs.kind != rhs.kind)
return lhs.kind == MemoryReportEntry::Kind::Batch;
const MemoryReportRow& lhsRow = lhs.row;
const MemoryReportRow& rhsRow = rhs.row;
if (lhsRow.sizeAlloca != rhsRow.sizeAlloca)
return lhsRow.sizeAlloca > rhsRow.sizeAlloca;
if (lhsRow.numAlloca != rhsRow.numAlloca)
return lhsRow.numAlloca > rhsRow.numAlloca;
if (lhsRow.sizeGlobal != rhsRow.sizeGlobal)
return lhsRow.sizeGlobal > rhsRow.sizeGlobal;
if (lhsRow.numGlobal != rhsRow.numGlobal)
return lhsRow.numGlobal > rhsRow.numGlobal;
return lhs.id < rhs.id;
});
sortReportEntriesByFirstCore(reportEntries);
for (size_t index = 0; index < reportEntries.size();) {
size_t runEnd = index + 1;
while (runEnd < reportEntries.size() && reportEntries[runEnd].kind == reportEntries[index].kind
&& reportEntries[runEnd].row == reportEntries[index].row) {
&& reportEntries[runEnd].row == reportEntries[index].row
&& reportEntries[runEnd].totalAllocaCount == reportEntries[index].totalAllocaCount
&& reportEntries[runEnd].totalAllocaBytes == reportEntries[index].totalAllocaBytes) {
++runEnd;
}
@@ -277,9 +290,11 @@ void PimAcceleratorMemory::flushReport() {
printCompressedIntegerEntries(os, ArrayRef<int32_t>(coreIds));
}
os << ":\n";
printMemoryReportRow(os, reportEntries[index].row);
if (runEnd < reportEntries.size())
os << "\n";
if (reportEntries[index].kind == MemoryReportEntry::Kind::Batch)
printBatchMemoryReportRow(os, reportEntries[index]);
else
printCoreMemoryReportRow(os, reportEntries[index]);
printReportEntrySeparator(os, runEnd < reportEntries.size());
index = runEnd;
}
@@ -876,7 +891,9 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::
}
for (Operation* op : coreLikeOps) {
auto emitCore = [&](pim::PimCoreOp coreOp, bool temporaryCore) -> OnnxMlirCompilerErrorCodes {
auto emitCore = [&](pim::PimCoreOp coreOp,
bool temporaryCore,
MemoryReportRow* reportRow = nullptr) -> OnnxMlirCompilerErrorCodes {
size_t originalCoreId = static_cast<size_t>(coreOp.getCoreId());
size_t coreId = emittedCoreIds.lookup(originalCoreId);
maxCoreId = std::max(maxCoreId, coreId);
@@ -892,13 +909,17 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::
PimCodeGen coreCodeGen(memory, coreFileStream, emittedCoreIds);
aliasMaterializedHostGlobals(moduleOp, funcOp, coreOp, memory);
memory.getOrCreateDeviceMem(coreId).allocateCore(coreOp);
auto& deviceMemory = memory.getOrCreateDeviceMem(coreId);
deviceMemory.allocateCore(coreOp);
int64_t processedOperations = codeGenCoreOps(coreOp.getBody().front(), coreCodeGen);
if (processedOperations < 0)
return CompilerFailure;
assert(processedOperations > 0);
if (reportRow)
*reportRow = deviceMemory.getReportRow();
coreFileStream.seek(coreFileStream.tell() - 1);
coreFileStream << ']';
coreFileStream.close();
@@ -936,11 +957,10 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::
};
if (auto coreOp = dyn_cast<pim::PimCoreOp>(op)) {
if (auto err = emitCore(coreOp, false))
MemoryReportRow coreRow;
if (auto err = emitCore(coreOp, false, &coreRow))
return err;
memory.recordCoreReport(
emittedCoreIds.lookup(static_cast<size_t>(coreOp.getCoreId())),
memory.getOrCreateDeviceMem(emittedCoreIds.lookup(static_cast<size_t>(coreOp.getCoreId()))).getReportRow());
memory.recordCoreReport(emittedCoreIds.lookup(static_cast<size_t>(coreOp.getCoreId())), coreRow);
continue;
}
@@ -949,20 +969,29 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::
SmallVector<int32_t> reportedCoreIds;
reportedCoreIds.reserve(batchCoreIds.size());
MemoryReportRow batchRow;
std::optional<MemoryReportRow> batchPerCoreRow;
for (unsigned lane = 0; lane < static_cast<unsigned>(coreBatchOp.getLaneCount()); ++lane) {
OnnxMlirCompilerErrorCodes laneResult = CompilerSuccess;
if (failed(withScalarCoreFromBatchLane(coreBatchOp, lane, [&](pim::PimCoreOp coreOp) {
size_t originalCoreId = static_cast<size_t>(batchCoreIds[lane]);
size_t coreId = emittedCoreIds.lookup(originalCoreId);
reportedCoreIds.push_back(static_cast<int32_t>(coreId));
laneResult = emitCore(coreOp, true);
if (laneResult == CompilerSuccess)
batchRow = addMemoryReportRows(batchRow, memory.getOrCreateDeviceMem(coreId).getReportRow());
MemoryReportRow laneRow;
laneResult = emitCore(coreOp, true, &laneRow);
if (laneResult == CompilerSuccess) {
if (!batchPerCoreRow.has_value())
batchPerCoreRow = laneRow;
batchRow = addMemoryReportRows(batchRow, laneRow);
}
return laneResult == CompilerSuccess ? success() : failure();
})))
return laneResult == CompilerSuccess ? CompilerFailure : laneResult;
}
memory.recordBatchReport(nextBatchReportId++, reportedCoreIds, batchRow);
memory.recordBatchReport(nextBatchReportId++,
reportedCoreIds,
batchPerCoreRow.value_or(MemoryReportRow {}),
batchRow.numAlloca,
batchRow.sizeAlloca);
}
memory.flushReport();
+10 -12
View File
@@ -12,6 +12,7 @@
#include "onnx-mlir/Compiler/OMCompilerTypes.h"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/Support/ReportUtils.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
namespace onnx_mlir {
@@ -43,11 +44,14 @@ struct MemoryReportEntry {
uint64_t id = 0;
llvm::SmallVector<int32_t, 8> coreIds;
MemoryReportRow row;
uint64_t totalAllocaCount = 0;
uint64_t totalAllocaBytes = 0;
};
class PimMemory {
llvm::SmallVector<std::pair<MemEntry, mlir::Value>, 32> memEntries;
llvm::SmallDenseMap<mlir::Value, MemEntry, 32>& globalMemEntriesMap;
llvm::SmallDenseMap<mlir::Value, MemEntry, 32> ownedMemEntriesMap;
size_t minAlignment = 4;
size_t firstAvailableAddress = 0;
@@ -82,24 +86,18 @@ private:
public:
PimAcceleratorMemory()
: hostMem(memEntriesMap) {
std::string outputDir = getOutputDir();
if (outputDir.empty())
return;
std::string dialectsDir = outputDir + "/reports/";
createDirectory(dialectsDir);
std::fstream file(dialectsDir + "/memory_report.txt", std::ios::out);
fileReport = std::move(file);
}
: hostMem(memEntriesMap), fileReport(openReportFile("memory_report")) {}
PimMemory& getOrCreateDeviceMem(size_t id);
size_t getValueAddress(mlir::Value value, const StaticValueKnowledge& knowledge = {}) const;
void reportHost();
void recordCoreReport(size_t coreId, const MemoryReportRow& row);
void recordBatchReport(uint64_t batchId, llvm::ArrayRef<int32_t> coreIds, const MemoryReportRow& row);
void recordBatchReport(uint64_t batchId,
llvm::ArrayRef<int32_t> coreIds,
const MemoryReportRow& perCoreRow,
uint64_t totalAllocaCount,
uint64_t totalAllocaBytes);
void flushReport();
void clean(mlir::Operation* op);
};
+1
View File
@@ -41,6 +41,7 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
if (pimEmissionTarget >= EmitPimBufferized) {
pm.addPass(createPimBufferizationPass());
pm.addPass(createPimStaticMemoryCoalescingPass());
// pm.addPass(createCountInstructionPass());
pm.addPass(createMessagePass("Pim bufferized"));
}