This commit is contained in:
@@ -0,0 +1,402 @@
|
|||||||
|
#include <iostream>
|
||||||
|
#include <vector>
|
||||||
|
#include <algorithm>
|
||||||
|
#include <queue>
|
||||||
|
#include <random>
|
||||||
|
#include <chrono>
|
||||||
|
#include <omp.h>
|
||||||
|
#include <numeric>
|
||||||
|
#include <fstream>
|
||||||
|
|
||||||
|
using namespace std;
|
||||||
|
|
||||||
|
// ==========================================
|
||||||
|
// 1. DATA STRUCTURES
|
||||||
|
// ==========================================
|
||||||
|
|
||||||
|
// High-performance DAG representation using adjacency lists
|
||||||
|
struct DAG {
|
||||||
|
int num_nodes;
|
||||||
|
int num_processors;
|
||||||
|
|
||||||
|
// Computation cost matrix (Flattened: num_nodes x num_processors)
|
||||||
|
vector<float> comp_costs;
|
||||||
|
|
||||||
|
// Edges: node -> vector of pair<target_node, communication_cost>
|
||||||
|
vector<vector<pair<int, float>>> successors;
|
||||||
|
vector<vector<pair<int, float>>> predecessors;
|
||||||
|
|
||||||
|
DAG(int n, int p) : num_nodes(n), num_processors(p),
|
||||||
|
comp_costs(n * p, 0.0f),
|
||||||
|
successors(n), predecessors(n) {}
|
||||||
|
|
||||||
|
inline float get_comp_cost(int task, int proc) const {
|
||||||
|
return comp_costs[task * num_processors + proc];
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Struct to hold the final schedule for each task
|
||||||
|
struct TaskSchedule {
|
||||||
|
int processor;
|
||||||
|
float start_time;
|
||||||
|
float end_time;
|
||||||
|
};
|
||||||
|
|
||||||
|
// ==========================================
|
||||||
|
// 2. DAG GENERATOR (REAL-WORLD SKEW)
|
||||||
|
// ==========================================
|
||||||
|
DAG generate_dag(int num_nodes, int num_processors, int levels, float ccr) {
|
||||||
|
DAG dag(num_nodes, num_processors);
|
||||||
|
mt19937 gen(42); // Fixed seed for reproducibility
|
||||||
|
uniform_real_distribution<float> comp_dist(10.0f, 100.0f);
|
||||||
|
|
||||||
|
for (int i = 0; i < num_nodes * num_processors; ++i) {
|
||||||
|
dag.comp_costs[i] = comp_dist(gen);
|
||||||
|
}
|
||||||
|
|
||||||
|
vector<vector<int>> nodes_per_level(levels);
|
||||||
|
uniform_int_distribution<int> lvl_dist(0, levels - 1);
|
||||||
|
|
||||||
|
for (int i = 0; i < num_nodes; ++i) {
|
||||||
|
if (i == 0) nodes_per_level[0].push_back(i);
|
||||||
|
else if (i == num_nodes - 1) nodes_per_level[levels - 1].push_back(i);
|
||||||
|
else nodes_per_level[lvl_dist(gen)].push_back(i);
|
||||||
|
}
|
||||||
|
|
||||||
|
float avg_comp = 55.0f;
|
||||||
|
uniform_real_distribution<float> comm_dist(avg_comp * ccr * 0.5f, avg_comp * ccr * 1.5f);
|
||||||
|
uniform_real_distribution<float> prob(0.0, 1.0);
|
||||||
|
|
||||||
|
long long total_edges = 0;
|
||||||
|
|
||||||
|
for (int l = 0; l < levels - 1; ++l) {
|
||||||
|
if (nodes_per_level[l].empty()) continue;
|
||||||
|
|
||||||
|
for (int u : nodes_per_level[l]) {
|
||||||
|
int target_level = l + 1;
|
||||||
|
while (target_level < levels && nodes_per_level[target_level].empty()) target_level++;
|
||||||
|
|
||||||
|
if (target_level < levels) {
|
||||||
|
// 1. BASE DEPENDENCY: Ensure the graph flows forward (no disconnected nodes)
|
||||||
|
int v = nodes_per_level[target_level][gen() % nodes_per_level[target_level].size()];
|
||||||
|
float comm = comm_dist(gen);
|
||||||
|
dag.successors[u].push_back({v, comm});
|
||||||
|
dag.predecessors[v].push_back({u, comm});
|
||||||
|
total_edges++;
|
||||||
|
|
||||||
|
// ----------------------------------------------------
|
||||||
|
// 2. THE REAL-WORLD SKEW LOGIC (Hubs vs Normal Nodes)
|
||||||
|
// ----------------------------------------------------
|
||||||
|
|
||||||
|
// Make 0.5% of nodes act as "Super Hubs" (Broadcast nodes)
|
||||||
|
bool is_super_hub = (prob(gen) < 0.005);
|
||||||
|
int extra_edges = 0;
|
||||||
|
|
||||||
|
if (is_super_hub) {
|
||||||
|
// This node is a Hub! Give it 20,000 children!
|
||||||
|
// (Or as many as the remaining graph size permits)
|
||||||
|
extra_edges = 2000;
|
||||||
|
} else {
|
||||||
|
// Normal node: 20% chance to just have 1 to 3 extra children
|
||||||
|
if (prob(gen) < 0.2) {
|
||||||
|
uniform_int_distribution<int> normal_dist(1, 3);
|
||||||
|
extra_edges = normal_dist(gen);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Randomly connect these extra edges to ANY node in ANY future level
|
||||||
|
for (int e = 0; e < extra_edges; ++e) {
|
||||||
|
uniform_int_distribution<int> future_lvl_dist(target_level, levels - 1);
|
||||||
|
int f_lvl = future_lvl_dist(gen);
|
||||||
|
if (nodes_per_level[f_lvl].empty()) continue;
|
||||||
|
|
||||||
|
int child_v = nodes_per_level[f_lvl][gen() % nodes_per_level[f_lvl].size()];
|
||||||
|
|
||||||
|
float extra_comm = comm_dist(gen);
|
||||||
|
dag.successors[u].push_back({child_v, extra_comm});
|
||||||
|
dag.predecessors[child_v].push_back({u, extra_comm});
|
||||||
|
total_edges++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
cout << " [Generator] Created " << total_edges << " total edges (simulating Hubs).\n";
|
||||||
|
return dag;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// ==========================================
|
||||||
|
// 3. PEFT SCHEDULER (O(E*P) Optimized)
|
||||||
|
// ==========================================
|
||||||
|
void run_peft(const DAG& dag, vector<TaskSchedule>& final_schedule) {
|
||||||
|
int N = dag.num_nodes;
|
||||||
|
int P = dag.num_processors;
|
||||||
|
|
||||||
|
// 1. Level sorting for parallel Bottom-Up OCT computation
|
||||||
|
vector<int> out_degree(N, 0);
|
||||||
|
for(int i=0; i<N; ++i) out_degree[i] = dag.successors[i].size();
|
||||||
|
|
||||||
|
vector<vector<int>> reverse_levels;
|
||||||
|
queue<int> q;
|
||||||
|
for(int i=0; i<N; ++i) if(out_degree[i] == 0) q.push(i);
|
||||||
|
|
||||||
|
while(!q.empty()) {
|
||||||
|
int size = q.size();
|
||||||
|
vector<int> current_level;
|
||||||
|
for(int i=0; i<size; ++i) {
|
||||||
|
int u = q.front(); q.pop();
|
||||||
|
current_level.push_back(u);
|
||||||
|
for(auto& edge : dag.predecessors[u]) {
|
||||||
|
int p_node = edge.first;
|
||||||
|
if(--out_degree[p_node] == 0) q.push(p_node);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
reverse_levels.push_back(current_level);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// 2. Compute OCT (Optimistic Cost Table)
|
||||||
|
vector<float> oct(N * P, 0.0f);
|
||||||
|
|
||||||
|
// NEW: Cache array to drop complexity to O(E * P)
|
||||||
|
vector<float> min_oct_comp(N, 0.0f);
|
||||||
|
|
||||||
|
for (const auto& level_nodes : reverse_levels) {
|
||||||
|
#pragma omp parallel for schedule(dynamic)
|
||||||
|
for (int idx = 0; idx < level_nodes.size(); ++idx) {
|
||||||
|
int task = level_nodes[idx];
|
||||||
|
|
||||||
|
vector<float> max_vals(P, 0.0f);
|
||||||
|
|
||||||
|
// Loop over successors first (O(E * P) total instead of O(E * P^2))
|
||||||
|
for (auto& edge : dag.successors[task]) {
|
||||||
|
int succ = edge.first;
|
||||||
|
float comm_cost = edge.second;
|
||||||
|
|
||||||
|
// The precalculated global minimum for this successor
|
||||||
|
float val_diff = min_oct_comp[succ] + comm_cost;
|
||||||
|
|
||||||
|
// Cache-friendly, auto-vectorizable O(P) loop
|
||||||
|
for (int p_j = 0; p_j < P; ++p_j) {
|
||||||
|
float val_same = oct[succ * P + p_j] + dag.get_comp_cost(succ, p_j);
|
||||||
|
float min_w = min(val_same, val_diff);
|
||||||
|
max_vals[p_j] = max(max_vals[p_j], min_w);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Assign to OCT and precalculate the minimum for the predecessors
|
||||||
|
float task_min_val = 1e9f;
|
||||||
|
for (int p_j = 0; p_j < P; ++p_j) {
|
||||||
|
oct[task * P + p_j] = max_vals[p_j];
|
||||||
|
task_min_val = min(task_min_val, max_vals[p_j] + dag.get_comp_cost(task, p_j));
|
||||||
|
}
|
||||||
|
min_oct_comp[task] = task_min_val;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. Compute Rank_OCT and sort tasks (Phase 1)
|
||||||
|
vector<pair<float, int>> rank_oct(N);
|
||||||
|
#pragma omp parallel for
|
||||||
|
for (int i = 0; i < N; ++i) {
|
||||||
|
float avg_oct = 0;
|
||||||
|
for (int p = 0; p < P; ++p) avg_oct += oct[i * P + p];
|
||||||
|
rank_oct[i] = {avg_oct / P, i};
|
||||||
|
}
|
||||||
|
|
||||||
|
sort(rank_oct.rbegin(), rank_oct.rend());
|
||||||
|
|
||||||
|
// 4. Processor Assignment (Phase 2)
|
||||||
|
final_schedule.resize(N);
|
||||||
|
vector<float> avail(P, 0.0f);
|
||||||
|
|
||||||
|
for (int i = 0; i < N; ++i) {
|
||||||
|
int task = rank_oct[i].second;
|
||||||
|
|
||||||
|
int best_p = -1;
|
||||||
|
float min_o_eft = 1e9f;
|
||||||
|
float best_est = 0.0f;
|
||||||
|
float best_eft = 0.0f;
|
||||||
|
|
||||||
|
for (int p = 0; p < P; ++p) {
|
||||||
|
float data_ready_time = 0.0f;
|
||||||
|
for (auto& pred_edge : dag.predecessors[task]) {
|
||||||
|
int pred = pred_edge.first;
|
||||||
|
float comm = pred_edge.second;
|
||||||
|
int pred_p = final_schedule[pred].processor;
|
||||||
|
float comm_penalty = (pred_p == p) ? 0.0f : comm;
|
||||||
|
data_ready_time = max(data_ready_time, final_schedule[pred].end_time + comm_penalty);
|
||||||
|
}
|
||||||
|
|
||||||
|
float est = max(avail[p], data_ready_time);
|
||||||
|
float eft = est + dag.get_comp_cost(task, p);
|
||||||
|
float o_eft = eft + oct[task * P + p];
|
||||||
|
|
||||||
|
if (o_eft < min_o_eft) {
|
||||||
|
min_o_eft = o_eft;
|
||||||
|
best_p = p;
|
||||||
|
best_est = est;
|
||||||
|
best_eft = eft;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
final_schedule[task] = {best_p, best_est, best_eft};
|
||||||
|
avail[best_p] = best_eft;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ==========================================
|
||||||
|
// 4. VISUALIZATION EXPORTERS (DOT)
|
||||||
|
// ==========================================
|
||||||
|
|
||||||
|
void export_dag_to_dot(const DAG& dag, const string& filename) {
|
||||||
|
ofstream out(filename);
|
||||||
|
out << "digraph RawDAG {\n";
|
||||||
|
out << " rankdir=TB;\n";
|
||||||
|
out << " node [shape=record, style=filled, fillcolor=lightgrey, fontname=\"Helvetica\"];\n";
|
||||||
|
out << " edge [fontname=\"Helvetica\", fontsize=10];\n\n";
|
||||||
|
|
||||||
|
for (int i = 0; i < dag.num_nodes; ++i) {
|
||||||
|
out << " Task_" << i << " [label=\"Task " << i << "\"];\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
out << "\n";
|
||||||
|
for (int i = 0; i < dag.num_nodes; ++i) {
|
||||||
|
for (const auto& edge : dag.successors[i]) {
|
||||||
|
out << " Task_" << i << " -> Task_" << edge.first
|
||||||
|
<< " [label=\"" << (int)edge.second << "\"];\n";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
out << "}\n";
|
||||||
|
out.close();
|
||||||
|
}
|
||||||
|
|
||||||
|
void export_schedule_to_dot(const DAG& dag, const vector<TaskSchedule>& schedule, const string& filename) {
|
||||||
|
ofstream out(filename);
|
||||||
|
out << "digraph ScheduledDAG {\n";
|
||||||
|
out << " rankdir=TB;\n";
|
||||||
|
// FIX 1: Change shape to standard 'box' to prevent the flat edge warning
|
||||||
|
out << " node [shape=box, fontname=\"Helvetica\", style=filled, fillcolor=white, rounded=true];\n";
|
||||||
|
out << " edge [fontname=\"Helvetica\", fontsize=10];\n\n";
|
||||||
|
|
||||||
|
vector<vector<int>> proc_tasks(dag.num_processors);
|
||||||
|
for (int i = 0; i < dag.num_nodes; ++i) {
|
||||||
|
proc_tasks[schedule[i].processor].push_back(i);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int p = 0; p < dag.num_processors; ++p) {
|
||||||
|
if (proc_tasks[p].empty()) continue;
|
||||||
|
|
||||||
|
sort(proc_tasks[p].begin(), proc_tasks[p].end(), [&schedule](int a, int b) {
|
||||||
|
return schedule[a].start_time < schedule[b].start_time;
|
||||||
|
});
|
||||||
|
|
||||||
|
out << " subgraph cluster_P" << p << " {\n";
|
||||||
|
out << " label=\"Processor " << p << "\";\n";
|
||||||
|
out << " fontname=\"Helvetica-Bold\";\n";
|
||||||
|
out << " style=rounded;\n";
|
||||||
|
out << " bgcolor=\"#f0f8ff\";\n";
|
||||||
|
out << " color=blue;\n\n";
|
||||||
|
|
||||||
|
for (int task : proc_tasks[p]) {
|
||||||
|
// FIX 1 cont: Use standard \n linebreaks instead of the record | syntax
|
||||||
|
out << " Task_" << task
|
||||||
|
<< " [label=\"Task " << task
|
||||||
|
<< "\\nStart: " << schedule[task].start_time
|
||||||
|
<< "\\nEnd: " << schedule[task].end_time << "\"];\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
for (size_t i = 0; i < proc_tasks[p].size() - 1; ++i) {
|
||||||
|
out << " Task_" << proc_tasks[p][i] << " -> Task_" << proc_tasks[p][i+1]
|
||||||
|
<< " [style=invis, weight=10];\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
out << " }\n\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < dag.num_nodes; ++i) {
|
||||||
|
for (const auto& edge : dag.successors[i]) {
|
||||||
|
int target = edge.first;
|
||||||
|
bool same_proc = (schedule[i].processor == schedule[target].processor);
|
||||||
|
|
||||||
|
// FIX 2: If on the same processor, add 'constraint=false'.
|
||||||
|
// The invisible edges already handle the layout, so don't let this edge confuse the solver.
|
||||||
|
out << " Task_" << i << " -> Task_" << target
|
||||||
|
<< " [label=\"" << (int)edge.second << "\""
|
||||||
|
<< (same_proc ? ", constraint=false]" : ", color=red, fontcolor=red, style=dashed]")
|
||||||
|
<< ";\n";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
out << "}\n";
|
||||||
|
out.close();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// ==========================================
|
||||||
|
// 5. MAIN
|
||||||
|
// ==========================================
|
||||||
|
int main() {
|
||||||
|
// Testing with a small graph to ensure DOT generation runs
|
||||||
|
int N = 30000;
|
||||||
|
int P = 1000;
|
||||||
|
|
||||||
|
cout << "Generating DAG with " << N << " nodes and " << P << " processors..." << endl;
|
||||||
|
auto start_gen = chrono::high_resolution_clock::now();
|
||||||
|
DAG dag = generate_dag(N, P, 300, 1.0f); // 10 levels for a small graph
|
||||||
|
auto end_gen = chrono::high_resolution_clock::now();
|
||||||
|
cout << "DAG Generation took: " << chrono::duration<double>(end_gen - start_gen).count() << " s\n";
|
||||||
|
|
||||||
|
vector<TaskSchedule> schedule;
|
||||||
|
|
||||||
|
cout << "Running PEFT Scheduling..." << endl;
|
||||||
|
auto start_sched = chrono::high_resolution_clock::now();
|
||||||
|
run_peft(dag, schedule);
|
||||||
|
auto end_sched = chrono::high_resolution_clock::now();
|
||||||
|
|
||||||
|
cout << "Scheduling took: " << chrono::duration<double>(end_sched - start_sched).count() << " s\n\n";
|
||||||
|
|
||||||
|
// ==========================================
|
||||||
|
// METRICS REPORT
|
||||||
|
// ==========================================
|
||||||
|
float makespan = 0.0f;
|
||||||
|
for (int i = 0; i < N; ++i) makespan = max(makespan, schedule[i].end_time);
|
||||||
|
|
||||||
|
// Calculate Sequential Makespan (If we ran all tasks on the single fastest processor)
|
||||||
|
float best_seq_makespan = 1e9f;
|
||||||
|
int best_seq_processor = -1;
|
||||||
|
for (int p = 0; p < P; ++p) {
|
||||||
|
float current_seq = 0.0f;
|
||||||
|
for (int i = 0; i < N; ++i) current_seq += dag.get_comp_cost(i, p);
|
||||||
|
if (current_seq < best_seq_makespan) {
|
||||||
|
best_seq_makespan = current_seq;
|
||||||
|
best_seq_processor = p;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
float time_gained = best_seq_makespan - makespan;
|
||||||
|
float speedup = best_seq_makespan / makespan;
|
||||||
|
|
||||||
|
cout << "--- METRICS REPORT ---\n";
|
||||||
|
cout << "Sequential Time (CPU " << best_seq_processor << "): " << best_seq_makespan << " units\n";
|
||||||
|
cout << "Parallel PEFT Makespan: " << makespan << " units\n";
|
||||||
|
cout << "Total Time Gained: " << time_gained << " units\n";
|
||||||
|
cout << "Overall Speedup: " << speedup << "x\n";
|
||||||
|
cout << "----------------------\n";
|
||||||
|
|
||||||
|
// Generate visualization only for small graphs
|
||||||
|
if (N <= 50) {
|
||||||
|
cout << "\nGraph size is small (N <= 50). Generating Graphviz DOT files...\n";
|
||||||
|
export_dag_to_dot(dag, "dag_raw.dot");
|
||||||
|
export_schedule_to_dot(dag, schedule, "dag_scheduled.dot");
|
||||||
|
|
||||||
|
cout << "Saved 'dag_raw.dot' and 'dag_scheduled.dot'.\n";
|
||||||
|
cout << "To render images, run the following commands in your terminal:\n";
|
||||||
|
cout << " dot -Tpng dag_raw.dot -o dag_raw.png\n";
|
||||||
|
cout << " dot -Tpng dag_scheduled.dot -o dag_scheduled.png\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
@@ -8,8 +8,11 @@ add_pim_library(SpatialOps
|
|||||||
SpatialOpsVerify.cpp
|
SpatialOpsVerify.cpp
|
||||||
SpatialOpsCanonicalization.cpp
|
SpatialOpsCanonicalization.cpp
|
||||||
Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp
|
Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp
|
||||||
|
Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp
|
||||||
|
Transforms/MergeComputeNodes/PostMergeCompaction.cpp
|
||||||
Transforms/MergeComputeNodes/RegularOpCompaction.cpp
|
Transforms/MergeComputeNodes/RegularOpCompaction.cpp
|
||||||
Transforms/MergeComputeNodes/Scheduling/ComputeGraph.cpp
|
Transforms/MergeComputeNodes/Scheduling/ComputeGraph.cpp
|
||||||
|
Transforms/MergeComputeNodes/Scheduling/ComputeInstanceUtils.cpp
|
||||||
Transforms/MergeComputeNodes/Scheduling/DcpScheduler.cpp
|
Transforms/MergeComputeNodes/Scheduling/DcpScheduler.cpp
|
||||||
Transforms/MergeComputeNodes/Scheduling/MergeSchedulingAnalysis.cpp
|
Transforms/MergeComputeNodes/Scheduling/MergeSchedulingAnalysis.cpp
|
||||||
Transforms/MergeComputeNodes/Scheduling/PeftScheduler.cpp
|
Transforms/MergeComputeNodes/Scheduling/PeftScheduler.cpp
|
||||||
|
|||||||
@@ -1,802 +1,19 @@
|
|||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
||||||
#include "mlir/IR/OpDefinition.h"
|
|
||||||
#include "mlir/IR/Value.h"
|
|
||||||
#include "mlir/IR/ValueRange.h"
|
|
||||||
|
|
||||||
#include "llvm/ADT/DenseMap.h"
|
|
||||||
#include "llvm/ADT/STLExtras.h"
|
|
||||||
#include "llvm/Support/Casting.h"
|
|
||||||
#include "llvm/Support/FormatVariadic.h"
|
|
||||||
#include "llvm/Support/raw_ostream.h"
|
|
||||||
|
|
||||||
#include <algorithm>
|
|
||||||
#include <cstdlib>
|
|
||||||
#include <iterator>
|
|
||||||
#include <numeric>
|
|
||||||
#include <optional>
|
|
||||||
#include <queue>
|
|
||||||
#include <utility>
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
#include "DCPAnalysis.hpp"
|
#include "DCPAnalysis.hpp"
|
||||||
#include "Graph.hpp"
|
#include "../Scheduling/ComputeGraph.hpp"
|
||||||
|
#include "../Scheduling/DcpScheduler.hpp"
|
||||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||||
#include "src/Support/TypeUtilities.hpp"
|
|
||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
namespace spatial {
|
namespace spatial {
|
||||||
|
|
||||||
using namespace mlir;
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
using SpatCompute = onnx_mlir::spatial::SpatCompute;
|
|
||||||
using SpatComputeBatch = onnx_mlir::spatial::SpatComputeBatch;
|
|
||||||
|
|
||||||
bool isDcpCoarsenDebugEnabled() { return std::getenv("DCP_COARSEN_DEBUG") != nullptr; }
|
|
||||||
|
|
||||||
struct VirtualNode {
|
|
||||||
SmallVector<size_t, 4> originalComputeIndices;
|
|
||||||
Weight weight = 0;
|
|
||||||
CrossbarUsage crossbarUsage = 0;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct VirtualGraph {
|
|
||||||
std::vector<VirtualNode> nodes;
|
|
||||||
std::vector<IndexedEdge> edges;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct TimingInfo {
|
|
||||||
std::vector<Time> aest;
|
|
||||||
std::vector<Time> alst;
|
|
||||||
std::vector<size_t> topologicalOrder;
|
|
||||||
bool valid = false;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct WindowScheduleResult {
|
|
||||||
std::vector<std::vector<size_t>> mergeGroups;
|
|
||||||
CPU cpuCount = 0;
|
|
||||||
size_t mergedNodeCount = 0;
|
|
||||||
size_t maxMergeGroupSize = 0;
|
|
||||||
};
|
|
||||||
|
|
||||||
size_t getSchedulingCpuBudget() {
|
|
||||||
if (coresCount.getValue() > 0)
|
|
||||||
return static_cast<size_t>(coresCount.getValue());
|
|
||||||
return std::numeric_limits<size_t>::max();
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t getBatchChunkTargetCount(int32_t laneCount) {
|
|
||||||
assert(laneCount > 0 && "laneCount must be positive");
|
|
||||||
return std::min(static_cast<size_t>(laneCount), std::max<size_t>(1, getSchedulingCpuBudget()));
|
|
||||||
}
|
|
||||||
|
|
||||||
ComputeInstance getBatchChunkForIndex(SpatComputeBatch batch, size_t chunkIndex) {
|
|
||||||
size_t totalLanes = static_cast<size_t>(batch.getLaneCount());
|
|
||||||
size_t chunkCount = getBatchChunkTargetCount(batch.getLaneCount());
|
|
||||||
size_t baseChunkSize = totalLanes / chunkCount;
|
|
||||||
size_t largeChunkCount = totalLanes % chunkCount;
|
|
||||||
|
|
||||||
size_t laneStart = chunkIndex * baseChunkSize + std::min(chunkIndex, largeChunkCount);
|
|
||||||
size_t laneCount = baseChunkSize + (chunkIndex < largeChunkCount ? 1 : 0);
|
|
||||||
return {batch.getOperation(), static_cast<uint32_t>(laneStart), static_cast<uint32_t>(laneCount)};
|
|
||||||
}
|
|
||||||
|
|
||||||
ComputeInstance getBatchChunkForLane(SpatComputeBatch batch, uint32_t lane) {
|
|
||||||
size_t totalLanes = static_cast<size_t>(batch.getLaneCount());
|
|
||||||
size_t chunkCount = getBatchChunkTargetCount(batch.getLaneCount());
|
|
||||||
size_t baseChunkSize = totalLanes / chunkCount;
|
|
||||||
size_t largeChunkCount = totalLanes % chunkCount;
|
|
||||||
size_t largeChunkSpan = largeChunkCount * (baseChunkSize + 1);
|
|
||||||
|
|
||||||
size_t chunkIndex = 0;
|
|
||||||
if (static_cast<size_t>(lane) < largeChunkSpan)
|
|
||||||
chunkIndex = static_cast<size_t>(lane) / (baseChunkSize + 1);
|
|
||||||
else
|
|
||||||
chunkIndex = largeChunkCount + (static_cast<size_t>(lane) - largeChunkSpan) / baseChunkSize;
|
|
||||||
return getBatchChunkForIndex(batch, chunkIndex);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<IndexedEdge> aggregateEdges(ArrayRef<IndexedEdge> edges) {
|
|
||||||
llvm::DenseMap<std::pair<size_t, size_t>, Weight> edgeWeights;
|
|
||||||
for (auto [start, end, weight] : edges) {
|
|
||||||
size_t startIndex = static_cast<size_t>(start);
|
|
||||||
size_t endIndex = static_cast<size_t>(end);
|
|
||||||
if (startIndex == endIndex)
|
|
||||||
continue;
|
|
||||||
auto key = std::make_pair(startIndex, endIndex);
|
|
||||||
Weight edgeWeight = static_cast<Weight>(weight);
|
|
||||||
auto inserted = edgeWeights.try_emplace(key, edgeWeight);
|
|
||||||
if (!inserted.second)
|
|
||||||
inserted.first->second = std::max(inserted.first->second, edgeWeight);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<IndexedEdge> aggregatedEdges;
|
|
||||||
aggregatedEdges.reserve(edgeWeights.size());
|
|
||||||
for (auto [key, weight] : edgeWeights)
|
|
||||||
aggregatedEdges.push_back(
|
|
||||||
{static_cast<int64_t>(key.first), static_cast<int64_t>(key.second), static_cast<int64_t>(weight)});
|
|
||||||
llvm::sort(aggregatedEdges, [](const IndexedEdge& lhs, const IndexedEdge& rhs) {
|
|
||||||
if (std::get<0>(lhs) != std::get<0>(rhs))
|
|
||||||
return std::get<0>(lhs) < std::get<0>(rhs);
|
|
||||||
return std::get<1>(lhs) < std::get<1>(rhs);
|
|
||||||
});
|
|
||||||
return aggregatedEdges;
|
|
||||||
}
|
|
||||||
|
|
||||||
Weight getComputeBodyWeight(Region& body) {
|
|
||||||
constexpr Weight kOperationWeight = 100;
|
|
||||||
Weight numOperations = 0;
|
|
||||||
for (auto& block : body)
|
|
||||||
for ([[maybe_unused]] auto& op : block)
|
|
||||||
numOperations = checkedAdd(numOperations, static_cast<Weight>(1));
|
|
||||||
return checkedMultiply(numOperations, kOperationWeight);
|
|
||||||
}
|
|
||||||
|
|
||||||
CrossbarUsage getComputeBodyCrossbarUsage(Region& body) {
|
|
||||||
CrossbarUsage crossbarUsage = 0;
|
|
||||||
for (auto& block : body)
|
|
||||||
for (auto& op : block)
|
|
||||||
if (isa<SpatVMMOp>(op))
|
|
||||||
crossbarUsage = checkedAdd(crossbarUsage, static_cast<CrossbarUsage>(1));
|
|
||||||
return crossbarUsage;
|
|
||||||
}
|
|
||||||
|
|
||||||
Weight getComputeInstanceWeight(const ComputeInstance& instance) {
|
|
||||||
if (auto spatCompute = dyn_cast<SpatCompute>(instance.op))
|
|
||||||
return getSpatComputeWeight(spatCompute);
|
|
||||||
auto batch = cast<SpatComputeBatch>(instance.op);
|
|
||||||
return checkedMultiply(getComputeBodyWeight(batch.getBody()), static_cast<Weight>(instance.laneCount));
|
|
||||||
}
|
|
||||||
|
|
||||||
CrossbarUsage getComputeInstanceCrossbarUsage(const ComputeInstance& instance) {
|
|
||||||
if (auto spatCompute = dyn_cast<SpatCompute>(instance.op))
|
|
||||||
return getSpatComputeCrossbarUsage(spatCompute);
|
|
||||||
auto batch = cast<SpatComputeBatch>(instance.op);
|
|
||||||
return checkedMultiply(getComputeBodyCrossbarUsage(batch.getBody()), static_cast<CrossbarUsage>(instance.laneCount));
|
|
||||||
}
|
|
||||||
|
|
||||||
SmallVector<Value, 4> getComputeInstanceInputs(const ComputeInstance& instance) {
|
|
||||||
if (auto spatCompute = dyn_cast<SpatCompute>(instance.op))
|
|
||||||
return SmallVector<Value, 4>(spatCompute.getInputs().begin(), spatCompute.getInputs().end());
|
|
||||||
auto batch = cast<SpatComputeBatch>(instance.op);
|
|
||||||
SmallVector<Value, 4> inputs;
|
|
||||||
inputs.reserve(instance.laneCount);
|
|
||||||
for (uint32_t lane = instance.laneStart; lane < instance.laneStart + instance.laneCount; ++lane)
|
|
||||||
inputs.push_back(batch.getInputs()[lane]);
|
|
||||||
return inputs;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::optional<ComputeInstance> getOriginalComputeInstance(Value value) {
|
|
||||||
Operation* op = value.getDefiningOp();
|
|
||||||
if (!op)
|
|
||||||
return std::nullopt;
|
|
||||||
|
|
||||||
while (auto extract = dyn_cast<tensor::ExtractSliceOp>(op)) {
|
|
||||||
value = extract.getSource();
|
|
||||||
op = value.getDefiningOp();
|
|
||||||
if (!op)
|
|
||||||
return std::nullopt;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (auto spatCompute = dyn_cast<SpatCompute>(op))
|
|
||||||
return ComputeInstance {spatCompute.getOperation(), 0, 1};
|
|
||||||
if (auto batch = dyn_cast<SpatComputeBatch>(op))
|
|
||||||
return getBatchChunkForLane(batch, static_cast<uint32_t>(cast<OpResult>(value).getResultNumber()));
|
|
||||||
return std::nullopt;
|
|
||||||
}
|
|
||||||
|
|
||||||
SmallVector<ComputeInstance> collectComputeInstances(Operation* entryOp) {
|
|
||||||
SmallVector<ComputeInstance> instances;
|
|
||||||
auto isUsedAsWeightOnly = [](Operation* producerOp) {
|
|
||||||
if (producerOp->getNumResults() == 0)
|
|
||||||
return false;
|
|
||||||
for (Value result : producerOp->getResults()) {
|
|
||||||
if (result.use_empty())
|
|
||||||
return false;
|
|
||||||
for (Operation* user : result.getUsers()) {
|
|
||||||
if (auto compute = dyn_cast<SpatCompute>(user)) {
|
|
||||||
if (!llvm::is_contained(compute.getWeights(), result))
|
|
||||||
return false;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (auto batch = dyn_cast<SpatComputeBatch>(user)) {
|
|
||||||
if (!llvm::is_contained(batch.getWeights(), result))
|
|
||||||
return false;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
};
|
|
||||||
for (Region& region : entryOp->getRegions()) {
|
|
||||||
for (Block& block : region) {
|
|
||||||
for (Operation& op : block) {
|
|
||||||
if (auto spatCompute = dyn_cast<SpatCompute>(&op)) {
|
|
||||||
if (isUsedAsWeightOnly(spatCompute.getOperation()))
|
|
||||||
continue;
|
|
||||||
instances.push_back({spatCompute.getOperation(), 0, 1});
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (auto batch = dyn_cast<SpatComputeBatch>(&op)) {
|
|
||||||
if (isUsedAsWeightOnly(batch.getOperation()))
|
|
||||||
continue;
|
|
||||||
size_t chunkCount = getBatchChunkTargetCount(batch.getLaneCount());
|
|
||||||
for (size_t chunkIndex = 0; chunkIndex < chunkCount; ++chunkIndex)
|
|
||||||
instances.push_back(getBatchChunkForIndex(batch, chunkIndex));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return instances;
|
|
||||||
}
|
|
||||||
|
|
||||||
VirtualGraph buildInitialVirtualGraph(ArrayRef<ComputeInstance> computeInstances, ArrayRef<IndexedEdge> edges) {
|
|
||||||
VirtualGraph graph;
|
|
||||||
graph.nodes.reserve(computeInstances.size());
|
|
||||||
for (auto [index, computeInstance] : llvm::enumerate(computeInstances)) {
|
|
||||||
VirtualNode node;
|
|
||||||
node.originalComputeIndices.push_back(index);
|
|
||||||
node.weight = getComputeInstanceWeight(computeInstance);
|
|
||||||
node.crossbarUsage = getComputeInstanceCrossbarUsage(computeInstance);
|
|
||||||
graph.nodes.push_back(std::move(node));
|
|
||||||
}
|
|
||||||
graph.edges = aggregateEdges(edges);
|
|
||||||
return graph;
|
|
||||||
}
|
|
||||||
|
|
||||||
TimingInfo computeTiming(const VirtualGraph& graph) {
|
|
||||||
TimingInfo timing;
|
|
||||||
size_t nodeCount = graph.nodes.size();
|
|
||||||
timing.aest.assign(nodeCount, 0);
|
|
||||||
timing.alst.assign(nodeCount, 0);
|
|
||||||
timing.topologicalOrder.reserve(nodeCount);
|
|
||||||
|
|
||||||
std::vector<std::vector<std::pair<size_t, Weight>>> parents(nodeCount);
|
|
||||||
std::vector<std::vector<std::pair<size_t, Weight>>> children(nodeCount);
|
|
||||||
std::vector<size_t> incomingEdgeCount(nodeCount, 0);
|
|
||||||
|
|
||||||
for (auto [start, end, weight] : graph.edges) {
|
|
||||||
size_t startIndex = static_cast<size_t>(start);
|
|
||||||
size_t endIndex = static_cast<size_t>(end);
|
|
||||||
Weight edgeWeight = static_cast<Weight>(weight);
|
|
||||||
assert(startIndex < nodeCount && endIndex < nodeCount && "virtual edge endpoint out of range");
|
|
||||||
children[startIndex].push_back({endIndex, edgeWeight});
|
|
||||||
parents[endIndex].push_back({startIndex, edgeWeight});
|
|
||||||
incomingEdgeCount[endIndex]++;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto getVirtualNodeOrderKey = [&](size_t nodeIndex) {
|
|
||||||
const VirtualNode& node = graph.nodes[nodeIndex];
|
|
||||||
if (!node.originalComputeIndices.empty())
|
|
||||||
return node.originalComputeIndices.front();
|
|
||||||
return nodeIndex;
|
|
||||||
};
|
|
||||||
auto readyNodeGreater = [&](size_t lhs, size_t rhs) {
|
|
||||||
size_t lhsKey = getVirtualNodeOrderKey(lhs);
|
|
||||||
size_t rhsKey = getVirtualNodeOrderKey(rhs);
|
|
||||||
if (lhsKey != rhsKey)
|
|
||||||
return lhsKey > rhsKey;
|
|
||||||
return lhs > rhs;
|
|
||||||
};
|
|
||||||
std::priority_queue<size_t, std::vector<size_t>, decltype(readyNodeGreater)> readyNodes(readyNodeGreater);
|
|
||||||
for (size_t i = 0; i < nodeCount; ++i)
|
|
||||||
if (incomingEdgeCount[i] == 0)
|
|
||||||
readyNodes.push(i);
|
|
||||||
|
|
||||||
while (!readyNodes.empty()) {
|
|
||||||
size_t current = readyNodes.top();
|
|
||||||
readyNodes.pop();
|
|
||||||
timing.topologicalOrder.push_back(current);
|
|
||||||
for (auto [child, weight] : children[current]) {
|
|
||||||
(void) weight;
|
|
||||||
assert(incomingEdgeCount[child] > 0 && "incoming edge count underflow");
|
|
||||||
incomingEdgeCount[child]--;
|
|
||||||
if (incomingEdgeCount[child] == 0)
|
|
||||||
readyNodes.push(child);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (timing.topologicalOrder.size() != nodeCount)
|
|
||||||
return timing;
|
|
||||||
|
|
||||||
Time dcpl = 0;
|
|
||||||
for (size_t nodeIndex : timing.topologicalOrder) {
|
|
||||||
Time maxParentAest = 0;
|
|
||||||
for (auto [parent, transferCost] : parents[nodeIndex]) {
|
|
||||||
maxParentAest =
|
|
||||||
std::max(maxParentAest, addOrMax(addOrMax(timing.aest[parent], graph.nodes[parent].weight), transferCost));
|
|
||||||
}
|
|
||||||
timing.aest[nodeIndex] = maxParentAest;
|
|
||||||
dcpl = std::max(dcpl, addOrMax(maxParentAest, graph.nodes[nodeIndex].weight));
|
|
||||||
}
|
|
||||||
|
|
||||||
for (size_t nodeIndex : llvm::reverse(timing.topologicalOrder)) {
|
|
||||||
Time minAlst = std::numeric_limits<Time>::max();
|
|
||||||
if (children[nodeIndex].empty())
|
|
||||||
minAlst = subtractOrZero(dcpl, graph.nodes[nodeIndex].weight);
|
|
||||||
for (auto [child, transferCost] : children[nodeIndex]) {
|
|
||||||
minAlst =
|
|
||||||
std::min(minAlst, subtractOrZero(timing.alst[child], addOrMax(graph.nodes[nodeIndex].weight, transferCost)));
|
|
||||||
}
|
|
||||||
timing.alst[nodeIndex] = minAlst;
|
|
||||||
}
|
|
||||||
|
|
||||||
timing.valid = true;
|
|
||||||
return timing;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<std::vector<size_t>> buildUndirectedAdjacency(const VirtualGraph& graph) {
|
|
||||||
std::vector<std::vector<size_t>> adjacency(graph.nodes.size());
|
|
||||||
for (auto [start, end, weight] : graph.edges) {
|
|
||||||
(void) weight;
|
|
||||||
size_t startIndex = static_cast<size_t>(start);
|
|
||||||
size_t endIndex = static_cast<size_t>(end);
|
|
||||||
assert(startIndex < graph.nodes.size() && endIndex < graph.nodes.size() && "virtual edge endpoint out of range");
|
|
||||||
adjacency[startIndex].push_back(endIndex);
|
|
||||||
adjacency[endIndex].push_back(startIndex);
|
|
||||||
}
|
|
||||||
for (auto& neighbours : adjacency) {
|
|
||||||
llvm::sort(neighbours);
|
|
||||||
neighbours.erase(std::unique(neighbours.begin(), neighbours.end()), neighbours.end());
|
|
||||||
}
|
|
||||||
return adjacency;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<size_t> selectCriticalWindow(const VirtualGraph& graph, const TimingInfo& timing, size_t windowSize) {
|
|
||||||
std::vector<size_t> ranked(timing.aest.size());
|
|
||||||
std::iota(ranked.begin(), ranked.end(), 0);
|
|
||||||
auto isHigherPriority = [&](size_t lhs, size_t rhs) {
|
|
||||||
Time lhsSlack = slackOrZero(timing.aest[lhs], timing.alst[lhs]);
|
|
||||||
Time rhsSlack = slackOrZero(timing.aest[rhs], timing.alst[rhs]);
|
|
||||||
if (lhsSlack != rhsSlack)
|
|
||||||
return lhsSlack < rhsSlack;
|
|
||||||
if (timing.aest[lhs] != timing.aest[rhs])
|
|
||||||
return timing.aest[lhs] < timing.aest[rhs];
|
|
||||||
return lhs < rhs;
|
|
||||||
};
|
|
||||||
|
|
||||||
windowSize = std::min(windowSize, ranked.size());
|
|
||||||
if (windowSize == 0)
|
|
||||||
return {};
|
|
||||||
if (windowSize == ranked.size()) {
|
|
||||||
llvm::sort(ranked, isHigherPriority);
|
|
||||||
return ranked;
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t criticalPoolSize = std::min(ranked.size(), std::max(windowSize, windowSize * 2));
|
|
||||||
if (criticalPoolSize < ranked.size())
|
|
||||||
std::nth_element(
|
|
||||||
ranked.begin(), ranked.begin() + static_cast<std::ptrdiff_t>(criticalPoolSize), ranked.end(), isHigherPriority);
|
|
||||||
|
|
||||||
std::vector<char> inCriticalPool(ranked.size(), false);
|
|
||||||
for (size_t i = 0; i < criticalPoolSize; ++i)
|
|
||||||
inCriticalPool[ranked[i]] = true;
|
|
||||||
|
|
||||||
size_t seed = *std::min_element(ranked.begin(), ranked.end(), isHigherPriority);
|
|
||||||
std::vector<std::vector<size_t>> adjacency = buildUndirectedAdjacency(graph);
|
|
||||||
std::vector<size_t> selected;
|
|
||||||
std::vector<char> inWindow(ranked.size(), false);
|
|
||||||
selected.reserve(windowSize);
|
|
||||||
|
|
||||||
struct FrontierEntry {
|
|
||||||
size_t node;
|
|
||||||
};
|
|
||||||
auto frontierCompare = [&](FrontierEntry lhs, FrontierEntry rhs) { return isHigherPriority(rhs.node, lhs.node); };
|
|
||||||
std::priority_queue<FrontierEntry, std::vector<FrontierEntry>, decltype(frontierCompare)> frontier(frontierCompare);
|
|
||||||
|
|
||||||
auto addToWindow = [&](size_t node, const std::vector<char>& eligible) {
|
|
||||||
if (inWindow[node])
|
|
||||||
return;
|
|
||||||
inWindow[node] = true;
|
|
||||||
selected.push_back(node);
|
|
||||||
for (size_t neighbour : adjacency[node])
|
|
||||||
if (!inWindow[neighbour] && eligible[neighbour])
|
|
||||||
frontier.push({neighbour});
|
|
||||||
};
|
|
||||||
|
|
||||||
addToWindow(seed, inCriticalPool);
|
|
||||||
while (!frontier.empty() && selected.size() < windowSize) {
|
|
||||||
size_t node = frontier.top().node;
|
|
||||||
frontier.pop();
|
|
||||||
if (!inWindow[node])
|
|
||||||
addToWindow(node, inCriticalPool);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (selected.size() < windowSize) {
|
|
||||||
std::vector<char> anyNode(ranked.size(), true);
|
|
||||||
for (size_t node : selected)
|
|
||||||
for (size_t neighbour : adjacency[node])
|
|
||||||
if (!inWindow[neighbour])
|
|
||||||
frontier.push({neighbour});
|
|
||||||
while (!frontier.empty() && selected.size() < windowSize) {
|
|
||||||
size_t node = frontier.top().node;
|
|
||||||
frontier.pop();
|
|
||||||
if (!inWindow[node])
|
|
||||||
addToWindow(node, anyNode);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (selected.size() < windowSize) {
|
|
||||||
llvm::sort(ranked, isHigherPriority);
|
|
||||||
for (size_t node : ranked) {
|
|
||||||
if (selected.size() == windowSize)
|
|
||||||
break;
|
|
||||||
if (!inWindow[node]) {
|
|
||||||
inWindow[node] = true;
|
|
||||||
selected.push_back(node);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
llvm::sort(selected, isHigherPriority);
|
|
||||||
return selected;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<IndexedEdge> buildWindowEdges(const VirtualGraph& graph, const std::vector<int64_t>& nodeToWindowIndex) {
|
|
||||||
std::vector<IndexedEdge> windowEdges;
|
|
||||||
windowEdges.reserve(graph.edges.size());
|
|
||||||
for (auto [start, end, weight] : graph.edges) {
|
|
||||||
int64_t mappedStart = nodeToWindowIndex[static_cast<size_t>(start)];
|
|
||||||
int64_t mappedEnd = nodeToWindowIndex[static_cast<size_t>(end)];
|
|
||||||
if (mappedStart == -1 || mappedEnd == -1)
|
|
||||||
continue;
|
|
||||||
windowEdges.push_back({mappedStart, mappedEnd, weight});
|
|
||||||
}
|
|
||||||
return aggregateEdges(windowEdges);
|
|
||||||
}
|
|
||||||
|
|
||||||
WindowScheduleResult scheduleWindow(const VirtualGraph& graph, ArrayRef<size_t> selectedNodes, MLIRContext* context) {
|
|
||||||
std::vector<Weight> windowWeights;
|
|
||||||
std::vector<CrossbarUsage> windowCrossbarUsage;
|
|
||||||
std::vector<int64_t> windowNodeOrderKeys;
|
|
||||||
std::vector<int64_t> nodeToWindowIndex(graph.nodes.size(), -1);
|
|
||||||
windowWeights.reserve(selectedNodes.size());
|
|
||||||
windowCrossbarUsage.reserve(selectedNodes.size());
|
|
||||||
windowNodeOrderKeys.reserve(selectedNodes.size());
|
|
||||||
|
|
||||||
for (auto [windowIndex, nodeIndex] : llvm::enumerate(selectedNodes)) {
|
|
||||||
nodeToWindowIndex[nodeIndex] = static_cast<int64_t>(windowIndex);
|
|
||||||
windowWeights.push_back(graph.nodes[nodeIndex].weight);
|
|
||||||
windowCrossbarUsage.push_back(graph.nodes[nodeIndex].crossbarUsage);
|
|
||||||
windowNodeOrderKeys.push_back(static_cast<int64_t>(nodeIndex));
|
|
||||||
}
|
|
||||||
|
|
||||||
GraphDCP windowGraph(
|
|
||||||
windowWeights, buildWindowEdges(graph, nodeToWindowIndex), windowNodeOrderKeys, windowCrossbarUsage);
|
|
||||||
if (coresCount.getValue() > 0)
|
|
||||||
windowGraph.setMaxCpuCount(static_cast<int>(coresCount.getValue()));
|
|
||||||
windowGraph.setContext(context);
|
|
||||||
windowGraph.runDcp();
|
|
||||||
|
|
||||||
WindowScheduleResult result;
|
|
||||||
result.cpuCount = windowGraph.cpuCount();
|
|
||||||
for (CPU cpu = 0; cpu < windowGraph.cpuCount(); ++cpu) {
|
|
||||||
auto scheduledTasks = windowGraph.getScheduledTasks(cpu);
|
|
||||||
if (scheduledTasks.size() < 2)
|
|
||||||
continue;
|
|
||||||
|
|
||||||
result.mergedNodeCount += scheduledTasks.size();
|
|
||||||
result.maxMergeGroupSize = std::max(result.maxMergeGroupSize, scheduledTasks.size());
|
|
||||||
std::vector<size_t> mergeGroup;
|
|
||||||
mergeGroup.reserve(scheduledTasks.size());
|
|
||||||
for (const auto& task : scheduledTasks)
|
|
||||||
mergeGroup.push_back(selectedNodes[task.nodeIndex]);
|
|
||||||
result.mergeGroups.push_back(std::move(mergeGroup));
|
|
||||||
}
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool coarsenGraph(const VirtualGraph& graph,
|
|
||||||
ArrayRef<std::vector<size_t>> mergeGroups,
|
|
||||||
VirtualGraph& coarsenedGraph,
|
|
||||||
std::vector<size_t>& oldToNewNode) {
|
|
||||||
TimingInfo timing = computeTiming(graph);
|
|
||||||
std::vector<size_t> topologicalRank(graph.nodes.size());
|
|
||||||
std::iota(topologicalRank.begin(), topologicalRank.end(), 0);
|
|
||||||
if (timing.valid)
|
|
||||||
for (auto [rank, nodeIndex] : llvm::enumerate(timing.topologicalOrder))
|
|
||||||
topologicalRank[nodeIndex] = rank;
|
|
||||||
|
|
||||||
std::vector<std::vector<size_t>> orderedMergeGroups;
|
|
||||||
orderedMergeGroups.reserve(mergeGroups.size());
|
|
||||||
for (const auto& mergeGroup : mergeGroups) {
|
|
||||||
orderedMergeGroups.emplace_back(mergeGroup.begin(), mergeGroup.end());
|
|
||||||
std::stable_sort(orderedMergeGroups.back().begin(), orderedMergeGroups.back().end(), [&](size_t lhs, size_t rhs) {
|
|
||||||
if (topologicalRank[lhs] != topologicalRank[rhs])
|
|
||||||
return topologicalRank[lhs] < topologicalRank[rhs];
|
|
||||||
return lhs < rhs;
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<int64_t> nodeToMergeGroup(graph.nodes.size(), -1);
|
|
||||||
for (auto [groupIndex, mergeGroup] : llvm::enumerate(orderedMergeGroups)) {
|
|
||||||
if (mergeGroup.size() < 2)
|
|
||||||
continue;
|
|
||||||
for (size_t nodeIndex : mergeGroup) {
|
|
||||||
assert(nodeIndex < graph.nodes.size() && "merge group node out of range");
|
|
||||||
nodeToMergeGroup[nodeIndex] = static_cast<int64_t>(groupIndex);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<std::optional<size_t>> mergeGroupToNewNode(orderedMergeGroups.size());
|
|
||||||
std::vector<size_t> newNodeRank;
|
|
||||||
oldToNewNode.assign(graph.nodes.size(), 0);
|
|
||||||
bool mergedAny = false;
|
|
||||||
coarsenedGraph.nodes.clear();
|
|
||||||
coarsenedGraph.edges.clear();
|
|
||||||
coarsenedGraph.nodes.reserve(graph.nodes.size());
|
|
||||||
newNodeRank.reserve(graph.nodes.size());
|
|
||||||
|
|
||||||
for (size_t nodeIndex = 0; nodeIndex < graph.nodes.size(); ++nodeIndex) {
|
|
||||||
int64_t mergeGroupIndex = nodeToMergeGroup[nodeIndex];
|
|
||||||
if (mergeGroupIndex == -1) {
|
|
||||||
oldToNewNode[nodeIndex] = coarsenedGraph.nodes.size();
|
|
||||||
coarsenedGraph.nodes.push_back(graph.nodes[nodeIndex]);
|
|
||||||
newNodeRank.push_back(topologicalRank[nodeIndex]);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto& newNodeIndex = mergeGroupToNewNode[static_cast<size_t>(mergeGroupIndex)];
|
|
||||||
if (newNodeIndex.has_value()) {
|
|
||||||
oldToNewNode[nodeIndex] = *newNodeIndex;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
VirtualNode mergedNode;
|
|
||||||
for (size_t memberIndex : orderedMergeGroups[static_cast<size_t>(mergeGroupIndex)]) {
|
|
||||||
const VirtualNode& memberNode = graph.nodes[memberIndex];
|
|
||||||
mergedNode.originalComputeIndices.append(memberNode.originalComputeIndices.begin(),
|
|
||||||
memberNode.originalComputeIndices.end());
|
|
||||||
mergedNode.weight = addOrMax(mergedNode.weight, memberNode.weight);
|
|
||||||
mergedNode.crossbarUsage = addOrMax(mergedNode.crossbarUsage, memberNode.crossbarUsage);
|
|
||||||
}
|
|
||||||
std::sort(mergedNode.originalComputeIndices.begin(), mergedNode.originalComputeIndices.end());
|
|
||||||
|
|
||||||
mergedAny = true;
|
|
||||||
newNodeIndex = coarsenedGraph.nodes.size();
|
|
||||||
for (size_t memberIndex : orderedMergeGroups[static_cast<size_t>(mergeGroupIndex)])
|
|
||||||
oldToNewNode[memberIndex] = *newNodeIndex;
|
|
||||||
newNodeRank.push_back(topologicalRank[orderedMergeGroups[static_cast<size_t>(mergeGroupIndex)].front()]);
|
|
||||||
coarsenedGraph.nodes.push_back(std::move(mergedNode));
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!mergedAny)
|
|
||||||
return false;
|
|
||||||
|
|
||||||
std::vector<IndexedEdge> remappedEdges;
|
|
||||||
remappedEdges.reserve(graph.edges.size());
|
|
||||||
for (auto [start, end, weight] : graph.edges) {
|
|
||||||
size_t newStart = oldToNewNode[static_cast<size_t>(start)];
|
|
||||||
size_t newEnd = oldToNewNode[static_cast<size_t>(end)];
|
|
||||||
if (newStart == newEnd)
|
|
||||||
continue;
|
|
||||||
if (newNodeRank[newStart] >= newNodeRank[newEnd])
|
|
||||||
continue;
|
|
||||||
remappedEdges.push_back({static_cast<int64_t>(newStart), static_cast<int64_t>(newEnd), weight});
|
|
||||||
}
|
|
||||||
coarsenedGraph.edges = aggregateEdges(remappedEdges);
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
CPU getVirtualGraphMaxCpuCount() { return static_cast<CPU>(getSchedulingCpuBudget()); }
|
|
||||||
|
|
||||||
size_t getDcpCoarseningWindowSize(size_t nodeCount) {
|
|
||||||
size_t windowSize = std::min(dcpCriticalWindowSize.getValue(), nodeCount);
|
|
||||||
CPU maxCpuCount = std::max<CPU>(1, getVirtualGraphMaxCpuCount());
|
|
||||||
if (nodeCount > static_cast<size_t>(maxCpuCount))
|
|
||||||
windowSize = std::max(windowSize, std::min(nodeCount, static_cast<size_t>(maxCpuCount) + 1));
|
|
||||||
return windowSize;
|
|
||||||
}
|
|
||||||
|
|
||||||
DCPAnalysisResult buildResultFromVirtualGraph(const VirtualGraph& graph, ArrayRef<ComputeInstance> computeInstances) {
|
|
||||||
DCPAnalysisResult result;
|
|
||||||
|
|
||||||
TimingInfo timing = computeTiming(graph);
|
|
||||||
std::vector<size_t> virtualNodeOrder;
|
|
||||||
if (timing.valid) {
|
|
||||||
virtualNodeOrder = std::move(timing.topologicalOrder);
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
virtualNodeOrder.resize(graph.nodes.size());
|
|
||||||
std::iota(virtualNodeOrder.begin(), virtualNodeOrder.end(), 0);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<size_t> originalComputeToCpu(computeInstances.size(), 0);
|
|
||||||
for (auto [cpu, virtualNodeIndex] : llvm::enumerate(virtualNodeOrder)) {
|
|
||||||
const VirtualNode& virtualNode = graph.nodes[virtualNodeIndex];
|
|
||||||
for (size_t originalIndex : virtualNode.originalComputeIndices)
|
|
||||||
originalComputeToCpu[originalIndex] = cpu;
|
|
||||||
}
|
|
||||||
|
|
||||||
result.dominanceOrderCompute.reserve(computeInstances.size());
|
|
||||||
llvm::DenseMap<size_t, size_t> nextCpuSlot;
|
|
||||||
for (auto [originalIndex, computeInstance] : llvm::enumerate(computeInstances)) {
|
|
||||||
size_t cpu = originalComputeToCpu[originalIndex];
|
|
||||||
result.dominanceOrderCompute.push_back(computeInstance);
|
|
||||||
result.computeToCpuMap[computeInstance] = cpu;
|
|
||||||
result.computeToCpuSlotMap[computeInstance] = nextCpuSlot[cpu]++;
|
|
||||||
result.computeToAestMap[computeInstance] = originalIndex;
|
|
||||||
result.cpuToLastComputeMap[cpu] = computeInstance;
|
|
||||||
}
|
|
||||||
for (const auto& [cpu, lastCompute] : result.cpuToLastComputeMap)
|
|
||||||
result.isLastComputeOfCpu.insert(lastCompute);
|
|
||||||
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
DCPAnalysisResult buildResultFromScheduledGraph(GraphDCP& graphDCP, ArrayRef<ComputeInstance> computeInstances) {
|
|
||||||
DCPAnalysisResult result;
|
|
||||||
result.dominanceOrderCompute.assign(computeInstances.begin(), computeInstances.end());
|
|
||||||
|
|
||||||
for (CPU cpu = 0; cpu < graphDCP.cpuCount(); ++cpu) {
|
|
||||||
auto scheduledTasks = graphDCP.getScheduledTasks(cpu);
|
|
||||||
if (scheduledTasks.empty())
|
|
||||||
continue;
|
|
||||||
|
|
||||||
for (auto [slot, task] : llvm::enumerate(scheduledTasks)) {
|
|
||||||
ComputeInstance instance = computeInstances[task.nodeIndex];
|
|
||||||
result.computeToCpuMap[instance] = cpu;
|
|
||||||
result.computeToCpuSlotMap[instance] = slot;
|
|
||||||
result.computeToAestMap[instance] = static_cast<uint64_t>(task.aest);
|
|
||||||
}
|
|
||||||
result.cpuToLastComputeMap[cpu] = computeInstances[scheduledTasks.back().nodeIndex];
|
|
||||||
result.isLastComputeOfCpu.insert(computeInstances[scheduledTasks.back().nodeIndex]);
|
|
||||||
}
|
|
||||||
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
DCPAnalysisResult
|
|
||||||
runLegacyDcp(ArrayRef<ComputeInstance> computeInstances, ArrayRef<IndexedEdge> edges, MLIRContext* context) {
|
|
||||||
SmallVector<Weight> nodeWeights;
|
|
||||||
SmallVector<CrossbarUsage> nodeCrossbarUsage;
|
|
||||||
SmallVector<int64_t> nodeOrderKeys;
|
|
||||||
nodeWeights.reserve(computeInstances.size());
|
|
||||||
nodeCrossbarUsage.reserve(computeInstances.size());
|
|
||||||
nodeOrderKeys.reserve(computeInstances.size());
|
|
||||||
for (auto [index, instance] : llvm::enumerate(computeInstances)) {
|
|
||||||
nodeWeights.push_back(getComputeInstanceWeight(instance));
|
|
||||||
nodeCrossbarUsage.push_back(getComputeInstanceCrossbarUsage(instance));
|
|
||||||
nodeOrderKeys.push_back(static_cast<int64_t>(index));
|
|
||||||
}
|
|
||||||
|
|
||||||
GraphDCP graphDCP(nodeWeights, edges, nodeOrderKeys, nodeCrossbarUsage);
|
|
||||||
if (coresCount.getValue() > 0)
|
|
||||||
graphDCP.setMaxCpuCount(static_cast<int>(coresCount.getValue()));
|
|
||||||
graphDCP.setContext(context);
|
|
||||||
graphDCP.runDcp();
|
|
||||||
return buildResultFromScheduledGraph(graphDCP, computeInstances);
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
SpatCompute getOriginalSpatCompute(Operation* op) {
|
|
||||||
if (!op)
|
|
||||||
return {};
|
|
||||||
while (auto extract = dyn_cast<tensor::ExtractSliceOp>(op)) {
|
|
||||||
op = extract.getSource().getDefiningOp();
|
|
||||||
if (!op)
|
|
||||||
return {};
|
|
||||||
}
|
|
||||||
if (auto res = dyn_cast<SpatCompute>(op))
|
|
||||||
return res;
|
|
||||||
return {};
|
|
||||||
}
|
|
||||||
|
|
||||||
DCPAnalysisResult DCPAnalysis::run() {
|
DCPAnalysisResult DCPAnalysis::run() {
|
||||||
SmallVector<ComputeInstance> computeInstances = collectComputeInstances(entryOp);
|
ComputeGraph graph = buildComputeGraph(entryOp);
|
||||||
SmallVector<IndexedEdge, 10> edges;
|
DcpScheduleOptions options;
|
||||||
|
if (coresCount.getValue() > 0)
|
||||||
llvm::DenseMap<ComputeInstance, size_t> instanceToIndex;
|
options.processorCount = static_cast<size_t>(coresCount.getValue());
|
||||||
instanceToIndex.reserve(computeInstances.size());
|
options.criticalWindowSize = dcpCriticalWindowSize.getValue();
|
||||||
for (auto [index, instance] : llvm::enumerate(computeInstances))
|
options.allowFallbackForAutoCoreCount = true;
|
||||||
instanceToIndex[instance] = index;
|
return runDcpScheduler(graph, options, entryOp->getContext());
|
||||||
|
|
||||||
for (auto [indexEndEdge, computeInstance] : llvm::enumerate(computeInstances)) {
|
|
||||||
for (Value input : getComputeInstanceInputs(computeInstance)) {
|
|
||||||
if (auto producerInstance = getOriginalComputeInstance(input)) {
|
|
||||||
auto producerIt = instanceToIndex.find(*producerInstance);
|
|
||||||
assert(producerIt != instanceToIndex.end());
|
|
||||||
auto indexStartEdge = producerIt->second;
|
|
||||||
edges.push_back({static_cast<int64_t>(indexStartEdge),
|
|
||||||
static_cast<int64_t>(indexEndEdge),
|
|
||||||
static_cast<int64_t>(getSizeInBytes(cast<ShapedType>(input.getType())))});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (coresCount.getValue() > 0) {
|
|
||||||
size_t schedulingCpuBudget = getSchedulingCpuBudget();
|
|
||||||
bool needsExactScheduledBatches = llvm::any_of(computeInstances, [&](const ComputeInstance& instance) {
|
|
||||||
auto batch = dyn_cast<SpatComputeBatch>(instance.op);
|
|
||||||
return batch && static_cast<size_t>(batch.getLaneCount()) > schedulingCpuBudget;
|
|
||||||
});
|
|
||||||
if (needsExactScheduledBatches)
|
|
||||||
return runLegacyDcp(computeInstances, edges, entryOp->getContext());
|
|
||||||
}
|
|
||||||
|
|
||||||
if (dcpCriticalWindowSize.getValue() == 0)
|
|
||||||
return runLegacyDcp(computeInstances, edges, entryOp->getContext());
|
|
||||||
|
|
||||||
VirtualGraph virtualGraph = buildInitialVirtualGraph(computeInstances, edges);
|
|
||||||
size_t iteration = 0;
|
|
||||||
bool debugCoarsening = isDcpCoarsenDebugEnabled();
|
|
||||||
auto tryCoarsenSelectedNodes = [&](ArrayRef<size_t> selectedNodes) {
|
|
||||||
size_t oldNodeCount = virtualGraph.nodes.size();
|
|
||||||
WindowScheduleResult windowSchedule = scheduleWindow(virtualGraph, selectedNodes, entryOp->getContext());
|
|
||||||
if (windowSchedule.mergeGroups.empty()) {
|
|
||||||
if (debugCoarsening && oldNodeCount >= 200)
|
|
||||||
llvm::errs() << llvm::formatv("[DCP-COARSEN] iter={0} old={1} selected={2} windowCpus={3} "
|
|
||||||
"groups=0 mergedNodes=0 maxGroup=0 new={1} changed=0\n",
|
|
||||||
iteration,
|
|
||||||
oldNodeCount,
|
|
||||||
selectedNodes.size(),
|
|
||||||
windowSchedule.cpuCount);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
VirtualGraph coarsenedGraph;
|
|
||||||
std::vector<size_t> oldToNewNode;
|
|
||||||
if (!coarsenGraph(virtualGraph, windowSchedule.mergeGroups, coarsenedGraph, oldToNewNode))
|
|
||||||
return false;
|
|
||||||
if (debugCoarsening && (oldNodeCount >= 200 || coarsenedGraph.nodes.size() >= 200))
|
|
||||||
llvm::errs() << llvm::formatv("[DCP-COARSEN] iter={0} old={1} selected={2} windowCpus={3} "
|
|
||||||
"groups={4} mergedNodes={5} maxGroup={6} new={7} changed={8}\n",
|
|
||||||
iteration,
|
|
||||||
oldNodeCount,
|
|
||||||
selectedNodes.size(),
|
|
||||||
windowSchedule.cpuCount,
|
|
||||||
windowSchedule.mergeGroups.size(),
|
|
||||||
windowSchedule.mergedNodeCount,
|
|
||||||
windowSchedule.maxMergeGroupSize,
|
|
||||||
coarsenedGraph.nodes.size(),
|
|
||||||
oldNodeCount - coarsenedGraph.nodes.size());
|
|
||||||
virtualGraph = std::move(coarsenedGraph);
|
|
||||||
return true;
|
|
||||||
};
|
|
||||||
|
|
||||||
while (virtualGraph.nodes.size() > 1) {
|
|
||||||
if (virtualGraph.nodes.size() <= getSchedulingCpuBudget()) {
|
|
||||||
if (debugCoarsening && virtualGraph.nodes.size() >= 200)
|
|
||||||
llvm::errs() << llvm::formatv(
|
|
||||||
"[DCP-COARSEN] iter={0} old={1} stop=cpu-budget\n", iteration, virtualGraph.nodes.size());
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
iteration++;
|
|
||||||
TimingInfo timing = computeTiming(virtualGraph);
|
|
||||||
if (!timing.valid) {
|
|
||||||
if (debugCoarsening && virtualGraph.nodes.size() >= 200)
|
|
||||||
llvm::errs() << llvm::formatv(
|
|
||||||
"[DCP-COARSEN] iter={0} old={1} invalid-timing\n", iteration, virtualGraph.nodes.size());
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
SmallVector<size_t> selectedNodes;
|
|
||||||
auto criticalWindow =
|
|
||||||
selectCriticalWindow(virtualGraph, timing, getDcpCoarseningWindowSize(virtualGraph.nodes.size()));
|
|
||||||
selectedNodes.append(criticalWindow.begin(), criticalWindow.end());
|
|
||||||
|
|
||||||
if (selectedNodes.size() < 2) {
|
|
||||||
if (debugCoarsening && virtualGraph.nodes.size() >= 200)
|
|
||||||
llvm::errs() << llvm::formatv("[DCP-COARSEN] iter={0} old={1} selected={2} stop=small-window\n",
|
|
||||||
iteration,
|
|
||||||
virtualGraph.nodes.size(),
|
|
||||||
selectedNodes.size());
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (tryCoarsenSelectedNodes(selectedNodes))
|
|
||||||
continue;
|
|
||||||
if (debugCoarsening && virtualGraph.nodes.size() >= 200)
|
|
||||||
llvm::errs() << llvm::formatv(
|
|
||||||
"[DCP-COARSEN] iter={0} old={1} stop=no-merge\n", iteration, virtualGraph.nodes.size());
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
return buildResultFromVirtualGraph(virtualGraph, computeInstances);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace spatial
|
} // namespace spatial
|
||||||
|
|||||||
@@ -0,0 +1,691 @@
|
|||||||
|
#include "MaterializeMergeSchedule.hpp"
|
||||||
|
|
||||||
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
#include "mlir/IR/IRMapping.h"
|
||||||
|
#include "mlir/IR/PatternMatch.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/DenseMap.h"
|
||||||
|
#include "llvm/ADT/DenseSet.h"
|
||||||
|
#include "llvm/ADT/STLExtras.h"
|
||||||
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
|
||||||
|
#include <cstddef>
|
||||||
|
#include <cstdint>
|
||||||
|
#include <functional>
|
||||||
|
#include <optional>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
|
#include "Scheduling/ComputeInstanceUtils.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
namespace spatial {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
using SpatCompute = spatial::SpatCompute;
|
||||||
|
using ProducerValueRef = spatial::ProducerValueRef;
|
||||||
|
using spatial::getComputeInstanceInputs;
|
||||||
|
using spatial::getComputeInstanceOutputTypes;
|
||||||
|
using spatial::getComputeInstanceOutputValues;
|
||||||
|
using spatial::getComputeInstanceTemplateBlock;
|
||||||
|
using spatial::getComputeInstanceWeights;
|
||||||
|
using spatial::getProducerValueRef;
|
||||||
|
|
||||||
|
static int32_t getPhysicalCoreId(size_t schedulerCpu) { return static_cast<int32_t>(schedulerCpu + 1); }
|
||||||
|
|
||||||
|
class MergeScheduleMaterializerImpl {
|
||||||
|
public:
|
||||||
|
explicit MergeScheduleMaterializerImpl(func::FuncOp funcOp)
|
||||||
|
: func(funcOp),
|
||||||
|
loc(funcOp.getLoc()),
|
||||||
|
returnOp(cast<func::ReturnOp>(funcOp.getBody().front().getTerminator())) {}
|
||||||
|
|
||||||
|
LogicalResult run(const MergeScheduleResult &scheduleResult, int64_t &nextChannelIdRef) {
|
||||||
|
schedule = &scheduleResult;
|
||||||
|
nextChannelId = &nextChannelIdRef;
|
||||||
|
|
||||||
|
collectScheduledTasks();
|
||||||
|
buildTaskIndex();
|
||||||
|
collectExternalInputsAndWeights();
|
||||||
|
planRemoteChannels();
|
||||||
|
planReceiveReordering();
|
||||||
|
createCpuComputeOps();
|
||||||
|
if (failed(cloneTaskBodies()))
|
||||||
|
return failure();
|
||||||
|
replaceExternalUses();
|
||||||
|
if (failed(eraseOldScheduledOps()))
|
||||||
|
return failure();
|
||||||
|
moveExternalUsersBeforeReturn();
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
struct ScheduledTask {
|
||||||
|
ComputeInstance key;
|
||||||
|
Operation *sourceOp = nullptr;
|
||||||
|
size_t cpu = 0;
|
||||||
|
size_t slot = 0;
|
||||||
|
size_t order = 0;
|
||||||
|
size_t executionOrder = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct ChannelInfo {
|
||||||
|
int64_t channelId = -1;
|
||||||
|
int32_t sourceCoreId = -1;
|
||||||
|
int32_t targetCoreId = -1;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct CpuProgram {
|
||||||
|
SpatCompute op;
|
||||||
|
Block *block = nullptr;
|
||||||
|
DenseMap<Value, Value> externalInputMap;
|
||||||
|
DenseMap<Value, size_t> weightToIndex;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct RemoteSendInfo {
|
||||||
|
ChannelInfo channelInfo;
|
||||||
|
ComputeInstance consumer;
|
||||||
|
size_t inputIndex = 0;
|
||||||
|
size_t consumerOrder = 0;
|
||||||
|
size_t sourceOrder = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct RemoteReceiveEntry {
|
||||||
|
ChannelInfo channelInfo;
|
||||||
|
ComputeInstance consumer;
|
||||||
|
size_t inputIndex = 0;
|
||||||
|
size_t sourceOrder = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
static uint64_t getRemoteSendPairKey(const ChannelInfo &channelInfo) {
|
||||||
|
return (static_cast<uint64_t>(static_cast<uint32_t>(channelInfo.sourceCoreId)) << 32)
|
||||||
|
| static_cast<uint32_t>(channelInfo.targetCoreId);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void appendUniqueValue(SmallVectorImpl<Value> &values, DenseSet<Value> &seen, Value value) {
|
||||||
|
if (seen.insert(value).second)
|
||||||
|
values.push_back(value);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool isInternalInputOp(Operation *op) {
|
||||||
|
auto it = isInternalInputOpCache.find(op);
|
||||||
|
if (it != isInternalInputOpCache.end())
|
||||||
|
return it->second;
|
||||||
|
|
||||||
|
auto extract = dyn_cast_or_null<tensor::ExtractSliceOp>(op);
|
||||||
|
if (!extract)
|
||||||
|
return isInternalInputOpCache[op] = false;
|
||||||
|
|
||||||
|
for (Value result : extract->getResults()) {
|
||||||
|
for (Operation *user : result.getUsers()) {
|
||||||
|
if (toEraseSet.contains(user))
|
||||||
|
continue;
|
||||||
|
if (isInternalInputOp(user))
|
||||||
|
continue;
|
||||||
|
return isInternalInputOpCache[op] = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return isInternalInputOpCache[op] = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
void collectInternalInputOps(Value value) {
|
||||||
|
Operation *op = value.getDefiningOp();
|
||||||
|
while (auto extract = dyn_cast_if_present<tensor::ExtractSliceOp>(op)) {
|
||||||
|
if (isInternalInputOp(extract.getOperation()))
|
||||||
|
internalInputOpsToErase.insert(extract.getOperation());
|
||||||
|
value = extract.getSource();
|
||||||
|
op = value.getDefiningOp();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void collectExternalUsers(Operation *op) {
|
||||||
|
if (!externalUsersToMove.insert(op).second)
|
||||||
|
return;
|
||||||
|
for (Value result : op->getResults()) {
|
||||||
|
for (Operation *user : result.getUsers()) {
|
||||||
|
if (toEraseSet.contains(user) || isa<func::ReturnOp>(user))
|
||||||
|
continue;
|
||||||
|
collectExternalUsers(user);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void collectScheduledTasks() {
|
||||||
|
size_t nextOrder = 0;
|
||||||
|
for (ComputeInstance scheduledInstance : schedule->dominanceOrderCompute) {
|
||||||
|
toEraseSet.insert(scheduledInstance.op);
|
||||||
|
scheduledTasks.push_back(
|
||||||
|
{scheduledInstance, scheduledInstance.op, schedule->computeToCpuMap.lookup(scheduledInstance),
|
||||||
|
schedule->computeToCpuSlotMap.lookup(scheduledInstance), nextOrder++});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void buildTaskIndex() {
|
||||||
|
auto markCpuSeen = [&](size_t cpu) {
|
||||||
|
if (seenCpus.insert(cpu).second)
|
||||||
|
orderedCpus.push_back(cpu);
|
||||||
|
};
|
||||||
|
|
||||||
|
for (const ScheduledTask &task : scheduledTasks) {
|
||||||
|
taskByKey[task.key] = task;
|
||||||
|
tasksByCpu[task.cpu].push_back(task);
|
||||||
|
markCpuSeen(task.cpu);
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::sort(orderedCpus);
|
||||||
|
for (size_t cpu : orderedCpus) {
|
||||||
|
llvm::stable_sort(tasksByCpu[cpu], [&](const ScheduledTask &lhs, const ScheduledTask &rhs) {
|
||||||
|
if (lhs.slot != rhs.slot)
|
||||||
|
return lhs.slot < rhs.slot;
|
||||||
|
return lhs.order < rhs.order;
|
||||||
|
});
|
||||||
|
for (auto [executionOrder, task] : llvm::enumerate(tasksByCpu[cpu])) {
|
||||||
|
task.executionOrder = executionOrder;
|
||||||
|
taskByKey[task.key].executionOrder = executionOrder;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void collectExternalInputsAndWeights() {
|
||||||
|
for (size_t cpu : orderedCpus) {
|
||||||
|
for (const ScheduledTask &task : tasksByCpu[cpu]) {
|
||||||
|
auto taskWeights = getComputeInstanceWeights(task.key);
|
||||||
|
for (Value weight : taskWeights)
|
||||||
|
appendUniqueValue(cpuWeights[cpu], seenWeightsByCpu[cpu], weight);
|
||||||
|
|
||||||
|
auto taskInputs = getComputeInstanceInputs(task.key);
|
||||||
|
auto &remoteInputs = remoteInputsByTask[task.key];
|
||||||
|
remoteInputs.resize(taskInputs.size());
|
||||||
|
for (auto [inputIndex, input] : llvm::enumerate(taskInputs)) {
|
||||||
|
auto producerRef = getProducerValueRef(input);
|
||||||
|
if (producerRef) {
|
||||||
|
collectInternalInputOps(input);
|
||||||
|
auto producerIt = taskByKey.find(producerRef->instance);
|
||||||
|
if (producerIt != taskByKey.end()) {
|
||||||
|
if (producerIt->second.cpu != cpu) {
|
||||||
|
ChannelInfo info {
|
||||||
|
(*nextChannelId)++,
|
||||||
|
getPhysicalCoreId(producerIt->second.cpu),
|
||||||
|
getPhysicalCoreId(cpu),
|
||||||
|
};
|
||||||
|
remoteInputs[inputIndex] = info;
|
||||||
|
auto &perResultChannels = remoteSendsByTask[producerRef->instance];
|
||||||
|
if (perResultChannels.empty())
|
||||||
|
perResultChannels.resize(getComputeInstanceOutputTypes(producerIt->second.key).size());
|
||||||
|
perResultChannels[producerRef->resultIndex].push_back(
|
||||||
|
{info, task.key, inputIndex, task.executionOrder, 0});
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
appendUniqueValue(cpuExternalInputs[cpu], seenExternalInputsByCpu[cpu], input);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto taskOutputs = getComputeInstanceOutputValues(task.key);
|
||||||
|
for (auto [resultIndex, output] : llvm::enumerate(taskOutputs)) {
|
||||||
|
bool hasExternalUser = false;
|
||||||
|
for (auto &use : output.getUses()) {
|
||||||
|
Operation *useOwner = use.getOwner();
|
||||||
|
if (toEraseSet.contains(useOwner))
|
||||||
|
continue;
|
||||||
|
hasExternalUser = true;
|
||||||
|
if (!isa<func::ReturnOp>(useOwner))
|
||||||
|
collectExternalUsers(useOwner);
|
||||||
|
}
|
||||||
|
if (hasExternalUser)
|
||||||
|
cpuExternalOutputs[cpu].push_back({task.key, resultIndex});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void planRemoteChannels() {
|
||||||
|
for (size_t cpu : orderedCpus) {
|
||||||
|
DenseMap<uint64_t, size_t> nextSourceOrderByPair;
|
||||||
|
DenseMap<uint64_t, size_t> lastConsumerOrderByPair;
|
||||||
|
for (const ScheduledTask &task : tasksByCpu[cpu]) {
|
||||||
|
auto sendsIt = remoteSendsByTask.find(task.key);
|
||||||
|
if (sendsIt == remoteSendsByTask.end())
|
||||||
|
continue;
|
||||||
|
for (auto &sendInfos : sendsIt->second) {
|
||||||
|
for (RemoteSendInfo &sendInfo : sendInfos) {
|
||||||
|
uint64_t pairKey = getRemoteSendPairKey(sendInfo.channelInfo);
|
||||||
|
sendInfo.sourceOrder = nextSourceOrderByPair[pairKey]++;
|
||||||
|
auto [it, inserted] = lastConsumerOrderByPair.try_emplace(pairKey, sendInfo.consumerOrder);
|
||||||
|
if (!inserted) {
|
||||||
|
if (sendInfo.consumerOrder < it->second)
|
||||||
|
pairsNeedingReceiveReorder.insert(pairKey);
|
||||||
|
it->second = sendInfo.consumerOrder;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void planReceiveReordering() {
|
||||||
|
DenseMap<uint64_t, SmallVector<RemoteSendInfo *>> reorderedSendsByPair;
|
||||||
|
for (auto &taskSends : remoteSendsByTask) {
|
||||||
|
for (auto &sendInfos : taskSends.second) {
|
||||||
|
for (RemoteSendInfo &sendInfo : sendInfos) {
|
||||||
|
uint64_t pairKey = getRemoteSendPairKey(sendInfo.channelInfo);
|
||||||
|
if (pairsNeedingReceiveReorder.contains(pairKey))
|
||||||
|
reorderedSendsByPair[pairKey].push_back(&sendInfo);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto &pairSends : reorderedSendsByPair) {
|
||||||
|
llvm::stable_sort(pairSends.second, [](const RemoteSendInfo *lhs, const RemoteSendInfo *rhs) {
|
||||||
|
if (lhs->sourceOrder != rhs->sourceOrder)
|
||||||
|
return lhs->sourceOrder < rhs->sourceOrder;
|
||||||
|
return lhs->channelInfo.channelId < rhs->channelInfo.channelId;
|
||||||
|
});
|
||||||
|
for (RemoteSendInfo *sendInfo : pairSends.second) {
|
||||||
|
int64_t channelId = (*nextChannelId)++;
|
||||||
|
sendInfo->channelInfo.channelId = channelId;
|
||||||
|
auto remoteInputsIt = remoteInputsByTask.find(sendInfo->consumer);
|
||||||
|
assert(remoteInputsIt != remoteInputsByTask.end() && "missing remote input for reordered send");
|
||||||
|
assert(sendInfo->inputIndex < remoteInputsIt->second.size() && "remote input index out of range");
|
||||||
|
assert(remoteInputsIt->second[sendInfo->inputIndex] && "missing reordered remote input channel");
|
||||||
|
remoteInputsIt->second[sendInfo->inputIndex]->channelId = channelId;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const auto &taskSends : remoteSendsByTask) {
|
||||||
|
for (const auto &sendInfos : taskSends.second) {
|
||||||
|
for (const RemoteSendInfo &sendInfo : sendInfos) {
|
||||||
|
auto remoteInputsIt = remoteInputsByTask.find(sendInfo.consumer);
|
||||||
|
assert(remoteInputsIt != remoteInputsByTask.end() && "missing remote input for send");
|
||||||
|
assert(sendInfo.inputIndex < remoteInputsIt->second.size() && "remote input index out of range");
|
||||||
|
assert(remoteInputsIt->second[sendInfo.inputIndex] && "missing remote input channel");
|
||||||
|
remoteInputsIt->second[sendInfo.inputIndex] = sendInfo.channelInfo;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto &taskSends : remoteSendsByTask) {
|
||||||
|
for (const auto &sendInfos : taskSends.second) {
|
||||||
|
for (const RemoteSendInfo &sendInfo : sendInfos) {
|
||||||
|
uint64_t pairKey = getRemoteSendPairKey(sendInfo.channelInfo);
|
||||||
|
if (!pairsNeedingReceiveReorder.contains(pairKey))
|
||||||
|
continue;
|
||||||
|
size_t targetCpu = static_cast<size_t>(sendInfo.channelInfo.targetCoreId - 1);
|
||||||
|
receiveQueuesByCpu[targetCpu][pairKey].push_back(
|
||||||
|
{sendInfo.channelInfo, sendInfo.consumer, sendInfo.inputIndex, sendInfo.sourceOrder});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto &cpuQueues : receiveQueuesByCpu) {
|
||||||
|
for (auto &pairQueue : cpuQueues.second) {
|
||||||
|
llvm::stable_sort(pairQueue.second, [](const RemoteReceiveEntry &lhs, const RemoteReceiveEntry &rhs) {
|
||||||
|
if (lhs.sourceOrder != rhs.sourceOrder)
|
||||||
|
return lhs.sourceOrder < rhs.sourceOrder;
|
||||||
|
return lhs.channelInfo.channelId < rhs.channelInfo.channelId;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void createCpuComputeOps() {
|
||||||
|
IRRewriter rewriter(func.getContext());
|
||||||
|
for (size_t cpu : orderedCpus) {
|
||||||
|
SmallVector<Value> operands;
|
||||||
|
operands.reserve(cpuWeights[cpu].size() + cpuExternalInputs[cpu].size());
|
||||||
|
llvm::append_range(operands, cpuWeights[cpu]);
|
||||||
|
llvm::append_range(operands, cpuExternalInputs[cpu]);
|
||||||
|
|
||||||
|
SmallVector<Type> resultTypes;
|
||||||
|
resultTypes.reserve(cpuExternalOutputs[cpu].size());
|
||||||
|
for (ProducerValueRef outputRef : cpuExternalOutputs[cpu]) {
|
||||||
|
ScheduledTask task = taskByKey.at(outputRef.instance);
|
||||||
|
resultTypes.push_back(getComputeInstanceOutputTypes(task.key)[outputRef.resultIndex]);
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter.setInsertionPoint(returnOp);
|
||||||
|
auto newCompute = SpatCompute::create(rewriter, loc, TypeRange(resultTypes), ValueRange(operands));
|
||||||
|
newCompute.getProperties().setOperandSegmentSizes(
|
||||||
|
{static_cast<int>(cpuWeights[cpu].size()), static_cast<int>(cpuExternalInputs[cpu].size())});
|
||||||
|
newCompute->setAttr(onnx_mlir::kCoreIdAttrName, rewriter.getI32IntegerAttr(getPhysicalCoreId(cpu)));
|
||||||
|
|
||||||
|
SmallVector<Type> blockArgTypes;
|
||||||
|
SmallVector<Location> blockArgLocs;
|
||||||
|
blockArgTypes.reserve(cpuExternalInputs[cpu].size());
|
||||||
|
blockArgLocs.reserve(cpuExternalInputs[cpu].size());
|
||||||
|
for (Value input : cpuExternalInputs[cpu]) {
|
||||||
|
blockArgTypes.push_back(input.getType());
|
||||||
|
blockArgLocs.push_back(loc);
|
||||||
|
}
|
||||||
|
Block *newBlock =
|
||||||
|
rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
|
||||||
|
|
||||||
|
CpuProgram program;
|
||||||
|
program.op = newCompute;
|
||||||
|
program.block = newBlock;
|
||||||
|
for (auto [weightIndex, weight] : llvm::enumerate(cpuWeights[cpu]))
|
||||||
|
program.weightToIndex[weight] = weightIndex;
|
||||||
|
for (auto [inputIndex, input] : llvm::enumerate(cpuExternalInputs[cpu]))
|
||||||
|
program.externalInputMap[input] = newBlock->getArgument(inputIndex);
|
||||||
|
for (auto [resultIndex, outputRef] : llvm::enumerate(cpuExternalOutputs[cpu])) {
|
||||||
|
ScheduledTask task = taskByKey.at(outputRef.instance);
|
||||||
|
oldToNewExternalValueMap[getComputeInstanceOutputValues(task.key)[outputRef.resultIndex]] =
|
||||||
|
newCompute.getResult(resultIndex);
|
||||||
|
}
|
||||||
|
cpuPrograms[cpu] = std::move(program);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
FailureOr<Value> receiveThroughInput(IRRewriter &rewriter,
|
||||||
|
size_t cpu,
|
||||||
|
DenseMap<uint64_t, size_t> &receiveQueueIndices,
|
||||||
|
DenseMap<ComputeInstance, SmallVector<Value>> &preReceivedInputsByTask,
|
||||||
|
const ChannelInfo &requestedChannelInfo,
|
||||||
|
ComputeInstance requestedConsumer,
|
||||||
|
size_t requestedInputIndex) {
|
||||||
|
uint64_t pairKey = getRemoteSendPairKey(requestedChannelInfo);
|
||||||
|
auto cpuQueuesIt = receiveQueuesByCpu.find(cpu);
|
||||||
|
if (cpuQueuesIt == receiveQueuesByCpu.end())
|
||||||
|
return failure();
|
||||||
|
auto queueIt = cpuQueuesIt->second.find(pairKey);
|
||||||
|
if (queueIt == cpuQueuesIt->second.end())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto &queue = queueIt->second;
|
||||||
|
size_t &queueIndex = receiveQueueIndices[pairKey];
|
||||||
|
while (queueIndex < queue.size()) {
|
||||||
|
const RemoteReceiveEntry &entry = queue[queueIndex++];
|
||||||
|
auto consumerTaskIt = taskByKey.find(entry.consumer);
|
||||||
|
if (consumerTaskIt == taskByKey.end())
|
||||||
|
return failure();
|
||||||
|
SmallVector<Value> consumerInputs = getComputeInstanceInputs(consumerTaskIt->second.key);
|
||||||
|
if (consumerInputs.size() <= entry.inputIndex)
|
||||||
|
return failure();
|
||||||
|
Type inputType = consumerInputs[entry.inputIndex].getType();
|
||||||
|
auto receive =
|
||||||
|
spatial::SpatChannelReceiveOp::create(rewriter,
|
||||||
|
loc,
|
||||||
|
inputType,
|
||||||
|
rewriter.getI64IntegerAttr(entry.channelInfo.channelId),
|
||||||
|
rewriter.getI32IntegerAttr(entry.channelInfo.sourceCoreId),
|
||||||
|
rewriter.getI32IntegerAttr(entry.channelInfo.targetCoreId));
|
||||||
|
|
||||||
|
auto &receivedInputs = preReceivedInputsByTask[entry.consumer];
|
||||||
|
if (receivedInputs.size() <= entry.inputIndex)
|
||||||
|
receivedInputs.resize(entry.inputIndex + 1);
|
||||||
|
receivedInputs[entry.inputIndex] = receive.getResult();
|
||||||
|
|
||||||
|
if (entry.consumer == requestedConsumer && entry.inputIndex == requestedInputIndex)
|
||||||
|
return receive.getResult();
|
||||||
|
}
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
LogicalResult cloneTaskBodies() {
|
||||||
|
for (size_t cpu : orderedCpus) {
|
||||||
|
CpuProgram &program = cpuPrograms[cpu];
|
||||||
|
IRRewriter rewriter(func.getContext());
|
||||||
|
rewriter.setInsertionPointToEnd(program.block);
|
||||||
|
DenseMap<uint64_t, size_t> receiveQueueIndices;
|
||||||
|
DenseMap<ComputeInstance, SmallVector<Value>> preReceivedInputsByTask;
|
||||||
|
|
||||||
|
auto lookupPreReceivedInput = [&](ComputeInstance consumer, size_t inputIndex) -> std::optional<Value> {
|
||||||
|
auto inputsIt = preReceivedInputsByTask.find(consumer);
|
||||||
|
if (inputsIt == preReceivedInputsByTask.end() || inputsIt->second.size() <= inputIndex)
|
||||||
|
return std::nullopt;
|
||||||
|
Value value = inputsIt->second[inputIndex];
|
||||||
|
if (!value)
|
||||||
|
return std::nullopt;
|
||||||
|
return value;
|
||||||
|
};
|
||||||
|
|
||||||
|
for (const ScheduledTask &task : tasksByCpu[cpu]) {
|
||||||
|
SmallVector<Value> taskInputs = getComputeInstanceInputs(task.key);
|
||||||
|
auto taskWeights = getComputeInstanceWeights(task.key);
|
||||||
|
Block &templateBlock = getComputeInstanceTemplateBlock(task.key);
|
||||||
|
|
||||||
|
SmallVector<Value> resolvedInputs;
|
||||||
|
resolvedInputs.reserve(taskInputs.size());
|
||||||
|
auto remoteInputsIt = remoteInputsByTask.find(task.key);
|
||||||
|
for (auto [inputIndex, input] : llvm::enumerate(taskInputs)) {
|
||||||
|
auto producerRef = getProducerValueRef(input);
|
||||||
|
if (producerRef) {
|
||||||
|
auto producerIt = taskByKey.find(producerRef->instance);
|
||||||
|
if (producerIt != taskByKey.end()) {
|
||||||
|
if (producerIt->second.cpu == cpu) {
|
||||||
|
auto producedIt = producedValuesByTask.find(producerRef->instance);
|
||||||
|
if (producedIt == producedValuesByTask.end() || producedIt->second.size() <= producerRef->resultIndex) {
|
||||||
|
task.sourceOp->emitOpError("missing local producer value during per-cpu merge materialization")
|
||||||
|
<< " consumerCpu=" << cpu << " consumerSlot=" << task.slot
|
||||||
|
<< " producerCpu=" << producerIt->second.cpu << " producerSlot=" << producerIt->second.slot
|
||||||
|
<< " producerLaneStart=" << producerRef->instance.laneStart
|
||||||
|
<< " producerLaneCount=" << producerRef->instance.laneCount;
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
resolvedInputs.push_back(producedIt->second[producerRef->resultIndex]);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
const ChannelInfo &channelInfo = *remoteInputsIt->second[inputIndex];
|
||||||
|
uint64_t pairKey = getRemoteSendPairKey(channelInfo);
|
||||||
|
if (pairsNeedingReceiveReorder.contains(pairKey)) {
|
||||||
|
if (std::optional<Value> preReceived = lookupPreReceivedInput(task.key, inputIndex)) {
|
||||||
|
resolvedInputs.push_back(*preReceived);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
FailureOr<Value> received = receiveThroughInput(
|
||||||
|
rewriter, cpu, receiveQueueIndices, preReceivedInputsByTask, channelInfo, task.key, inputIndex);
|
||||||
|
if (failed(received)) {
|
||||||
|
task.sourceOp->emitOpError("failed to materialize reordered remote receive")
|
||||||
|
<< " consumerCpu=" << cpu << " consumerSlot=" << task.slot
|
||||||
|
<< " sourceCoreId=" << channelInfo.sourceCoreId << " targetCoreId=" << channelInfo.targetCoreId
|
||||||
|
<< " channelId=" << channelInfo.channelId;
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
resolvedInputs.push_back(*received);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
auto receive =
|
||||||
|
spatial::SpatChannelReceiveOp::create(rewriter,
|
||||||
|
loc,
|
||||||
|
input.getType(),
|
||||||
|
rewriter.getI64IntegerAttr(channelInfo.channelId),
|
||||||
|
rewriter.getI32IntegerAttr(channelInfo.sourceCoreId),
|
||||||
|
rewriter.getI32IntegerAttr(channelInfo.targetCoreId));
|
||||||
|
resolvedInputs.push_back(receive.getResult());
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
resolvedInputs.push_back(program.externalInputMap.at(input));
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<Value> taskYieldValues;
|
||||||
|
rewriter.setInsertionPointToEnd(program.block);
|
||||||
|
if (isa<SpatCompute>(task.sourceOp)) {
|
||||||
|
IRMapping mapper;
|
||||||
|
for (auto [argIndex, oldArg] : llvm::enumerate(templateBlock.getArguments()))
|
||||||
|
mapper.map(oldArg, resolvedInputs[argIndex]);
|
||||||
|
|
||||||
|
for (Operation &op : templateBlock) {
|
||||||
|
if (auto yield = dyn_cast<spatial::SpatYieldOp>(&op)) {
|
||||||
|
for (Value yieldOperand : yield.getOperands())
|
||||||
|
taskYieldValues.push_back(mapper.lookup(yieldOperand));
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
Operation *clonedOp = rewriter.clone(op, mapper);
|
||||||
|
if (auto oldWeightedMvmOp = dyn_cast<spatial::SpatMVMOp>(&op)) {
|
||||||
|
auto newWeightedMvmOp = cast<spatial::SpatMVMOp>(clonedOp);
|
||||||
|
Value weight = taskWeights[oldWeightedMvmOp.getWeightIndex()];
|
||||||
|
newWeightedMvmOp.setWeightIndex(program.weightToIndex.at(weight));
|
||||||
|
}
|
||||||
|
if (auto oldWeightedVmmOp = dyn_cast<spatial::SpatVMMOp>(&op)) {
|
||||||
|
auto newWeightedVmmOp = cast<spatial::SpatVMMOp>(clonedOp);
|
||||||
|
Value weight = taskWeights[oldWeightedVmmOp.getWeightIndex()];
|
||||||
|
newWeightedVmmOp.setWeightIndex(program.weightToIndex.at(weight));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (size_t laneOffset = 0; laneOffset < task.key.laneCount; ++laneOffset) {
|
||||||
|
IRMapping mapper;
|
||||||
|
if (templateBlock.getNumArguments() == 1)
|
||||||
|
mapper.map(templateBlock.getArgument(0), resolvedInputs[laneOffset]);
|
||||||
|
|
||||||
|
for (Operation &op : templateBlock) {
|
||||||
|
if (auto yield = dyn_cast<spatial::SpatYieldOp>(&op)) {
|
||||||
|
for (Value yieldOperand : yield.getOperands())
|
||||||
|
taskYieldValues.push_back(mapper.lookup(yieldOperand));
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
Operation *clonedOp = rewriter.clone(op, mapper);
|
||||||
|
if (auto oldWeightedMvmOp = dyn_cast<spatial::SpatMVMOp>(&op)) {
|
||||||
|
if (oldWeightedMvmOp.getWeightIndex() != 0) {
|
||||||
|
task.sourceOp->emitOpError(
|
||||||
|
"batched per-cpu merge materialization expects lane-local weight index 0");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
auto newWeightedMvmOp = cast<spatial::SpatMVMOp>(clonedOp);
|
||||||
|
newWeightedMvmOp.setWeightIndex(program.weightToIndex.at(taskWeights[laneOffset]));
|
||||||
|
}
|
||||||
|
if (auto oldWeightedVmmOp = dyn_cast<spatial::SpatVMMOp>(&op)) {
|
||||||
|
if (oldWeightedVmmOp.getWeightIndex() != 0) {
|
||||||
|
task.sourceOp->emitOpError(
|
||||||
|
"batched per-cpu merge materialization expects lane-local weight index 0");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
auto newWeightedVmmOp = cast<spatial::SpatVMMOp>(clonedOp);
|
||||||
|
newWeightedVmmOp.setWeightIndex(program.weightToIndex.at(taskWeights[laneOffset]));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
producedValuesByTask[task.key] = taskYieldValues;
|
||||||
|
if (auto sendsIt = remoteSendsByTask.find(task.key); sendsIt != remoteSendsByTask.end()) {
|
||||||
|
for (auto [resultIndex, sendInfos] : llvm::enumerate(sendsIt->second)) {
|
||||||
|
if (sendInfos.empty())
|
||||||
|
continue;
|
||||||
|
Value producedValue = taskYieldValues[resultIndex];
|
||||||
|
for (const RemoteSendInfo &sendInfo : sendInfos) {
|
||||||
|
spatial::SpatChannelSendOp::create(rewriter,
|
||||||
|
loc,
|
||||||
|
rewriter.getI64IntegerAttr(sendInfo.channelInfo.channelId),
|
||||||
|
rewriter.getI32IntegerAttr(sendInfo.channelInfo.sourceCoreId),
|
||||||
|
rewriter.getI32IntegerAttr(sendInfo.channelInfo.targetCoreId),
|
||||||
|
producedValue);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<Value> yieldValues;
|
||||||
|
yieldValues.reserve(cpuExternalOutputs[cpu].size());
|
||||||
|
for (ProducerValueRef outputRef : cpuExternalOutputs[cpu]) {
|
||||||
|
auto producedIt = producedValuesByTask.find(outputRef.instance);
|
||||||
|
if (producedIt == producedValuesByTask.end() || producedIt->second.size() <= outputRef.resultIndex) {
|
||||||
|
ScheduledTask task = taskByKey.at(outputRef.instance);
|
||||||
|
task.sourceOp->emitOpError("missing yielded external value during per-cpu merge materialization")
|
||||||
|
<< " cpu=" << cpu << " slot=" << task.slot << " laneStart=" << outputRef.instance.laneStart;
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
yieldValues.push_back(producedIt->second[outputRef.resultIndex]);
|
||||||
|
}
|
||||||
|
spatial::SpatYieldOp::create(rewriter, loc, ValueRange(yieldValues));
|
||||||
|
}
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
void replaceExternalUses() {
|
||||||
|
for (auto [oldValue, newValue] : oldToNewExternalValueMap) {
|
||||||
|
for (auto &use : llvm::make_early_inc_range(oldValue.getUses()))
|
||||||
|
if (!toEraseSet.contains(use.getOwner()))
|
||||||
|
use.assign(newValue);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
LogicalResult eraseOldScheduledOps() {
|
||||||
|
DenseSet<Operation *> allOpsToErase = toEraseSet;
|
||||||
|
for (Operation *op : internalInputOpsToErase)
|
||||||
|
allOpsToErase.insert(op);
|
||||||
|
|
||||||
|
SmallVector<Operation *> orderedOpsToErase;
|
||||||
|
for (Operation &op : func.getBody().front())
|
||||||
|
if (allOpsToErase.contains(&op))
|
||||||
|
orderedOpsToErase.push_back(&op);
|
||||||
|
|
||||||
|
for (Operation *op : llvm::reverse(orderedOpsToErase)) {
|
||||||
|
SmallVector<Operation *> remainingUsers;
|
||||||
|
for (Value result : op->getResults())
|
||||||
|
for (Operation *user : result.getUsers())
|
||||||
|
remainingUsers.push_back(user);
|
||||||
|
if (!remainingUsers.empty()) {
|
||||||
|
InFlightDiagnostic diagnostic = op->emitOpError("still has uses during per-cpu merge cleanup")
|
||||||
|
<< "; erase-set=" << (allOpsToErase.contains(op) ? "yes" : "no");
|
||||||
|
for (Operation *user : remainingUsers) {
|
||||||
|
diagnostic.attachNote(user->getLoc())
|
||||||
|
<< "remaining user " << user->getName() << "; erase-set=" << (allOpsToErase.contains(user) ? "yes" : "no");
|
||||||
|
}
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
op->erase();
|
||||||
|
}
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
void moveExternalUsersBeforeReturn() {
|
||||||
|
SmallVector<Operation *> orderedUsersToMove;
|
||||||
|
for (Operation &op : func.getBody().front()) {
|
||||||
|
if (&op == returnOp.getOperation())
|
||||||
|
break;
|
||||||
|
if (externalUsersToMove.contains(&op))
|
||||||
|
orderedUsersToMove.push_back(&op);
|
||||||
|
}
|
||||||
|
for (Operation *op : orderedUsersToMove)
|
||||||
|
op->moveBefore(returnOp);
|
||||||
|
}
|
||||||
|
|
||||||
|
func::FuncOp func;
|
||||||
|
const MergeScheduleResult *schedule = nullptr;
|
||||||
|
int64_t *nextChannelId = nullptr;
|
||||||
|
Location loc;
|
||||||
|
func::ReturnOp returnOp;
|
||||||
|
|
||||||
|
SmallVector<ScheduledTask> scheduledTasks;
|
||||||
|
DenseSet<Operation *> toEraseSet;
|
||||||
|
DenseMap<ComputeInstance, ScheduledTask> taskByKey;
|
||||||
|
DenseMap<size_t, SmallVector<ScheduledTask>> tasksByCpu;
|
||||||
|
SmallVector<size_t> orderedCpus;
|
||||||
|
DenseSet<size_t> seenCpus;
|
||||||
|
DenseSet<Operation *> internalInputOpsToErase;
|
||||||
|
DenseMap<Operation *, bool> isInternalInputOpCache;
|
||||||
|
DenseSet<Operation *> externalUsersToMove;
|
||||||
|
DenseMap<ComputeInstance, SmallVector<SmallVector<RemoteSendInfo>>> remoteSendsByTask;
|
||||||
|
DenseMap<ComputeInstance, SmallVector<std::optional<ChannelInfo>>> remoteInputsByTask;
|
||||||
|
DenseMap<size_t, SmallVector<Value>> cpuExternalInputs;
|
||||||
|
DenseMap<size_t, SmallVector<Value>> cpuWeights;
|
||||||
|
DenseMap<size_t, SmallVector<ProducerValueRef>> cpuExternalOutputs;
|
||||||
|
DenseMap<size_t, DenseSet<Value>> seenExternalInputsByCpu;
|
||||||
|
DenseMap<size_t, DenseSet<Value>> seenWeightsByCpu;
|
||||||
|
DenseSet<uint64_t> pairsNeedingReceiveReorder;
|
||||||
|
DenseMap<size_t, DenseMap<uint64_t, SmallVector<RemoteReceiveEntry>>> receiveQueuesByCpu;
|
||||||
|
DenseMap<size_t, CpuProgram> cpuPrograms;
|
||||||
|
DenseMap<Value, Value> oldToNewExternalValueMap;
|
||||||
|
DenseMap<ComputeInstance, SmallVector<Value>> producedValuesByTask;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
LogicalResult
|
||||||
|
MergeScheduleMaterializer::run(func::FuncOp func, const MergeScheduleResult &schedule, int64_t &nextChannelId) {
|
||||||
|
return MergeScheduleMaterializerImpl(func).run(schedule, nextChannelId);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace spatial
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,18 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
|
#include "mlir/Support/LogicalResult.h"
|
||||||
|
|
||||||
|
#include "Scheduling/MergeSchedule.hpp"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
namespace spatial {
|
||||||
|
|
||||||
|
class MergeScheduleMaterializer {
|
||||||
|
public:
|
||||||
|
mlir::LogicalResult
|
||||||
|
run(mlir::func::FuncOp func, const MergeScheduleResult &schedule, int64_t &nextChannelId);
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace spatial
|
||||||
|
} // namespace onnx_mlir
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,459 @@
|
|||||||
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
#include "mlir/IR/IRMapping.h"
|
||||||
|
#include "mlir/IR/PatternMatch.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/DenseMap.h"
|
||||||
|
#include "llvm/ADT/DenseSet.h"
|
||||||
|
#include "llvm/ADT/Hashing.h"
|
||||||
|
#include "llvm/ADT/STLExtras.h"
|
||||||
|
#include "llvm/Support/FormatVariadic.h"
|
||||||
|
#include "llvm/Support/raw_ostream.h"
|
||||||
|
|
||||||
|
#include <chrono>
|
||||||
|
#include <cstdlib>
|
||||||
|
#include <limits>
|
||||||
|
#include <optional>
|
||||||
|
|
||||||
|
#include "PostMergeCompaction.hpp"
|
||||||
|
#include "RegularOpCompaction.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
using SpatCompute = spatial::SpatCompute;
|
||||||
|
using SpatComputeBatch = spatial::SpatComputeBatch;
|
||||||
|
|
||||||
|
bool isMergeProfilingEnabled() { return std::getenv("RAPTOR_PROFILE_MERGE") != nullptr; }
|
||||||
|
|
||||||
|
class ScopedMergePhaseTimer {
|
||||||
|
public:
|
||||||
|
explicit ScopedMergePhaseTimer(StringRef phaseName)
|
||||||
|
: enabled(isMergeProfilingEnabled()), phase(phaseName.str()) {
|
||||||
|
if (enabled)
|
||||||
|
start = std::chrono::steady_clock::now();
|
||||||
|
}
|
||||||
|
|
||||||
|
~ScopedMergePhaseTimer() {
|
||||||
|
if (!enabled)
|
||||||
|
return;
|
||||||
|
auto elapsed = std::chrono::steady_clock::now() - start;
|
||||||
|
double millis = std::chrono::duration<double, std::milli>(elapsed).count();
|
||||||
|
llvm::errs() << "[merge-profile] " << phase << ": " << llvm::formatv("{0:F3}", millis) << " ms\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
bool enabled = false;
|
||||||
|
std::string phase;
|
||||||
|
std::chrono::steady_clock::time_point start;
|
||||||
|
};
|
||||||
|
|
||||||
|
std::optional<int32_t> getComputeCoreId(SpatCompute compute) {
|
||||||
|
if (auto coreIdAttr = compute->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
|
||||||
|
return static_cast<int32_t>(coreIdAttr.getInt());
|
||||||
|
return std::nullopt;
|
||||||
|
}
|
||||||
|
|
||||||
|
static constexpr StringLiteral kRebatchPhaseAttrName = "_pim_rebatch_phase";
|
||||||
|
|
||||||
|
std::optional<uint64_t> getComputeRebatchPhase(SpatCompute compute) {
|
||||||
|
if (auto phaseAttr = compute->getAttrOfType<IntegerAttr>(kRebatchPhaseAttrName))
|
||||||
|
return static_cast<uint64_t>(phaseAttr.getInt());
|
||||||
|
return std::nullopt;
|
||||||
|
}
|
||||||
|
|
||||||
|
struct RebatchKey {
|
||||||
|
unsigned inputCount = 0;
|
||||||
|
unsigned resultCount = 0;
|
||||||
|
unsigned weightCount = 0;
|
||||||
|
uint64_t phase = 0;
|
||||||
|
bool hasPhase = false;
|
||||||
|
uint64_t structureHash = 0;
|
||||||
|
|
||||||
|
bool operator==(const RebatchKey& other) const {
|
||||||
|
return inputCount == other.inputCount && resultCount == other.resultCount && weightCount == other.weightCount
|
||||||
|
&& phase == other.phase && hasPhase == other.hasPhase && structureHash == other.structureHash;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct RebatchKeyInfo {
|
||||||
|
static inline RebatchKey getEmptyKey() { return {std::numeric_limits<unsigned>::max(), 0, 0, 0, false, 0}; }
|
||||||
|
|
||||||
|
static inline RebatchKey getTombstoneKey() { return {std::numeric_limits<unsigned>::max() - 1, 0, 0, 0, false, 0}; }
|
||||||
|
|
||||||
|
static unsigned getHashValue(const RebatchKey& key) {
|
||||||
|
return static_cast<unsigned>(
|
||||||
|
llvm::hash_combine(key.inputCount, key.resultCount, key.weightCount, key.phase, key.hasPhase, key.structureHash));
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool isEqual(const RebatchKey& lhs, const RebatchKey& rhs) { return lhs == rhs; }
|
||||||
|
};
|
||||||
|
|
||||||
|
uint64_t getTypeHash(Type type) { return reinterpret_cast<uintptr_t>(type.getAsOpaquePointer()); }
|
||||||
|
|
||||||
|
uint64_t getValueHash(Value value) { return reinterpret_cast<uintptr_t>(value.getAsOpaquePointer()); }
|
||||||
|
|
||||||
|
uint64_t getAttributeHash(Attribute attr) { return reinterpret_cast<uintptr_t>(attr.getAsOpaquePointer()); }
|
||||||
|
|
||||||
|
RebatchKey computeRebatchKey(SpatCompute compute) {
|
||||||
|
llvm::hash_code structureHash =
|
||||||
|
llvm::hash_combine(compute.getInputs().size(), compute.getResultTypes().size(), compute.getWeights().size());
|
||||||
|
|
||||||
|
for (Value weight : compute.getWeights())
|
||||||
|
structureHash = llvm::hash_combine(structureHash, getValueHash(weight));
|
||||||
|
if (std::optional<uint64_t> phase = getComputeRebatchPhase(compute))
|
||||||
|
structureHash = llvm::hash_combine(structureHash, *phase);
|
||||||
|
|
||||||
|
Block& body = compute.getBody().front();
|
||||||
|
structureHash = llvm::hash_combine(structureHash, body.getNumArguments());
|
||||||
|
for (BlockArgument arg : body.getArguments())
|
||||||
|
structureHash = llvm::hash_combine(structureHash, getTypeHash(arg.getType()));
|
||||||
|
|
||||||
|
for (Operation& op : body) {
|
||||||
|
structureHash = llvm::hash_combine(
|
||||||
|
structureHash, op.getName().getStringRef(), op.getNumOperands(), op.getNumResults(), op.getNumRegions());
|
||||||
|
for (Type type : op.getResultTypes())
|
||||||
|
structureHash = llvm::hash_combine(structureHash, getTypeHash(type));
|
||||||
|
for (NamedAttribute attr : op.getAttrs())
|
||||||
|
structureHash = llvm::hash_combine(structureHash, attr.getName().strref(), getAttributeHash(attr.getValue()));
|
||||||
|
}
|
||||||
|
|
||||||
|
std::optional<uint64_t> phase = getComputeRebatchPhase(compute);
|
||||||
|
return {static_cast<unsigned>(compute.getInputs().size()),
|
||||||
|
static_cast<unsigned>(compute.getResultTypes().size()),
|
||||||
|
static_cast<unsigned>(compute.getWeights().size()),
|
||||||
|
phase.value_or(0),
|
||||||
|
phase.has_value(),
|
||||||
|
static_cast<uint64_t>(structureHash)};
|
||||||
|
}
|
||||||
|
|
||||||
|
bool areEquivalentForRebatch(SpatCompute lhs, SpatCompute rhs) {
|
||||||
|
if (!lhs || !rhs)
|
||||||
|
return false;
|
||||||
|
if (lhs.getInputs().size() != rhs.getInputs().size())
|
||||||
|
return false;
|
||||||
|
if (lhs.getResultTypes() != rhs.getResultTypes())
|
||||||
|
return false;
|
||||||
|
if (lhs.getWeights().size() != rhs.getWeights().size())
|
||||||
|
return false;
|
||||||
|
if (getComputeRebatchPhase(lhs) != getComputeRebatchPhase(rhs))
|
||||||
|
return false;
|
||||||
|
if (!llvm::equal(lhs.getWeights(), rhs.getWeights()))
|
||||||
|
return false;
|
||||||
|
|
||||||
|
auto& lhsBlock = lhs.getBody().front();
|
||||||
|
auto& rhsBlock = rhs.getBody().front();
|
||||||
|
if (lhsBlock.getNumArguments() != rhsBlock.getNumArguments())
|
||||||
|
return false;
|
||||||
|
|
||||||
|
DenseMap<Value, Value> mappedValues;
|
||||||
|
for (auto [lhsArg, rhsArg] : llvm::zip(lhsBlock.getArguments(), rhsBlock.getArguments())) {
|
||||||
|
if (lhsArg.getType() != rhsArg.getType())
|
||||||
|
return false;
|
||||||
|
mappedValues[lhsArg] = rhsArg;
|
||||||
|
}
|
||||||
|
auto lhsIt = lhsBlock.begin();
|
||||||
|
auto rhsIt = rhsBlock.begin();
|
||||||
|
for (; lhsIt != lhsBlock.end() && rhsIt != rhsBlock.end(); ++lhsIt, ++rhsIt) {
|
||||||
|
Operation& lhsOp = *lhsIt;
|
||||||
|
Operation& rhsOp = *rhsIt;
|
||||||
|
|
||||||
|
if (lhsOp.getName() != rhsOp.getName())
|
||||||
|
return false;
|
||||||
|
if (lhsOp.getNumOperands() != rhsOp.getNumOperands())
|
||||||
|
return false;
|
||||||
|
if (lhsOp.getNumResults() != rhsOp.getNumResults())
|
||||||
|
return false;
|
||||||
|
if (lhsOp.getNumRegions() != 0 || rhsOp.getNumRegions() != 0)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
for (auto [lhsOperand, rhsOperand] : llvm::zip(lhsOp.getOperands(), rhsOp.getOperands())) {
|
||||||
|
auto mapped = mappedValues.find(lhsOperand);
|
||||||
|
if (mapped != mappedValues.end()) {
|
||||||
|
if (mapped->second != rhsOperand)
|
||||||
|
return false;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (lhsOperand != rhsOperand)
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto lhsReceive = dyn_cast<spatial::SpatChannelReceiveOp>(lhsOp)) {
|
||||||
|
auto rhsReceive = cast<spatial::SpatChannelReceiveOp>(rhsOp);
|
||||||
|
if (lhsReceive.getOutput().getType() != rhsReceive.getOutput().getType())
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
else if (auto lhsSend = dyn_cast<spatial::SpatChannelSendOp>(lhsOp)) {
|
||||||
|
auto rhsSend = cast<spatial::SpatChannelSendOp>(rhsOp);
|
||||||
|
if (lhsSend.getInput().getType() != rhsSend.getInput().getType())
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
else if (lhsOp.getAttrs() != rhsOp.getAttrs()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (lhsOp.getResultTypes() != rhsOp.getResultTypes())
|
||||||
|
return false;
|
||||||
|
for (auto [lhsResult, rhsResult] : llvm::zip(lhsOp.getResults(), rhsOp.getResults()))
|
||||||
|
mappedValues[lhsResult] = rhsResult;
|
||||||
|
}
|
||||||
|
|
||||||
|
return lhsIt == lhsBlock.end() && rhsIt == rhsBlock.end();
|
||||||
|
}
|
||||||
|
|
||||||
|
void rebatchEquivalentComputes(func::FuncOp funcOp) {
|
||||||
|
IRRewriter rewriter(funcOp.getContext());
|
||||||
|
SmallVector<SpatCompute> computes(funcOp.getOps<SpatCompute>());
|
||||||
|
DenseSet<Operation*> consumed;
|
||||||
|
DenseMap<Operation*, size_t> computeOrder;
|
||||||
|
DenseMap<RebatchKey, SmallVector<SpatCompute>, RebatchKeyInfo> candidatesByKey;
|
||||||
|
|
||||||
|
for (auto [index, compute] : llvm::enumerate(computes)) {
|
||||||
|
computeOrder[compute.getOperation()] = index;
|
||||||
|
if (compute.getInputs().size() <= 1 && compute.getResults().empty())
|
||||||
|
candidatesByKey[computeRebatchKey(compute)].push_back(compute);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (size_t index = 0; index < computes.size(); ++index) {
|
||||||
|
auto anchor = computes[index];
|
||||||
|
if (consumed.contains(anchor))
|
||||||
|
continue;
|
||||||
|
if (anchor.getInputs().size() > 1)
|
||||||
|
continue;
|
||||||
|
if (!anchor.getResults().empty())
|
||||||
|
continue;
|
||||||
|
|
||||||
|
SmallVector<SpatCompute> group {anchor};
|
||||||
|
llvm::SmallDenseSet<int32_t, 8> usedCoreIds;
|
||||||
|
if (auto coreId = getComputeCoreId(anchor))
|
||||||
|
usedCoreIds.insert(*coreId);
|
||||||
|
|
||||||
|
auto bucketIt = candidatesByKey.find(computeRebatchKey(anchor));
|
||||||
|
if (bucketIt == candidatesByKey.end())
|
||||||
|
continue;
|
||||||
|
|
||||||
|
for (auto candidate : bucketIt->second) {
|
||||||
|
if (computeOrder.lookup(candidate.getOperation()) <= index)
|
||||||
|
continue;
|
||||||
|
if (consumed.contains(candidate))
|
||||||
|
continue;
|
||||||
|
if (!areEquivalentForRebatch(anchor, candidate))
|
||||||
|
continue;
|
||||||
|
|
||||||
|
if (auto coreId = getComputeCoreId(candidate))
|
||||||
|
if (!usedCoreIds.insert(*coreId).second)
|
||||||
|
continue;
|
||||||
|
|
||||||
|
group.push_back(candidate);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (group.size() <= 1)
|
||||||
|
continue;
|
||||||
|
|
||||||
|
auto insertionAnchor = group.front();
|
||||||
|
if (llvm::all_of(group, [](SpatCompute compute) { return getComputeCoreId(compute).has_value(); })) {
|
||||||
|
llvm::stable_sort(
|
||||||
|
group, [](SpatCompute lhs, SpatCompute rhs) { return *getComputeCoreId(lhs) < *getComputeCoreId(rhs); });
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<Value> weights;
|
||||||
|
weights.reserve(group.size() * anchor.getWeights().size());
|
||||||
|
SmallVector<Value> inputs;
|
||||||
|
inputs.reserve(group.size() * anchor.getInputs().size());
|
||||||
|
SmallVector<int32_t> coreIds;
|
||||||
|
coreIds.reserve(group.size());
|
||||||
|
bool haveAllCoreIds = true;
|
||||||
|
for (auto compute : group) {
|
||||||
|
llvm::append_range(weights, compute.getWeights());
|
||||||
|
llvm::append_range(inputs, compute.getInputs());
|
||||||
|
auto coreIdAttr = compute->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName);
|
||||||
|
if (!coreIdAttr)
|
||||||
|
haveAllCoreIds = false;
|
||||||
|
else if (haveAllCoreIds)
|
||||||
|
coreIds.push_back(static_cast<int32_t>(coreIdAttr.getInt()));
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter.setInsertionPoint(insertionAnchor);
|
||||||
|
auto rebatched = SpatComputeBatch::create(rewriter,
|
||||||
|
insertionAnchor.getLoc(),
|
||||||
|
TypeRange {},
|
||||||
|
rewriter.getI32IntegerAttr(static_cast<int32_t>(group.size())),
|
||||||
|
ValueRange(weights),
|
||||||
|
ValueRange(inputs));
|
||||||
|
rebatched.getProperties().setOperandSegmentSizes(
|
||||||
|
{static_cast<int>(weights.size()), static_cast<int>(inputs.size())});
|
||||||
|
if (haveAllCoreIds)
|
||||||
|
rebatched->setAttr(onnx_mlir::kCoreIdsAttrName, rewriter.getDenseI32ArrayAttr(coreIds));
|
||||||
|
|
||||||
|
SmallVector<Type> blockArgTypes;
|
||||||
|
SmallVector<Location> blockArgLocs;
|
||||||
|
for (BlockArgument arg : anchor.getBody().front().getArguments()) {
|
||||||
|
blockArgTypes.push_back(arg.getType());
|
||||||
|
blockArgLocs.push_back(arg.getLoc());
|
||||||
|
}
|
||||||
|
auto* newBlock =
|
||||||
|
rewriter.createBlock(&rebatched.getBody(), rebatched.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
|
||||||
|
rewriter.setInsertionPointToEnd(newBlock);
|
||||||
|
|
||||||
|
IRMapping mapper;
|
||||||
|
auto& anchorBlock = anchor.getBody().front();
|
||||||
|
for (auto [oldArg, newArg] : llvm::zip(anchorBlock.getArguments(), newBlock->getArguments()))
|
||||||
|
mapper.map(oldArg, newArg);
|
||||||
|
auto opIts = llvm::map_to_vector(group, [](SpatCompute compute) { return compute.getBody().front().begin(); });
|
||||||
|
for (Operation& anchorOp : anchorBlock) {
|
||||||
|
if (auto receiveOp = dyn_cast<spatial::SpatChannelReceiveOp>(&anchorOp)) {
|
||||||
|
struct BatchReceiveEntry {
|
||||||
|
uint64_t channelId = 0;
|
||||||
|
uint32_t sourceCoreId = 0;
|
||||||
|
uint32_t targetCoreId = 0;
|
||||||
|
};
|
||||||
|
SmallVector<BatchReceiveEntry> entries;
|
||||||
|
entries.reserve(group.size());
|
||||||
|
for (auto [groupIndex, compute] : llvm::enumerate(group)) {
|
||||||
|
auto groupReceive = cast<spatial::SpatChannelReceiveOp>(&*opIts[groupIndex]);
|
||||||
|
entries.push_back(
|
||||||
|
{groupReceive.getChannelId(), groupReceive.getSourceCoreId(), groupReceive.getTargetCoreId()});
|
||||||
|
++opIts[groupIndex];
|
||||||
|
}
|
||||||
|
SmallVector<int64_t> channelIds;
|
||||||
|
SmallVector<int32_t> sourceCoreIds;
|
||||||
|
SmallVector<int32_t> targetCoreIds;
|
||||||
|
channelIds.reserve(group.size());
|
||||||
|
sourceCoreIds.reserve(group.size());
|
||||||
|
targetCoreIds.reserve(group.size());
|
||||||
|
for (const BatchReceiveEntry& entry : entries) {
|
||||||
|
channelIds.push_back(static_cast<int64_t>(entry.channelId));
|
||||||
|
sourceCoreIds.push_back(static_cast<int32_t>(entry.sourceCoreId));
|
||||||
|
targetCoreIds.push_back(static_cast<int32_t>(entry.targetCoreId));
|
||||||
|
}
|
||||||
|
auto batchReceive = spatial::SpatChannelReceiveBatchOp::create(rewriter,
|
||||||
|
receiveOp.getLoc(),
|
||||||
|
receiveOp.getOutput().getType(),
|
||||||
|
rewriter.getDenseI64ArrayAttr(channelIds),
|
||||||
|
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
|
||||||
|
rewriter.getDenseI32ArrayAttr(targetCoreIds));
|
||||||
|
mapper.map(receiveOp.getOutput(), batchReceive.getOutput());
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto sendOp = dyn_cast<spatial::SpatChannelSendOp>(&anchorOp)) {
|
||||||
|
struct BatchSendEntry {
|
||||||
|
uint64_t channelId = 0;
|
||||||
|
uint32_t sourceCoreId = 0;
|
||||||
|
uint32_t targetCoreId = 0;
|
||||||
|
};
|
||||||
|
SmallVector<BatchSendEntry> entries;
|
||||||
|
entries.reserve(group.size());
|
||||||
|
for (auto [groupIndex, compute] : llvm::enumerate(group)) {
|
||||||
|
auto groupSend = cast<spatial::SpatChannelSendOp>(&*opIts[groupIndex]);
|
||||||
|
entries.push_back({groupSend.getChannelId(), groupSend.getSourceCoreId(), groupSend.getTargetCoreId()});
|
||||||
|
++opIts[groupIndex];
|
||||||
|
}
|
||||||
|
SmallVector<int64_t> channelIds;
|
||||||
|
SmallVector<int32_t> sourceCoreIds;
|
||||||
|
SmallVector<int32_t> targetCoreIds;
|
||||||
|
channelIds.reserve(group.size());
|
||||||
|
sourceCoreIds.reserve(group.size());
|
||||||
|
targetCoreIds.reserve(group.size());
|
||||||
|
for (const BatchSendEntry& entry : entries) {
|
||||||
|
channelIds.push_back(static_cast<int64_t>(entry.channelId));
|
||||||
|
sourceCoreIds.push_back(static_cast<int32_t>(entry.sourceCoreId));
|
||||||
|
targetCoreIds.push_back(static_cast<int32_t>(entry.targetCoreId));
|
||||||
|
}
|
||||||
|
spatial::SpatChannelSendBatchOp::create(rewriter,
|
||||||
|
sendOp.getLoc(),
|
||||||
|
rewriter.getDenseI64ArrayAttr(channelIds),
|
||||||
|
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
|
||||||
|
rewriter.getDenseI32ArrayAttr(targetCoreIds),
|
||||||
|
mapper.lookup(sendOp.getInput()));
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (isa<spatial::SpatYieldOp>(anchorOp)) {
|
||||||
|
for (auto& opIt : opIts)
|
||||||
|
++opIt;
|
||||||
|
spatial::SpatYieldOp::create(rewriter, anchorOp.getLoc(), ValueRange {});
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
Operation* cloned = rewriter.clone(anchorOp, mapper);
|
||||||
|
for (auto [originalResult, clonedResult] : llvm::zip(anchorOp.getResults(), cloned->getResults()))
|
||||||
|
mapper.map(originalResult, clonedResult);
|
||||||
|
for (auto& opIt : opIts)
|
||||||
|
++opIt;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto compute : group) {
|
||||||
|
compute->removeAttr(kRebatchPhaseAttrName);
|
||||||
|
consumed.insert(compute);
|
||||||
|
rewriter.eraseOp(compute);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto compute : funcOp.getOps<SpatCompute>())
|
||||||
|
compute->removeAttr(kRebatchPhaseAttrName);
|
||||||
|
}
|
||||||
|
|
||||||
|
void cleanupDeadPackingOps(func::FuncOp funcOp) {
|
||||||
|
auto eraseUnusedOps = [&](auto tag) {
|
||||||
|
using OpTy = decltype(tag);
|
||||||
|
SmallVector<OpTy> ops;
|
||||||
|
funcOp.walk([&](OpTy op) { ops.push_back(op); });
|
||||||
|
for (auto op : llvm::reverse(ops))
|
||||||
|
if (op->use_empty())
|
||||||
|
op.erase();
|
||||||
|
};
|
||||||
|
eraseUnusedOps(tensor::ExtractSliceOp {});
|
||||||
|
eraseUnusedOps(spatial::SpatConcatOp {});
|
||||||
|
eraseUnusedOps(spatial::SpatExtractRowsOp {});
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
LogicalResult runPostMergeCompactionPipeline(func::FuncOp funcOp, int64_t& nextChannelId) {
|
||||||
|
{
|
||||||
|
ScopedMergePhaseTimer timer("order-bilateral-channel-ops");
|
||||||
|
orderBilateralChannelOps(funcOp);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
ScopedMergePhaseTimer timer("rebatch-equivalent-computes");
|
||||||
|
rebatchEquivalentComputes(funcOp);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
ScopedMergePhaseTimer timer("compact-scalar-channel-runs-1");
|
||||||
|
compactScalarChannelRuns(funcOp, nextChannelId);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
ScopedMergePhaseTimer timer("compact-batch-channel-runs-1");
|
||||||
|
compactBatchChannelRuns(funcOp);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
ScopedMergePhaseTimer timer("compact-regular-op-runs");
|
||||||
|
compactRegularOpRuns(funcOp);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
ScopedMergePhaseTimer timer("compact-row-wise-wvmm-runs");
|
||||||
|
compactRowWiseWvmmRuns(funcOp);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
ScopedMergePhaseTimer timer("compact-scalar-channel-runs-2");
|
||||||
|
compactScalarChannelRuns(funcOp, nextChannelId);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
ScopedMergePhaseTimer timer("compact-batch-channel-runs-2");
|
||||||
|
compactBatchChannelRuns(funcOp);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
ScopedMergePhaseTimer timer("cleanup-dead-packing-ops");
|
||||||
|
cleanupDeadPackingOps(funcOp);
|
||||||
|
}
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,12 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
|
#include "mlir/Support/LogicalResult.h"
|
||||||
|
|
||||||
|
#include <cstdint>
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
mlir::LogicalResult runPostMergeCompactionPipeline(mlir::func::FuncOp funcOp, int64_t &nextChannelId);
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -7,6 +7,8 @@
|
|||||||
#include "mlir/IR/Value.h"
|
#include "mlir/IR/Value.h"
|
||||||
#include "mlir/Support/LLVM.h"
|
#include "mlir/Support/LLVM.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/DenseMap.h"
|
||||||
|
#include "llvm/ADT/DenseSet.h"
|
||||||
#include "llvm/ADT/STLExtras.h"
|
#include "llvm/ADT/STLExtras.h"
|
||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
|
||||||
@@ -41,6 +43,47 @@ struct RegularChunk {
|
|||||||
Value output;
|
Value output;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct RegularCompactionResult {
|
||||||
|
bool changed = false;
|
||||||
|
Operation* resumeAfter = nullptr;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename OpTy>
|
||||||
|
struct ConsecutiveRun {
|
||||||
|
SmallVector<OpTy> ops;
|
||||||
|
Block::iterator end;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename OpTy, typename Predicate>
|
||||||
|
static ConsecutiveRun<OpTy>
|
||||||
|
collectConsecutiveRun(Block::iterator start, Block::iterator blockEnd, Predicate predicate) {
|
||||||
|
ConsecutiveRun<OpTy> run;
|
||||||
|
run.end = start;
|
||||||
|
while (run.end != blockEnd) {
|
||||||
|
auto current = dyn_cast<OpTy>(&*run.end);
|
||||||
|
if (!current || !predicate(current))
|
||||||
|
break;
|
||||||
|
run.ops.push_back(current);
|
||||||
|
++run.end;
|
||||||
|
}
|
||||||
|
return run;
|
||||||
|
}
|
||||||
|
|
||||||
|
static uint64_t getEndpointKey(uint32_t sourceCoreId, uint32_t targetCoreId) {
|
||||||
|
return (static_cast<uint64_t>(sourceCoreId) << 32) | static_cast<uint64_t>(targetCoreId);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void appendChannelAttrs(SmallVectorImpl<int64_t>& channelIds,
|
||||||
|
SmallVectorImpl<int32_t>& sourceCoreIds,
|
||||||
|
SmallVectorImpl<int32_t>& targetCoreIds,
|
||||||
|
uint64_t channelId,
|
||||||
|
uint32_t sourceCoreId,
|
||||||
|
uint32_t targetCoreId) {
|
||||||
|
channelIds.push_back(static_cast<int64_t>(channelId));
|
||||||
|
sourceCoreIds.push_back(static_cast<int32_t>(sourceCoreId));
|
||||||
|
targetCoreIds.push_back(static_cast<int32_t>(targetCoreId));
|
||||||
|
}
|
||||||
|
|
||||||
static spatial::SpatConcatOp getContiguousConcatUse(ValueRange values, unsigned& startOperandIndex) {
|
static spatial::SpatConcatOp getContiguousConcatUse(ValueRange values, unsigned& startOperandIndex) {
|
||||||
if (values.empty() || !values.front().hasOneUse())
|
if (values.empty() || !values.front().hasOneUse())
|
||||||
return {};
|
return {};
|
||||||
@@ -212,9 +255,10 @@ static FailureOr<RegularChunk> analyzeRegularChunk(spatial::SpatVMMOp startOp) {
|
|||||||
return chunk;
|
return chunk;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void compactRegularChunkRun(IRRewriter& rewriter, ArrayRef<RegularChunk> run) {
|
static RegularCompactionResult compactRegularChunkRun(IRRewriter& rewriter, ArrayRef<RegularChunk> run) {
|
||||||
assert(!run.empty() && "expected a non-empty regular chunk run");
|
assert(!run.empty() && "expected a non-empty regular chunk run");
|
||||||
const RegularChunk& anchorChunk = run.front();
|
const RegularChunk& anchorChunk = run.front();
|
||||||
|
RegularCompactionResult result;
|
||||||
|
|
||||||
SmallVector<Value> inputs;
|
SmallVector<Value> inputs;
|
||||||
inputs.reserve(run.size());
|
inputs.reserve(run.size());
|
||||||
@@ -224,7 +268,7 @@ static void compactRegularChunkRun(IRRewriter& rewriter, ArrayRef<RegularChunk>
|
|||||||
rewriter.setInsertionPoint(anchorChunk.startOp);
|
rewriter.setInsertionPoint(anchorChunk.startOp);
|
||||||
Value packedInput = createPackedTensorForValues(ValueRange(inputs), rewriter, anchorChunk.startOp->getLoc());
|
Value packedInput = createPackedTensorForValues(ValueRange(inputs), rewriter, anchorChunk.startOp->getLoc());
|
||||||
if (!packedInput)
|
if (!packedInput)
|
||||||
return;
|
return result;
|
||||||
|
|
||||||
auto inputType = cast<RankedTensorType>(anchorChunk.input.getType());
|
auto inputType = cast<RankedTensorType>(anchorChunk.input.getType());
|
||||||
auto outputType = cast<RankedTensorType>(anchorChunk.output.getType());
|
auto outputType = cast<RankedTensorType>(anchorChunk.output.getType());
|
||||||
@@ -327,6 +371,10 @@ static void compactRegularChunkRun(IRRewriter& rewriter, ArrayRef<RegularChunk>
|
|||||||
llvm::append_range(opsToErase, chunk.ops);
|
llvm::append_range(opsToErase, chunk.ops);
|
||||||
for (Operation* op : llvm::reverse(opsToErase))
|
for (Operation* op : llvm::reverse(opsToErase))
|
||||||
rewriter.eraseOp(op);
|
rewriter.eraseOp(op);
|
||||||
|
|
||||||
|
result.changed = true;
|
||||||
|
result.resumeAfter = loop.getOperation()->getNextNode();
|
||||||
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
@@ -340,27 +388,28 @@ void orderBilateralChannelOps(func::FuncOp funcOp) {
|
|||||||
int32_t coreId = static_cast<int32_t>(coreIdAttr.getInt());
|
int32_t coreId = static_cast<int32_t>(coreIdAttr.getInt());
|
||||||
Block& block = compute.getBody().front();
|
Block& block = compute.getBody().front();
|
||||||
SmallVector<std::pair<spatial::SpatChannelReceiveOp, Operation*>> moves;
|
SmallVector<std::pair<spatial::SpatChannelReceiveOp, Operation*>> moves;
|
||||||
|
DenseMap<uint64_t, Operation*> firstForwardedSendByEndpoint;
|
||||||
|
|
||||||
for (Operation& op : block) {
|
for (Operation& op : block) {
|
||||||
|
if (auto sendOp = dyn_cast<spatial::SpatChannelSendOp>(&op)) {
|
||||||
|
if (sendOp.getSourceCoreId() == static_cast<uint32_t>(coreId)
|
||||||
|
&& isForwardedChannelPayload(sendOp.getInput(), block)) {
|
||||||
|
uint64_t key = getEndpointKey(sendOp.getSourceCoreId(), sendOp.getTargetCoreId());
|
||||||
|
firstForwardedSendByEndpoint.try_emplace(key, sendOp.getOperation());
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
auto receiveOp = dyn_cast<spatial::SpatChannelReceiveOp>(&op);
|
auto receiveOp = dyn_cast<spatial::SpatChannelReceiveOp>(&op);
|
||||||
if (!receiveOp || receiveOp.getTargetCoreId() != static_cast<uint32_t>(coreId)
|
if (!receiveOp || receiveOp.getTargetCoreId() != static_cast<uint32_t>(coreId)
|
||||||
|| receiveOp.getSourceCoreId() >= static_cast<uint32_t>(coreId)) {
|
|| receiveOp.getSourceCoreId() >= static_cast<uint32_t>(coreId)) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
Operation* firstMatchingSend = nullptr;
|
uint64_t key = getEndpointKey(static_cast<uint32_t>(coreId), receiveOp.getSourceCoreId());
|
||||||
for (Operation* previous = receiveOp->getPrevNode(); previous; previous = previous->getPrevNode()) {
|
auto firstMatchingSend = firstForwardedSendByEndpoint.find(key);
|
||||||
auto sendOp = dyn_cast<spatial::SpatChannelSendOp>(previous);
|
if (firstMatchingSend != firstForwardedSendByEndpoint.end())
|
||||||
if (!sendOp || sendOp.getSourceCoreId() != static_cast<uint32_t>(coreId)
|
moves.push_back({receiveOp, firstMatchingSend->second});
|
||||||
|| sendOp.getTargetCoreId() != receiveOp.getSourceCoreId()
|
|
||||||
|| !isForwardedChannelPayload(sendOp.getInput(), block)) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
firstMatchingSend = sendOp.getOperation();
|
|
||||||
}
|
|
||||||
|
|
||||||
if (firstMatchingSend)
|
|
||||||
moves.push_back({receiveOp, firstMatchingSend});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for (auto [receiveOp, insertionPoint] : moves)
|
for (auto [receiveOp, insertionPoint] : moves)
|
||||||
@@ -373,30 +422,24 @@ void orderBilateralChannelOps(func::FuncOp funcOp) {
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
SmallVector<spatial::SpatChannelReceiveOp> run;
|
|
||||||
Type outputType = receiveOp.getOutput().getType();
|
Type outputType = receiveOp.getOutput().getType();
|
||||||
auto runIt = it;
|
auto run = collectConsecutiveRun<spatial::SpatChannelReceiveOp>(
|
||||||
while (runIt != block.end()) {
|
it, block.end(), [&](spatial::SpatChannelReceiveOp current) {
|
||||||
auto current = dyn_cast<spatial::SpatChannelReceiveOp>(&*runIt);
|
return current.getOutput().getType() == outputType
|
||||||
if (!current || current.getOutput().getType() != outputType
|
&& current.getSourceCoreId() < static_cast<uint32_t>(coreId);
|
||||||
|| current.getSourceCoreId() >= static_cast<uint32_t>(coreId)) {
|
});
|
||||||
break;
|
|
||||||
}
|
|
||||||
run.push_back(current);
|
|
||||||
++runIt;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (run.size() > 1) {
|
if (run.ops.size() > 1) {
|
||||||
SmallVector<spatial::SpatChannelReceiveOp> sorted(run);
|
SmallVector<spatial::SpatChannelReceiveOp> sorted(run.ops);
|
||||||
llvm::stable_sort(sorted, [](spatial::SpatChannelReceiveOp lhs, spatial::SpatChannelReceiveOp rhs) {
|
llvm::stable_sort(sorted, [](spatial::SpatChannelReceiveOp lhs, spatial::SpatChannelReceiveOp rhs) {
|
||||||
return lhs.getSourceCoreId() > rhs.getSourceCoreId();
|
return lhs.getSourceCoreId() > rhs.getSourceCoreId();
|
||||||
});
|
});
|
||||||
Block::iterator insertIt = runIt;
|
Block::iterator insertIt = run.end;
|
||||||
for (auto op : sorted)
|
for (auto op : sorted)
|
||||||
op->moveBefore(&block, insertIt);
|
op->moveBefore(&block, insertIt);
|
||||||
}
|
}
|
||||||
|
|
||||||
it = runIt;
|
it = run.end;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -409,29 +452,23 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
|
|||||||
for (auto it = block.begin(); it != block.end();) {
|
for (auto it = block.begin(); it != block.end();) {
|
||||||
auto receiveOp = dyn_cast<spatial::SpatChannelReceiveOp>(&*it);
|
auto receiveOp = dyn_cast<spatial::SpatChannelReceiveOp>(&*it);
|
||||||
if (receiveOp) {
|
if (receiveOp) {
|
||||||
SmallVector<spatial::SpatChannelReceiveOp> run;
|
|
||||||
Type outputType = receiveOp.getOutput().getType();
|
Type outputType = receiveOp.getOutput().getType();
|
||||||
auto runIt = it;
|
auto run = collectConsecutiveRun<spatial::SpatChannelReceiveOp>(
|
||||||
while (runIt != block.end()) {
|
it, block.end(), [&](spatial::SpatChannelReceiveOp current) {
|
||||||
auto current = dyn_cast<spatial::SpatChannelReceiveOp>(&*runIt);
|
return current.getOutput().getType() == outputType;
|
||||||
if (!current || current.getOutput().getType() != outputType)
|
});
|
||||||
break;
|
|
||||||
run.push_back(current);
|
|
||||||
++runIt;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool hasRepeatedEndpoint = false;
|
bool hasRepeatedEndpoint = false;
|
||||||
for (size_t lhs = 0; lhs < run.size() && !hasRepeatedEndpoint; ++lhs) {
|
DenseSet<uint64_t> seenEndpoints;
|
||||||
for (size_t rhs = lhs + 1; rhs < run.size(); ++rhs) {
|
for (auto op : run.ops) {
|
||||||
if (run[lhs].getSourceCoreId() == run[rhs].getSourceCoreId()
|
uint64_t endpointKey = getEndpointKey(op.getSourceCoreId(), op.getTargetCoreId());
|
||||||
&& run[lhs].getTargetCoreId() == run[rhs].getTargetCoreId()) {
|
if (!seenEndpoints.insert(endpointKey).second) {
|
||||||
hasRepeatedEndpoint = true;
|
hasRepeatedEndpoint = true;
|
||||||
break;
|
break;
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (run.size() > 1 && !hasRepeatedEndpoint) {
|
if (run.ops.size() > 1 && !hasRepeatedEndpoint) {
|
||||||
struct ReceiveEntry {
|
struct ReceiveEntry {
|
||||||
spatial::SpatChannelReceiveOp op;
|
spatial::SpatChannelReceiveOp op;
|
||||||
size_t originalIndex = 0;
|
size_t originalIndex = 0;
|
||||||
@@ -440,8 +477,8 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
|
|||||||
uint64_t channelId = 0;
|
uint64_t channelId = 0;
|
||||||
};
|
};
|
||||||
SmallVector<ReceiveEntry> sortedEntries;
|
SmallVector<ReceiveEntry> sortedEntries;
|
||||||
sortedEntries.reserve(run.size());
|
sortedEntries.reserve(run.ops.size());
|
||||||
for (auto [originalIndex, op] : llvm::enumerate(run))
|
for (auto [originalIndex, op] : llvm::enumerate(run.ops))
|
||||||
sortedEntries.push_back({op, originalIndex, op.getSourceCoreId(), op.getTargetCoreId(), op.getChannelId()});
|
sortedEntries.push_back({op, originalIndex, op.getSourceCoreId(), op.getTargetCoreId(), op.getChannelId()});
|
||||||
|
|
||||||
SmallVector<int64_t> channelIds;
|
SmallVector<int64_t> channelIds;
|
||||||
@@ -451,12 +488,11 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
|
|||||||
sourceCoreIds.reserve(sortedEntries.size());
|
sourceCoreIds.reserve(sortedEntries.size());
|
||||||
targetCoreIds.reserve(sortedEntries.size());
|
targetCoreIds.reserve(sortedEntries.size());
|
||||||
for (ReceiveEntry& entry : sortedEntries) {
|
for (ReceiveEntry& entry : sortedEntries) {
|
||||||
channelIds.push_back(static_cast<int64_t>(entry.channelId));
|
appendChannelAttrs(
|
||||||
sourceCoreIds.push_back(static_cast<int32_t>(entry.sourceCoreId));
|
channelIds, sourceCoreIds, targetCoreIds, entry.channelId, entry.sourceCoreId, entry.targetCoreId);
|
||||||
targetCoreIds.push_back(static_cast<int32_t>(entry.targetCoreId));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
auto rowType = cast<RankedTensorType>(run.front().getOutput().getType());
|
auto rowType = cast<RankedTensorType>(run.ops.front().getOutput().getType());
|
||||||
auto fallbackPackedType = getPackedTensorType(rowType, static_cast<int64_t>(sortedEntries.size()));
|
auto fallbackPackedType = getPackedTensorType(rowType, static_cast<int64_t>(sortedEntries.size()));
|
||||||
SmallVector<Value> sortedOutputs;
|
SmallVector<Value> sortedOutputs;
|
||||||
sortedOutputs.reserve(sortedEntries.size());
|
sortedOutputs.reserve(sortedEntries.size());
|
||||||
@@ -469,10 +505,10 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
|
|||||||
concatOp ? getPackedConcatSliceType(concatOp, concatStartIndex, static_cast<unsigned>(sortedOutputs.size()))
|
concatOp ? getPackedConcatSliceType(concatOp, concatStartIndex, static_cast<unsigned>(sortedOutputs.size()))
|
||||||
: RankedTensorType {};
|
: RankedTensorType {};
|
||||||
auto packedType = concatPackedType ? concatPackedType : fallbackPackedType;
|
auto packedType = concatPackedType ? concatPackedType : fallbackPackedType;
|
||||||
rewriter.setInsertionPoint(run.front());
|
rewriter.setInsertionPoint(run.ops.front());
|
||||||
auto compactReceive =
|
auto compactReceive =
|
||||||
spatial::SpatChannelReceiveTensorOp::create(rewriter,
|
spatial::SpatChannelReceiveTensorOp::create(rewriter,
|
||||||
run.front().getLoc(),
|
run.ops.front().getLoc(),
|
||||||
packedType,
|
packedType,
|
||||||
rewriter.getDenseI64ArrayAttr(channelIds),
|
rewriter.getDenseI64ArrayAttr(channelIds),
|
||||||
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
|
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
|
||||||
@@ -489,7 +525,7 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
|
|||||||
entry.op.getOutput().replaceAllUsesWith(extractPackedChunk(
|
entry.op.getOutput().replaceAllUsesWith(extractPackedChunk(
|
||||||
compactReceive.getOutput(), rowType, static_cast<unsigned>(sortedIndex), rewriter, entry.op.getLoc()));
|
compactReceive.getOutput(), rowType, static_cast<unsigned>(sortedIndex), rewriter, entry.op.getLoc()));
|
||||||
}
|
}
|
||||||
for (auto op : run)
|
for (auto op : run.ops)
|
||||||
rewriter.eraseOp(op);
|
rewriter.eraseOp(op);
|
||||||
|
|
||||||
it = compactReceive->getIterator();
|
it = compactReceive->getIterator();
|
||||||
@@ -500,18 +536,13 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
|
|||||||
|
|
||||||
auto sendOp = dyn_cast<spatial::SpatChannelSendOp>(&*it);
|
auto sendOp = dyn_cast<spatial::SpatChannelSendOp>(&*it);
|
||||||
if (sendOp) {
|
if (sendOp) {
|
||||||
SmallVector<spatial::SpatChannelSendOp> run;
|
|
||||||
Type inputType = sendOp.getInput().getType();
|
Type inputType = sendOp.getInput().getType();
|
||||||
auto runIt = it;
|
auto run =
|
||||||
while (runIt != block.end()) {
|
collectConsecutiveRun<spatial::SpatChannelSendOp>(it, block.end(), [&](spatial::SpatChannelSendOp current) {
|
||||||
auto current = dyn_cast<spatial::SpatChannelSendOp>(&*runIt);
|
return current.getInput().getType() == inputType;
|
||||||
if (!current || current.getInput().getType() != inputType)
|
});
|
||||||
break;
|
|
||||||
run.push_back(current);
|
|
||||||
++runIt;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (run.size() > 1) {
|
if (run.ops.size() > 1) {
|
||||||
struct SendEntry {
|
struct SendEntry {
|
||||||
spatial::SpatChannelSendOp op;
|
spatial::SpatChannelSendOp op;
|
||||||
uint32_t sourceCoreId = 0;
|
uint32_t sourceCoreId = 0;
|
||||||
@@ -519,8 +550,8 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
|
|||||||
uint64_t channelId = 0;
|
uint64_t channelId = 0;
|
||||||
};
|
};
|
||||||
SmallVector<SendEntry> sortedEntries;
|
SmallVector<SendEntry> sortedEntries;
|
||||||
sortedEntries.reserve(run.size());
|
sortedEntries.reserve(run.ops.size());
|
||||||
for (auto op : run)
|
for (auto op : run.ops)
|
||||||
sortedEntries.push_back({op, op.getSourceCoreId(), op.getTargetCoreId(), op.getChannelId()});
|
sortedEntries.push_back({op, op.getSourceCoreId(), op.getTargetCoreId(), op.getChannelId()});
|
||||||
|
|
||||||
SmallVector<int64_t> channelIds;
|
SmallVector<int64_t> channelIds;
|
||||||
@@ -532,25 +563,24 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
|
|||||||
targetCoreIds.reserve(sortedEntries.size());
|
targetCoreIds.reserve(sortedEntries.size());
|
||||||
inputs.reserve(sortedEntries.size());
|
inputs.reserve(sortedEntries.size());
|
||||||
for (SendEntry& entry : sortedEntries) {
|
for (SendEntry& entry : sortedEntries) {
|
||||||
channelIds.push_back(static_cast<int64_t>(entry.channelId));
|
appendChannelAttrs(
|
||||||
sourceCoreIds.push_back(static_cast<int32_t>(entry.sourceCoreId));
|
channelIds, sourceCoreIds, targetCoreIds, entry.channelId, entry.sourceCoreId, entry.targetCoreId);
|
||||||
targetCoreIds.push_back(static_cast<int32_t>(entry.targetCoreId));
|
|
||||||
inputs.push_back(entry.op.getInput());
|
inputs.push_back(entry.op.getInput());
|
||||||
}
|
}
|
||||||
|
|
||||||
rewriter.setInsertionPoint(run.front());
|
rewriter.setInsertionPoint(run.ops.front());
|
||||||
Value packedInput = createPackedTensorForValues(ValueRange(inputs), rewriter, run.front().getLoc());
|
Value packedInput = createPackedTensorForValues(ValueRange(inputs), rewriter, run.ops.front().getLoc());
|
||||||
if (packedInput) {
|
if (packedInput) {
|
||||||
spatial::SpatChannelSendTensorOp::create(rewriter,
|
spatial::SpatChannelSendTensorOp::create(rewriter,
|
||||||
run.front().getLoc(),
|
run.ops.front().getLoc(),
|
||||||
rewriter.getDenseI64ArrayAttr(channelIds),
|
rewriter.getDenseI64ArrayAttr(channelIds),
|
||||||
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
|
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
|
||||||
rewriter.getDenseI32ArrayAttr(targetCoreIds),
|
rewriter.getDenseI32ArrayAttr(targetCoreIds),
|
||||||
packedInput);
|
packedInput);
|
||||||
for (auto op : run)
|
for (auto op : run.ops)
|
||||||
rewriter.eraseOp(op);
|
rewriter.eraseOp(op);
|
||||||
|
|
||||||
it = runIt;
|
it = run.end;
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -569,32 +599,27 @@ void compactBatchChannelRuns(func::FuncOp funcOp) {
|
|||||||
for (auto it = block.begin(); it != block.end();) {
|
for (auto it = block.begin(); it != block.end();) {
|
||||||
auto receiveOp = dyn_cast<spatial::SpatChannelReceiveBatchOp>(&*it);
|
auto receiveOp = dyn_cast<spatial::SpatChannelReceiveBatchOp>(&*it);
|
||||||
if (receiveOp) {
|
if (receiveOp) {
|
||||||
SmallVector<spatial::SpatChannelReceiveBatchOp> run;
|
|
||||||
Type outputType = receiveOp.getOutput().getType();
|
Type outputType = receiveOp.getOutput().getType();
|
||||||
auto runIt = it;
|
auto run = collectConsecutiveRun<spatial::SpatChannelReceiveBatchOp>(
|
||||||
while (runIt != block.end()) {
|
it, block.end(), [&](spatial::SpatChannelReceiveBatchOp current) {
|
||||||
auto current = dyn_cast<spatial::SpatChannelReceiveBatchOp>(&*runIt);
|
return current.getOutput().getType() == outputType;
|
||||||
if (!current || current.getOutput().getType() != outputType)
|
});
|
||||||
break;
|
|
||||||
run.push_back(current);
|
|
||||||
++runIt;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (run.size() > 1) {
|
if (run.ops.size() > 1) {
|
||||||
SmallVector<int64_t> channelIds;
|
SmallVector<int64_t> channelIds;
|
||||||
SmallVector<int32_t> sourceCoreIds;
|
SmallVector<int32_t> sourceCoreIds;
|
||||||
SmallVector<int32_t> targetCoreIds;
|
SmallVector<int32_t> targetCoreIds;
|
||||||
for (auto op : run) {
|
for (auto op : run.ops) {
|
||||||
llvm::append_range(channelIds, op.getChannelIds());
|
llvm::append_range(channelIds, op.getChannelIds());
|
||||||
llvm::append_range(sourceCoreIds, op.getSourceCoreIds());
|
llvm::append_range(sourceCoreIds, op.getSourceCoreIds());
|
||||||
llvm::append_range(targetCoreIds, op.getTargetCoreIds());
|
llvm::append_range(targetCoreIds, op.getTargetCoreIds());
|
||||||
}
|
}
|
||||||
|
|
||||||
auto rowType = cast<RankedTensorType>(run.front().getOutput().getType());
|
auto rowType = cast<RankedTensorType>(run.ops.front().getOutput().getType());
|
||||||
auto fallbackPackedType = getPackedTensorType(rowType, static_cast<int64_t>(run.size()));
|
auto fallbackPackedType = getPackedTensorType(rowType, static_cast<int64_t>(run.ops.size()));
|
||||||
SmallVector<Value> outputs;
|
SmallVector<Value> outputs;
|
||||||
outputs.reserve(run.size());
|
outputs.reserve(run.ops.size());
|
||||||
for (auto op : run)
|
for (auto op : run.ops)
|
||||||
outputs.push_back(op.getOutput());
|
outputs.push_back(op.getOutput());
|
||||||
|
|
||||||
unsigned concatStartIndex = 0;
|
unsigned concatStartIndex = 0;
|
||||||
@@ -603,10 +628,10 @@ void compactBatchChannelRuns(func::FuncOp funcOp) {
|
|||||||
concatOp ? getPackedConcatSliceType(concatOp, concatStartIndex, static_cast<unsigned>(outputs.size()))
|
concatOp ? getPackedConcatSliceType(concatOp, concatStartIndex, static_cast<unsigned>(outputs.size()))
|
||||||
: RankedTensorType {};
|
: RankedTensorType {};
|
||||||
auto packedType = concatPackedType ? concatPackedType : fallbackPackedType;
|
auto packedType = concatPackedType ? concatPackedType : fallbackPackedType;
|
||||||
rewriter.setInsertionPoint(run.front());
|
rewriter.setInsertionPoint(run.ops.front());
|
||||||
auto compactReceive =
|
auto compactReceive =
|
||||||
spatial::SpatChannelReceiveTensorBatchOp::create(rewriter,
|
spatial::SpatChannelReceiveTensorBatchOp::create(rewriter,
|
||||||
run.front().getLoc(),
|
run.ops.front().getLoc(),
|
||||||
packedType,
|
packedType,
|
||||||
rewriter.getDenseI64ArrayAttr(channelIds),
|
rewriter.getDenseI64ArrayAttr(channelIds),
|
||||||
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
|
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
|
||||||
@@ -616,11 +641,11 @@ void compactBatchChannelRuns(func::FuncOp funcOp) {
|
|||||||
concatOp, concatStartIndex, static_cast<unsigned>(outputs.size()), compactReceive.getOutput(), rewriter);
|
concatOp, concatStartIndex, static_cast<unsigned>(outputs.size()), compactReceive.getOutput(), rewriter);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
for (auto [index, op] : llvm::enumerate(run))
|
for (auto [index, op] : llvm::enumerate(run.ops))
|
||||||
op.getOutput().replaceAllUsesWith(extractPackedChunk(
|
op.getOutput().replaceAllUsesWith(extractPackedChunk(
|
||||||
compactReceive.getOutput(), rowType, static_cast<unsigned>(index), rewriter, op.getLoc()));
|
compactReceive.getOutput(), rowType, static_cast<unsigned>(index), rewriter, op.getLoc()));
|
||||||
}
|
}
|
||||||
for (auto op : run)
|
for (auto op : run.ops)
|
||||||
rewriter.eraseOp(op);
|
rewriter.eraseOp(op);
|
||||||
|
|
||||||
it = compactReceive->getIterator();
|
it = compactReceive->getIterator();
|
||||||
@@ -631,43 +656,38 @@ void compactBatchChannelRuns(func::FuncOp funcOp) {
|
|||||||
|
|
||||||
auto sendOp = dyn_cast<spatial::SpatChannelSendBatchOp>(&*it);
|
auto sendOp = dyn_cast<spatial::SpatChannelSendBatchOp>(&*it);
|
||||||
if (sendOp) {
|
if (sendOp) {
|
||||||
SmallVector<spatial::SpatChannelSendBatchOp> run;
|
|
||||||
Type inputType = sendOp.getInput().getType();
|
Type inputType = sendOp.getInput().getType();
|
||||||
auto runIt = it;
|
auto run = collectConsecutiveRun<spatial::SpatChannelSendBatchOp>(
|
||||||
while (runIt != block.end()) {
|
it, block.end(), [&](spatial::SpatChannelSendBatchOp current) {
|
||||||
auto current = dyn_cast<spatial::SpatChannelSendBatchOp>(&*runIt);
|
return current.getInput().getType() == inputType;
|
||||||
if (!current || current.getInput().getType() != inputType)
|
});
|
||||||
break;
|
|
||||||
run.push_back(current);
|
|
||||||
++runIt;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (run.size() > 1) {
|
if (run.ops.size() > 1) {
|
||||||
SmallVector<int64_t> channelIds;
|
SmallVector<int64_t> channelIds;
|
||||||
SmallVector<int32_t> sourceCoreIds;
|
SmallVector<int32_t> sourceCoreIds;
|
||||||
SmallVector<int32_t> targetCoreIds;
|
SmallVector<int32_t> targetCoreIds;
|
||||||
SmallVector<Value> inputs;
|
SmallVector<Value> inputs;
|
||||||
inputs.reserve(run.size());
|
inputs.reserve(run.ops.size());
|
||||||
for (auto op : run) {
|
for (auto op : run.ops) {
|
||||||
llvm::append_range(channelIds, op.getChannelIds());
|
llvm::append_range(channelIds, op.getChannelIds());
|
||||||
llvm::append_range(sourceCoreIds, op.getSourceCoreIds());
|
llvm::append_range(sourceCoreIds, op.getSourceCoreIds());
|
||||||
llvm::append_range(targetCoreIds, op.getTargetCoreIds());
|
llvm::append_range(targetCoreIds, op.getTargetCoreIds());
|
||||||
inputs.push_back(op.getInput());
|
inputs.push_back(op.getInput());
|
||||||
}
|
}
|
||||||
|
|
||||||
rewriter.setInsertionPoint(run.front());
|
rewriter.setInsertionPoint(run.ops.front());
|
||||||
Value packedInput = createPackedTensorForValues(ValueRange(inputs), rewriter, run.front().getLoc());
|
Value packedInput = createPackedTensorForValues(ValueRange(inputs), rewriter, run.ops.front().getLoc());
|
||||||
if (packedInput) {
|
if (packedInput) {
|
||||||
spatial::SpatChannelSendTensorBatchOp::create(rewriter,
|
spatial::SpatChannelSendTensorBatchOp::create(rewriter,
|
||||||
run.front().getLoc(),
|
run.ops.front().getLoc(),
|
||||||
rewriter.getDenseI64ArrayAttr(channelIds),
|
rewriter.getDenseI64ArrayAttr(channelIds),
|
||||||
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
|
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
|
||||||
rewriter.getDenseI32ArrayAttr(targetCoreIds),
|
rewriter.getDenseI32ArrayAttr(targetCoreIds),
|
||||||
packedInput);
|
packedInput);
|
||||||
for (auto op : run)
|
for (auto op : run.ops)
|
||||||
rewriter.eraseOp(op);
|
rewriter.eraseOp(op);
|
||||||
|
|
||||||
it = runIt;
|
it = run.end;
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -695,8 +715,9 @@ void compactRegularOpRuns(func::FuncOp funcOp) {
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auto anchorEndIt = std::next(it, static_cast<std::ptrdiff_t>(anchorChunk->ops.size()));
|
||||||
SmallVector<RegularChunk> run {*anchorChunk};
|
SmallVector<RegularChunk> run {*anchorChunk};
|
||||||
auto runIt = std::next(it, static_cast<std::ptrdiff_t>(anchorChunk->ops.size()));
|
auto runIt = anchorEndIt;
|
||||||
while (runIt != block.end()) {
|
while (runIt != block.end()) {
|
||||||
auto candidateStart = dyn_cast<spatial::SpatVMMOp>(&*runIt);
|
auto candidateStart = dyn_cast<spatial::SpatVMMOp>(&*runIt);
|
||||||
if (!candidateStart)
|
if (!candidateStart)
|
||||||
@@ -711,12 +732,26 @@ void compactRegularOpRuns(func::FuncOp funcOp) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (run.size() <= 1) {
|
if (run.size() <= 1) {
|
||||||
++it;
|
it = anchorEndIt;
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
compactRegularChunkRun(rewriter, run);
|
size_t originalOpCount = 0;
|
||||||
it = block.begin();
|
for (const RegularChunk& chunk : run)
|
||||||
|
originalOpCount += chunk.ops.size();
|
||||||
|
|
||||||
|
RegularCompactionResult result = compactRegularChunkRun(rewriter, run);
|
||||||
|
if (result.changed) {
|
||||||
|
assert(originalOpCount > anchorChunk->ops.size() && "successful regular compaction must consume the run");
|
||||||
|
if (!result.resumeAfter) {
|
||||||
|
it = block.end();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
it = result.resumeAfter->getIterator();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
it = anchorEndIt;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -747,37 +782,32 @@ void compactRowWiseWvmmRuns(func::FuncOp funcOp) {
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
SmallVector<spatial::SpatVMMOp> run;
|
|
||||||
auto runIt = it;
|
|
||||||
int64_t expectedRow = static_cast<int64_t>(rowResult.getResultNumber());
|
int64_t expectedRow = static_cast<int64_t>(rowResult.getResultNumber());
|
||||||
while (runIt != block.end()) {
|
auto run = collectConsecutiveRun<spatial::SpatVMMOp>(it, block.end(), [&](spatial::SpatVMMOp current) {
|
||||||
auto current = dyn_cast<spatial::SpatVMMOp>(&*runIt);
|
if (current.getWeightIndex() != wvmmOp.getWeightIndex()
|
||||||
if (!current || current.getWeightIndex() != wvmmOp.getWeightIndex()
|
|
||||||
|| current.getInput().getDefiningOp<spatial::SpatExtractRowsOp>() != extractRowsOp
|
|| current.getInput().getDefiningOp<spatial::SpatExtractRowsOp>() != extractRowsOp
|
||||||
|| current.getInput().getType() != wvmmOp.getInput().getType()
|
|| current.getInput().getType() != wvmmOp.getInput().getType()
|
||||||
|| current.getOutput().getType() != wvmmOp.getOutput().getType()) {
|
|| current.getOutput().getType() != wvmmOp.getOutput().getType())
|
||||||
break;
|
return false;
|
||||||
}
|
|
||||||
|
|
||||||
auto currentRow = dyn_cast<OpResult>(current.getInput());
|
auto currentRow = dyn_cast<OpResult>(current.getInput());
|
||||||
if (!currentRow || currentRow.getResultNumber() != static_cast<unsigned>(expectedRow))
|
if (!currentRow || currentRow.getResultNumber() != static_cast<unsigned>(expectedRow))
|
||||||
break;
|
return false;
|
||||||
|
|
||||||
run.push_back(current);
|
|
||||||
++expectedRow;
|
++expectedRow;
|
||||||
++runIt;
|
return true;
|
||||||
}
|
});
|
||||||
|
|
||||||
if (run.size() <= 1) {
|
if (run.ops.size() <= 1) {
|
||||||
++it;
|
++it;
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!run.front().getOutput().hasOneUse()) {
|
if (!run.ops.front().getOutput().hasOneUse()) {
|
||||||
++it;
|
++it;
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
auto concatUse = run.front().getOutput().getUses().begin();
|
auto concatUse = run.ops.front().getOutput().getUses().begin();
|
||||||
auto concatOp = dyn_cast<spatial::SpatConcatOp>(concatUse->getOwner());
|
auto concatOp = dyn_cast<spatial::SpatConcatOp>(concatUse->getOwner());
|
||||||
if (!concatOp) {
|
if (!concatOp) {
|
||||||
++it;
|
++it;
|
||||||
@@ -786,7 +816,7 @@ void compactRowWiseWvmmRuns(func::FuncOp funcOp) {
|
|||||||
|
|
||||||
unsigned concatStartIndex = concatUse->getOperandNumber();
|
unsigned concatStartIndex = concatUse->getOperandNumber();
|
||||||
bool validConcatRun = true;
|
bool validConcatRun = true;
|
||||||
for (auto [index, op] : llvm::enumerate(run)) {
|
for (auto [index, op] : llvm::enumerate(run.ops)) {
|
||||||
if (!op.getOutput().hasOneUse()) {
|
if (!op.getOutput().hasOneUse()) {
|
||||||
validConcatRun = false;
|
validConcatRun = false;
|
||||||
break;
|
break;
|
||||||
@@ -817,17 +847,17 @@ void compactRowWiseWvmmRuns(func::FuncOp funcOp) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
int64_t firstRow = static_cast<int64_t>(rowResult.getResultNumber());
|
int64_t firstRow = static_cast<int64_t>(rowResult.getResultNumber());
|
||||||
int64_t runLength = static_cast<int64_t>(run.size());
|
int64_t runLength = static_cast<int64_t>(run.ops.size());
|
||||||
auto packedType = RankedTensorType::get({runLength, outputCols}, outputType.getElementType());
|
auto packedType = RankedTensorType::get({runLength, outputCols}, outputType.getElementType());
|
||||||
|
|
||||||
rewriter.setInsertionPoint(run.front());
|
rewriter.setInsertionPoint(run.ops.front());
|
||||||
auto zero = arith::ConstantIndexOp::create(rewriter, run.front().getLoc(), 0);
|
auto zero = arith::ConstantIndexOp::create(rewriter, run.ops.front().getLoc(), 0);
|
||||||
auto upper = arith::ConstantIndexOp::create(rewriter, run.front().getLoc(), runLength);
|
auto upper = arith::ConstantIndexOp::create(rewriter, run.ops.front().getLoc(), runLength);
|
||||||
auto step = arith::ConstantIndexOp::create(rewriter, run.front().getLoc(), 1);
|
auto step = arith::ConstantIndexOp::create(rewriter, run.ops.front().getLoc(), 1);
|
||||||
auto packedInit =
|
auto packedInit =
|
||||||
tensor::EmptyOp::create(rewriter, run.front().getLoc(), packedType.getShape(), packedType.getElementType());
|
tensor::EmptyOp::create(rewriter, run.ops.front().getLoc(), packedType.getShape(), packedType.getElementType());
|
||||||
auto loop =
|
auto loop =
|
||||||
scf::ForOp::create(rewriter, run.front().getLoc(), zero, upper, step, ValueRange {packedInit.getResult()});
|
scf::ForOp::create(rewriter, run.ops.front().getLoc(), zero, upper, step, ValueRange {packedInit.getResult()});
|
||||||
|
|
||||||
{
|
{
|
||||||
OpBuilder::InsertionGuard guard(rewriter);
|
OpBuilder::InsertionGuard guard(rewriter);
|
||||||
@@ -838,41 +868,41 @@ void compactRowWiseWvmmRuns(func::FuncOp funcOp) {
|
|||||||
|
|
||||||
Value sourceRow = iv;
|
Value sourceRow = iv;
|
||||||
if (firstRow != 0) {
|
if (firstRow != 0) {
|
||||||
auto firstRowValue = arith::ConstantIndexOp::create(rewriter, run.front().getLoc(), firstRow);
|
auto firstRowValue = arith::ConstantIndexOp::create(rewriter, run.ops.front().getLoc(), firstRow);
|
||||||
sourceRow = arith::AddIOp::create(rewriter, run.front().getLoc(), iv, firstRowValue);
|
sourceRow = arith::AddIOp::create(rewriter, run.ops.front().getLoc(), iv, firstRowValue);
|
||||||
}
|
}
|
||||||
|
|
||||||
SmallVector<OpFoldResult> extractOffsets = {sourceRow, rewriter.getIndexAttr(0)};
|
SmallVector<OpFoldResult> extractOffsets = {sourceRow, rewriter.getIndexAttr(0)};
|
||||||
SmallVector<OpFoldResult> extractSizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(inputCols)};
|
SmallVector<OpFoldResult> extractSizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(inputCols)};
|
||||||
SmallVector<OpFoldResult> extractStrides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
SmallVector<OpFoldResult> extractStrides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||||
auto extractedRow = tensor::ExtractSliceOp::create(rewriter,
|
auto extractedRow = tensor::ExtractSliceOp::create(rewriter,
|
||||||
run.front().getLoc(),
|
run.ops.front().getLoc(),
|
||||||
inputType,
|
inputType,
|
||||||
extractRowsOp.getInput(),
|
extractRowsOp.getInput(),
|
||||||
extractOffsets,
|
extractOffsets,
|
||||||
extractSizes,
|
extractSizes,
|
||||||
extractStrides);
|
extractStrides);
|
||||||
auto loopWvmm = spatial::SpatVMMOp::create(
|
auto loopWvmm = spatial::SpatVMMOp::create(
|
||||||
rewriter, run.front().getLoc(), outputType, wvmmOp.getWeightIndex(), extractedRow.getResult());
|
rewriter, run.ops.front().getLoc(), outputType, wvmmOp.getWeightIndex(), extractedRow.getResult());
|
||||||
|
|
||||||
SmallVector<OpFoldResult> insertOffsets = {iv, rewriter.getIndexAttr(0)};
|
SmallVector<OpFoldResult> insertOffsets = {iv, rewriter.getIndexAttr(0)};
|
||||||
SmallVector<OpFoldResult> insertSizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(outputCols)};
|
SmallVector<OpFoldResult> insertSizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(outputCols)};
|
||||||
SmallVector<OpFoldResult> insertStrides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
SmallVector<OpFoldResult> insertStrides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||||
auto inserted = tensor::InsertSliceOp::create(
|
auto inserted = tensor::InsertSliceOp::create(
|
||||||
rewriter, run.front().getLoc(), loopWvmm.getResult(), acc, insertOffsets, insertSizes, insertStrides);
|
rewriter, run.ops.front().getLoc(), loopWvmm.getResult(), acc, insertOffsets, insertSizes, insertStrides);
|
||||||
scf::YieldOp::create(rewriter, run.front().getLoc(), inserted.getResult());
|
scf::YieldOp::create(rewriter, run.ops.front().getLoc(), inserted.getResult());
|
||||||
}
|
}
|
||||||
|
|
||||||
SmallVector<Value> newConcatInputs;
|
SmallVector<Value> newConcatInputs;
|
||||||
newConcatInputs.reserve(concatOp.getInputs().size() - run.size() + 1);
|
newConcatInputs.reserve(concatOp.getInputs().size() - run.ops.size() + 1);
|
||||||
for (auto [operandIndex, operand] : llvm::enumerate(concatOp.getInputs())) {
|
for (auto [operandIndex, operand] : llvm::enumerate(concatOp.getInputs())) {
|
||||||
if (operandIndex == concatStartIndex)
|
if (operandIndex == concatStartIndex)
|
||||||
newConcatInputs.push_back(loop.getResult(0));
|
newConcatInputs.push_back(loop.getResult(0));
|
||||||
if (operandIndex < concatStartIndex || operandIndex >= concatStartIndex + run.size())
|
if (operandIndex < concatStartIndex || operandIndex >= concatStartIndex + run.ops.size())
|
||||||
newConcatInputs.push_back(operand);
|
newConcatInputs.push_back(operand);
|
||||||
}
|
}
|
||||||
rewriter.modifyOpInPlace(concatOp, [&] { concatOp->setOperands(newConcatInputs); });
|
rewriter.modifyOpInPlace(concatOp, [&] { concatOp->setOperands(newConcatInputs); });
|
||||||
for (auto op : run)
|
for (auto op : run.ops)
|
||||||
rewriter.eraseOp(op);
|
rewriter.eraseOp(op);
|
||||||
|
|
||||||
it = loop->getIterator();
|
it = loop->getIterator();
|
||||||
|
|||||||
@@ -14,7 +14,6 @@
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "ComputeGraph.hpp"
|
#include "ComputeGraph.hpp"
|
||||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
|
||||||
#include "src/Support/TypeUtilities.hpp"
|
#include "src/Support/TypeUtilities.hpp"
|
||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
@@ -24,12 +23,6 @@ using namespace mlir;
|
|||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
size_t getSchedulingCpuBudget() {
|
|
||||||
if (coresCount.getValue() > 0)
|
|
||||||
return static_cast<size_t>(coresCount.getValue());
|
|
||||||
return std::numeric_limits<size_t>::max();
|
|
||||||
}
|
|
||||||
|
|
||||||
Weight getComputeBodyWeight(Region &body) {
|
Weight getComputeBodyWeight(Region &body) {
|
||||||
constexpr Weight kOperationWeight = 100;
|
constexpr Weight kOperationWeight = 100;
|
||||||
Weight numOperations = 0;
|
Weight numOperations = 0;
|
||||||
@@ -95,41 +88,6 @@ std::vector<ComputeGraphEdge> aggregateEdges(llvm::ArrayRef<ComputeGraphEdge> ed
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
size_t getBatchChunkTargetCount(int32_t laneCount) {
|
|
||||||
assert(laneCount > 0 && "laneCount must be positive");
|
|
||||||
return std::min(static_cast<size_t>(laneCount), std::max<size_t>(1, getSchedulingCpuBudget()));
|
|
||||||
}
|
|
||||||
|
|
||||||
ComputeInstance getBatchChunkForIndex(SpatComputeBatch batch, size_t chunkIndex) {
|
|
||||||
size_t totalLanes = static_cast<size_t>(batch.getLaneCount());
|
|
||||||
size_t chunkCount = getBatchChunkTargetCount(batch.getLaneCount());
|
|
||||||
size_t baseChunkSize = totalLanes / chunkCount;
|
|
||||||
size_t largeChunkCount = totalLanes % chunkCount;
|
|
||||||
|
|
||||||
size_t laneStart = chunkIndex * baseChunkSize + std::min(chunkIndex, largeChunkCount);
|
|
||||||
size_t laneCount = baseChunkSize + (chunkIndex < largeChunkCount ? 1 : 0);
|
|
||||||
return {batch.getOperation(), static_cast<uint32_t>(laneStart), static_cast<uint32_t>(laneCount)};
|
|
||||||
}
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
|
|
||||||
ComputeInstance getBatchChunkForLane(SpatComputeBatch batch, uint32_t lane) {
|
|
||||||
size_t totalLanes = static_cast<size_t>(batch.getLaneCount());
|
|
||||||
size_t chunkCount = getBatchChunkTargetCount(batch.getLaneCount());
|
|
||||||
size_t baseChunkSize = totalLanes / chunkCount;
|
|
||||||
size_t largeChunkCount = totalLanes % chunkCount;
|
|
||||||
size_t largeChunkSpan = largeChunkCount * (baseChunkSize + 1);
|
|
||||||
|
|
||||||
size_t chunkIndex = 0;
|
|
||||||
if (static_cast<size_t>(lane) < largeChunkSpan)
|
|
||||||
chunkIndex = static_cast<size_t>(lane) / (baseChunkSize + 1);
|
|
||||||
else
|
|
||||||
chunkIndex = largeChunkCount + (static_cast<size_t>(lane) - largeChunkSpan) / baseChunkSize;
|
|
||||||
return getBatchChunkForIndex(batch, chunkIndex);
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
Weight getComputeInstanceWeight(const ComputeInstance &instance) {
|
Weight getComputeInstanceWeight(const ComputeInstance &instance) {
|
||||||
if (auto spatCompute = dyn_cast<SpatCompute>(instance.op))
|
if (auto spatCompute = dyn_cast<SpatCompute>(instance.op))
|
||||||
return getSpatComputeWeight(spatCompute);
|
return getSpatComputeWeight(spatCompute);
|
||||||
@@ -145,47 +103,6 @@ CrossbarUsage getComputeInstanceCrossbarUsage(const ComputeInstance &instance) {
|
|||||||
static_cast<CrossbarUsage>(instance.laneCount));
|
static_cast<CrossbarUsage>(instance.laneCount));
|
||||||
}
|
}
|
||||||
|
|
||||||
llvm::SmallVector<Value, 4> getComputeInstanceInputs(const ComputeInstance &instance) {
|
|
||||||
if (auto spatCompute = dyn_cast<SpatCompute>(instance.op))
|
|
||||||
return llvm::SmallVector<Value, 4>(spatCompute.getInputs().begin(), spatCompute.getInputs().end());
|
|
||||||
auto batch = cast<SpatComputeBatch>(instance.op);
|
|
||||||
llvm::SmallVector<Value, 4> inputs;
|
|
||||||
inputs.reserve(instance.laneCount);
|
|
||||||
for (uint32_t lane = instance.laneStart; lane < instance.laneStart + instance.laneCount; ++lane)
|
|
||||||
inputs.push_back(batch.getInputs()[lane]);
|
|
||||||
return inputs;
|
|
||||||
}
|
|
||||||
|
|
||||||
llvm::SmallVector<Value, 4> getComputeInstanceWeights(const ComputeInstance &instance) {
|
|
||||||
if (auto spatCompute = dyn_cast<SpatCompute>(instance.op))
|
|
||||||
return llvm::SmallVector<Value, 4>(spatCompute.getWeights().begin(), spatCompute.getWeights().end());
|
|
||||||
auto batch = cast<SpatComputeBatch>(instance.op);
|
|
||||||
llvm::SmallVector<Value, 4> weights;
|
|
||||||
weights.reserve(instance.laneCount);
|
|
||||||
for (uint32_t lane = instance.laneStart; lane < instance.laneStart + instance.laneCount; ++lane)
|
|
||||||
weights.push_back(batch.getWeights()[lane]);
|
|
||||||
return weights;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::optional<ComputeInstance> getComputeProducerInstance(Value value) {
|
|
||||||
Operation *op = value.getDefiningOp();
|
|
||||||
if (!op)
|
|
||||||
return std::nullopt;
|
|
||||||
|
|
||||||
while (auto extract = dyn_cast<tensor::ExtractSliceOp>(op)) {
|
|
||||||
value = extract.getSource();
|
|
||||||
op = value.getDefiningOp();
|
|
||||||
if (!op)
|
|
||||||
return std::nullopt;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (auto spatCompute = dyn_cast<SpatCompute>(op))
|
|
||||||
return ComputeInstance {spatCompute.getOperation(), 0, 1};
|
|
||||||
if (auto batch = dyn_cast<SpatComputeBatch>(op))
|
|
||||||
return getBatchChunkForLane(batch, static_cast<uint32_t>(cast<OpResult>(value).getResultNumber()));
|
|
||||||
return std::nullopt;
|
|
||||||
}
|
|
||||||
|
|
||||||
ComputeGraph buildComputeGraph(Operation *entryOp) {
|
ComputeGraph buildComputeGraph(Operation *entryOp) {
|
||||||
ComputeGraph graph;
|
ComputeGraph graph;
|
||||||
|
|
||||||
|
|||||||
@@ -13,6 +13,7 @@
|
|||||||
|
|
||||||
#include "../DCPGraph/Utils.hpp"
|
#include "../DCPGraph/Utils.hpp"
|
||||||
#include "ComputeInstance.hpp"
|
#include "ComputeInstance.hpp"
|
||||||
|
#include "ComputeInstanceUtils.hpp"
|
||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
namespace spatial {
|
namespace spatial {
|
||||||
@@ -41,11 +42,6 @@ struct ComputeGraph {
|
|||||||
ComputeGraph buildComputeGraph(mlir::Operation *entryOp);
|
ComputeGraph buildComputeGraph(mlir::Operation *entryOp);
|
||||||
bool verifyAcyclic(const ComputeGraph &graph);
|
bool verifyAcyclic(const ComputeGraph &graph);
|
||||||
|
|
||||||
size_t getBatchChunkTargetCount(int32_t laneCount);
|
|
||||||
ComputeInstance getBatchChunkForIndex(SpatComputeBatch batch, size_t chunkIndex);
|
|
||||||
std::optional<ComputeInstance> getComputeProducerInstance(mlir::Value value);
|
|
||||||
llvm::SmallVector<mlir::Value, 4> getComputeInstanceInputs(const ComputeInstance &instance);
|
|
||||||
llvm::SmallVector<mlir::Value, 4> getComputeInstanceWeights(const ComputeInstance &instance);
|
|
||||||
Weight getComputeInstanceWeight(const ComputeInstance &instance);
|
Weight getComputeInstanceWeight(const ComputeInstance &instance);
|
||||||
CrossbarUsage getComputeInstanceCrossbarUsage(const ComputeInstance &instance);
|
CrossbarUsage getComputeInstanceCrossbarUsage(const ComputeInstance &instance);
|
||||||
|
|
||||||
|
|||||||
+150
@@ -0,0 +1,150 @@
|
|||||||
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
|
||||||
|
#include <limits>
|
||||||
|
|
||||||
|
#include "ComputeInstanceUtils.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
namespace spatial {
|
||||||
|
|
||||||
|
size_t getSchedulingCpuBudget() {
|
||||||
|
if (coresCount.getValue() > 0)
|
||||||
|
return static_cast<size_t>(coresCount.getValue());
|
||||||
|
return std::numeric_limits<size_t>::max();
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t getBatchChunkTargetCount(int32_t laneCount) {
|
||||||
|
assert(laneCount > 0 && "laneCount must be positive");
|
||||||
|
return std::min(static_cast<size_t>(laneCount), std::max<size_t>(1, getSchedulingCpuBudget()));
|
||||||
|
}
|
||||||
|
|
||||||
|
ComputeInstance getBatchChunkForIndex(SpatComputeBatch batch, size_t chunkIndex) {
|
||||||
|
size_t totalLanes = static_cast<size_t>(batch.getLaneCount());
|
||||||
|
size_t chunkCount = getBatchChunkTargetCount(batch.getLaneCount());
|
||||||
|
size_t baseChunkSize = totalLanes / chunkCount;
|
||||||
|
size_t largeChunkCount = totalLanes % chunkCount;
|
||||||
|
|
||||||
|
size_t laneStart = chunkIndex * baseChunkSize + std::min(chunkIndex, largeChunkCount);
|
||||||
|
size_t laneCount = baseChunkSize + (chunkIndex < largeChunkCount ? 1 : 0);
|
||||||
|
return {batch.getOperation(), static_cast<uint32_t>(laneStart), static_cast<uint32_t>(laneCount)};
|
||||||
|
}
|
||||||
|
|
||||||
|
ComputeInstance getBatchChunkForLane(SpatComputeBatch batch, uint32_t lane) {
|
||||||
|
size_t totalLanes = static_cast<size_t>(batch.getLaneCount());
|
||||||
|
size_t chunkCount = getBatchChunkTargetCount(batch.getLaneCount());
|
||||||
|
size_t baseChunkSize = totalLanes / chunkCount;
|
||||||
|
size_t largeChunkCount = totalLanes % chunkCount;
|
||||||
|
size_t largeChunkSpan = largeChunkCount * (baseChunkSize + 1);
|
||||||
|
|
||||||
|
size_t chunkIndex = 0;
|
||||||
|
if (static_cast<size_t>(lane) < largeChunkSpan)
|
||||||
|
chunkIndex = static_cast<size_t>(lane) / (baseChunkSize + 1);
|
||||||
|
else
|
||||||
|
chunkIndex = largeChunkCount + (static_cast<size_t>(lane) - largeChunkSpan) / baseChunkSize;
|
||||||
|
return getBatchChunkForIndex(batch, chunkIndex);
|
||||||
|
}
|
||||||
|
|
||||||
|
SpatCompute getOriginalSpatCompute(Operation *op) {
|
||||||
|
if (!op)
|
||||||
|
return {};
|
||||||
|
|
||||||
|
while (auto extract = dyn_cast<tensor::ExtractSliceOp>(op)) {
|
||||||
|
op = extract.getSource().getDefiningOp();
|
||||||
|
if (!op)
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
return dyn_cast<SpatCompute>(op);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::optional<ProducerValueRef> getProducerValueRef(Value value) {
|
||||||
|
Operation *op = value.getDefiningOp();
|
||||||
|
if (!op)
|
||||||
|
return std::nullopt;
|
||||||
|
|
||||||
|
while (auto extract = dyn_cast<tensor::ExtractSliceOp>(op)) {
|
||||||
|
value = extract.getSource();
|
||||||
|
op = value.getDefiningOp();
|
||||||
|
if (!op)
|
||||||
|
return std::nullopt;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto compute = dyn_cast<SpatCompute>(op)) {
|
||||||
|
return ProducerValueRef {
|
||||||
|
ComputeInstance {compute.getOperation(), 0, 1},
|
||||||
|
static_cast<size_t>(cast<OpResult>(value).getResultNumber())
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto batch = dyn_cast<SpatComputeBatch>(op)) {
|
||||||
|
uint32_t lane = static_cast<uint32_t>(cast<OpResult>(value).getResultNumber());
|
||||||
|
ComputeInstance instance = getBatchChunkForLane(batch, lane);
|
||||||
|
size_t resultIndex = static_cast<size_t>(lane - instance.laneStart);
|
||||||
|
return ProducerValueRef {instance, resultIndex};
|
||||||
|
}
|
||||||
|
|
||||||
|
return std::nullopt;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::optional<ComputeInstance> getComputeProducerInstance(Value value) {
|
||||||
|
if (std::optional<ProducerValueRef> producer = getProducerValueRef(value))
|
||||||
|
return producer->instance;
|
||||||
|
return std::nullopt;
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::SmallVector<Value, 4> getComputeInstanceInputs(const ComputeInstance &instance) {
|
||||||
|
if (auto compute = dyn_cast<SpatCompute>(instance.op))
|
||||||
|
return llvm::SmallVector<Value, 4>(compute.getInputs().begin(), compute.getInputs().end());
|
||||||
|
|
||||||
|
auto batch = cast<SpatComputeBatch>(instance.op);
|
||||||
|
llvm::SmallVector<Value, 4> inputs;
|
||||||
|
inputs.reserve(instance.laneCount);
|
||||||
|
for (uint32_t lane = instance.laneStart; lane < instance.laneStart + instance.laneCount; ++lane)
|
||||||
|
if (!batch.getInputs().empty())
|
||||||
|
inputs.push_back(batch.getInputs()[lane]);
|
||||||
|
return inputs;
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::SmallVector<Value, 4> getComputeInstanceWeights(const ComputeInstance &instance) {
|
||||||
|
if (auto compute = dyn_cast<SpatCompute>(instance.op))
|
||||||
|
return llvm::SmallVector<Value, 4>(compute.getWeights().begin(), compute.getWeights().end());
|
||||||
|
|
||||||
|
auto batch = cast<SpatComputeBatch>(instance.op);
|
||||||
|
llvm::SmallVector<Value, 4> weights;
|
||||||
|
weights.reserve(instance.laneCount);
|
||||||
|
for (uint32_t lane = instance.laneStart; lane < instance.laneStart + instance.laneCount; ++lane)
|
||||||
|
weights.push_back(batch.getWeights()[lane]);
|
||||||
|
return weights;
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::SmallVector<Value, 4> getComputeInstanceOutputValues(const ComputeInstance &instance) {
|
||||||
|
if (auto compute = dyn_cast<SpatCompute>(instance.op))
|
||||||
|
return llvm::SmallVector<Value, 4>(compute.getResults().begin(), compute.getResults().end());
|
||||||
|
|
||||||
|
auto batch = cast<SpatComputeBatch>(instance.op);
|
||||||
|
llvm::SmallVector<Value, 4> outputs;
|
||||||
|
outputs.reserve(instance.laneCount);
|
||||||
|
for (uint32_t lane = instance.laneStart; lane < instance.laneStart + instance.laneCount; ++lane)
|
||||||
|
if (!batch.getOutputs().empty())
|
||||||
|
outputs.push_back(batch.getOutputs()[lane]);
|
||||||
|
return outputs;
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::SmallVector<Type, 4> getComputeInstanceOutputTypes(const ComputeInstance &instance) {
|
||||||
|
llvm::SmallVector<Type, 4> outputTypes;
|
||||||
|
for (Value output : getComputeInstanceOutputValues(instance))
|
||||||
|
outputTypes.push_back(output.getType());
|
||||||
|
return outputTypes;
|
||||||
|
}
|
||||||
|
|
||||||
|
Block &getComputeInstanceTemplateBlock(const ComputeInstance &instance) {
|
||||||
|
if (auto compute = dyn_cast<SpatCompute>(instance.op))
|
||||||
|
return compute.getBody().front();
|
||||||
|
return cast<SpatComputeBatch>(instance.op).getBody().front();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace spatial
|
||||||
|
} // namespace onnx_mlir
|
||||||
+40
@@ -0,0 +1,40 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/IR/Block.h"
|
||||||
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
|
#include "mlir/IR/Operation.h"
|
||||||
|
#include "mlir/IR/Value.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
|
||||||
|
#include <cstddef>
|
||||||
|
#include <optional>
|
||||||
|
|
||||||
|
#include "ComputeInstance.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
namespace spatial {
|
||||||
|
|
||||||
|
struct ProducerValueRef {
|
||||||
|
ComputeInstance instance;
|
||||||
|
size_t resultIndex = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
size_t getSchedulingCpuBudget();
|
||||||
|
size_t getBatchChunkTargetCount(int32_t laneCount);
|
||||||
|
ComputeInstance getBatchChunkForIndex(SpatComputeBatch batch, size_t chunkIndex);
|
||||||
|
ComputeInstance getBatchChunkForLane(SpatComputeBatch batch, uint32_t lane);
|
||||||
|
|
||||||
|
SpatCompute getOriginalSpatCompute(mlir::Operation *op);
|
||||||
|
std::optional<ProducerValueRef> getProducerValueRef(mlir::Value value);
|
||||||
|
std::optional<ComputeInstance> getComputeProducerInstance(mlir::Value value);
|
||||||
|
|
||||||
|
llvm::SmallVector<mlir::Value, 4> getComputeInstanceInputs(const ComputeInstance &instance);
|
||||||
|
llvm::SmallVector<mlir::Value, 4> getComputeInstanceWeights(const ComputeInstance &instance);
|
||||||
|
llvm::SmallVector<mlir::Value, 4> getComputeInstanceOutputValues(const ComputeInstance &instance);
|
||||||
|
llvm::SmallVector<mlir::Type, 4> getComputeInstanceOutputTypes(const ComputeInstance &instance);
|
||||||
|
mlir::Block &getComputeInstanceTemplateBlock(const ComputeInstance &instance);
|
||||||
|
|
||||||
|
} // namespace spatial
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -1,13 +1,595 @@
|
|||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
#include "llvm/Support/FormatVariadic.h"
|
||||||
|
#include "llvm/Support/raw_ostream.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <cstdlib>
|
||||||
|
#include <limits>
|
||||||
|
#include <numeric>
|
||||||
|
#include <optional>
|
||||||
|
#include <queue>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
#include "DcpScheduler.hpp"
|
#include "DcpScheduler.hpp"
|
||||||
#include "../DCPGraph/Graph.hpp"
|
#include "../DCPGraph/Graph.hpp"
|
||||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
|
||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
namespace spatial {
|
namespace spatial {
|
||||||
|
|
||||||
MergeScheduleResult runDcpScheduler(const ComputeGraph &graph, mlir::MLIRContext *context) {
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
bool isDcpCoarsenDebugEnabled() { return std::getenv("DCP_COARSEN_DEBUG") != nullptr; }
|
||||||
|
|
||||||
|
struct VirtualNode {
|
||||||
|
llvm::SmallVector<size_t, 4> originalNodeIndices;
|
||||||
|
Weight weight = 0;
|
||||||
|
CrossbarUsage crossbarUsage = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct VirtualGraph {
|
||||||
|
std::vector<VirtualNode> nodes;
|
||||||
|
std::vector<IndexedEdge> edges;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct TimingInfo {
|
||||||
|
std::vector<Time> aest;
|
||||||
|
std::vector<Time> alst;
|
||||||
|
std::vector<size_t> topologicalOrder;
|
||||||
|
bool valid = false;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct WindowScheduleResult {
|
||||||
|
std::vector<std::vector<size_t>> mergeGroups;
|
||||||
|
CPU cpuCount = 0;
|
||||||
|
size_t mergedNodeCount = 0;
|
||||||
|
size_t maxMergeGroupSize = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
size_t getSchedulingCpuBudget(const DcpScheduleOptions &options) {
|
||||||
|
if (options.processorCount > 0)
|
||||||
|
return options.processorCount;
|
||||||
|
return std::numeric_limits<size_t>::max();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<IndexedEdge> aggregateEdges(llvm::ArrayRef<IndexedEdge> edges) {
|
||||||
|
llvm::DenseMap<std::pair<size_t, size_t>, Weight> edgeWeights;
|
||||||
|
for (auto [start, end, weight] : edges) {
|
||||||
|
size_t startIndex = static_cast<size_t>(start);
|
||||||
|
size_t endIndex = static_cast<size_t>(end);
|
||||||
|
if (startIndex == endIndex)
|
||||||
|
continue;
|
||||||
|
auto key = std::make_pair(startIndex, endIndex);
|
||||||
|
Weight edgeWeight = static_cast<Weight>(weight);
|
||||||
|
auto inserted = edgeWeights.try_emplace(key, edgeWeight);
|
||||||
|
if (!inserted.second)
|
||||||
|
inserted.first->second = std::max(inserted.first->second, edgeWeight);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<IndexedEdge> aggregatedEdges;
|
||||||
|
aggregatedEdges.reserve(edgeWeights.size());
|
||||||
|
for (auto [key, weight] : edgeWeights)
|
||||||
|
aggregatedEdges.push_back(
|
||||||
|
{static_cast<int64_t>(key.first), static_cast<int64_t>(key.second), static_cast<int64_t>(weight)});
|
||||||
|
llvm::sort(aggregatedEdges, [](const IndexedEdge &lhs, const IndexedEdge &rhs) {
|
||||||
|
if (std::get<0>(lhs) != std::get<0>(rhs))
|
||||||
|
return std::get<0>(lhs) < std::get<0>(rhs);
|
||||||
|
return std::get<1>(lhs) < std::get<1>(rhs);
|
||||||
|
});
|
||||||
|
return aggregatedEdges;
|
||||||
|
}
|
||||||
|
|
||||||
|
VirtualGraph buildInitialVirtualGraph(const ComputeGraph &graph) {
|
||||||
|
VirtualGraph virtualGraph;
|
||||||
|
virtualGraph.nodes.reserve(graph.nodes.size());
|
||||||
|
for (auto [index, node] : llvm::enumerate(graph.nodes)) {
|
||||||
|
VirtualNode virtualNode;
|
||||||
|
virtualNode.originalNodeIndices.push_back(index);
|
||||||
|
virtualNode.weight = node.weight;
|
||||||
|
virtualNode.crossbarUsage = node.crossbarUsage;
|
||||||
|
virtualGraph.nodes.push_back(std::move(virtualNode));
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<IndexedEdge> edges;
|
||||||
|
edges.reserve(graph.edges.size());
|
||||||
|
for (const ComputeGraphEdge &edge : graph.edges)
|
||||||
|
edges.push_back(
|
||||||
|
{static_cast<int64_t>(edge.source), static_cast<int64_t>(edge.target), static_cast<int64_t>(edge.transferCost)});
|
||||||
|
virtualGraph.edges = aggregateEdges(edges);
|
||||||
|
return virtualGraph;
|
||||||
|
}
|
||||||
|
|
||||||
|
TimingInfo computeTiming(const VirtualGraph &graph) {
|
||||||
|
TimingInfo timing;
|
||||||
|
size_t nodeCount = graph.nodes.size();
|
||||||
|
timing.aest.assign(nodeCount, 0);
|
||||||
|
timing.alst.assign(nodeCount, 0);
|
||||||
|
timing.topologicalOrder.reserve(nodeCount);
|
||||||
|
|
||||||
|
std::vector<std::vector<std::pair<size_t, Weight>>> parents(nodeCount);
|
||||||
|
std::vector<std::vector<std::pair<size_t, Weight>>> children(nodeCount);
|
||||||
|
std::vector<size_t> incomingEdgeCount(nodeCount, 0);
|
||||||
|
|
||||||
|
for (auto [start, end, weight] : graph.edges) {
|
||||||
|
size_t startIndex = static_cast<size_t>(start);
|
||||||
|
size_t endIndex = static_cast<size_t>(end);
|
||||||
|
Weight edgeWeight = static_cast<Weight>(weight);
|
||||||
|
assert(startIndex < nodeCount && endIndex < nodeCount && "virtual edge endpoint out of range");
|
||||||
|
children[startIndex].push_back({endIndex, edgeWeight});
|
||||||
|
parents[endIndex].push_back({startIndex, edgeWeight});
|
||||||
|
incomingEdgeCount[endIndex]++;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto getVirtualNodeOrderKey = [&](size_t nodeIndex) {
|
||||||
|
const VirtualNode &node = graph.nodes[nodeIndex];
|
||||||
|
if (!node.originalNodeIndices.empty())
|
||||||
|
return node.originalNodeIndices.front();
|
||||||
|
return nodeIndex;
|
||||||
|
};
|
||||||
|
auto readyNodeGreater = [&](size_t lhs, size_t rhs) {
|
||||||
|
size_t lhsKey = getVirtualNodeOrderKey(lhs);
|
||||||
|
size_t rhsKey = getVirtualNodeOrderKey(rhs);
|
||||||
|
if (lhsKey != rhsKey)
|
||||||
|
return lhsKey > rhsKey;
|
||||||
|
return lhs > rhs;
|
||||||
|
};
|
||||||
|
std::priority_queue<size_t, std::vector<size_t>, decltype(readyNodeGreater)> readyNodes(readyNodeGreater);
|
||||||
|
for (size_t i = 0; i < nodeCount; ++i)
|
||||||
|
if (incomingEdgeCount[i] == 0)
|
||||||
|
readyNodes.push(i);
|
||||||
|
|
||||||
|
while (!readyNodes.empty()) {
|
||||||
|
size_t current = readyNodes.top();
|
||||||
|
readyNodes.pop();
|
||||||
|
timing.topologicalOrder.push_back(current);
|
||||||
|
for (auto [child, weight] : children[current]) {
|
||||||
|
(void) weight;
|
||||||
|
assert(incomingEdgeCount[child] > 0 && "incoming edge count underflow");
|
||||||
|
incomingEdgeCount[child]--;
|
||||||
|
if (incomingEdgeCount[child] == 0)
|
||||||
|
readyNodes.push(child);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (timing.topologicalOrder.size() != nodeCount)
|
||||||
|
return timing;
|
||||||
|
|
||||||
|
Time dcpl = 0;
|
||||||
|
for (size_t nodeIndex : timing.topologicalOrder) {
|
||||||
|
Time maxParentAest = 0;
|
||||||
|
for (auto [parent, transferCost] : parents[nodeIndex]) {
|
||||||
|
maxParentAest =
|
||||||
|
std::max(maxParentAest, addOrMax(addOrMax(timing.aest[parent], graph.nodes[parent].weight), transferCost));
|
||||||
|
}
|
||||||
|
timing.aest[nodeIndex] = maxParentAest;
|
||||||
|
dcpl = std::max(dcpl, addOrMax(maxParentAest, graph.nodes[nodeIndex].weight));
|
||||||
|
}
|
||||||
|
|
||||||
|
for (size_t nodeIndex : llvm::reverse(timing.topologicalOrder)) {
|
||||||
|
Time minAlst = std::numeric_limits<Time>::max();
|
||||||
|
if (children[nodeIndex].empty())
|
||||||
|
minAlst = subtractOrZero(dcpl, graph.nodes[nodeIndex].weight);
|
||||||
|
for (auto [child, transferCost] : children[nodeIndex]) {
|
||||||
|
minAlst =
|
||||||
|
std::min(minAlst, subtractOrZero(timing.alst[child], addOrMax(graph.nodes[nodeIndex].weight, transferCost)));
|
||||||
|
}
|
||||||
|
timing.alst[nodeIndex] = minAlst;
|
||||||
|
}
|
||||||
|
|
||||||
|
timing.valid = true;
|
||||||
|
return timing;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<std::vector<size_t>> buildUndirectedAdjacency(const VirtualGraph &graph) {
|
||||||
|
std::vector<std::vector<size_t>> adjacency(graph.nodes.size());
|
||||||
|
for (auto [start, end, weight] : graph.edges) {
|
||||||
|
(void) weight;
|
||||||
|
size_t startIndex = static_cast<size_t>(start);
|
||||||
|
size_t endIndex = static_cast<size_t>(end);
|
||||||
|
assert(startIndex < graph.nodes.size() && endIndex < graph.nodes.size() && "virtual edge endpoint out of range");
|
||||||
|
adjacency[startIndex].push_back(endIndex);
|
||||||
|
adjacency[endIndex].push_back(startIndex);
|
||||||
|
}
|
||||||
|
for (auto &neighbours : adjacency) {
|
||||||
|
llvm::sort(neighbours);
|
||||||
|
neighbours.erase(std::unique(neighbours.begin(), neighbours.end()), neighbours.end());
|
||||||
|
}
|
||||||
|
return adjacency;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<size_t> selectCriticalWindow(const VirtualGraph &graph, const TimingInfo &timing, size_t windowSize) {
|
||||||
|
std::vector<size_t> ranked(timing.aest.size());
|
||||||
|
std::iota(ranked.begin(), ranked.end(), 0);
|
||||||
|
auto isHigherPriority = [&](size_t lhs, size_t rhs) {
|
||||||
|
Time lhsSlack = slackOrZero(timing.aest[lhs], timing.alst[lhs]);
|
||||||
|
Time rhsSlack = slackOrZero(timing.aest[rhs], timing.alst[rhs]);
|
||||||
|
if (lhsSlack != rhsSlack)
|
||||||
|
return lhsSlack < rhsSlack;
|
||||||
|
if (timing.aest[lhs] != timing.aest[rhs])
|
||||||
|
return timing.aest[lhs] < timing.aest[rhs];
|
||||||
|
return lhs < rhs;
|
||||||
|
};
|
||||||
|
|
||||||
|
windowSize = std::min(windowSize, ranked.size());
|
||||||
|
if (windowSize == 0)
|
||||||
|
return {};
|
||||||
|
if (windowSize == ranked.size()) {
|
||||||
|
llvm::sort(ranked, isHigherPriority);
|
||||||
|
return ranked;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t criticalPoolSize = std::min(ranked.size(), std::max(windowSize, windowSize * 2));
|
||||||
|
if (criticalPoolSize < ranked.size())
|
||||||
|
std::nth_element(
|
||||||
|
ranked.begin(), ranked.begin() + static_cast<std::ptrdiff_t>(criticalPoolSize), ranked.end(), isHigherPriority);
|
||||||
|
|
||||||
|
std::vector<char> inCriticalPool(ranked.size(), false);
|
||||||
|
for (size_t i = 0; i < criticalPoolSize; ++i)
|
||||||
|
inCriticalPool[ranked[i]] = true;
|
||||||
|
|
||||||
|
size_t seed = *std::min_element(ranked.begin(), ranked.end(), isHigherPriority);
|
||||||
|
std::vector<std::vector<size_t>> adjacency = buildUndirectedAdjacency(graph);
|
||||||
|
std::vector<size_t> selected;
|
||||||
|
std::vector<char> inWindow(ranked.size(), false);
|
||||||
|
selected.reserve(windowSize);
|
||||||
|
|
||||||
|
struct FrontierEntry {
|
||||||
|
size_t node;
|
||||||
|
};
|
||||||
|
auto frontierCompare = [&](FrontierEntry lhs, FrontierEntry rhs) { return isHigherPriority(rhs.node, lhs.node); };
|
||||||
|
std::priority_queue<FrontierEntry, std::vector<FrontierEntry>, decltype(frontierCompare)> frontier(frontierCompare);
|
||||||
|
|
||||||
|
auto addToWindow = [&](size_t node, const std::vector<char> &eligible) {
|
||||||
|
if (inWindow[node])
|
||||||
|
return;
|
||||||
|
inWindow[node] = true;
|
||||||
|
selected.push_back(node);
|
||||||
|
for (size_t neighbour : adjacency[node])
|
||||||
|
if (!inWindow[neighbour] && eligible[neighbour])
|
||||||
|
frontier.push({neighbour});
|
||||||
|
};
|
||||||
|
|
||||||
|
addToWindow(seed, inCriticalPool);
|
||||||
|
while (!frontier.empty() && selected.size() < windowSize) {
|
||||||
|
size_t node = frontier.top().node;
|
||||||
|
frontier.pop();
|
||||||
|
if (!inWindow[node])
|
||||||
|
addToWindow(node, inCriticalPool);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (selected.size() < windowSize) {
|
||||||
|
std::vector<char> anyNode(ranked.size(), true);
|
||||||
|
for (size_t node : selected)
|
||||||
|
for (size_t neighbour : adjacency[node])
|
||||||
|
if (!inWindow[neighbour])
|
||||||
|
frontier.push({neighbour});
|
||||||
|
while (!frontier.empty() && selected.size() < windowSize) {
|
||||||
|
size_t node = frontier.top().node;
|
||||||
|
frontier.pop();
|
||||||
|
if (!inWindow[node])
|
||||||
|
addToWindow(node, anyNode);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (selected.size() < windowSize) {
|
||||||
|
llvm::sort(ranked, isHigherPriority);
|
||||||
|
for (size_t node : ranked) {
|
||||||
|
if (selected.size() == windowSize)
|
||||||
|
break;
|
||||||
|
if (!inWindow[node]) {
|
||||||
|
inWindow[node] = true;
|
||||||
|
selected.push_back(node);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::sort(selected, isHigherPriority);
|
||||||
|
return selected;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<IndexedEdge> buildWindowEdges(const VirtualGraph &graph, const std::vector<int64_t> &nodeToWindowIndex) {
|
||||||
|
std::vector<IndexedEdge> windowEdges;
|
||||||
|
windowEdges.reserve(graph.edges.size());
|
||||||
|
for (auto [start, end, weight] : graph.edges) {
|
||||||
|
int64_t mappedStart = nodeToWindowIndex[static_cast<size_t>(start)];
|
||||||
|
int64_t mappedEnd = nodeToWindowIndex[static_cast<size_t>(end)];
|
||||||
|
if (mappedStart == -1 || mappedEnd == -1)
|
||||||
|
continue;
|
||||||
|
windowEdges.push_back({mappedStart, mappedEnd, weight});
|
||||||
|
}
|
||||||
|
return aggregateEdges(windowEdges);
|
||||||
|
}
|
||||||
|
|
||||||
|
WindowScheduleResult scheduleWindow(const VirtualGraph &graph,
|
||||||
|
llvm::ArrayRef<size_t> selectedNodes,
|
||||||
|
const DcpScheduleOptions &options,
|
||||||
|
mlir::MLIRContext *context) {
|
||||||
|
std::vector<Weight> windowWeights;
|
||||||
|
std::vector<CrossbarUsage> windowCrossbarUsage;
|
||||||
|
std::vector<int64_t> windowNodeOrderKeys;
|
||||||
|
std::vector<int64_t> nodeToWindowIndex(graph.nodes.size(), -1);
|
||||||
|
windowWeights.reserve(selectedNodes.size());
|
||||||
|
windowCrossbarUsage.reserve(selectedNodes.size());
|
||||||
|
windowNodeOrderKeys.reserve(selectedNodes.size());
|
||||||
|
|
||||||
|
for (auto [windowIndex, nodeIndex] : llvm::enumerate(selectedNodes)) {
|
||||||
|
nodeToWindowIndex[nodeIndex] = static_cast<int64_t>(windowIndex);
|
||||||
|
windowWeights.push_back(graph.nodes[nodeIndex].weight);
|
||||||
|
windowCrossbarUsage.push_back(graph.nodes[nodeIndex].crossbarUsage);
|
||||||
|
windowNodeOrderKeys.push_back(static_cast<int64_t>(nodeIndex));
|
||||||
|
}
|
||||||
|
|
||||||
|
GraphDCP windowGraph(
|
||||||
|
windowWeights, buildWindowEdges(graph, nodeToWindowIndex), windowNodeOrderKeys, windowCrossbarUsage);
|
||||||
|
if (options.processorCount > 0)
|
||||||
|
windowGraph.setMaxCpuCount(static_cast<int>(options.processorCount));
|
||||||
|
windowGraph.setContext(context);
|
||||||
|
windowGraph.runDcp();
|
||||||
|
|
||||||
|
WindowScheduleResult result;
|
||||||
|
result.cpuCount = windowGraph.cpuCount();
|
||||||
|
for (CPU cpu = 0; cpu < windowGraph.cpuCount(); ++cpu) {
|
||||||
|
auto scheduledTasks = windowGraph.getScheduledTasks(cpu);
|
||||||
|
if (scheduledTasks.size() < 2)
|
||||||
|
continue;
|
||||||
|
|
||||||
|
result.mergedNodeCount += scheduledTasks.size();
|
||||||
|
result.maxMergeGroupSize = std::max(result.maxMergeGroupSize, scheduledTasks.size());
|
||||||
|
std::vector<size_t> mergeGroup;
|
||||||
|
mergeGroup.reserve(scheduledTasks.size());
|
||||||
|
for (const auto &task : scheduledTasks)
|
||||||
|
mergeGroup.push_back(selectedNodes[task.nodeIndex]);
|
||||||
|
result.mergeGroups.push_back(std::move(mergeGroup));
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool coarsenGraph(const VirtualGraph &graph,
|
||||||
|
llvm::ArrayRef<std::vector<size_t>> mergeGroups,
|
||||||
|
VirtualGraph &coarsenedGraph,
|
||||||
|
std::vector<size_t> &oldToNewNode) {
|
||||||
|
TimingInfo timing = computeTiming(graph);
|
||||||
|
std::vector<size_t> topologicalRank(graph.nodes.size());
|
||||||
|
std::iota(topologicalRank.begin(), topologicalRank.end(), 0);
|
||||||
|
if (timing.valid)
|
||||||
|
for (auto [rank, nodeIndex] : llvm::enumerate(timing.topologicalOrder))
|
||||||
|
topologicalRank[nodeIndex] = rank;
|
||||||
|
|
||||||
|
std::vector<std::vector<size_t>> orderedMergeGroups;
|
||||||
|
orderedMergeGroups.reserve(mergeGroups.size());
|
||||||
|
for (const auto &mergeGroup : mergeGroups) {
|
||||||
|
orderedMergeGroups.emplace_back(mergeGroup.begin(), mergeGroup.end());
|
||||||
|
std::stable_sort(orderedMergeGroups.back().begin(), orderedMergeGroups.back().end(), [&](size_t lhs, size_t rhs) {
|
||||||
|
if (topologicalRank[lhs] != topologicalRank[rhs])
|
||||||
|
return topologicalRank[lhs] < topologicalRank[rhs];
|
||||||
|
return lhs < rhs;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<int64_t> nodeToMergeGroup(graph.nodes.size(), -1);
|
||||||
|
for (auto [groupIndex, mergeGroup] : llvm::enumerate(orderedMergeGroups)) {
|
||||||
|
if (mergeGroup.size() < 2)
|
||||||
|
continue;
|
||||||
|
for (size_t nodeIndex : mergeGroup) {
|
||||||
|
assert(nodeIndex < graph.nodes.size() && "merge group node out of range");
|
||||||
|
nodeToMergeGroup[nodeIndex] = static_cast<int64_t>(groupIndex);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<std::optional<size_t>> mergeGroupToNewNode(orderedMergeGroups.size());
|
||||||
|
std::vector<size_t> newNodeRank;
|
||||||
|
oldToNewNode.assign(graph.nodes.size(), 0);
|
||||||
|
bool mergedAny = false;
|
||||||
|
coarsenedGraph.nodes.clear();
|
||||||
|
coarsenedGraph.edges.clear();
|
||||||
|
coarsenedGraph.nodes.reserve(graph.nodes.size());
|
||||||
|
newNodeRank.reserve(graph.nodes.size());
|
||||||
|
|
||||||
|
for (size_t nodeIndex = 0; nodeIndex < graph.nodes.size(); ++nodeIndex) {
|
||||||
|
int64_t mergeGroupIndex = nodeToMergeGroup[nodeIndex];
|
||||||
|
if (mergeGroupIndex == -1) {
|
||||||
|
oldToNewNode[nodeIndex] = coarsenedGraph.nodes.size();
|
||||||
|
coarsenedGraph.nodes.push_back(graph.nodes[nodeIndex]);
|
||||||
|
newNodeRank.push_back(topologicalRank[nodeIndex]);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto &newNodeIndex = mergeGroupToNewNode[static_cast<size_t>(mergeGroupIndex)];
|
||||||
|
if (newNodeIndex.has_value()) {
|
||||||
|
oldToNewNode[nodeIndex] = *newNodeIndex;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
VirtualNode mergedNode;
|
||||||
|
for (size_t memberIndex : orderedMergeGroups[static_cast<size_t>(mergeGroupIndex)]) {
|
||||||
|
const VirtualNode &memberNode = graph.nodes[memberIndex];
|
||||||
|
mergedNode.originalNodeIndices.append(memberNode.originalNodeIndices.begin(), memberNode.originalNodeIndices.end());
|
||||||
|
mergedNode.weight = addOrMax(mergedNode.weight, memberNode.weight);
|
||||||
|
mergedNode.crossbarUsage = addOrMax(mergedNode.crossbarUsage, memberNode.crossbarUsage);
|
||||||
|
}
|
||||||
|
std::sort(mergedNode.originalNodeIndices.begin(), mergedNode.originalNodeIndices.end());
|
||||||
|
|
||||||
|
mergedAny = true;
|
||||||
|
newNodeIndex = coarsenedGraph.nodes.size();
|
||||||
|
for (size_t memberIndex : orderedMergeGroups[static_cast<size_t>(mergeGroupIndex)])
|
||||||
|
oldToNewNode[memberIndex] = *newNodeIndex;
|
||||||
|
newNodeRank.push_back(topologicalRank[orderedMergeGroups[static_cast<size_t>(mergeGroupIndex)].front()]);
|
||||||
|
coarsenedGraph.nodes.push_back(std::move(mergedNode));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!mergedAny)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
std::vector<IndexedEdge> remappedEdges;
|
||||||
|
remappedEdges.reserve(graph.edges.size());
|
||||||
|
for (auto [start, end, weight] : graph.edges) {
|
||||||
|
size_t newStart = oldToNewNode[static_cast<size_t>(start)];
|
||||||
|
size_t newEnd = oldToNewNode[static_cast<size_t>(end)];
|
||||||
|
if (newStart == newEnd)
|
||||||
|
continue;
|
||||||
|
if (newNodeRank[newStart] >= newNodeRank[newEnd])
|
||||||
|
continue;
|
||||||
|
remappedEdges.push_back({static_cast<int64_t>(newStart), static_cast<int64_t>(newEnd), weight});
|
||||||
|
}
|
||||||
|
coarsenedGraph.edges = aggregateEdges(remappedEdges);
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t getDcpCoarseningWindowSize(size_t nodeCount, const DcpScheduleOptions &options) {
|
||||||
|
size_t windowSize = std::min(options.criticalWindowSize, nodeCount);
|
||||||
|
CPU maxCpuCount = std::max<CPU>(1, static_cast<CPU>(getSchedulingCpuBudget(options)));
|
||||||
|
if (nodeCount > static_cast<size_t>(maxCpuCount))
|
||||||
|
windowSize = std::max(windowSize, std::min(nodeCount, static_cast<size_t>(maxCpuCount) + 1));
|
||||||
|
return windowSize;
|
||||||
|
}
|
||||||
|
|
||||||
|
void assignFeasibleAest(const ComputeGraph &graph, MergeScheduleResult &result) {
|
||||||
|
llvm::DenseMap<ComputeInstance, size_t> nodeIndexByInstance;
|
||||||
|
nodeIndexByInstance.reserve(graph.nodes.size());
|
||||||
|
for (auto [nodeIndex, node] : llvm::enumerate(graph.nodes))
|
||||||
|
nodeIndexByInstance[node.instance] = nodeIndex;
|
||||||
|
|
||||||
|
struct ScheduledEdge {
|
||||||
|
size_t target = 0;
|
||||||
|
Time delay = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
std::vector<std::vector<ScheduledEdge>> scheduledChildren(graph.nodes.size());
|
||||||
|
std::vector<size_t> incomingEdgeCount(graph.nodes.size(), 0);
|
||||||
|
for (const ComputeGraphEdge &edge : graph.edges) {
|
||||||
|
const ComputeInstance sourceInstance = graph.nodes[edge.source].instance;
|
||||||
|
const ComputeInstance targetInstance = graph.nodes[edge.target].instance;
|
||||||
|
const size_t sourceCpu = result.computeToCpuMap.lookup(sourceInstance);
|
||||||
|
const size_t targetCpu = result.computeToCpuMap.lookup(targetInstance);
|
||||||
|
|
||||||
|
Time delay = graph.nodes[edge.source].weight;
|
||||||
|
if (sourceCpu != targetCpu)
|
||||||
|
delay = addOrMax(delay, edge.transferCost);
|
||||||
|
|
||||||
|
scheduledChildren[edge.source].push_back({edge.target, delay});
|
||||||
|
incomingEdgeCount[edge.target]++;
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::DenseMap<size_t, std::vector<std::pair<size_t, size_t>>> tasksByCpu;
|
||||||
|
for (const ComputeGraphNode &node : graph.nodes) {
|
||||||
|
size_t cpu = result.computeToCpuMap.lookup(node.instance);
|
||||||
|
size_t slot = result.computeToCpuSlotMap.lookup(node.instance);
|
||||||
|
tasksByCpu[cpu].push_back({slot, nodeIndexByInstance.lookup(node.instance)});
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto &entry : tasksByCpu) {
|
||||||
|
auto &scheduledTasks = entry.second;
|
||||||
|
llvm::sort(scheduledTasks, [](const auto &lhs, const auto &rhs) {
|
||||||
|
if (lhs.first != rhs.first)
|
||||||
|
return lhs.first < rhs.first;
|
||||||
|
return lhs.second < rhs.second;
|
||||||
|
});
|
||||||
|
|
||||||
|
for (size_t i = 1; i < scheduledTasks.size(); ++i) {
|
||||||
|
size_t sourceIndex = scheduledTasks[i - 1].second;
|
||||||
|
size_t targetIndex = scheduledTasks[i].second;
|
||||||
|
scheduledChildren[sourceIndex].push_back({targetIndex, graph.nodes[sourceIndex].weight});
|
||||||
|
incomingEdgeCount[targetIndex]++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
auto readyNodeGreater = [&](size_t lhs, size_t rhs) {
|
||||||
|
if (graph.nodes[lhs].originalOrder != graph.nodes[rhs].originalOrder)
|
||||||
|
return graph.nodes[lhs].originalOrder > graph.nodes[rhs].originalOrder;
|
||||||
|
return lhs > rhs;
|
||||||
|
};
|
||||||
|
std::priority_queue<size_t, std::vector<size_t>, decltype(readyNodeGreater)> readyNodes(readyNodeGreater);
|
||||||
|
for (size_t nodeIndex = 0; nodeIndex < graph.nodes.size(); ++nodeIndex)
|
||||||
|
if (incomingEdgeCount[nodeIndex] == 0)
|
||||||
|
readyNodes.push(nodeIndex);
|
||||||
|
|
||||||
|
std::vector<Time> startTimes(graph.nodes.size(), 0);
|
||||||
|
size_t processedNodeCount = 0;
|
||||||
|
while (!readyNodes.empty()) {
|
||||||
|
size_t sourceIndex = readyNodes.top();
|
||||||
|
readyNodes.pop();
|
||||||
|
processedNodeCount++;
|
||||||
|
|
||||||
|
for (const ScheduledEdge &edge : scheduledChildren[sourceIndex]) {
|
||||||
|
startTimes[edge.target] = std::max(startTimes[edge.target], addOrMax(startTimes[sourceIndex], edge.delay));
|
||||||
|
assert(incomingEdgeCount[edge.target] > 0 && "scheduled incoming edge count underflow");
|
||||||
|
incomingEdgeCount[edge.target]--;
|
||||||
|
if (incomingEdgeCount[edge.target] == 0)
|
||||||
|
readyNodes.push(edge.target);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (processedNodeCount != graph.nodes.size())
|
||||||
|
llvm::report_fatal_error("merge scheduling: coarsened DCP schedule is cyclic");
|
||||||
|
|
||||||
|
for (auto [nodeIndex, node] : llvm::enumerate(graph.nodes))
|
||||||
|
result.computeToAestMap[node.instance] = startTimes[nodeIndex];
|
||||||
|
}
|
||||||
|
|
||||||
|
MergeScheduleResult buildResultFromVirtualGraph(const VirtualGraph &graph, const ComputeGraph &originalGraph) {
|
||||||
|
MergeScheduleResult result;
|
||||||
|
|
||||||
|
TimingInfo timing = computeTiming(graph);
|
||||||
|
std::vector<size_t> virtualNodeOrder;
|
||||||
|
if (timing.valid)
|
||||||
|
virtualNodeOrder = std::move(timing.topologicalOrder);
|
||||||
|
else {
|
||||||
|
virtualNodeOrder.resize(graph.nodes.size());
|
||||||
|
std::iota(virtualNodeOrder.begin(), virtualNodeOrder.end(), 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<size_t> originalNodeToCpu(originalGraph.nodes.size(), 0);
|
||||||
|
for (auto [cpu, virtualNodeIndex] : llvm::enumerate(virtualNodeOrder)) {
|
||||||
|
const VirtualNode &virtualNode = graph.nodes[virtualNodeIndex];
|
||||||
|
for (size_t originalIndex : virtualNode.originalNodeIndices)
|
||||||
|
originalNodeToCpu[originalIndex] = cpu;
|
||||||
|
}
|
||||||
|
|
||||||
|
result.dominanceOrderCompute.reserve(originalGraph.nodes.size());
|
||||||
|
llvm::DenseMap<size_t, size_t> nextCpuSlot;
|
||||||
|
for (auto [originalIndex, node] : llvm::enumerate(originalGraph.nodes)) {
|
||||||
|
size_t cpu = originalNodeToCpu[originalIndex];
|
||||||
|
result.dominanceOrderCompute.push_back(node.instance);
|
||||||
|
result.computeToCpuMap[node.instance] = cpu;
|
||||||
|
result.computeToCpuSlotMap[node.instance] = nextCpuSlot[cpu]++;
|
||||||
|
result.cpuToLastComputeMap[cpu] = node.instance;
|
||||||
|
}
|
||||||
|
for (const auto &[cpu, lastCompute] : result.cpuToLastComputeMap)
|
||||||
|
result.isLastComputeOfCpu.insert(lastCompute);
|
||||||
|
assignFeasibleAest(originalGraph, result);
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
MergeScheduleResult buildResultFromScheduledGraph(GraphDCP &graphDCP, const ComputeGraph &graph) {
|
||||||
|
MergeScheduleResult result;
|
||||||
|
result.dominanceOrderCompute.reserve(graph.nodes.size());
|
||||||
|
for (const ComputeGraphNode &node : graph.nodes)
|
||||||
|
result.dominanceOrderCompute.push_back(node.instance);
|
||||||
|
|
||||||
|
for (CPU cpu = 0; cpu < graphDCP.cpuCount(); ++cpu) {
|
||||||
|
auto scheduledTasks = graphDCP.getScheduledTasks(cpu);
|
||||||
|
if (scheduledTasks.empty())
|
||||||
|
continue;
|
||||||
|
|
||||||
|
for (auto [slot, task] : llvm::enumerate(scheduledTasks)) {
|
||||||
|
const ComputeInstance instance = graph.nodes[task.nodeIndex].instance;
|
||||||
|
result.computeToCpuMap[instance] = cpu;
|
||||||
|
result.computeToCpuSlotMap[instance] = slot;
|
||||||
|
result.computeToAestMap[instance] = static_cast<uint64_t>(task.aest);
|
||||||
|
}
|
||||||
|
|
||||||
|
const ComputeInstance lastInstance = graph.nodes[scheduledTasks.back().nodeIndex].instance;
|
||||||
|
result.cpuToLastComputeMap[cpu] = lastInstance;
|
||||||
|
result.isLastComputeOfCpu.insert(lastInstance);
|
||||||
|
}
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
MergeScheduleResult runLegacyDcp(const ComputeGraph &graph, const DcpScheduleOptions &options, mlir::MLIRContext *context) {
|
||||||
llvm::SmallVector<Weight> nodeWeights;
|
llvm::SmallVector<Weight> nodeWeights;
|
||||||
llvm::SmallVector<CrossbarUsage> nodeCrossbarUsage;
|
llvm::SmallVector<CrossbarUsage> nodeCrossbarUsage;
|
||||||
llvm::SmallVector<int64_t> nodeOrderKeys;
|
llvm::SmallVector<int64_t> nodeOrderKeys;
|
||||||
@@ -28,34 +610,110 @@ MergeScheduleResult runDcpScheduler(const ComputeGraph &graph, mlir::MLIRContext
|
|||||||
}
|
}
|
||||||
|
|
||||||
GraphDCP graphDCP(nodeWeights, edges, nodeOrderKeys, nodeCrossbarUsage);
|
GraphDCP graphDCP(nodeWeights, edges, nodeOrderKeys, nodeCrossbarUsage);
|
||||||
if (coresCount.getValue() > 0)
|
if (options.processorCount > 0)
|
||||||
graphDCP.setMaxCpuCount(static_cast<int>(coresCount.getValue()));
|
graphDCP.setMaxCpuCount(static_cast<int>(options.processorCount));
|
||||||
graphDCP.setContext(context);
|
graphDCP.setContext(context);
|
||||||
graphDCP.runDcp();
|
graphDCP.runDcp();
|
||||||
|
return buildResultFromScheduledGraph(graphDCP, graph);
|
||||||
|
}
|
||||||
|
|
||||||
MergeScheduleResult result;
|
bool needsExactScheduledBatches(const ComputeGraph &graph, const DcpScheduleOptions &options) {
|
||||||
result.dominanceOrderCompute.reserve(graph.nodes.size());
|
if (options.processorCount == 0 || !options.allowFallbackForAutoCoreCount)
|
||||||
for (const ComputeGraphNode &node : graph.nodes)
|
return false;
|
||||||
result.dominanceOrderCompute.push_back(node.instance);
|
size_t schedulingCpuBudget = getSchedulingCpuBudget(options);
|
||||||
|
return llvm::any_of(graph.nodes, [&](const ComputeGraphNode &node) {
|
||||||
|
auto batch = dyn_cast<SpatComputeBatch>(node.instance.op);
|
||||||
|
return batch && static_cast<size_t>(batch.getLaneCount()) > schedulingCpuBudget;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
for (CPU cpu = 0; cpu < graphDCP.cpuCount(); ++cpu) {
|
} // namespace
|
||||||
auto scheduledTasks = graphDCP.getScheduledTasks(cpu);
|
|
||||||
if (scheduledTasks.empty())
|
|
||||||
continue;
|
|
||||||
|
|
||||||
for (const auto &[slot, task] : llvm::enumerate(scheduledTasks)) {
|
MergeScheduleResult
|
||||||
const ComputeInstance instance = graph.nodes[task.nodeIndex].instance;
|
runDcpScheduler(const ComputeGraph &graph, const DcpScheduleOptions &options, mlir::MLIRContext *context) {
|
||||||
result.computeToCpuMap[instance] = cpu;
|
if (needsExactScheduledBatches(graph, options))
|
||||||
result.computeToCpuSlotMap[instance] = slot;
|
return runLegacyDcp(graph, options, context);
|
||||||
result.computeToAestMap[instance] = static_cast<uint64_t>(task.aest);
|
|
||||||
|
if (options.criticalWindowSize == 0)
|
||||||
|
return runLegacyDcp(graph, options, context);
|
||||||
|
|
||||||
|
VirtualGraph virtualGraph = buildInitialVirtualGraph(graph);
|
||||||
|
size_t iteration = 0;
|
||||||
|
bool debugCoarsening = isDcpCoarsenDebugEnabled();
|
||||||
|
auto tryCoarsenSelectedNodes = [&](llvm::ArrayRef<size_t> selectedNodes) {
|
||||||
|
size_t oldNodeCount = virtualGraph.nodes.size();
|
||||||
|
WindowScheduleResult windowSchedule = scheduleWindow(virtualGraph, selectedNodes, options, context);
|
||||||
|
if (windowSchedule.mergeGroups.empty()) {
|
||||||
|
if (debugCoarsening && oldNodeCount >= 200)
|
||||||
|
llvm::errs() << llvm::formatv("[DCP-COARSEN] iter={0} old={1} selected={2} windowCpus={3} "
|
||||||
|
"groups=0 mergedNodes=0 maxGroup=0 new={1} changed=0\n",
|
||||||
|
iteration,
|
||||||
|
oldNodeCount,
|
||||||
|
selectedNodes.size(),
|
||||||
|
windowSchedule.cpuCount);
|
||||||
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
const ComputeInstance lastInstance = graph.nodes[scheduledTasks.back().nodeIndex].instance;
|
VirtualGraph coarsenedGraph;
|
||||||
result.cpuToLastComputeMap[cpu] = lastInstance;
|
std::vector<size_t> oldToNewNode;
|
||||||
result.isLastComputeOfCpu.insert(lastInstance);
|
if (!coarsenGraph(virtualGraph, windowSchedule.mergeGroups, coarsenedGraph, oldToNewNode))
|
||||||
|
return false;
|
||||||
|
if (debugCoarsening && (oldNodeCount >= 200 || coarsenedGraph.nodes.size() >= 200))
|
||||||
|
llvm::errs() << llvm::formatv("[DCP-COARSEN] iter={0} old={1} selected={2} windowCpus={3} "
|
||||||
|
"groups={4} mergedNodes={5} maxGroup={6} new={7} changed={8}\n",
|
||||||
|
iteration,
|
||||||
|
oldNodeCount,
|
||||||
|
selectedNodes.size(),
|
||||||
|
windowSchedule.cpuCount,
|
||||||
|
windowSchedule.mergeGroups.size(),
|
||||||
|
windowSchedule.mergedNodeCount,
|
||||||
|
windowSchedule.maxMergeGroupSize,
|
||||||
|
coarsenedGraph.nodes.size(),
|
||||||
|
oldNodeCount - coarsenedGraph.nodes.size());
|
||||||
|
virtualGraph = std::move(coarsenedGraph);
|
||||||
|
return true;
|
||||||
|
};
|
||||||
|
|
||||||
|
while (virtualGraph.nodes.size() > 1) {
|
||||||
|
if (virtualGraph.nodes.size() <= getSchedulingCpuBudget(options)) {
|
||||||
|
if (debugCoarsening && virtualGraph.nodes.size() >= 200)
|
||||||
|
llvm::errs() << llvm::formatv(
|
||||||
|
"[DCP-COARSEN] iter={0} old={1} stop=cpu-budget\n", iteration, virtualGraph.nodes.size());
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
iteration++;
|
||||||
|
TimingInfo timing = computeTiming(virtualGraph);
|
||||||
|
if (!timing.valid) {
|
||||||
|
if (debugCoarsening && virtualGraph.nodes.size() >= 200)
|
||||||
|
llvm::errs() << llvm::formatv(
|
||||||
|
"[DCP-COARSEN] iter={0} old={1} invalid-timing\n", iteration, virtualGraph.nodes.size());
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::SmallVector<size_t> selectedNodes;
|
||||||
|
auto criticalWindow =
|
||||||
|
selectCriticalWindow(virtualGraph, timing, getDcpCoarseningWindowSize(virtualGraph.nodes.size(), options));
|
||||||
|
selectedNodes.append(criticalWindow.begin(), criticalWindow.end());
|
||||||
|
|
||||||
|
if (selectedNodes.size() < 2) {
|
||||||
|
if (debugCoarsening && virtualGraph.nodes.size() >= 200)
|
||||||
|
llvm::errs() << llvm::formatv("[DCP-COARSEN] iter={0} old={1} selected={2} stop=small-window\n",
|
||||||
|
iteration,
|
||||||
|
virtualGraph.nodes.size(),
|
||||||
|
selectedNodes.size());
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (tryCoarsenSelectedNodes(selectedNodes))
|
||||||
|
continue;
|
||||||
|
if (debugCoarsening && virtualGraph.nodes.size() >= 200)
|
||||||
|
llvm::errs() << llvm::formatv(
|
||||||
|
"[DCP-COARSEN] iter={0} old={1} stop=no-merge\n", iteration, virtualGraph.nodes.size());
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
return result;
|
return buildResultFromVirtualGraph(virtualGraph, graph);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace spatial
|
} // namespace spatial
|
||||||
|
|||||||
@@ -8,7 +8,14 @@
|
|||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
namespace spatial {
|
namespace spatial {
|
||||||
|
|
||||||
MergeScheduleResult runDcpScheduler(const ComputeGraph &graph, mlir::MLIRContext *context);
|
struct DcpScheduleOptions {
|
||||||
|
size_t processorCount = 0;
|
||||||
|
size_t criticalWindowSize = 0;
|
||||||
|
bool allowFallbackForAutoCoreCount = true;
|
||||||
|
};
|
||||||
|
|
||||||
|
MergeScheduleResult
|
||||||
|
runDcpScheduler(const ComputeGraph &graph, const DcpScheduleOptions &options, mlir::MLIRContext *context);
|
||||||
|
|
||||||
} // namespace spatial
|
} // namespace spatial
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
+9
-3
@@ -121,11 +121,17 @@ MergeScheduleResult MergeSchedulingAnalysis::run() {
|
|||||||
entryOp->getContext()});
|
entryOp->getContext()});
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
schedule = DCPAnalysis(entryOp).getResult();
|
schedule = runDcpScheduler(
|
||||||
|
graph,
|
||||||
|
DcpScheduleOptions {
|
||||||
|
options.processorCount,
|
||||||
|
dcpCriticalWindowSize.getValue(),
|
||||||
|
options.allowDcpFallbackForAutoCoreCount
|
||||||
|
},
|
||||||
|
entryOp->getContext());
|
||||||
}
|
}
|
||||||
|
|
||||||
if (options.kind == MergeSchedulerKind::Peft)
|
verifySchedule(graph, schedule, static_cast<CrossbarUsage>(crossbarCountInCore.getValue()));
|
||||||
verifySchedule(graph, schedule, static_cast<CrossbarUsage>(crossbarCountInCore.getValue()));
|
|
||||||
return schedule;
|
return schedule;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user