compact syntax for spatial tensor ops
Validate Operations / validate-operations (push) Has been cancelled
Validate Operations / validate-operations (push) Has been cancelled
better IR compaction after dcp merge remove pim.mvm op better memory report
This commit is contained in:
@@ -134,6 +134,15 @@ static void printMemoryReportRow(raw_ostream& os, const MemoryReportRow& row) {
|
||||
os << "\tGlobal memory: " << formatMemory(row.sizeGlobal) << "\n";
|
||||
}
|
||||
|
||||
static MemoryReportRow addMemoryReportRows(const MemoryReportRow& lhs, const MemoryReportRow& rhs) {
|
||||
MemoryReportRow result = lhs;
|
||||
result.numAlloca += rhs.numAlloca;
|
||||
result.sizeAlloca += rhs.sizeAlloca;
|
||||
result.numGlobal += rhs.numGlobal;
|
||||
result.sizeGlobal += rhs.sizeGlobal;
|
||||
return result;
|
||||
}
|
||||
|
||||
MemoryReportRow PimMemory::getReportRow() const {
|
||||
MemoryReportRow row;
|
||||
for (auto& [val, memEntry] : globalMemEntriesMap) {
|
||||
@@ -201,8 +210,17 @@ void PimAcceleratorMemory::reportHost() {
|
||||
hostReportRow = hostMem.getReportRow();
|
||||
}
|
||||
|
||||
void PimAcceleratorMemory::reportCore(size_t coreId) {
|
||||
coreReportRows.push_back({coreId, deviceMem.at(coreId).getReportRow()});
|
||||
void PimAcceleratorMemory::recordCoreReport(size_t coreId, const MemoryReportRow& row) {
|
||||
reportEntries.push_back({MemoryReportEntry::Kind::Core, coreId, {static_cast<int32_t>(coreId)}, row});
|
||||
}
|
||||
|
||||
void PimAcceleratorMemory::recordBatchReport(uint64_t batchId, ArrayRef<int32_t> coreIds, const MemoryReportRow& row) {
|
||||
MemoryReportEntry entry;
|
||||
entry.kind = MemoryReportEntry::Kind::Batch;
|
||||
entry.id = batchId;
|
||||
llvm::append_range(entry.coreIds, coreIds);
|
||||
entry.row = row;
|
||||
reportEntries.push_back(std::move(entry));
|
||||
}
|
||||
|
||||
void PimAcceleratorMemory::flushReport() {
|
||||
@@ -215,13 +233,16 @@ void PimAcceleratorMemory::flushReport() {
|
||||
printMemoryReportRow(os, *hostReportRow);
|
||||
}
|
||||
|
||||
if (!coreReportRows.empty()) {
|
||||
if (!reportEntries.empty()) {
|
||||
if (hostReportRow.has_value())
|
||||
os << "\n";
|
||||
|
||||
llvm::stable_sort(coreReportRows, [](const auto& lhs, const auto& rhs) {
|
||||
const MemoryReportRow& lhsRow = lhs.second;
|
||||
const MemoryReportRow& rhsRow = rhs.second;
|
||||
llvm::stable_sort(reportEntries, [](const MemoryReportEntry& lhs, const MemoryReportEntry& rhs) {
|
||||
if (lhs.kind != rhs.kind)
|
||||
return lhs.kind == MemoryReportEntry::Kind::Batch;
|
||||
|
||||
const MemoryReportRow& lhsRow = lhs.row;
|
||||
const MemoryReportRow& rhsRow = rhs.row;
|
||||
if (lhsRow.sizeAlloca != rhsRow.sizeAlloca)
|
||||
return lhsRow.sizeAlloca > rhsRow.sizeAlloca;
|
||||
if (lhsRow.numAlloca != rhsRow.numAlloca)
|
||||
@@ -230,24 +251,36 @@ void PimAcceleratorMemory::flushReport() {
|
||||
return lhsRow.sizeGlobal > rhsRow.sizeGlobal;
|
||||
if (lhsRow.numGlobal != rhsRow.numGlobal)
|
||||
return lhsRow.numGlobal > rhsRow.numGlobal;
|
||||
return lhs.first < rhs.first;
|
||||
return lhs.id < rhs.id;
|
||||
});
|
||||
|
||||
for (size_t index = 0; index < coreReportRows.size();) {
|
||||
for (size_t index = 0; index < reportEntries.size();) {
|
||||
size_t runEnd = index + 1;
|
||||
while (runEnd < coreReportRows.size() && coreReportRows[runEnd].second == coreReportRows[index].second)
|
||||
while (runEnd < reportEntries.size() && reportEntries[runEnd].kind == reportEntries[index].kind
|
||||
&& reportEntries[runEnd].row == reportEntries[index].row) {
|
||||
++runEnd;
|
||||
}
|
||||
|
||||
llvm::SmallVector<size_t, 8> coreIds;
|
||||
coreIds.reserve(runEnd - index);
|
||||
for (size_t coreIndex = index; coreIndex < runEnd; ++coreIndex)
|
||||
coreIds.push_back(coreReportRows[coreIndex].first);
|
||||
|
||||
os << "Core ";
|
||||
printCompressedIntegerEntries(os, ArrayRef<size_t>(coreIds));
|
||||
if (reportEntries[index].kind == MemoryReportEntry::Kind::Batch) {
|
||||
os << "Batch ";
|
||||
for (size_t batchIndex = index; batchIndex < runEnd; ++batchIndex) {
|
||||
if (batchIndex != index)
|
||||
os << ",\n ";
|
||||
os << reportEntries[batchIndex].id << " (cores ";
|
||||
printCompressedIntegerEntries(os, ArrayRef<int32_t>(reportEntries[batchIndex].coreIds));
|
||||
os << ")";
|
||||
}
|
||||
}
|
||||
else {
|
||||
llvm::SmallVector<int32_t, 8> coreIds;
|
||||
for (size_t coreIndex = index; coreIndex < runEnd; ++coreIndex)
|
||||
coreIds.push_back(reportEntries[coreIndex].coreIds.front());
|
||||
os << "Core ";
|
||||
printCompressedIntegerEntries(os, ArrayRef<int32_t>(coreIds));
|
||||
}
|
||||
os << ":\n";
|
||||
printMemoryReportRow(os, coreReportRows[index].second);
|
||||
if (runEnd < coreReportRows.size())
|
||||
printMemoryReportRow(os, reportEntries[index].row);
|
||||
if (runEnd < reportEntries.size())
|
||||
os << "\n";
|
||||
|
||||
index = runEnd;
|
||||
@@ -678,7 +711,6 @@ static SmallVector<unsigned, 8> getUsedWeightIndices(Block& block) {
|
||||
indices.push_back(weightIndex);
|
||||
};
|
||||
|
||||
block.walk([&](pim::PimMVMOp mvmOp) { addIndex(mvmOp.getWeightIndex()); });
|
||||
block.walk([&](pim::PimVMMOp vmmOp) { addIndex(vmmOp.getWeightIndex()); });
|
||||
llvm::sort(indices);
|
||||
return indices;
|
||||
@@ -753,8 +785,6 @@ static int64_t codeGenCoreOps(Block& block, PimCodeGen& coreCodeGen) {
|
||||
coreCodeGen.codeGenConcatOp(concatOp, knowledge);
|
||||
else if (auto vmmOp = dyn_cast<pim::PimVMMOp>(op))
|
||||
coreCodeGen.codeGenMVMLikeOp<pim::PimVMMOp>(vmmOp.getWeightIndex(), vmmOp, true, knowledge);
|
||||
else if (auto mvmOp = dyn_cast<pim::PimMVMOp>(op))
|
||||
coreCodeGen.codeGenMVMLikeOp<pim::PimMVMOp>(mvmOp.getWeightIndex(), mvmOp, false, knowledge);
|
||||
else if (auto transposeOp = dyn_cast<pim::PimTransposeOp>(op))
|
||||
coreCodeGen.codeGenTransposeOp(transposeOp, knowledge);
|
||||
else if (auto vvaddOp = dyn_cast<pim::PimVVAddOp>(op))
|
||||
@@ -816,6 +846,7 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::
|
||||
// This implementation always assigns one crossbar per group.
|
||||
json::Object xbarsPerArrayGroup;
|
||||
size_t maxCoreId = 0;
|
||||
uint64_t nextBatchReportId = 0;
|
||||
|
||||
// Create Weight Folder
|
||||
auto mapCoreWeightToFileName = createAndPopulateWeightFolder(funcOp, outputDirPath);
|
||||
@@ -859,7 +890,6 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::
|
||||
PimCodeGen coreCodeGen(memory, coreFileStream, emittedCoreIds);
|
||||
aliasMaterializedHostGlobals(moduleOp, funcOp, coreOp, memory);
|
||||
memory.getOrCreateDeviceMem(coreId).allocateCore(coreOp);
|
||||
memory.reportCore(coreId);
|
||||
|
||||
int64_t processedOperations = codeGenCoreOps(coreOp.getBody().front(), coreCodeGen);
|
||||
if (processedOperations < 0)
|
||||
@@ -905,18 +935,31 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::
|
||||
if (auto coreOp = dyn_cast<pim::PimCoreOp>(op)) {
|
||||
if (auto err = emitCore(coreOp, false))
|
||||
return err;
|
||||
memory.recordCoreReport(emittedCoreIds.lookup(static_cast<size_t>(coreOp.getCoreId())),
|
||||
memory.getOrCreateDeviceMem(emittedCoreIds.lookup(static_cast<size_t>(coreOp.getCoreId())))
|
||||
.getReportRow());
|
||||
continue;
|
||||
}
|
||||
|
||||
auto coreBatchOp = cast<pim::PimCoreBatchOp>(op);
|
||||
auto batchCoreIds = getBatchCoreIds(coreBatchOp);
|
||||
SmallVector<int32_t> reportedCoreIds;
|
||||
reportedCoreIds.reserve(batchCoreIds.size());
|
||||
MemoryReportRow batchRow;
|
||||
for (unsigned lane = 0; lane < static_cast<unsigned>(coreBatchOp.getLaneCount()); ++lane) {
|
||||
OnnxMlirCompilerErrorCodes laneResult = CompilerSuccess;
|
||||
if (failed(withScalarCoreFromBatchLane(coreBatchOp, lane, [&](pim::PimCoreOp coreOp) {
|
||||
size_t originalCoreId = static_cast<size_t>(batchCoreIds[lane]);
|
||||
size_t coreId = emittedCoreIds.lookup(originalCoreId);
|
||||
reportedCoreIds.push_back(static_cast<int32_t>(coreId));
|
||||
laneResult = emitCore(coreOp, true);
|
||||
if (laneResult == CompilerSuccess)
|
||||
batchRow = addMemoryReportRows(batchRow, memory.getOrCreateDeviceMem(coreId).getReportRow());
|
||||
return laneResult == CompilerSuccess ? success() : failure();
|
||||
})))
|
||||
return laneResult == CompilerSuccess ? CompilerFailure : laneResult;
|
||||
}
|
||||
memory.recordBatchReport(nextBatchReportId++, reportedCoreIds, batchRow);
|
||||
}
|
||||
|
||||
memory.flushReport();
|
||||
|
||||
@@ -33,6 +33,18 @@ struct MemoryReportRow {
|
||||
}
|
||||
};
|
||||
|
||||
struct MemoryReportEntry {
|
||||
enum class Kind {
|
||||
Core,
|
||||
Batch
|
||||
};
|
||||
|
||||
Kind kind = Kind::Core;
|
||||
uint64_t id = 0;
|
||||
llvm::SmallVector<int32_t, 8> coreIds;
|
||||
MemoryReportRow row;
|
||||
};
|
||||
|
||||
class PimMemory {
|
||||
llvm::SmallVector<std::pair<MemEntry, mlir::Value>, 32> memEntries;
|
||||
llvm::SmallDenseMap<mlir::Value, MemEntry, 32>& globalMemEntriesMap;
|
||||
@@ -66,7 +78,7 @@ private:
|
||||
llvm::SmallDenseMap<size_t, PimMemory> deviceMem;
|
||||
std::fstream fileReport;
|
||||
std::optional<MemoryReportRow> hostReportRow;
|
||||
llvm::SmallVector<std::pair<size_t, MemoryReportRow>, 32> coreReportRows;
|
||||
llvm::SmallVector<MemoryReportEntry, 32> reportEntries;
|
||||
|
||||
public:
|
||||
PimAcceleratorMemory()
|
||||
@@ -86,7 +98,8 @@ public:
|
||||
|
||||
size_t getValueAddress(mlir::Value value, const StaticValueKnowledge& knowledge = {}) const;
|
||||
void reportHost();
|
||||
void reportCore(size_t coreId);
|
||||
void recordCoreReport(size_t coreId, const MemoryReportRow& row);
|
||||
void recordBatchReport(uint64_t batchId, llvm::ArrayRef<int32_t> coreIds, const MemoryReportRow& row);
|
||||
void flushReport();
|
||||
void clean(mlir::Operation* op);
|
||||
};
|
||||
|
||||
@@ -103,7 +103,6 @@ SmallVector<unsigned, 8> getUsedWeightIndices(Block& block) {
|
||||
indices.push_back(weightIndex);
|
||||
};
|
||||
|
||||
block.walk([&](pim::PimMVMOp mvmOp) { addIndex(mvmOp.getWeightIndex()); });
|
||||
block.walk([&](pim::PimVMMOp vmmOp) { addIndex(vmmOp.getWeightIndex()); });
|
||||
llvm::sort(indices);
|
||||
return indices;
|
||||
|
||||
Reference in New Issue
Block a user