Add register reuse + peft scheduler cost model + Useless merger
This commit is contained in:
@@ -2018,6 +2018,7 @@ LogicalResult collectProjectedTransfers(MaterializerState& state) {
|
||||
struct PendingProjectedTransferDescriptor {
|
||||
ProjectedBatchInputKey inputKey;
|
||||
Operation* extractOp = nullptr;
|
||||
RankedTensorType sourceType;
|
||||
RankedTensorType fragmentType;
|
||||
SmallVector<int64_t, 4> fragmentShape;
|
||||
SmallVector<SmallVector<SmallVector<int64_t, 4>, 16>, 8> fragmentOffsetsByLane;
|
||||
@@ -2029,6 +2030,20 @@ LogicalResult collectProjectedTransfers(MaterializerState& state) {
|
||||
|
||||
DenseMap<ProducerKey, DenseMap<ClassId, PendingProjectedTransferDescriptor>, ProducerKeyInfo> pending;
|
||||
|
||||
// Returns true when a pending descriptor describes a transfer that projects nothing: the fragment type
// equals the (non-null) source type, there is exactly one lane holding exactly one fragment, and every
// offset of that fragment is zero.
const auto isIdentityProjectedTransfer = [&](const PendingProjectedTransferDescriptor& descriptor) {
  const bool coversWholeSource = descriptor.sourceType && descriptor.sourceType == descriptor.fragmentType;
  if (!coversWholeSource || descriptor.fragmentOffsetsByLane.size() != 1)
    return false;

  const auto& laneFragments = descriptor.fragmentOffsetsByLane.front();
  if (laneFragments.size() != 1)
    return false;

  // Any non-zero offset means the fragment is shifted inside the source, hence not an identity.
  for (int64_t offset : laneFragments.front())
    if (offset != 0)
      return false;
  return true;
};
|
||||
|
||||
const auto appendEvaluatedFragments = [&](PendingProjectedTransferDescriptor& descriptor,
|
||||
unsigned targetLane,
|
||||
const AffineProjectedInputSliceMatch& match,
|
||||
@@ -2117,6 +2132,7 @@ LogicalResult collectProjectedTransfers(MaterializerState& state) {
|
||||
if (descriptor.fragmentOffsetsByLane.empty()) {
|
||||
descriptor.inputKey = {batch.getOperation(), static_cast<unsigned>(inputIndex)};
|
||||
descriptor.extractOp = match->extract.getOperation();
|
||||
descriptor.sourceType = match->sourceType;
|
||||
descriptor.fragmentType = match->fragmentType;
|
||||
descriptor.fragmentShape = match->fragmentShape;
|
||||
descriptor.fragmentOffsetsByLane.resize(targetClass.isBatch ? targetClass.cpus.size() : 1);
|
||||
@@ -2132,7 +2148,8 @@ LogicalResult collectProjectedTransfers(MaterializerState& state) {
|
||||
|
||||
ProjectedBatchInputKey currentInputKey {batch.getOperation(), static_cast<unsigned>(inputIndex)};
|
||||
if (!(descriptor.inputKey == currentInputKey) || descriptor.extractOp != match->extract.getOperation()
|
||||
|| descriptor.fragmentType != match->fragmentType || descriptor.fragmentShape != match->fragmentShape
|
||||
|| descriptor.sourceType != match->sourceType || descriptor.fragmentType != match->fragmentType
|
||||
|| descriptor.fragmentShape != match->fragmentShape
|
||||
|| descriptor.loopLowerBounds.size() != match->loops.size()) {
|
||||
descriptor.invalid = true;
|
||||
continue;
|
||||
@@ -2175,6 +2192,8 @@ LogicalResult collectProjectedTransfers(MaterializerState& state) {
|
||||
continue;
|
||||
if (pendingDescriptor.fragmentOffsetsByLane.empty())
|
||||
continue;
|
||||
if (isIdentityProjectedTransfer(pendingDescriptor))
|
||||
continue;
|
||||
|
||||
MaterializedClass& targetClass = state.classes[targetClassId];
|
||||
ProjectedTransferDescriptor descriptor;
|
||||
@@ -2755,8 +2774,14 @@ FailureOr<ScalarSourceFanoutPlan> buildScalarSourceFanoutPlan(MaterializerState&
|
||||
if (*descriptor) {
|
||||
const ProjectedTransferDescriptor& projectedDescriptor = **descriptor;
|
||||
|
||||
if (!targetClass.isBatch && projectedDescriptor.payloadType == payload.getType())
|
||||
return targetClass.op->emitError("scalar projected receive unexpectedly uses the full producer tensor type");
|
||||
if (!targetClass.isBatch && projectedDescriptor.payloadType == payload.getType()) {
|
||||
if (!fanoutPlan.ordinaryMessages)
|
||||
fanoutPlan.ordinaryMessages = MessageVector {};
|
||||
fanoutPlan.ordinaryMessages->append(
|
||||
receivePlan.messages.channelIds, receivePlan.messages.sourceCoreIds, receivePlan.messages.targetCoreIds);
|
||||
fanoutPlan.receivePlans.push_back(std::move(receivePlan));
|
||||
continue;
|
||||
}
|
||||
|
||||
receivePlan.receiveType = projectedDescriptor.payloadType;
|
||||
receivePlan.projectedExtractOp = projectedDescriptor.extractOp;
|
||||
|
||||
@@ -42,7 +42,6 @@ namespace {
|
||||
using namespace onnx_mlir::compact_asm;
|
||||
using SpatCompute = spatial::SpatCompute;
|
||||
using SpatComputeBatch = spatial::SpatComputeBatch;
|
||||
using spatial::getProducerValueRef;
|
||||
|
||||
static std::optional<int32_t> getComputeCoreId(SpatCompute compute) {
|
||||
if (auto coreIdAttr = compute->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName)) {
|
||||
@@ -187,13 +186,23 @@ void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpu
|
||||
SmallVector<int32_t> coreIds;
|
||||
};
|
||||
|
||||
// TODO: used for the report refactor.
// One line of the "Collector concat materialization" report section. Field order matters: rows are
// built positionally as {computeId, coreId, operandCount}.
struct CollectorConcatRow {
  // Report index of the compute the row refers to.
  uint64_t computeId {0};
  // Core the compute is assigned to; -1 until a core id is supplied.
  int32_t coreId {-1};
  // Number of concat operands reported for the compute.
  uint64_t operandCount {0};
};
|
||||
|
||||
uint64_t totalComputeOps = 0;
|
||||
uint64_t totalLogicalComputes = 0;
|
||||
uint64_t totalBatchComputeOps = 0;
|
||||
uint64_t totalInstructionCount = 0;
|
||||
uint64_t totalCrossbarCount = 0;
|
||||
uint64_t nextBatchId = 0;
|
||||
//TODO Used for report refactor
|
||||
std::vector<ReportRow> collectedData;
|
||||
//TODO Used for report refactor
|
||||
std::vector<CollectorConcatRow> collectorConcatRows;
|
||||
|
||||
auto getPerInstanceCrossbarCount = [&](Operation* op) -> uint64_t {
|
||||
return static_cast<uint64_t>(spatial::collectDistinctCrossbarWeights(op).size());
|
||||
@@ -206,7 +215,15 @@ void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpu
|
||||
SmallVector<int32_t> coreIds;
|
||||
if (auto coreId = getComputeCoreId(spatCompute))
|
||||
coreIds.push_back(*coreId);
|
||||
collectedData.push_back({totalComputeOps++, 1, perInstanceCrossbarCount, numInst, false, coreIds});
|
||||
uint64_t computeId = totalComputeOps++;
|
||||
collectedData.push_back({computeId, 1, perInstanceCrossbarCount, numInst, false, coreIds});
|
||||
uint64_t maxConcatOperands = 0;
|
||||
spatCompute.getBody().walk([&](spatial::SpatConcatOp concatOp) {
|
||||
maxConcatOperands = std::max<uint64_t>(maxConcatOperands, concatOp.getInputs().size());
|
||||
});
|
||||
//TODO 128 is a magic number
|
||||
if (maxConcatOperands >= 128 && !coreIds.empty())
|
||||
collectorConcatRows.push_back({computeId, coreIds.front(), maxConcatOperands});
|
||||
totalLogicalComputes += 1;
|
||||
totalInstructionCount += numInst;
|
||||
totalCrossbarCount += perInstanceCrossbarCount;
|
||||
@@ -238,9 +255,17 @@ void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpu
|
||||
{"Number of used crossbars", std::to_string(totalCrossbarCount) }
|
||||
};
|
||||
printReportTotalsBlock(os, totalFields);
|
||||
if (!collectedData.empty())
|
||||
if (!collectedData.empty() || !collectorConcatRows.empty())
|
||||
os << "\n";
|
||||
|
||||
if (!collectorConcatRows.empty()) {
|
||||
os << "Collector concat materialization:\n";
|
||||
for (const CollectorConcatRow& row : collectorConcatRows)
|
||||
os << "\tmaterialization_kind = single_collector_concat, compute = " << row.computeId
|
||||
<< ", concat_operand_count = " << row.operandCount << ", collector_core = " << row.coreId << "\n";
|
||||
os << "\n";
|
||||
}
|
||||
|
||||
sortReportEntriesByFirstCore(collectedData);
|
||||
|
||||
for (uint64_t cI = 0; cI < totalComputeOps; ++cI) {
|
||||
|
||||
+10
-3
@@ -23,7 +23,10 @@ MergeSchedulerKind getSchedulerKind() {
|
||||
llvm_unreachable("unknown merge scheduler kind");
|
||||
}
|
||||
|
||||
void verifySchedule(const ComputeGraph& graph, const MergeScheduleResult& result, unsigned long crossbarCapacity) {
|
||||
void verifySchedule(const ComputeGraph& graph,
|
||||
const MergeScheduleResult& result,
|
||||
unsigned long crossbarCapacity,
|
||||
size_t processorCount) {
|
||||
llvm::DenseMap<size_t, std::vector<std::pair<size_t, size_t>>> tasksByCpu;
|
||||
tasksByCpu.reserve(result.cpuToLastComputeMap.size());
|
||||
|
||||
@@ -79,7 +82,8 @@ void verifySchedule(const ComputeGraph& graph, const MergeScheduleResult& result
|
||||
|
||||
Time earliestTargetStart = addOrMax(sourceStart, graph.nodes[edge.source].cost);
|
||||
if (sourceCpu != targetCpu)
|
||||
earliestTargetStart = addOrMax(earliestTargetStart, edge.transferCost);
|
||||
earliestTargetStart = addOrMax(
|
||||
earliestTargetStart, getPeftTransferTime(edge.transferCost, sourceCpu, targetCpu, processorCount));
|
||||
if (targetStart < earliestTargetStart) {
|
||||
std::string message = llvm::formatv("merge scheduling: dependency legality failed between tasks {0} and {1}",
|
||||
graph.nodes[edge.source].originalOrder,
|
||||
@@ -115,7 +119,10 @@ MergeScheduleResult MergeSchedulingAnalysis::run() {
|
||||
static_cast<unsigned long>(crossbarCountInCore.getValue()),
|
||||
entryOp->getContext()});
|
||||
}
|
||||
verifySchedule(graph, schedule, static_cast<unsigned long>(crossbarCountInCore.getValue()));
|
||||
verifySchedule(graph,
|
||||
schedule,
|
||||
static_cast<unsigned long>(crossbarCountInCore.getValue()),
|
||||
options.processorCount);
|
||||
return schedule;
|
||||
}
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#include "llvm/Support/ErrorHandling.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
|
||||
#include <cmath>
|
||||
#include <limits>
|
||||
#include <queue>
|
||||
#include <vector>
|
||||
@@ -21,6 +22,63 @@ struct ScheduledTask {
|
||||
Time endTime = 0;
|
||||
};
|
||||
|
||||
struct MeshModel {
|
||||
size_t rows = 1;
|
||||
size_t cols = 1;
|
||||
long double averageDistance = 0.0L;
|
||||
|
||||
static MeshModel infer(size_t processorCount) {
|
||||
MeshModel model;
|
||||
if (processorCount == 0)
|
||||
return model;
|
||||
|
||||
model.rows = static_cast<size_t>(std::sqrt(static_cast<long double>(processorCount)));
|
||||
if (model.rows == 0)
|
||||
model.rows = 1;
|
||||
while (model.rows > 1 && processorCount % model.rows != 0)
|
||||
--model.rows;
|
||||
model.cols = (processorCount + model.rows - 1) / model.rows;
|
||||
|
||||
auto averageAxisDistance = [](size_t size) -> long double {
|
||||
if (size <= 1)
|
||||
return 0.0L;
|
||||
return static_cast<long double>(size * size - 1) / (3.0L * static_cast<long double>(size));
|
||||
};
|
||||
model.averageDistance = averageAxisDistance(model.rows) + averageAxisDistance(model.cols);
|
||||
return model;
|
||||
}
|
||||
|
||||
std::pair<size_t, size_t> getCoord(size_t processor) const {
|
||||
return {processor / cols, processor % cols};
|
||||
}
|
||||
|
||||
size_t getDistance(size_t lhs, size_t rhs) const {
|
||||
auto [lhsRow, lhsCol] = getCoord(lhs);
|
||||
auto [rhsRow, rhsCol] = getCoord(rhs);
|
||||
size_t rowDistance = lhsRow > rhsRow ? lhsRow - rhsRow : rhsRow - lhsRow;
|
||||
size_t colDistance = lhsCol > rhsCol ? lhsCol - rhsCol : rhsCol - lhsCol;
|
||||
return rowDistance + colDistance;
|
||||
}
|
||||
|
||||
Time scaleTransferCost(Time transferCost, size_t sourceProcessor, size_t targetProcessor) const {
|
||||
if (sourceProcessor == targetProcessor || transferCost == 0)
|
||||
return 0;
|
||||
long double distance = static_cast<long double>(getDistance(sourceProcessor, targetProcessor));
|
||||
long double scale = averageDistance > 0.0L ? distance / averageDistance : 1.0L;
|
||||
scale = std::max(0.25L, scale);
|
||||
return static_cast<Time>(std::ceil(static_cast<long double>(transferCost) * scale));
|
||||
}
|
||||
|
||||
size_t getCenterDistance(size_t processor) const {
|
||||
auto [row, col] = getCoord(processor);
|
||||
size_t centerRow = rows / 2;
|
||||
size_t centerCol = cols / 2;
|
||||
size_t rowDistance = row > centerRow ? row - centerRow : centerRow - row;
|
||||
size_t colDistance = col > centerCol ? col - centerCol : centerCol - col;
|
||||
return rowDistance + colDistance;
|
||||
}
|
||||
};
|
||||
|
||||
std::vector<std::vector<size_t>> buildReverseLevels(const ComputeGraph& graph) {
|
||||
std::vector<size_t> remainingSuccessors(graph.nodes.size(), 0);
|
||||
std::queue<size_t> readySinks;
|
||||
@@ -77,11 +135,16 @@ void verifyOctTableSize(size_t nodeCount, size_t processorCount) {
|
||||
|
||||
} // namespace
|
||||
|
||||
// Communication delay for moving data with nominal cost `transferCost` from `sourceProcessor` to
// `targetProcessor`, on the mesh inferred from `processorCount`.
Time getPeftTransferTime(Time transferCost, size_t sourceProcessor, size_t targetProcessor, size_t processorCount) {
  const MeshModel mesh = MeshModel::infer(processorCount);
  return mesh.scaleTransferCost(transferCost, sourceProcessor, targetProcessor);
}
|
||||
|
||||
MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftScheduleOptions& options) {
|
||||
const size_t nodeCount = graph.nodes.size();
|
||||
const size_t processorCount = options.processorCount;
|
||||
if (processorCount == 0)
|
||||
llvm::report_fatal_error("PEFT scheduler: processor count must be positive");
|
||||
MeshModel mesh = MeshModel::infer(processorCount);
|
||||
|
||||
verifyOctTableSize(nodeCount, processorCount);
|
||||
std::vector<std::vector<size_t>> reverseLevels = buildReverseLevels(graph);
|
||||
@@ -89,7 +152,6 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu
|
||||
// MOCK: Replace this with your actual heterogeneous cost lookup.
|
||||
// If graph.nodes[task] is modified to hold a vector of costs per processor, access it here.
|
||||
auto getComputeCost = [&](size_t task, size_t processor) -> Time { return graph.nodes[task].cost; };
|
||||
|
||||
std::vector<Time> oct(nodeCount * processorCount, 0);
|
||||
std::vector<Time> minOctPlusComp(nodeCount, 0);
|
||||
|
||||
@@ -177,6 +239,7 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu
|
||||
Time bestEft = 0;
|
||||
Time bestOeft = std::numeric_limits<Time>::max();
|
||||
unsigned int bestOverlapCount = 0;
|
||||
size_t bestCenterDistance = std::numeric_limits<size_t>::max();
|
||||
bool crossbarRejected = false;
|
||||
|
||||
for (size_t processor = 0; processor < processorCount; ++processor) {
|
||||
@@ -191,7 +254,7 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu
|
||||
Time dataReady = 0;
|
||||
for (const auto& [pred, comm] : graph.predecessors[task]) {
|
||||
const ScheduledTask& predSchedule = schedules[pred];
|
||||
Time commPenalty = predSchedule.processor == processor ? 0 : comm;
|
||||
Time commPenalty = getPeftTransferTime(comm, predSchedule.processor, processor, processorCount);
|
||||
dataReady = std::max(dataReady, addOrMax(predSchedule.endTime, commPenalty));
|
||||
}
|
||||
|
||||
@@ -218,6 +281,7 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu
|
||||
|
||||
Time eft = addOrMax(est, computeCost);
|
||||
Time oeft = addOrMax(eft, oct[task * processorCount + processor]);
|
||||
size_t centerDistance = mesh.getCenterDistance(processor);
|
||||
|
||||
if (oeft < bestOeft || (oeft == bestOeft && eft < bestEft)
|
||||
|| (oeft == bestOeft && eft == bestEft && est < bestEst)) {
|
||||
@@ -226,13 +290,25 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu
|
||||
bestEft = eft;
|
||||
bestOeft = oeft;
|
||||
bestOverlapCount = overlapCount;
|
||||
bestCenterDistance = centerDistance;
|
||||
}
|
||||
else if (oeft == bestOeft && eft == bestEft && est < bestEst && overlapCount < bestOverlapCount) {
|
||||
else if (oeft == bestOeft && eft == bestEft && est == bestEst
|
||||
&& centerDistance < bestCenterDistance) {
|
||||
bestProcessor = processor;
|
||||
bestEst = est;
|
||||
bestEft = eft;
|
||||
bestOeft = oeft;
|
||||
bestOverlapCount = overlapCount;
|
||||
bestCenterDistance = centerDistance;
|
||||
}
|
||||
else if (oeft == bestOeft && eft == bestEft && est == bestEst
|
||||
&& centerDistance == bestCenterDistance && overlapCount < bestOverlapCount) {
|
||||
bestProcessor = processor;
|
||||
bestEst = est;
|
||||
bestEft = eft;
|
||||
bestOeft = oeft;
|
||||
bestOverlapCount = overlapCount;
|
||||
bestCenterDistance = centerDistance;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ struct PeftScheduleOptions {
|
||||
mlir::MLIRContext* context = nullptr;
|
||||
};
|
||||
|
||||
/// Returns the transfer delay for data with nominal cost `transferCost` sent from `sourceProcessor` to
/// `targetProcessor`, given `processorCount` processors. The result is 0 when both processors are the same
/// or the nominal cost is 0; otherwise the nominal cost is scaled by the distance between the two
/// processors relative to the average distance, with the scale factor never below 0.25, and rounded up.
Time getPeftTransferTime(Time transferCost, size_t sourceProcessor, size_t targetProcessor, size_t processorCount);
|
||||
MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftScheduleOptions& options);
|
||||
|
||||
} // namespace spatial
|
||||
|
||||
Reference in New Issue
Block a user