faster DCPAnalysis on partial graph
All checks were successful
Validate Operations / validate-operations (push) Successful in 27m37s

This commit is contained in:
NiccoloN
2026-04-21 18:36:16 +02:00
parent dafc1d15b7
commit 0f13269040
4 changed files with 406 additions and 10 deletions

View File

@@ -47,6 +47,12 @@ llvm::cl::opt<long> coresCount("core-count",
llvm::cl::desc("Number of cores in the chip. `-1` to use the minimum amount of cores."), llvm::cl::desc("Number of cores in the chip. `-1` to use the minimum amount of cores."),
llvm::cl::init(-1)); llvm::cl::init(-1));
llvm::cl::opt<size_t> dcpCriticalWindowSize(
"dcp-critical-window-size",
llvm::cl::desc("Number of lowest-slack virtual nodes considered by each DCP coarsening iteration. "
"Use 0 to run the legacy full-graph DCP analysis."),
llvm::cl::init(1024));
llvm::cl::opt<bool> llvm::cl::opt<bool>
ignoreConcatError("ignore-concat-error", ignoreConcatError("ignore-concat-error",
llvm::cl::desc("Ignore ConcatOp corner case: do not assert and do a simplification"), llvm::cl::desc("Ignore ConcatOp corner case: do not assert and do a simplification"),

View File

@@ -29,6 +29,7 @@ extern llvm::cl::opt<bool> useExperimentalConvImpl;
extern llvm::cl::opt<size_t> crossbarSize; extern llvm::cl::opt<size_t> crossbarSize;
extern llvm::cl::opt<size_t> crossbarCountInCore; extern llvm::cl::opt<size_t> crossbarCountInCore;
extern llvm::cl::opt<long> coresCount; extern llvm::cl::opt<long> coresCount;
extern llvm::cl::opt<size_t> dcpCriticalWindowSize;
// This option, by default set to false, will ignore an error when resolving a // This option, by default set to false, will ignore an error when resolving a
// specific tiles of the operands of a concat. This specific case is when the // specific tiles of the operands of a concat. This specific case is when the

View File

@@ -6,10 +6,18 @@
#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Casting.h" #include "llvm/Support/Casting.h"
#include <algorithm>
#include <iterator> #include <iterator>
#include <map>
#include <numeric>
#include <optional>
#include <set>
#include <utility>
#include <vector>
#include "DCPAnalysis.hpp" #include "DCPAnalysis.hpp"
#include "Graph.hpp" #include "Graph.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Support/TypeUtilities.hpp" #include "src/Support/TypeUtilities.hpp"
namespace onnx_mlir { namespace onnx_mlir {
@@ -17,6 +25,361 @@ namespace spatial {
using namespace mlir; using namespace mlir;
namespace {
struct VirtualNode {
llvm::SmallVector<size_t, 4> originalComputeIndices;
Weight weight = 0;
CrossbarUsage crossbarUsage = 0;
};
struct VirtualGraph {
std::vector<VirtualNode> nodes;
std::vector<IndexedEdge> edges;
};
struct TimingInfo {
std::vector<Time> aest;
std::vector<Time> alst;
std::vector<size_t> topologicalOrder;
bool valid = false;
};
struct WindowScheduleResult {
std::vector<std::vector<size_t>> mergeGroups;
bool usedAllAvailableCpus = false;
};
std::vector<IndexedEdge> aggregateEdges(llvm::ArrayRef<IndexedEdge> edges) {
std::map<std::pair<size_t, size_t>, Weight> edgeWeights;
for (auto [start, end, weight] : edges) {
size_t startIndex = static_cast<size_t>(start);
size_t endIndex = static_cast<size_t>(end);
if (startIndex == endIndex)
continue;
auto key = std::make_pair(startIndex, endIndex);
Weight edgeWeight = static_cast<Weight>(weight);
auto it = edgeWeights.find(key);
if (it == edgeWeights.end())
edgeWeights.insert({key, edgeWeight});
else
it->second = std::max(it->second, edgeWeight);
}
std::vector<IndexedEdge> aggregatedEdges;
aggregatedEdges.reserve(edgeWeights.size());
for (auto [key, weight] : edgeWeights)
aggregatedEdges.push_back(
{static_cast<int64_t>(key.first), static_cast<int64_t>(key.second), static_cast<int64_t>(weight)});
return aggregatedEdges;
}
VirtualGraph buildInitialVirtualGraph(llvm::ArrayRef<SpatWeightedCompute> spatWeightedComputes,
llvm::ArrayRef<IndexedEdge> edges) {
VirtualGraph graph;
graph.nodes.reserve(spatWeightedComputes.size());
for (auto [index, spatWeightedCompute] : llvm::enumerate(spatWeightedComputes)) {
VirtualNode node;
node.originalComputeIndices.push_back(index);
node.weight = getSpatComputeWeight(spatWeightedCompute);
node.crossbarUsage = getSpatComputeCrossbarUsage(spatWeightedCompute);
graph.nodes.push_back(std::move(node));
}
graph.edges = aggregateEdges(edges);
return graph;
}
TimingInfo computeTiming(const VirtualGraph& graph) {
TimingInfo timing;
size_t nodeCount = graph.nodes.size();
timing.aest.assign(nodeCount, 0);
timing.alst.assign(nodeCount, 0);
timing.topologicalOrder.reserve(nodeCount);
std::vector<std::vector<std::pair<size_t, Weight>>> parents(nodeCount);
std::vector<std::vector<std::pair<size_t, Weight>>> children(nodeCount);
std::vector<size_t> incomingEdgeCount(nodeCount, 0);
for (auto [start, end, weight] : graph.edges) {
size_t startIndex = static_cast<size_t>(start);
size_t endIndex = static_cast<size_t>(end);
Weight edgeWeight = static_cast<Weight>(weight);
assert(startIndex < nodeCount && endIndex < nodeCount && "virtual edge endpoint out of range");
children[startIndex].push_back({endIndex, edgeWeight});
parents[endIndex].push_back({startIndex, edgeWeight});
incomingEdgeCount[endIndex]++;
}
std::vector<size_t> readyNodes;
readyNodes.reserve(nodeCount);
for (size_t i = 0; i < nodeCount; ++i)
if (incomingEdgeCount[i] == 0)
readyNodes.push_back(i);
size_t readyIndex = 0;
while (readyIndex != readyNodes.size()) {
size_t current = readyNodes[readyIndex++];
timing.topologicalOrder.push_back(current);
for (auto [child, weight] : children[current]) {
(void) weight;
assert(incomingEdgeCount[child] > 0 && "incoming edge count underflow");
incomingEdgeCount[child]--;
if (incomingEdgeCount[child] == 0)
readyNodes.push_back(child);
}
}
if (timing.topologicalOrder.size() != nodeCount)
return timing;
Time dcpl = 0;
for (size_t nodeIndex : timing.topologicalOrder) {
Time maxParentAest = 0;
for (auto [parent, transferCost] : parents[nodeIndex]) {
maxParentAest =
std::max(maxParentAest, addOrMax(addOrMax(timing.aest[parent], graph.nodes[parent].weight), transferCost));
}
timing.aest[nodeIndex] = maxParentAest;
dcpl = std::max(dcpl, addOrMax(maxParentAest, graph.nodes[nodeIndex].weight));
}
for (size_t nodeIndex : llvm::reverse(timing.topologicalOrder)) {
Time minAlst = std::numeric_limits<Time>::max();
if (children[nodeIndex].empty())
minAlst = subtractOrZero(dcpl, graph.nodes[nodeIndex].weight);
for (auto [child, transferCost] : children[nodeIndex]) {
minAlst =
std::min(minAlst, subtractOrZero(timing.alst[child], addOrMax(graph.nodes[nodeIndex].weight, transferCost)));
}
timing.alst[nodeIndex] = minAlst;
}
timing.valid = true;
return timing;
}
std::vector<size_t> selectCriticalWindow(const TimingInfo& timing, size_t windowSize) {
std::vector<size_t> selected(timing.aest.size());
std::iota(selected.begin(), selected.end(), 0);
std::stable_sort(selected.begin(), selected.end(), [&](size_t lhs, size_t rhs) {
Time lhsSlack = slackOrZero(timing.aest[lhs], timing.alst[lhs]);
Time rhsSlack = slackOrZero(timing.aest[rhs], timing.alst[rhs]);
if (lhsSlack != rhsSlack)
return lhsSlack < rhsSlack;
if (timing.aest[lhs] != timing.aest[rhs])
return timing.aest[lhs] < timing.aest[rhs];
return lhs < rhs;
});
selected.resize(std::min(windowSize, selected.size()));
return selected;
}
std::vector<size_t> getOriginalSignature(const VirtualGraph& graph, llvm::ArrayRef<size_t> selectedNodes) {
std::vector<size_t> signature;
for (size_t nodeIndex : selectedNodes) {
const VirtualNode& node = graph.nodes[nodeIndex];
signature.insert(signature.end(), node.originalComputeIndices.begin(), node.originalComputeIndices.end());
}
std::sort(signature.begin(), signature.end());
return signature;
}
std::vector<IndexedEdge> buildWindowEdges(const VirtualGraph& graph, const std::vector<int64_t>& nodeToWindowIndex) {
std::vector<IndexedEdge> windowEdges;
windowEdges.reserve(graph.edges.size());
for (auto [start, end, weight] : graph.edges) {
int64_t mappedStart = nodeToWindowIndex[static_cast<size_t>(start)];
int64_t mappedEnd = nodeToWindowIndex[static_cast<size_t>(end)];
if (mappedStart == -1 || mappedEnd == -1)
continue;
windowEdges.push_back({mappedStart, mappedEnd, weight});
}
return aggregateEdges(windowEdges);
}
WindowScheduleResult
scheduleWindow(const VirtualGraph& graph, llvm::ArrayRef<size_t> selectedNodes, MLIRContext* context) {
std::vector<Weight> windowWeights;
std::vector<CrossbarUsage> windowCrossbarUsage;
std::vector<int64_t> nodeToWindowIndex(graph.nodes.size(), -1);
windowWeights.reserve(selectedNodes.size());
windowCrossbarUsage.reserve(selectedNodes.size());
for (auto [windowIndex, nodeIndex] : llvm::enumerate(selectedNodes)) {
nodeToWindowIndex[nodeIndex] = static_cast<int64_t>(windowIndex);
windowWeights.push_back(graph.nodes[nodeIndex].weight);
windowCrossbarUsage.push_back(graph.nodes[nodeIndex].crossbarUsage);
}
GraphDCP windowGraph(windowWeights, buildWindowEdges(graph, nodeToWindowIndex), windowCrossbarUsage);
if (coresCount.getValue() > 0)
windowGraph.setMaxCpuCount(static_cast<int>(coresCount.getValue()));
windowGraph.setContext(context);
windowGraph.runDcp();
WindowScheduleResult result;
result.usedAllAvailableCpus = windowGraph.cpuCount() >= windowGraph.getMaxCpuCount();
for (CPU cpu = 0; cpu < windowGraph.cpuCount(); ++cpu) {
auto scheduledTasks = windowGraph.getScheduledTasks(cpu);
if (scheduledTasks.size() < 2)
continue;
std::vector<size_t> mergeGroup;
mergeGroup.reserve(scheduledTasks.size());
for (const auto& task : scheduledTasks)
mergeGroup.push_back(selectedNodes[task.nodeIndex]);
std::sort(mergeGroup.begin(), mergeGroup.end());
result.mergeGroups.push_back(std::move(mergeGroup));
}
return result;
}
bool coarsenGraph(const VirtualGraph& graph,
llvm::ArrayRef<std::vector<size_t>> mergeGroups,
VirtualGraph& coarsenedGraph) {
std::vector<int64_t> nodeToMergeGroup(graph.nodes.size(), -1);
for (auto [groupIndex, mergeGroup] : llvm::enumerate(mergeGroups)) {
if (mergeGroup.size() < 2)
continue;
for (size_t nodeIndex : mergeGroup) {
assert(nodeIndex < graph.nodes.size() && "merge group node out of range");
nodeToMergeGroup[nodeIndex] = static_cast<int64_t>(groupIndex);
}
}
std::vector<std::optional<size_t>> mergeGroupToNewNode(mergeGroups.size());
std::vector<size_t> oldToNewNode(graph.nodes.size(), 0);
bool mergedAny = false;
coarsenedGraph.nodes.clear();
coarsenedGraph.edges.clear();
coarsenedGraph.nodes.reserve(graph.nodes.size());
for (size_t nodeIndex = 0; nodeIndex < graph.nodes.size(); ++nodeIndex) {
int64_t mergeGroupIndex = nodeToMergeGroup[nodeIndex];
if (mergeGroupIndex == -1) {
oldToNewNode[nodeIndex] = coarsenedGraph.nodes.size();
coarsenedGraph.nodes.push_back(graph.nodes[nodeIndex]);
continue;
}
auto& newNodeIndex = mergeGroupToNewNode[static_cast<size_t>(mergeGroupIndex)];
if (newNodeIndex.has_value()) {
oldToNewNode[nodeIndex] = *newNodeIndex;
continue;
}
VirtualNode mergedNode;
for (size_t memberIndex : mergeGroups[static_cast<size_t>(mergeGroupIndex)]) {
const VirtualNode& memberNode = graph.nodes[memberIndex];
mergedNode.originalComputeIndices.append(memberNode.originalComputeIndices.begin(),
memberNode.originalComputeIndices.end());
mergedNode.weight = addOrMax(mergedNode.weight, memberNode.weight);
mergedNode.crossbarUsage = addOrMax(mergedNode.crossbarUsage, memberNode.crossbarUsage);
}
std::sort(mergedNode.originalComputeIndices.begin(), mergedNode.originalComputeIndices.end());
mergedAny = true;
newNodeIndex = coarsenedGraph.nodes.size();
for (size_t memberIndex : mergeGroups[static_cast<size_t>(mergeGroupIndex)])
oldToNewNode[memberIndex] = *newNodeIndex;
coarsenedGraph.nodes.push_back(std::move(mergedNode));
}
if (!mergedAny)
return false;
std::vector<IndexedEdge> remappedEdges;
remappedEdges.reserve(graph.edges.size());
for (auto [start, end, weight] : graph.edges) {
size_t newStart = oldToNewNode[static_cast<size_t>(start)];
size_t newEnd = oldToNewNode[static_cast<size_t>(end)];
if (newStart == newEnd)
continue;
remappedEdges.push_back({static_cast<int64_t>(newStart), static_cast<int64_t>(newEnd), weight});
}
coarsenedGraph.edges = aggregateEdges(remappedEdges);
return computeTiming(coarsenedGraph).valid;
}
bool coarsenGraphWithFallback(const VirtualGraph& graph,
llvm::ArrayRef<std::vector<size_t>> mergeGroups,
VirtualGraph& coarsenedGraph) {
if (coarsenGraph(graph, mergeGroups, coarsenedGraph))
return true;
std::vector<size_t> orderedGroupIndices(mergeGroups.size());
std::iota(orderedGroupIndices.begin(), orderedGroupIndices.end(), 0);
std::stable_sort(orderedGroupIndices.begin(), orderedGroupIndices.end(), [&](size_t lhs, size_t rhs) {
return mergeGroups[lhs].size() > mergeGroups[rhs].size();
});
std::vector<std::vector<size_t>> acceptedMergeGroups;
acceptedMergeGroups.reserve(mergeGroups.size());
for (size_t groupIndex : orderedGroupIndices) {
std::vector<std::vector<size_t>> candidateMergeGroups = acceptedMergeGroups;
candidateMergeGroups.push_back(mergeGroups[groupIndex]);
VirtualGraph candidateGraph;
if (!coarsenGraph(graph, candidateMergeGroups, candidateGraph))
continue;
acceptedMergeGroups = std::move(candidateMergeGroups);
coarsenedGraph = std::move(candidateGraph);
}
return !acceptedMergeGroups.empty();
}
std::vector<size_t> computeOriginalTopologicalOrder(size_t computeCount, llvm::ArrayRef<IndexedEdge> edges) {
VirtualGraph graph;
graph.nodes.resize(computeCount);
graph.edges = aggregateEdges(edges);
TimingInfo timing = computeTiming(graph);
if (timing.valid)
return timing.topologicalOrder;
std::vector<size_t> fallbackOrder(computeCount);
std::iota(fallbackOrder.begin(), fallbackOrder.end(), 0);
return fallbackOrder;
}
DCPAnalysisResult buildResultFromVirtualGraph(const VirtualGraph& graph,
llvm::ArrayRef<SpatWeightedCompute> spatWeightedComputes,
llvm::ArrayRef<IndexedEdge> originalEdges) {
DCPAnalysisResult result;
std::vector<size_t> originalToVirtualNode(spatWeightedComputes.size(), 0);
for (auto [virtualNodeIndex, virtualNode] : llvm::enumerate(graph.nodes))
for (size_t originalIndex : virtualNode.originalComputeIndices)
originalToVirtualNode[originalIndex] = virtualNodeIndex;
auto dominanceOrder = computeOriginalTopologicalOrder(spatWeightedComputes.size(), originalEdges);
result.dominanceOrderCompute.reserve(dominanceOrder.size());
for (size_t originalIndex : dominanceOrder) {
SpatWeightedCompute spatWeightedCompute = spatWeightedComputes[originalIndex];
size_t cpu = originalToVirtualNode[originalIndex];
result.dominanceOrderCompute.push_back(spatWeightedCompute);
result.computeToCpuMap[spatWeightedCompute] = cpu;
result.cpuToLastComputeMap[cpu] = spatWeightedCompute;
}
for (auto [cpu, lastCompute] : result.cpuToLastComputeMap)
result.isLastComputeOfCpu.insert(lastCompute);
return result;
}
DCPAnalysisResult runLegacyDcp(llvm::ArrayRef<SpatWeightedCompute> spatWeightedComputes,
llvm::ArrayRef<IndexedEdge> edges,
MLIRContext* context) {
GraphDCP graphDCP(spatWeightedComputes, edges);
if (coresCount.getValue() > 0)
graphDCP.setMaxCpuCount(static_cast<int>(coresCount.getValue()));
graphDCP.setContext(context);
graphDCP.runDcp();
return graphDCP.getResult();
}
} // namespace
SpatWeightedCompute getOriginalSpatWeightedCompute(Operation* op) { SpatWeightedCompute getOriginalSpatWeightedCompute(Operation* op) {
if (!op) if (!op)
return {}; return {};
@@ -31,8 +394,8 @@ SpatWeightedCompute getOriginalSpatWeightedCompute(Operation* op) {
} }
DCPAnalysisResult DCPAnalysis::run() { DCPAnalysisResult DCPAnalysis::run() {
llvm::SmallVector<SpatWeightedCompute, 10> spatWeightedComputes; SmallVector<SpatWeightedCompute, 10> spatWeightedComputes;
llvm::SmallVector<IndexedEdge, 10> edges; SmallVector<IndexedEdge, 10> edges;
for (auto& region : entryOp->getRegions()) for (auto& region : entryOp->getRegions())
for (SpatWeightedCompute spatWeightedCompute : region.getOps<SpatWeightedCompute>()) for (SpatWeightedCompute spatWeightedCompute : region.getOps<SpatWeightedCompute>())
spatWeightedComputes.push_back(spatWeightedCompute); spatWeightedComputes.push_back(spatWeightedCompute);
@@ -53,10 +416,37 @@ DCPAnalysisResult DCPAnalysis::run() {
} }
} }
} }
GraphDCP graphDCP(spatWeightedComputes, edges);
graphDCP.setContext(entryOp->getContext()); if (dcpCriticalWindowSize.getValue() == 0)
graphDCP.runDcp(); return runLegacyDcp(spatWeightedComputes, edges, entryOp->getContext());
return graphDCP.getResult();
VirtualGraph virtualGraph = buildInitialVirtualGraph(spatWeightedComputes, edges);
std::set<std::vector<size_t>> seenCriticalWindows;
while (virtualGraph.nodes.size() > 1) {
TimingInfo timing = computeTiming(virtualGraph);
if (!timing.valid)
break;
auto selectedNodes = selectCriticalWindow(timing, dcpCriticalWindowSize.getValue());
if (selectedNodes.size() < 2)
break;
if (!seenCriticalWindows.insert(getOriginalSignature(virtualGraph, selectedNodes)).second)
break;
WindowScheduleResult windowSchedule = scheduleWindow(virtualGraph, selectedNodes, entryOp->getContext());
if (windowSchedule.mergeGroups.empty())
break;
VirtualGraph coarsenedGraph;
if (!coarsenGraphWithFallback(virtualGraph, windowSchedule.mergeGroups, coarsenedGraph))
break;
virtualGraph = std::move(coarsenedGraph);
if (windowSchedule.usedAllAvailableCpus)
break;
}
return buildResultFromVirtualGraph(virtualGraph, spatWeightedComputes, edges);
} }
} // namespace spatial } // namespace spatial

