Add register reuse + peft scheduler cost model + Useless merger

This commit is contained in:
ilgeco
2026-06-18 10:56:57 +02:00
parent 852bef7605
commit e083c27d80
13 changed files with 350 additions and 20 deletions
@@ -2018,6 +2018,7 @@ LogicalResult collectProjectedTransfers(MaterializerState& state) {
struct PendingProjectedTransferDescriptor {
ProjectedBatchInputKey inputKey;
Operation* extractOp = nullptr;
RankedTensorType sourceType;
RankedTensorType fragmentType;
SmallVector<int64_t, 4> fragmentShape;
SmallVector<SmallVector<SmallVector<int64_t, 4>, 16>, 8> fragmentOffsetsByLane;
@@ -2029,6 +2030,20 @@ LogicalResult collectProjectedTransfers(MaterializerState& state) {
DenseMap<ProducerKey, DenseMap<ClassId, PendingProjectedTransferDescriptor>, ProducerKeyInfo> pending;
const auto isIdentityProjectedTransfer = [&](const PendingProjectedTransferDescriptor& descriptor) {
if (!descriptor.sourceType || descriptor.sourceType != descriptor.fragmentType)
return false;
if (descriptor.fragmentOffsetsByLane.size() != 1)
return false;
ArrayRef<SmallVector<int64_t, 4>> fragments = descriptor.fragmentOffsetsByLane.front();
if (fragments.size() != 1)
return false;
return llvm::all_of(fragments.front(), [](int64_t offset) { return offset == 0; });
};
const auto appendEvaluatedFragments = [&](PendingProjectedTransferDescriptor& descriptor,
unsigned targetLane,
const AffineProjectedInputSliceMatch& match,
@@ -2117,6 +2132,7 @@ LogicalResult collectProjectedTransfers(MaterializerState& state) {
if (descriptor.fragmentOffsetsByLane.empty()) {
descriptor.inputKey = {batch.getOperation(), static_cast<unsigned>(inputIndex)};
descriptor.extractOp = match->extract.getOperation();
descriptor.sourceType = match->sourceType;
descriptor.fragmentType = match->fragmentType;
descriptor.fragmentShape = match->fragmentShape;
descriptor.fragmentOffsetsByLane.resize(targetClass.isBatch ? targetClass.cpus.size() : 1);
@@ -2132,7 +2148,8 @@ LogicalResult collectProjectedTransfers(MaterializerState& state) {
ProjectedBatchInputKey currentInputKey {batch.getOperation(), static_cast<unsigned>(inputIndex)};
if (!(descriptor.inputKey == currentInputKey) || descriptor.extractOp != match->extract.getOperation()
|| descriptor.fragmentType != match->fragmentType || descriptor.fragmentShape != match->fragmentShape
|| descriptor.sourceType != match->sourceType || descriptor.fragmentType != match->fragmentType
|| descriptor.fragmentShape != match->fragmentShape
|| descriptor.loopLowerBounds.size() != match->loops.size()) {
descriptor.invalid = true;
continue;
@@ -2175,6 +2192,8 @@ LogicalResult collectProjectedTransfers(MaterializerState& state) {
continue;
if (pendingDescriptor.fragmentOffsetsByLane.empty())
continue;
if (isIdentityProjectedTransfer(pendingDescriptor))
continue;
MaterializedClass& targetClass = state.classes[targetClassId];
ProjectedTransferDescriptor descriptor;
@@ -2755,8 +2774,14 @@ FailureOr<ScalarSourceFanoutPlan> buildScalarSourceFanoutPlan(MaterializerState&
if (*descriptor) {
const ProjectedTransferDescriptor& projectedDescriptor = **descriptor;
if (!targetClass.isBatch && projectedDescriptor.payloadType == payload.getType())
return targetClass.op->emitError("scalar projected receive unexpectedly uses the full producer tensor type");
if (!targetClass.isBatch && projectedDescriptor.payloadType == payload.getType()) {
if (!fanoutPlan.ordinaryMessages)
fanoutPlan.ordinaryMessages = MessageVector {};
fanoutPlan.ordinaryMessages->append(
receivePlan.messages.channelIds, receivePlan.messages.sourceCoreIds, receivePlan.messages.targetCoreIds);
fanoutPlan.receivePlans.push_back(std::move(receivePlan));
continue;
}
receivePlan.receiveType = projectedDescriptor.payloadType;
receivePlan.projectedExtractOp = projectedDescriptor.extractOp;
@@ -42,7 +42,6 @@ namespace {
using namespace onnx_mlir::compact_asm;
using SpatCompute = spatial::SpatCompute;
using SpatComputeBatch = spatial::SpatComputeBatch;
using spatial::getProducerValueRef;
static std::optional<int32_t> getComputeCoreId(SpatCompute compute) {
if (auto coreIdAttr = compute->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName)) {
@@ -187,13 +186,23 @@ void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpu
SmallVector<int32_t> coreIds;
};
//TODO Used for report refactor
struct CollectorConcatRow {
uint64_t computeId = 0;
int32_t coreId = -1;
uint64_t operandCount = 0;
};
uint64_t totalComputeOps = 0;
uint64_t totalLogicalComputes = 0;
uint64_t totalBatchComputeOps = 0;
uint64_t totalInstructionCount = 0;
uint64_t totalCrossbarCount = 0;
uint64_t nextBatchId = 0;
//TODO Used for report refactor
std::vector<ReportRow> collectedData;
//TODO Used for report refactor
std::vector<CollectorConcatRow> collectorConcatRows;
auto getPerInstanceCrossbarCount = [&](Operation* op) -> uint64_t {
return static_cast<uint64_t>(spatial::collectDistinctCrossbarWeights(op).size());
@@ -206,7 +215,15 @@ void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpu
SmallVector<int32_t> coreIds;
if (auto coreId = getComputeCoreId(spatCompute))
coreIds.push_back(*coreId);
collectedData.push_back({totalComputeOps++, 1, perInstanceCrossbarCount, numInst, false, coreIds});
uint64_t computeId = totalComputeOps++;
collectedData.push_back({computeId, 1, perInstanceCrossbarCount, numInst, false, coreIds});
uint64_t maxConcatOperands = 0;
spatCompute.getBody().walk([&](spatial::SpatConcatOp concatOp) {
maxConcatOperands = std::max<uint64_t>(maxConcatOperands, concatOp.getInputs().size());
});
//TODO 128 is a magic number
if (maxConcatOperands >= 128 && !coreIds.empty())
collectorConcatRows.push_back({computeId, coreIds.front(), maxConcatOperands});
totalLogicalComputes += 1;
totalInstructionCount += numInst;
totalCrossbarCount += perInstanceCrossbarCount;
@@ -238,9 +255,17 @@ void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpu
{"Number of used crossbars", std::to_string(totalCrossbarCount) }
};
printReportTotalsBlock(os, totalFields);
if (!collectedData.empty())
if (!collectedData.empty() || !collectorConcatRows.empty())
os << "\n";
if (!collectorConcatRows.empty()) {
os << "Collector concat materialization:\n";
for (const CollectorConcatRow& row : collectorConcatRows)
os << "\tmaterialization_kind = single_collector_concat, compute = " << row.computeId
<< ", concat_operand_count = " << row.operandCount << ", collector_core = " << row.coreId << "\n";
os << "\n";
}
sortReportEntriesByFirstCore(collectedData);
for (uint64_t cI = 0; cI < totalComputeOps; ++cI) {
@@ -23,7 +23,10 @@ MergeSchedulerKind getSchedulerKind() {
llvm_unreachable("unknown merge scheduler kind");
}
void verifySchedule(const ComputeGraph& graph, const MergeScheduleResult& result, unsigned long crossbarCapacity) {
void verifySchedule(const ComputeGraph& graph,
const MergeScheduleResult& result,
unsigned long crossbarCapacity,
size_t processorCount) {
llvm::DenseMap<size_t, std::vector<std::pair<size_t, size_t>>> tasksByCpu;
tasksByCpu.reserve(result.cpuToLastComputeMap.size());
@@ -79,7 +82,8 @@ void verifySchedule(const ComputeGraph& graph, const MergeScheduleResult& result
Time earliestTargetStart = addOrMax(sourceStart, graph.nodes[edge.source].cost);
if (sourceCpu != targetCpu)
earliestTargetStart = addOrMax(earliestTargetStart, edge.transferCost);
earliestTargetStart = addOrMax(
earliestTargetStart, getPeftTransferTime(edge.transferCost, sourceCpu, targetCpu, processorCount));
if (targetStart < earliestTargetStart) {
std::string message = llvm::formatv("merge scheduling: dependency legality failed between tasks {0} and {1}",
graph.nodes[edge.source].originalOrder,
@@ -115,7 +119,10 @@ MergeScheduleResult MergeSchedulingAnalysis::run() {
static_cast<unsigned long>(crossbarCountInCore.getValue()),
entryOp->getContext()});
}
verifySchedule(graph, schedule, static_cast<unsigned long>(crossbarCountInCore.getValue()));
verifySchedule(graph,
schedule,
static_cast<unsigned long>(crossbarCountInCore.getValue()),
options.processorCount);
return schedule;
}
@@ -4,6 +4,7 @@
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FormatVariadic.h"
#include <cmath>
#include <limits>
#include <queue>
#include <vector>
@@ -21,6 +22,63 @@ struct ScheduledTask {
Time endTime = 0;
};
struct MeshModel {
size_t rows = 1;
size_t cols = 1;
long double averageDistance = 0.0L;
static MeshModel infer(size_t processorCount) {
MeshModel model;
if (processorCount == 0)
return model;
model.rows = static_cast<size_t>(std::sqrt(static_cast<long double>(processorCount)));
if (model.rows == 0)
model.rows = 1;
while (model.rows > 1 && processorCount % model.rows != 0)
--model.rows;
model.cols = (processorCount + model.rows - 1) / model.rows;
auto averageAxisDistance = [](size_t size) -> long double {
if (size <= 1)
return 0.0L;
return static_cast<long double>(size * size - 1) / (3.0L * static_cast<long double>(size));
};
model.averageDistance = averageAxisDistance(model.rows) + averageAxisDistance(model.cols);
return model;
}
std::pair<size_t, size_t> getCoord(size_t processor) const {
return {processor / cols, processor % cols};
}
size_t getDistance(size_t lhs, size_t rhs) const {
auto [lhsRow, lhsCol] = getCoord(lhs);
auto [rhsRow, rhsCol] = getCoord(rhs);
size_t rowDistance = lhsRow > rhsRow ? lhsRow - rhsRow : rhsRow - lhsRow;
size_t colDistance = lhsCol > rhsCol ? lhsCol - rhsCol : rhsCol - lhsCol;
return rowDistance + colDistance;
}
Time scaleTransferCost(Time transferCost, size_t sourceProcessor, size_t targetProcessor) const {
if (sourceProcessor == targetProcessor || transferCost == 0)
return 0;
long double distance = static_cast<long double>(getDistance(sourceProcessor, targetProcessor));
long double scale = averageDistance > 0.0L ? distance / averageDistance : 1.0L;
scale = std::max(0.25L, scale);
return static_cast<Time>(std::ceil(static_cast<long double>(transferCost) * scale));
}
size_t getCenterDistance(size_t processor) const {
auto [row, col] = getCoord(processor);
size_t centerRow = rows / 2;
size_t centerCol = cols / 2;
size_t rowDistance = row > centerRow ? row - centerRow : centerRow - row;
size_t colDistance = col > centerCol ? col - centerCol : centerCol - col;
return rowDistance + colDistance;
}
};
std::vector<std::vector<size_t>> buildReverseLevels(const ComputeGraph& graph) {
std::vector<size_t> remainingSuccessors(graph.nodes.size(), 0);
std::queue<size_t> readySinks;
@@ -77,11 +135,16 @@ void verifyOctTableSize(size_t nodeCount, size_t processorCount) {
} // namespace
Time getPeftTransferTime(Time transferCost, size_t sourceProcessor, size_t targetProcessor, size_t processorCount) {
return MeshModel::infer(processorCount).scaleTransferCost(transferCost, sourceProcessor, targetProcessor);
}
MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftScheduleOptions& options) {
const size_t nodeCount = graph.nodes.size();
const size_t processorCount = options.processorCount;
if (processorCount == 0)
llvm::report_fatal_error("PEFT scheduler: processor count must be positive");
MeshModel mesh = MeshModel::infer(processorCount);
verifyOctTableSize(nodeCount, processorCount);
std::vector<std::vector<size_t>> reverseLevels = buildReverseLevels(graph);
@@ -89,7 +152,6 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu
// MOCK: Replace this with your actual heterogeneous cost lookup.
// If graph.nodes[task] is modified to hold a vector of costs per processor, access it here.
auto getComputeCost = [&](size_t task, size_t processor) -> Time { return graph.nodes[task].cost; };
std::vector<Time> oct(nodeCount * processorCount, 0);
std::vector<Time> minOctPlusComp(nodeCount, 0);
@@ -177,6 +239,7 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu
Time bestEft = 0;
Time bestOeft = std::numeric_limits<Time>::max();
unsigned int bestOverlapCount = 0;
size_t bestCenterDistance = std::numeric_limits<size_t>::max();
bool crossbarRejected = false;
for (size_t processor = 0; processor < processorCount; ++processor) {
@@ -191,7 +254,7 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu
Time dataReady = 0;
for (const auto& [pred, comm] : graph.predecessors[task]) {
const ScheduledTask& predSchedule = schedules[pred];
Time commPenalty = predSchedule.processor == processor ? 0 : comm;
Time commPenalty = getPeftTransferTime(comm, predSchedule.processor, processor, processorCount);
dataReady = std::max(dataReady, addOrMax(predSchedule.endTime, commPenalty));
}
@@ -218,6 +281,7 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu
Time eft = addOrMax(est, computeCost);
Time oeft = addOrMax(eft, oct[task * processorCount + processor]);
size_t centerDistance = mesh.getCenterDistance(processor);
if (oeft < bestOeft || (oeft == bestOeft && eft < bestEft)
|| (oeft == bestOeft && eft == bestEft && est < bestEst)) {
@@ -226,13 +290,25 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu
bestEft = eft;
bestOeft = oeft;
bestOverlapCount = overlapCount;
bestCenterDistance = centerDistance;
}
else if (oeft == bestOeft && eft == bestEft && est < bestEst && overlapCount < bestOverlapCount) {
else if (oeft == bestOeft && eft == bestEft && est == bestEst
&& centerDistance < bestCenterDistance) {
bestProcessor = processor;
bestEst = est;
bestEft = eft;
bestOeft = oeft;
bestOverlapCount = overlapCount;
bestCenterDistance = centerDistance;
}
else if (oeft == bestOeft && eft == bestEft && est == bestEst
&& centerDistance == bestCenterDistance && overlapCount < bestOverlapCount) {
bestProcessor = processor;
bestEst = est;
bestEft = eft;
bestOeft = oeft;
bestOverlapCount = overlapCount;
bestCenterDistance = centerDistance;
}
}
@@ -14,6 +14,7 @@ struct PeftScheduleOptions {
mlir::MLIRContext* context = nullptr;
};
Time getPeftTransferTime(Time transferCost, size_t sourceProcessor, size_t targetProcessor, size_t processorCount);
MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftScheduleOptions& options);
} // namespace spatial