View File

@@ -6,7 +6,7 @@
// consumer land on different CPUs. // consumer land on different CPUs.
// //
// Output: an assignment of every task to a CPU and an order within that CPU, // Output: an assignment of every task to a CPU and an order within that CPU,
// aiming to minimise the overall critical-path length (DCPL). // aiming to minimize the overall critical-path length (DCPL).
// //
// Every task keeps two timing estimates: // Every task keeps two timing estimates:
// AEST - earliest start time, driven by parent completions + transfers. // AEST - earliest start time, driven by parent completions + transfers.
@@ -16,9 +16,9 @@
// Main loop (runDcp): // Main loop (runDcp):
// 1. Build a topological order and seed AEST/ALST from the unscheduled DAG. // 1. Build a topological order and seed AEST/ALST from the unscheduled DAG.
// 2. While there are ready tasks (all dependency parents scheduled): // 2. While there are ready tasks (all dependency parents scheduled):
// a. Pick the candidate with tightest slack (earliest AEST breaks ties). // a. Pick the candidate with the tightest slack (earliest AEST breaks ties).
// b. selectProcessor() tries every candidate CPU and picks the one that // b. selectProcessor() tries every candidate CPU and picks the one that
// minimises a composite cost (own slot + smallest unscheduled child). // minimizes a composite cost (own slot + the smallest unscheduled child).
// c. Commit the placement and refresh AEST/ALST. // c. Commit the placement and refresh AEST/ALST.
// d. Release any child whose dependency parents are now all scheduled. // d. Release any child whose dependency parents are now all scheduled.
// //
@@ -43,7 +43,6 @@
#include <cassert> #include <cassert>
#include <chrono> #include <chrono>
#include <cstdio> #include <cstdio>
#include <cstdlib>
#include <vector> #include <vector>
#include "DCPAnalysis.hpp" #include "DCPAnalysis.hpp"