Compare commits
23 Commits
9f9e7c0892
..
main
| Author | SHA1 | Date | |
|---|---|---|---|
| 5637c861b4 | |||
| 94157a8404 | |||
| 68a3521978 | |||
| e263e05f56 | |||
| 34c29fdec4 | |||
| aa088e2ba5 | |||
| 2836e759ab | |||
| 8071ebab0b | |||
| f1602c0550 | |||
| de0a2f4561 | |||
| 1c4a5bde76 | |||
| 78242e2887 | |||
| fe244d5aa1 | |||
| d09e76c8f9 | |||
| c5e608fa5b | |||
| 43f3ccdd21 | |||
| 8d95c604a6 | |||
| 55eda487dc | |||
| 061139aefb | |||
| ea61540e08 | |||
| 324178cba8 | |||
| e71ba07cd5 | |||
| 64a3805619 |
@@ -114,7 +114,9 @@ Pass these on the `onnx-mlir` command line when compiling for PIM:
|
||||
run only the codegen tail.
|
||||
- `--crossbar-size=<N>` / `--crossbar-count=<N>` — crossbar dimensions and
|
||||
per-core count.
|
||||
- `--core-count=<N>` — number of cores (`-1` picks the minimum).
|
||||
- `--core-count=<N>` — number of cores. Required for PIM compilation.
|
||||
- `--pim-merge-scheduler={peft,dcp}` — scheduler used by the Spatial
|
||||
merge-compute-nodes pass (default: `peft`).
|
||||
- `--dcp-critical-window-size=<N>` — DCP coarsening window (0 = legacy).
|
||||
- `--use-experimental-conv-impl` — alternative convolution lowering.
|
||||
- `--ignore-concat-error` — soft-fail corner case in `ConcatOp`.
|
||||
@@ -129,7 +131,8 @@ Per-operation validation (from `validation/`):
|
||||
```
|
||||
validate.py \
|
||||
--raptor-path ../cmake-build-release/Release/bin/onnx-mlir \
|
||||
--onnx-include-dir ../onnx-mlir/include
|
||||
--onnx-include-dir ../onnx-mlir/include \
|
||||
--core-count 1000
|
||||
```
|
||||
|
||||
End-to-end network validation (example: first 4 layers of YOLOv11n):
|
||||
|
||||
@@ -67,7 +67,7 @@ fn main() -> Result<()> {
|
||||
.lock()
|
||||
.unwrap()
|
||||
.init(executor.cpu().num_core(), args.output.clone());
|
||||
executor.execute();
|
||||
executor.execute()?;
|
||||
dump_memory(executor, &args)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -392,6 +392,7 @@ mod tests {
|
||||
HEADER_SIZE, InstructionRecord, MAGIC, RECORD_SIZE, VERSION, binary_to_instructions,
|
||||
};
|
||||
use crate::{
|
||||
functor_to_name,
|
||||
instruction_set::{InstructionsBuilder, instruction_data::InstructionDataBuilder},
|
||||
json_to_instruction::json_isa::json_to_instruction,
|
||||
};
|
||||
@@ -486,7 +487,10 @@ mod tests {
|
||||
|
||||
assert_eq!(json_instructions.len(), binary_instructions.len());
|
||||
for (json_inst, binary_inst) in json_instructions.iter().zip(binary_instructions.iter()) {
|
||||
assert_eq!(json_inst.functor_name(), binary_inst.functor_name());
|
||||
assert_eq!(
|
||||
functor_to_name(json_inst.functor as usize),
|
||||
functor_to_name(binary_inst.functor as usize)
|
||||
);
|
||||
assert_eq!(json_inst.data, binary_inst.data);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
mod json_isa;
|
||||
pub(crate) mod json_isa;
|
||||
pub mod json_to_executor;
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
#![allow(unused)]
|
||||
|
||||
use anyhow::{Result, bail};
|
||||
use std::{
|
||||
collections::{HashMap, HashSet},
|
||||
time::{Duration, SystemTime},
|
||||
@@ -87,6 +88,11 @@ pub struct Executable<'a> {
|
||||
send_recv: SendRecv,
|
||||
}
|
||||
|
||||
struct DeadlockInfo {
|
||||
cycle: String,
|
||||
states: String,
|
||||
}
|
||||
|
||||
fn print_status(core_instructions: &[CoreInstructions]) {
|
||||
let mut tot_instructions = 0;
|
||||
let mut progress = 0;
|
||||
@@ -118,7 +124,7 @@ impl<'a> Executable<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn execute<'b>(&'b mut self)
|
||||
pub fn execute<'b>(&'b mut self) -> Result<()>
|
||||
where
|
||||
'a: 'b,
|
||||
{
|
||||
@@ -153,7 +159,13 @@ impl<'a> Executable<'a> {
|
||||
}
|
||||
if (now.elapsed().unwrap() > Duration::from_secs(5)) {
|
||||
print_status(cores_instructions);
|
||||
check_cycle(cpu, cores_instructions, send_recv);
|
||||
if let Some(deadlock) = detect_deadlock(cores_instructions) {
|
||||
bail!(
|
||||
"Deadlock cycle detected: {} [{}]",
|
||||
deadlock.cycle,
|
||||
deadlock.states
|
||||
);
|
||||
}
|
||||
now = SystemTime::now();
|
||||
}
|
||||
}
|
||||
@@ -178,8 +190,23 @@ impl<'a> Executable<'a> {
|
||||
}
|
||||
print_status(cores_instructions);
|
||||
|
||||
if let Some(deadlock) = detect_deadlock(cores_instructions) {
|
||||
bail!(
|
||||
"Deadlock cycle detected: {} [{}]",
|
||||
deadlock.cycle,
|
||||
deadlock.states
|
||||
);
|
||||
}
|
||||
if cores_instructions
|
||||
.iter()
|
||||
.any(|core_inst| core_inst.program_counter < core_inst.instructions.len())
|
||||
{
|
||||
bail!("Execution stalled with unfinished instructions");
|
||||
}
|
||||
|
||||
#[cfg(feature = "profile_time")]
|
||||
TRACER.lock().unwrap().report();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn cpu(&self) -> &CPU<'a> {
|
||||
@@ -201,11 +228,11 @@ impl<'a> Executable<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
fn check_cycle(cpu: &mut CPU, cores_instructions: &[CoreInstructions], send_recv: &mut SendRecv) {
|
||||
fn detect_deadlock(cores_instructions: &[CoreInstructions]) -> Option<DeadlockInfo> {
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
enum CoreState {
|
||||
SendingTo(i32),
|
||||
ReceivingFrom(i32),
|
||||
SendingTo(i32, i32),
|
||||
ReceivingFrom(i32, i32),
|
||||
Working,
|
||||
Halted,
|
||||
}
|
||||
@@ -223,9 +250,9 @@ fn check_cycle(cpu: &mut CPU, cores_instructions: &[CoreInstructions], send_recv
|
||||
let (this_core, target_core) = data.get_core_immcore();
|
||||
|
||||
if isa_recv(functor_address) {
|
||||
states.insert(this_core, CoreState::ReceivingFrom(target_core));
|
||||
states.insert(this_core, CoreState::ReceivingFrom(target_core, data.imm_len()));
|
||||
} else if isa_send(functor_address) {
|
||||
states.insert(this_core, CoreState::SendingTo(target_core));
|
||||
states.insert(this_core, CoreState::SendingTo(target_core, data.imm_len()));
|
||||
} else {
|
||||
states.insert(this_core, CoreState::Working);
|
||||
}
|
||||
@@ -235,15 +262,15 @@ fn check_cycle(cpu: &mut CPU, cores_instructions: &[CoreInstructions], send_recv
|
||||
|
||||
for (&core_id, state) in states.iter() {
|
||||
match state {
|
||||
CoreState::SendingTo(target_core) => {
|
||||
CoreState::SendingTo(target_core, size) => {
|
||||
let target_state = states.get(target_core).unwrap_or(&CoreState::Halted);
|
||||
if target_state != &CoreState::ReceivingFrom(core_id) {
|
||||
if target_state != &CoreState::ReceivingFrom(core_id, *size) {
|
||||
wait_for.insert(core_id, *target_core);
|
||||
}
|
||||
}
|
||||
CoreState::ReceivingFrom(target_core) => {
|
||||
CoreState::ReceivingFrom(target_core, size) => {
|
||||
let target_state = states.get(target_core).unwrap_or(&CoreState::Halted);
|
||||
if target_state != &CoreState::SendingTo(core_id) {
|
||||
if target_state != &CoreState::SendingTo(core_id, *size) {
|
||||
wait_for.insert(core_id, *target_core);
|
||||
}
|
||||
}
|
||||
@@ -279,11 +306,33 @@ fn check_cycle(cpu: &mut CPU, cores_instructions: &[CoreInstructions], send_recv
|
||||
.collect::<Vec<_>>()
|
||||
.join(" -> ");
|
||||
|
||||
let cycle = cycle
|
||||
.iter()
|
||||
.copied()
|
||||
.chain(std::iter::once(waiting_for))
|
||||
.collect::<Vec<_>>();
|
||||
let cycle_msg = format!("{} -> {}", cycle_str, waiting_for);
|
||||
let states_msg = cycle
|
||||
.iter()
|
||||
.filter_map(|core| {
|
||||
states.get(core).map(|state| match state {
|
||||
CoreState::SendingTo(target, size) => {
|
||||
format!("core {} send {}B -> {}", core, size, target)
|
||||
}
|
||||
CoreState::ReceivingFrom(source, size) => {
|
||||
format!("core {} recv {}B <- {}", core, size, source)
|
||||
}
|
||||
CoreState::Working => format!("core {} working", core),
|
||||
CoreState::Halted => format!("core {} halted", core),
|
||||
})
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ");
|
||||
|
||||
println!("Fatal: Deadlock cycle detected: {}", cycle_msg);
|
||||
// bail!("Deadlock detected: {}", cycle_msg);
|
||||
break; // Stop tracing
|
||||
return Some(DeadlockInfo {
|
||||
cycle: cycle_msg,
|
||||
states: states_msg,
|
||||
});
|
||||
}
|
||||
|
||||
// Hit a known branch that didn't result in a cycle
|
||||
@@ -294,6 +343,7 @@ fn check_cycle(cpu: &mut CPU, cores_instructions: &[CoreInstructions], send_recv
|
||||
current_core = waiting_for;
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
fn handle_wait_sync<'a, 'b, 'c>(
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
use std::path::Path;
|
||||
|
||||
use pimcore::{Executable, cpu::CPU, instruction_set::{InstructionsBuilder, instruction_data::InstructionDataBuilder, isa::*}};
|
||||
use pimcore::{
|
||||
Executable,
|
||||
cpu::crossbar::Crossbar,
|
||||
instruction_set::{InstructionsBuilder, instruction_data::InstructionDataBuilder, isa::*},
|
||||
memory_manager::CoreMemory,
|
||||
};
|
||||
|
||||
fn simple_read(path: &Path) -> Vec<f32> {
|
||||
if !path.exists() {
|
||||
@@ -17,14 +22,12 @@ fn simple_read(path: &Path) -> Vec<f32> {
|
||||
fn mvmul_f32(err: &str)
|
||||
where
|
||||
{
|
||||
let mut cpu = CPU::new(0);
|
||||
cpu.reserve_crossbar(1, 1024 * size_of::<f32>(), 1024);
|
||||
let (memory, crossbars) = cpu.host().get_memory_crossbar();
|
||||
let matrix = simple_read(Path::new("B.txt")) ;
|
||||
|
||||
|
||||
crossbars.get_mut(0).unwrap().execute_store( &matrix).unwrap();
|
||||
let vector = simple_read(Path::new("A.txt"));
|
||||
let matrix = simple_read(Path::new("tests/B.txt"));
|
||||
let mut crossbar = Crossbar::new(1024 * size_of::<f32>(), 1024, CoreMemory::new());
|
||||
crossbar.execute_store(&matrix).unwrap();
|
||||
let mut cpu = pimcore::cpu::CPU::new(0, vec![vec![&crossbar]]);
|
||||
let (memory, _) = cpu.host().get_memory_crossbar();
|
||||
let vector = simple_read(Path::new("tests/A.txt"));
|
||||
memory.execute_store(0, &vector).unwrap();
|
||||
|
||||
let mut inst_builder = InstructionsBuilder::new();
|
||||
@@ -57,7 +60,7 @@ where
|
||||
.cpu_mut()
|
||||
.host()
|
||||
.load::<f32>(1024 * size_of::<f32>(), 1024*size_of::<f32>()).unwrap()[0].iter().zip(
|
||||
simple_read(Path::new("X.txt")) ).all(|(&a,b) : (&f32, f32)| {a-b < 0.001}),
|
||||
simple_read(Path::new("tests/X.txt")) ).all(|(&a,b) : (&f32, f32)| {a-b < 0.001}),
|
||||
"Wrong result for {}",
|
||||
err
|
||||
);
|
||||
@@ -69,5 +72,3 @@ fn mvmul_big_test() {
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
use pimcore::cpu::CPU;
|
||||
|
||||
pub fn empty_cpu(num_cores: usize) -> CPU<'static> {
|
||||
CPU::new(num_cores, vec![Vec::new(); num_cores + 1])
|
||||
}
|
||||
@@ -1,51 +1,103 @@
|
||||
use std::{fs, io::BufReader, path::Path};
|
||||
use std::{
|
||||
fs::{self, File},
|
||||
io::BufReader,
|
||||
path::{Path, PathBuf},
|
||||
};
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use pimcore::json_to_instruction::json_to_executor;
|
||||
use pimcore::{
|
||||
cpu::crossbar::Crossbar,
|
||||
json_to_instruction::json_to_executor,
|
||||
memory_manager::CoreMemory,
|
||||
};
|
||||
use serde_json::Value;
|
||||
|
||||
fn collect_json_from_subfolders<P: AsRef<Path>>(root: P) -> Result<Vec<(Value, Vec<Value>)>> {
|
||||
fn collect_examples<P: AsRef<Path>>(root: P) -> Result<Vec<PathBuf>> {
|
||||
let mut result = Vec::new();
|
||||
for entry in fs::read_dir(root)? {
|
||||
let entry = entry.context("Root not found")?;
|
||||
let path = entry.path();
|
||||
if path.is_dir() {
|
||||
let mut cores = Vec::new();
|
||||
let mut config: Option<Value> = None;
|
||||
for sub_entry in fs::read_dir(&path)
|
||||
.with_context(|| format!("File {} not readable", path.display()))?
|
||||
{
|
||||
let sub_entry =
|
||||
sub_entry.with_context(|| format!("File {} not readable", path.display()))?;
|
||||
let sub_path = sub_entry.path();
|
||||
if sub_path.is_file()
|
||||
&& sub_path.extension().and_then(|s| s.to_str()) == Some("json")
|
||||
{
|
||||
let file = fs::File::open(&sub_path)
|
||||
.with_context(|| format!("Subpath {} not opened", sub_path.display()))?;
|
||||
let reader = BufReader::new(file);
|
||||
let val: Value = serde_json::from_reader(reader).with_context(|| format!(
|
||||
"Serde reader fail for subpath {}",
|
||||
sub_path.display()
|
||||
))?;
|
||||
if sub_path.file_name().unwrap() == "config.json" {
|
||||
config = Some(val);
|
||||
} else {
|
||||
cores.push(val);
|
||||
}
|
||||
}
|
||||
}
|
||||
result.push((config.unwrap(), cores));
|
||||
result.push(path);
|
||||
}
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
fn core_sort_key(path: &Path) -> i32 {
|
||||
let stem = path.file_stem().unwrap().to_str().unwrap();
|
||||
stem[5..].parse::<i32>().unwrap()
|
||||
}
|
||||
|
||||
fn crossbar_sort_key(path: &Path) -> i32 {
|
||||
let stem = path.file_stem().unwrap().to_str().unwrap();
|
||||
stem[9..].parse::<i32>().unwrap()
|
||||
}
|
||||
|
||||
fn load_crossbars(folder: &Path, config: &Value) -> Result<Vec<Vec<Crossbar>>> {
|
||||
let xbar_size = config.get("xbar_size").unwrap().as_array().unwrap();
|
||||
let rows = xbar_size[0].as_i64().unwrap() as usize;
|
||||
let cols = xbar_size[1].as_i64().unwrap() as usize;
|
||||
let core_cnt = config.get("core_cnt").unwrap().as_i64().unwrap() as usize;
|
||||
let mut owned_crossbars = Vec::with_capacity(core_cnt + 1);
|
||||
owned_crossbars.push(Vec::new());
|
||||
|
||||
for core_idx in 0..core_cnt {
|
||||
let core_folder = folder.join(format!("core_{core_idx}"));
|
||||
let mut core_crossbars = Vec::new();
|
||||
if core_folder.is_dir() {
|
||||
let mut paths: Vec<_> = fs::read_dir(&core_folder)?
|
||||
.map(|entry| entry.map(|entry| entry.path()))
|
||||
.collect::<std::io::Result<Vec<_>>>()?;
|
||||
paths.sort_by_cached_key(|path| crossbar_sort_key(path));
|
||||
for path in paths {
|
||||
if path.extension().and_then(|ext| ext.to_str()) != Some("bin") {
|
||||
continue;
|
||||
}
|
||||
let bytes = fs::read(&path)
|
||||
.with_context(|| format!("failed to read crossbar {}", path.display()))?;
|
||||
let mut crossbar = Crossbar::new(cols * 4, rows, CoreMemory::new());
|
||||
crossbar.execute_store(&bytes)?;
|
||||
core_crossbars.push(crossbar);
|
||||
}
|
||||
}
|
||||
owned_crossbars.push(core_crossbars);
|
||||
}
|
||||
Ok(owned_crossbars)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn json_folder_tester() {
|
||||
let examples = collect_json_from_subfolders("data").unwrap();
|
||||
for example in examples {
|
||||
let (config, cores) = example;
|
||||
json_to_executor::json_to_executor(config, cores.iter()).execute();
|
||||
let examples = collect_examples("tests/data").unwrap();
|
||||
for folder in examples {
|
||||
let config_path = folder.join("config.json");
|
||||
let config_file = File::open(&config_path).unwrap();
|
||||
let config: Value = serde_json::from_reader(BufReader::new(config_file)).unwrap();
|
||||
|
||||
let core_cnt = config.get("core_cnt").unwrap().as_i64().unwrap() as usize;
|
||||
let mut core_paths: Vec<_> = fs::read_dir(&folder)
|
||||
.unwrap()
|
||||
.map(|entry| entry.unwrap().path())
|
||||
.filter(|path| path.extension().and_then(|ext| ext.to_str()) == Some("json"))
|
||||
.filter(|path| path.file_name().unwrap() != "config.json")
|
||||
.collect();
|
||||
core_paths.sort_by_cached_key(|path| core_sort_key(path));
|
||||
assert_eq!(core_paths.len(), core_cnt);
|
||||
|
||||
let mut core_readers: Vec<_> = core_paths
|
||||
.into_iter()
|
||||
.map(|path| BufReader::new(File::open(path).unwrap()))
|
||||
.collect();
|
||||
|
||||
let owned_crossbars = load_crossbars(&folder, &config).unwrap();
|
||||
let crossbars = owned_crossbars
|
||||
.iter()
|
||||
.map(|core_crossbars| core_crossbars.iter().collect())
|
||||
.collect();
|
||||
|
||||
let mut executable = json_to_executor::json_to_executor(config, &mut core_readers, crossbars);
|
||||
let memory = fs::read(folder.join("memory.bin")).unwrap();
|
||||
executable.cpu_mut().host().execute_store(0, &memory).unwrap();
|
||||
executable.execute();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,11 +1,17 @@
|
||||
mod common;
|
||||
|
||||
use pimcore::{Executable, cpu::CPU, instruction_set::{InstructionType, InstructionsBuilder, instruction_data::InstructionDataBuilder, isa::*}};
|
||||
use pimcore::{
|
||||
Executable,
|
||||
instruction_set::{
|
||||
InstructionType, InstructionsBuilder, instruction_data::InstructionDataBuilder, isa::*,
|
||||
},
|
||||
};
|
||||
|
||||
|
||||
#[test]
|
||||
#[should_panic(expected = "Function not found for the requested size") ]
|
||||
fn wrong_size_place_holder() {
|
||||
let cpu = CPU::new(0);
|
||||
let cpu = common::empty_cpu(0);
|
||||
let mut inst_builder = InstructionsBuilder::new();
|
||||
let mut idata_build = InstructionDataBuilder::new();
|
||||
idata_build.set_core_indx(0).fix_core_indx();
|
||||
@@ -30,7 +36,7 @@ fn wrong_size_place_holder() {
|
||||
|
||||
|
||||
fn place_holder(inst : InstructionType) {
|
||||
let mut cpu = CPU::new(0);
|
||||
let mut cpu = common::empty_cpu(0);
|
||||
let mut idata_build = InstructionDataBuilder::new();
|
||||
idata_build.set_core_indx(0).fix_core_indx();
|
||||
inst(&mut cpu, idata_build.build()).unwrap();
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
mod common;
|
||||
|
||||
use pimcore::{
|
||||
Executable,
|
||||
cpu::CPU,
|
||||
cpu::crossbar::Crossbar,
|
||||
instruction_set::{InstructionsBuilder, instruction_data::InstructionDataBuilder, isa::*},
|
||||
memory_manager::{MemoryStorable, type_traits::UpcastDestTraits},
|
||||
memory_manager::{CoreMemory, MemoryStorable, type_traits::UpcastDestTraits},
|
||||
};
|
||||
|
||||
/// VVADD Test
|
||||
@@ -11,7 +13,7 @@ where
|
||||
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
|
||||
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable,
|
||||
{
|
||||
let mut cpu = CPU::new(0);
|
||||
let mut cpu = common::empty_cpu(0);
|
||||
let buff: [F; _] = [
|
||||
1.0.into(),
|
||||
2.0.into(),
|
||||
@@ -115,7 +117,7 @@ where
|
||||
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
|
||||
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable,
|
||||
{
|
||||
let mut cpu = CPU::new(0);
|
||||
let mut cpu = common::empty_cpu(0);
|
||||
let buff: [F; _] = [
|
||||
1.0.into(),
|
||||
2.0.into(),
|
||||
@@ -219,7 +221,7 @@ where
|
||||
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
|
||||
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable,
|
||||
{
|
||||
let mut cpu = CPU::new(0);
|
||||
let mut cpu = common::empty_cpu(0);
|
||||
let buff: [F; _] = [
|
||||
1.0.into(),
|
||||
2.0.into(),
|
||||
@@ -323,7 +325,7 @@ where
|
||||
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
|
||||
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable,
|
||||
{
|
||||
let mut cpu = CPU::new(0);
|
||||
let mut cpu = common::empty_cpu(0);
|
||||
let buff: [F; _] = [
|
||||
1.0.into(),
|
||||
2.0.into(),
|
||||
@@ -420,7 +422,7 @@ where
|
||||
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
|
||||
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable,
|
||||
{
|
||||
let mut cpu = CPU::new(0);
|
||||
let mut cpu = common::empty_cpu(0);
|
||||
let buff: [F; _] = [
|
||||
9.0.into(),
|
||||
2.0.into(),
|
||||
@@ -524,7 +526,7 @@ where
|
||||
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
|
||||
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable,
|
||||
{
|
||||
let mut cpu = CPU::new(0);
|
||||
let mut cpu = common::empty_cpu(0);
|
||||
let buff: [F; _] = [
|
||||
9.0.into(),
|
||||
2.0.into(),
|
||||
@@ -562,6 +564,7 @@ where
|
||||
vavg,
|
||||
idata_build
|
||||
.set_rdr1r2(3, 1, 1)
|
||||
.set_offset_select(1)
|
||||
.set_imm_len(8 * size_of::<F>() as i32)
|
||||
.build(),
|
||||
);
|
||||
@@ -617,7 +620,7 @@ where
|
||||
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
|
||||
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable,
|
||||
{
|
||||
let mut cpu = CPU::new(0);
|
||||
let mut cpu = common::empty_cpu(0);
|
||||
let buff: [F; _] = [
|
||||
(-9.0).into(),
|
||||
2.0.into(),
|
||||
@@ -717,7 +720,7 @@ where
|
||||
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
|
||||
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable + UpcastDestTraits<T>,
|
||||
{
|
||||
let mut cpu = CPU::new(0);
|
||||
let mut cpu = common::empty_cpu(0);
|
||||
let buff: [F; _] = [
|
||||
0.1.into(),
|
||||
0.2.into(),
|
||||
@@ -819,7 +822,7 @@ where
|
||||
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
|
||||
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable + UpcastDestTraits<T>,
|
||||
{
|
||||
let mut cpu = CPU::new(0);
|
||||
let mut cpu = common::empty_cpu(0);
|
||||
let buff: [F; _] = [
|
||||
0.1.into(),
|
||||
0.2.into(),
|
||||
@@ -923,9 +926,6 @@ where
|
||||
M: From<f32> + std::fmt::Debug + PartialEq<M> + MemoryStorable,
|
||||
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable + UpcastDestTraits<T>,
|
||||
{
|
||||
let mut cpu = CPU::new(0);
|
||||
cpu.reserve_crossbar(1, 4 * size_of::<M>(), 4);
|
||||
let (memory, crossbars) = cpu.host().get_memory_crossbar();
|
||||
let matrix: [M; _] = [
|
||||
1.0.into(),
|
||||
2.0.into(),
|
||||
@@ -944,7 +944,10 @@ where
|
||||
15.0.into(),
|
||||
16.0.into(),
|
||||
];
|
||||
crossbars.get_mut(0).unwrap().execute_store( &matrix).unwrap();
|
||||
let mut crossbar = Crossbar::new(4 * size_of::<M>(), 4, CoreMemory::new());
|
||||
crossbar.execute_store(&matrix).unwrap();
|
||||
let mut cpu = pimcore::cpu::CPU::new(0, vec![vec![&crossbar]]);
|
||||
let (memory, _) = cpu.host().get_memory_crossbar();
|
||||
let vector: [F; _] = [
|
||||
1.0.into(),
|
||||
2.0.into(),
|
||||
@@ -1054,5 +1057,3 @@ fn mvmul_test() {
|
||||
mvmul_test_generic::<f64,f64,f64>("mvmul<f64,f64,f64>",1);
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
mod common;
|
||||
|
||||
use pimcore::{
|
||||
Executable, CoreInstructionsBuilder,
|
||||
cpu::CPU,
|
||||
instruction_set::{InstructionsBuilder, instruction_data::InstructionDataBuilder, isa::*},
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn ld_test() {
|
||||
let mut cpu = CPU::new(1);
|
||||
let mut cpu = common::empty_cpu(1);
|
||||
let mut core_instruction_builder = CoreInstructionsBuilder::new(1);
|
||||
let buff: [f32; _] = [
|
||||
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
|
||||
@@ -41,7 +42,7 @@ fn ld_test() {
|
||||
|
||||
#[test]
|
||||
fn st_test() {
|
||||
let mut cpu = CPU::new(1);
|
||||
let mut cpu = common::empty_cpu(1);
|
||||
let mut core_instruction_builder = CoreInstructionsBuilder::new(1);
|
||||
let buff: [f32; _] = [
|
||||
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
|
||||
@@ -76,7 +77,7 @@ fn st_test() {
|
||||
|
||||
#[test]
|
||||
fn lldi_test() {
|
||||
let cpu = CPU::new(1);
|
||||
let cpu = common::empty_cpu(1);
|
||||
let mut core_instruction_builder = CoreInstructionsBuilder::new(1);
|
||||
let mut inst_builder = InstructionsBuilder::new();
|
||||
let mut idata_build = InstructionDataBuilder::new();
|
||||
@@ -106,7 +107,7 @@ fn lldi_test() {
|
||||
|
||||
#[test]
|
||||
fn lmv_test() {
|
||||
let mut cpu = CPU::new(1);
|
||||
let mut cpu = common::empty_cpu(1);
|
||||
let mut core_instruction_builder = CoreInstructionsBuilder::new(1);
|
||||
let buff: [f32; _] = [
|
||||
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
|
||||
@@ -148,7 +149,7 @@ fn lmv_test() {
|
||||
|
||||
#[test]
|
||||
fn simple_send_recv_test() {
|
||||
let mut cpu = CPU::new(2);
|
||||
let mut cpu = common::empty_cpu(2);
|
||||
let mut core_instruction_builder = CoreInstructionsBuilder::new(2);
|
||||
let buff: [f32; _] = [
|
||||
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
|
||||
@@ -207,7 +208,7 @@ fn simple_send_recv_test() {
|
||||
|
||||
#[test]
|
||||
fn multiple_send_recv_test() {
|
||||
let mut cpu = CPU::new(4);
|
||||
let mut cpu = common::empty_cpu(4);
|
||||
let mut core_instruction_builder = CoreInstructionsBuilder::new(4);
|
||||
let buff: [f32; _] = [
|
||||
1.0, 1.0, 1.0, 1.0, 1.0
|
||||
@@ -226,7 +227,7 @@ fn multiple_send_recv_test() {
|
||||
];
|
||||
cpu.core(4).execute_store(0, &buff).unwrap();
|
||||
|
||||
let send_inst = |cpu :&mut CPU, core_instruction_builder: &mut CoreInstructionsBuilder, inst_builder: &mut InstructionsBuilder, from : i32, to : i32| {
|
||||
let send_inst = |inst_builder: &mut InstructionsBuilder, from: i32, to: i32| {
|
||||
let mut idata_build = InstructionDataBuilder::new();
|
||||
idata_build.set_core_indx(from).fix_core_indx();
|
||||
inst_builder.make_inst(sldi, idata_build.set_rdimm(1, from*size_of::<f32>() as i32).build());
|
||||
@@ -240,7 +241,7 @@ fn multiple_send_recv_test() {
|
||||
);
|
||||
};
|
||||
|
||||
let recv_inst = |cpu :&mut CPU, core_instruction_builder: &mut CoreInstructionsBuilder, mut inst_builder: &mut InstructionsBuilder, to : i32, from : i32| {
|
||||
let recv_inst = |inst_builder: &mut InstructionsBuilder, to: i32, from: i32| {
|
||||
let mut idata_build = InstructionDataBuilder::new();
|
||||
idata_build.set_core_indx(to).fix_core_indx();
|
||||
inst_builder.make_inst(sldi, idata_build.set_rdimm(1, from*size_of::<f32>() as i32).build());
|
||||
@@ -257,26 +258,26 @@ fn multiple_send_recv_test() {
|
||||
|
||||
|
||||
// 1 -> 3
|
||||
send_inst(&mut cpu,&mut core_instruction_builder,&mut inst_builder,1, 3);
|
||||
send_inst(&mut inst_builder, 1, 3);
|
||||
core_instruction_builder.set_core(1, inst_builder.build());
|
||||
|
||||
// 2 -> 3
|
||||
// 2 <- 4
|
||||
send_inst(&mut cpu,&mut core_instruction_builder,&mut inst_builder,2, 3);
|
||||
recv_inst(&mut cpu,&mut core_instruction_builder,&mut inst_builder,2, 4);
|
||||
send_inst(&mut inst_builder, 2, 3);
|
||||
recv_inst(&mut inst_builder, 2, 4);
|
||||
core_instruction_builder.set_core(2, inst_builder.build());
|
||||
|
||||
// 3 <- 2
|
||||
// 3 <- 4
|
||||
// 3 <- 1
|
||||
recv_inst(&mut cpu,&mut core_instruction_builder,&mut inst_builder,3, 2);
|
||||
recv_inst(&mut cpu,&mut core_instruction_builder,&mut inst_builder,3, 4);
|
||||
recv_inst(&mut cpu,&mut core_instruction_builder,&mut inst_builder,3, 1);
|
||||
recv_inst(&mut inst_builder, 3, 2);
|
||||
recv_inst(&mut inst_builder, 3, 4);
|
||||
recv_inst(&mut inst_builder, 3, 1);
|
||||
core_instruction_builder.set_core(3, inst_builder.build());
|
||||
// 4 -> 2
|
||||
// 4 -> 3
|
||||
send_inst(&mut cpu,&mut core_instruction_builder,&mut inst_builder,4, 2);
|
||||
send_inst(&mut cpu,&mut core_instruction_builder,&mut inst_builder,4, 3);
|
||||
send_inst(&mut inst_builder, 4, 2);
|
||||
send_inst(&mut inst_builder, 4, 3);
|
||||
core_instruction_builder.set_core(4, inst_builder.build());
|
||||
|
||||
let mut executable = Executable::new(cpu, core_instruction_builder.build());
|
||||
|
||||
@@ -110,6 +110,14 @@ llvm::FailureOr<int64_t> resolveIndexValueImpl(mlir::Value value, const StaticVa
|
||||
return static_cast<int64_t>(static_cast<uint64_t>(*lhs) / static_cast<uint64_t>(*rhs));
|
||||
}
|
||||
|
||||
if (auto minOp = mlir::dyn_cast<mlir::arith::MinUIOp>(definingOp)) {
|
||||
auto lhs = resolveIndexValueImpl(minOp.getLhs(), knowledge);
|
||||
auto rhs = resolveIndexValueImpl(minOp.getRhs(), knowledge);
|
||||
if (failed(lhs) || failed(rhs))
|
||||
return mlir::failure();
|
||||
return static_cast<int64_t>(std::min(static_cast<uint64_t>(*lhs), static_cast<uint64_t>(*rhs)));
|
||||
}
|
||||
|
||||
if (auto remOp = mlir::dyn_cast<mlir::arith::RemUIOp>(definingOp)) {
|
||||
auto lhs = resolveIndexValueImpl(remOp.getLhs(), knowledge);
|
||||
auto rhs = resolveIndexValueImpl(remOp.getRhs(), knowledge);
|
||||
|
||||
@@ -12,6 +12,7 @@ bool isCoreStaticAddressOp(mlir::Operation* op) {
|
||||
mlir::arith::SubIOp,
|
||||
mlir::arith::MulIOp,
|
||||
mlir::arith::DivUIOp,
|
||||
mlir::arith::MinUIOp,
|
||||
mlir::arith::RemUIOp,
|
||||
mlir::arith::IndexCastOp,
|
||||
mlir::memref::AllocOp,
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
|
||||
|
||||
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
@@ -7,10 +7,34 @@
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
|
||||
#include <cstdint>
|
||||
#include <system_error>
|
||||
|
||||
namespace onnx_mlir::pim {
|
||||
|
||||
struct CappedDiagnosticReporter {
|
||||
explicit CappedDiagnosticReporter(int64_t maxReportedFailures = 8) : maxReportedFailures(maxReportedFailures) {}
|
||||
|
||||
template <typename EmitFn>
|
||||
void report(mlir::Operation* op, EmitFn&& emit) {
|
||||
numFailures++;
|
||||
if (numFailures <= maxReportedFailures)
|
||||
emit(op);
|
||||
}
|
||||
|
||||
void emitSuppressedSummary(mlir::Operation* op, llvm::StringRef failureDescription) const {
|
||||
if (numFailures > maxReportedFailures)
|
||||
op->emitError() << "suppressed " << (numFailures - maxReportedFailures) << " additional "
|
||||
<< failureDescription;
|
||||
}
|
||||
|
||||
bool hasFailure() const { return numFailures != 0; }
|
||||
|
||||
private:
|
||||
int64_t maxReportedFailures;
|
||||
int64_t numFailures = 0;
|
||||
};
|
||||
|
||||
/// Emits a consistent diagnostic for target paths that require static shapes.
|
||||
mlir::InFlightDiagnostic emitUnsupportedStaticShapeDiagnostic(mlir::Operation* op, llvm::StringRef valueDescription);
|
||||
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
#include "src/Accelerators/PIM/Common/Support/ReportUtils.hpp"
|
||||
|
||||
#include "llvm/Support/Format.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/Support/FileSystemUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/Support/ReportUtils.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
#pragma once
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
#include <cstdint>
|
||||
#include <fstream>
|
||||
#include <limits>
|
||||
#include <string>
|
||||
|
||||
@@ -70,9 +70,7 @@ inline void writeUint32LE(llvm::raw_ostream& os, uint32_t value) {
|
||||
os.write(bytes.data(), bytes.size());
|
||||
}
|
||||
|
||||
inline void writeInt32LE(llvm::raw_ostream& os, int32_t value) {
|
||||
writeUint32LE(os, static_cast<uint32_t>(value));
|
||||
}
|
||||
inline void writeInt32LE(llvm::raw_ostream& os, int32_t value) { writeUint32LE(os, static_cast<uint32_t>(value)); }
|
||||
|
||||
inline void writeHeader(llvm::raw_ostream& os) {
|
||||
os.write(kMagic, sizeof(kMagic));
|
||||
@@ -186,39 +184,39 @@ inline Opcode opcodeFromString(llvm::StringRef opName) {
|
||||
|
||||
inline llvm::StringRef opcodeToString(Opcode opcode) {
|
||||
switch (opcode) {
|
||||
case Opcode::nop: return "nop";
|
||||
case Opcode::sldi: return "sldi";
|
||||
case Opcode::sld: return "sld";
|
||||
case Opcode::sadd: return "sadd";
|
||||
case Opcode::ssub: return "ssub";
|
||||
case Opcode::smul: return "smul";
|
||||
case Opcode::saddi: return "saddi";
|
||||
case Opcode::smuli: return "smuli";
|
||||
case Opcode::setbw: return "setbw";
|
||||
case Opcode::mvmul: return "mvmul";
|
||||
case Opcode::vvadd: return "vvadd";
|
||||
case Opcode::vvsub: return "vvsub";
|
||||
case Opcode::vvmul: return "vvmul";
|
||||
case Opcode::vvdmul: return "vvdmul";
|
||||
case Opcode::vvmax: return "vvmax";
|
||||
case Opcode::vvsll: return "vvsll";
|
||||
case Opcode::vvsra: return "vvsra";
|
||||
case Opcode::vavg: return "vavg";
|
||||
case Opcode::vrelu: return "vrelu";
|
||||
case Opcode::vtanh: return "vtanh";
|
||||
case Opcode::vsigm: return "vsigm";
|
||||
case Opcode::nop: return "nop";
|
||||
case Opcode::sldi: return "sldi";
|
||||
case Opcode::sld: return "sld";
|
||||
case Opcode::sadd: return "sadd";
|
||||
case Opcode::ssub: return "ssub";
|
||||
case Opcode::smul: return "smul";
|
||||
case Opcode::saddi: return "saddi";
|
||||
case Opcode::smuli: return "smuli";
|
||||
case Opcode::setbw: return "setbw";
|
||||
case Opcode::mvmul: return "mvmul";
|
||||
case Opcode::vvadd: return "vvadd";
|
||||
case Opcode::vvsub: return "vvsub";
|
||||
case Opcode::vvmul: return "vvmul";
|
||||
case Opcode::vvdmul: return "vvdmul";
|
||||
case Opcode::vvmax: return "vvmax";
|
||||
case Opcode::vvsll: return "vvsll";
|
||||
case Opcode::vvsra: return "vvsra";
|
||||
case Opcode::vavg: return "vavg";
|
||||
case Opcode::vrelu: return "vrelu";
|
||||
case Opcode::vtanh: return "vtanh";
|
||||
case Opcode::vsigm: return "vsigm";
|
||||
case Opcode::vsoftmax: return "vsoftmax";
|
||||
case Opcode::vmv: return "vmv";
|
||||
case Opcode::vrsu: return "vrsu";
|
||||
case Opcode::vrsl: return "vrsl";
|
||||
case Opcode::ld: return "ld";
|
||||
case Opcode::st: return "st";
|
||||
case Opcode::lldi: return "lldi";
|
||||
case Opcode::lmv: return "lmv";
|
||||
case Opcode::send: return "send";
|
||||
case Opcode::recv: return "recv";
|
||||
case Opcode::wait: return "wait";
|
||||
case Opcode::sync: return "sync";
|
||||
case Opcode::vmv: return "vmv";
|
||||
case Opcode::vrsu: return "vrsu";
|
||||
case Opcode::vrsl: return "vrsl";
|
||||
case Opcode::ld: return "ld";
|
||||
case Opcode::st: return "st";
|
||||
case Opcode::lldi: return "lldi";
|
||||
case Opcode::lmv: return "lmv";
|
||||
case Opcode::send: return "send";
|
||||
case Opcode::recv: return "recv";
|
||||
case Opcode::wait: return "wait";
|
||||
case Opcode::sync: return "sync";
|
||||
}
|
||||
llvm_unreachable("Unsupported PIM binary opcode");
|
||||
}
|
||||
@@ -235,9 +233,7 @@ inline InstructionRecord makeInstructionRecord(const llvm::json::Object& instruc
|
||||
case Opcode::sldi:
|
||||
case Opcode::saddi:
|
||||
case Opcode::smuli:
|
||||
case Opcode::lldi:
|
||||
record.r2OrImm = getOptionalInt(instruction, "imm");
|
||||
break;
|
||||
case Opcode::lldi: record.r2OrImm = getOptionalInt(instruction, "imm"); break;
|
||||
case Opcode::mvmul:
|
||||
record.r2OrImm = getOptionalInt(instruction, "mbiw");
|
||||
record.generic1 = getOptionalInt(instruction, "relu");
|
||||
@@ -252,9 +248,7 @@ inline InstructionRecord makeInstructionRecord(const llvm::json::Object& instruc
|
||||
record.r2OrImm = getOptionalInt(instruction, "core");
|
||||
record.generic3 = getOptionalInt(instruction, "size");
|
||||
break;
|
||||
default:
|
||||
record.r2OrImm = getOptionalInt(instruction, "rs2");
|
||||
break;
|
||||
default: record.r2OrImm = getOptionalInt(instruction, "rs2"); break;
|
||||
}
|
||||
|
||||
if (record.opcode != Opcode::mvmul && record.opcode != Opcode::setbw) {
|
||||
@@ -371,8 +365,7 @@ inline llvm::json::Object makeInstructionJson(const InstructionRecord& record) {
|
||||
break;
|
||||
case Opcode::wait:
|
||||
case Opcode::sync:
|
||||
case Opcode::nop:
|
||||
break;
|
||||
case Opcode::nop: break;
|
||||
}
|
||||
|
||||
return instruction;
|
||||
|
||||
@@ -367,7 +367,7 @@ void PimCodeGen::emitMemCopyOp(StringRef opName,
|
||||
instruction.generic1 = 0;
|
||||
instruction.generic2 = 0;
|
||||
instruction.generic3 = static_cast<int32_t>(size);
|
||||
(void)sizeFieldName;
|
||||
(void) sizeFieldName;
|
||||
emitInstruction(instruction);
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
|
||||
#include "llvm/Support/ErrorHandling.h"
|
||||
|
||||
#define DEBUG_TYPE "PimCompilerOptions"
|
||||
|
||||
namespace onnx_mlir {
|
||||
@@ -13,6 +15,14 @@ llvm::cl::opt<PimEmissionTargetType> pimEmissionTarget(
|
||||
llvm::cl::init(EmitPimCodegen),
|
||||
llvm::cl::cat(OnnxMlirOptions));
|
||||
|
||||
// Command-line selector for the scheduler run by the Spatial
// merge-compute-nodes pass. PEFT is the value used when the flag is absent.
llvm::cl::opt<PimMergeSchedulerType> pimMergeScheduler("pim-merge-scheduler",
    llvm::cl::desc("Scheduler used by the Spatial merge-compute-nodes pass"),
    llvm::cl::values(clEnumValN(MergeSchedulerPeft, "peft", "Use PEFT scheduling"),
        clEnumValN(MergeSchedulerDcp, "dcp", "Use the legacy DCP-inspired scheduler")),
    llvm::cl::init(MergeSchedulerPeft),
    llvm::cl::cat(OnnxMlirOptions));
|
||||
|
||||
llvm::cl::opt<bool>
|
||||
pimOnlyCodegen("pim-only-codegen",
|
||||
llvm::cl::desc("Only generate code for PIM (assume input is already in bufferized PIM IR)"),
|
||||
@@ -30,19 +40,19 @@ llvm::cl::opt<bool> pimEmitJson("pim-emit-json",
|
||||
llvm::cl::cat(OnnxMlirOptions));
|
||||
|
||||
llvm::cl::opt<size_t>
|
||||
crossbarSize("crossbar-size", llvm::cl::desc("Width and heigth of a single crossbar"), llvm::cl::init(2));
|
||||
crossbarSize("crossbar-size", llvm::cl::desc("Width and height of a single crossbar"), llvm::cl::init(2));
|
||||
|
||||
llvm::cl::opt<size_t>
|
||||
crossbarCountInCore("crossbar-count", llvm::cl::desc("Number of crossbars in each core"), llvm::cl::init(256));
|
||||
|
||||
llvm::cl::opt<long> coresCount("core-count",
|
||||
llvm::cl::desc("Number of cores in the chip. `-1` to use the minimum amount of cores."),
|
||||
llvm::cl::desc("Number of cores in the chip. Required for PIM compilation."),
|
||||
llvm::cl::init(-1));
|
||||
|
||||
llvm::cl::opt<size_t> dcpCriticalWindowSize(
|
||||
"dcp-critical-window-size",
|
||||
llvm::cl::desc("Number of lowest-slack virtual nodes considered by each DCP coarsening iteration. "
|
||||
"Use 0 to run the legacy full-graph DCP analysis."),
|
||||
"Use 0 to run the legacy full-graph DCP analysis. Only used by the DCP scheduler."),
|
||||
llvm::cl::init(4000));
|
||||
|
||||
llvm::cl::opt<bool>
|
||||
@@ -50,4 +60,13 @@ llvm::cl::opt<bool>
|
||||
llvm::cl::desc("Ignore ConcatOp corner case: do not assert and do a simplification"),
|
||||
llvm::cl::init(false));
|
||||
|
||||
/// Reports whether `--core-count` appeared on the command line at least once.
bool hasExplicitPimCoreCount() {
  const int occurrences = coresCount.getNumOccurrences();
  return occurrences != 0;
}
|
||||
|
||||
void verifyExplicitPimCoreCount() {
|
||||
if (!hasExplicitPimCoreCount())
|
||||
llvm::report_fatal_error("PIM compilation requires an explicit --core-count=<positive integer>");
|
||||
if (coresCount.getValue() <= 0)
|
||||
llvm::report_fatal_error("PIM compilation requires --core-count to be a positive integer");
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -20,8 +20,14 @@ typedef enum {
|
||||
EmitPimCodegen = 3
|
||||
} PimEmissionTargetType;
|
||||
|
||||
// Scheduler choices for the Spatial merge-compute-nodes pass.
typedef enum {
  // PEFT scheduling.
  MergeSchedulerPeft = 0,
  // DCP-inspired scheduling.
  MergeSchedulerDcp = 1
} PimMergeSchedulerType;
|
||||
|
||||
extern llvm::cl::OptionCategory OnnxMlirOptions;
|
||||
extern llvm::cl::opt<PimEmissionTargetType> pimEmissionTarget;
|
||||
extern llvm::cl::opt<PimMergeSchedulerType> pimMergeScheduler;
|
||||
|
||||
extern llvm::cl::opt<bool> pimOnlyCodegen;
|
||||
extern llvm::cl::opt<bool> useExperimentalConvImpl;
|
||||
@@ -32,6 +38,9 @@ extern llvm::cl::opt<size_t> crossbarCountInCore;
|
||||
extern llvm::cl::opt<long> coresCount;
|
||||
extern llvm::cl::opt<size_t> dcpCriticalWindowSize;
|
||||
|
||||
bool hasExplicitPimCoreCount();
|
||||
void verifyExplicitPimCoreCount();
|
||||
|
||||
// This option, by default set to false, will ignore an error when resolving
|
||||
// specific tiles of the operands of a concat. This specific case is when the
|
||||
// wanted tile is generated by two separate operands of the concat. If this is
|
||||
|
||||
@@ -17,6 +17,7 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
|
||||
PassManager& pm,
|
||||
EmissionTargetType& emissionTarget,
|
||||
std::string outputNameNoExt) {
|
||||
verifyExplicitPimCoreCount();
|
||||
|
||||
if (pimOnlyCodegen) {
|
||||
// Skip all the lowering passes and directly generate code for PIM.
|
||||
|
||||
@@ -33,7 +33,7 @@ struct DenseWeightView {
|
||||
};
|
||||
|
||||
FailureOr<DenseWeightView> resolveDenseWeightView(ModuleOp moduleOp, mlir::Value weight) {
|
||||
SmallVector<memref::SubViewOp> subviews;
|
||||
SmallVector<Operation*> viewOps;
|
||||
mlir::Value current = weight;
|
||||
memref::GetGlobalOp getGlobalOp;
|
||||
|
||||
@@ -46,7 +46,7 @@ FailureOr<DenseWeightView> resolveDenseWeightView(ModuleOp moduleOp, mlir::Value
|
||||
if (auto subview = dyn_cast<memref::SubViewOp>(defOp)) {
|
||||
if (!hasAllStaticSubviewParts(subview))
|
||||
return failure();
|
||||
subviews.push_back(subview);
|
||||
viewOps.push_back(subview);
|
||||
current = subview.getSource();
|
||||
continue;
|
||||
}
|
||||
@@ -54,6 +54,24 @@ FailureOr<DenseWeightView> resolveDenseWeightView(ModuleOp moduleOp, mlir::Value
|
||||
current = cast.getSource();
|
||||
continue;
|
||||
}
|
||||
if (auto collapse = dyn_cast<memref::CollapseShapeOp>(defOp)) {
|
||||
auto srcType = dyn_cast<MemRefType>(collapse.getSrc().getType());
|
||||
auto resultType = dyn_cast<MemRefType>(collapse.getResult().getType());
|
||||
if (!srcType || !resultType || !srcType.hasStaticShape() || !resultType.hasStaticShape())
|
||||
return failure();
|
||||
viewOps.push_back(collapse);
|
||||
current = collapse.getSrc();
|
||||
continue;
|
||||
}
|
||||
if (auto expand = dyn_cast<memref::ExpandShapeOp>(defOp)) {
|
||||
auto srcType = dyn_cast<MemRefType>(expand.getSrc().getType());
|
||||
auto resultType = dyn_cast<MemRefType>(expand.getResult().getType());
|
||||
if (!srcType || !resultType || !srcType.hasStaticShape() || !resultType.hasStaticShape())
|
||||
return failure();
|
||||
viewOps.push_back(expand);
|
||||
current = expand.getSrc();
|
||||
continue;
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
|
||||
@@ -70,16 +88,39 @@ FailureOr<DenseWeightView> resolveDenseWeightView(ModuleOp moduleOp, mlir::Value
|
||||
view.shape.assign(denseAttr.getType().getShape().begin(), denseAttr.getType().getShape().end());
|
||||
view.strides = computeRowMajorStrides(view.shape);
|
||||
|
||||
for (memref::SubViewOp subview : llvm::reverse(subviews)) {
|
||||
SmallVector<int64_t> nextStrides;
|
||||
nextStrides.reserve(subview.getStaticStrides().size());
|
||||
for (auto [offset, stride, sourceStride] :
|
||||
llvm::zip_equal(subview.getStaticOffsets(), subview.getStaticStrides(), view.strides)) {
|
||||
view.offset += offset * sourceStride;
|
||||
nextStrides.push_back(stride * sourceStride);
|
||||
for (Operation* viewOp : llvm::reverse(viewOps)) {
|
||||
if (auto subview = dyn_cast<memref::SubViewOp>(viewOp)) {
|
||||
SmallVector<int64_t> nextStrides;
|
||||
nextStrides.reserve(subview.getStaticStrides().size());
|
||||
for (auto [offset, stride, sourceStride] :
|
||||
llvm::zip_equal(subview.getStaticOffsets(), subview.getStaticStrides(), view.strides)) {
|
||||
view.offset += offset * sourceStride;
|
||||
nextStrides.push_back(stride * sourceStride);
|
||||
}
|
||||
view.shape.assign(subview.getStaticSizes().begin(), subview.getStaticSizes().end());
|
||||
view.strides = std::move(nextStrides);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Collapse/expand are accepted only as contiguous static reshapes of a
|
||||
// dense global view, so a row-major stride recomputation preserves layout.
|
||||
if (auto collapse = dyn_cast<memref::CollapseShapeOp>(viewOp)) {
|
||||
if (view.strides != computeRowMajorStrides(view.shape))
|
||||
return failure();
|
||||
auto resultType = cast<MemRefType>(collapse.getResult().getType());
|
||||
view.shape.assign(resultType.getShape().begin(), resultType.getShape().end());
|
||||
view.strides = computeRowMajorStrides(view.shape);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto expand = dyn_cast<memref::ExpandShapeOp>(viewOp)) {
|
||||
if (view.strides != computeRowMajorStrides(view.shape))
|
||||
return failure();
|
||||
auto resultType = cast<MemRefType>(expand.getResult().getType());
|
||||
view.shape.assign(resultType.getShape().begin(), resultType.getShape().end());
|
||||
view.strides = computeRowMajorStrides(view.shape);
|
||||
continue;
|
||||
}
|
||||
view.shape.assign(subview.getStaticSizes().begin(), subview.getStaticSizes().end());
|
||||
view.strides = std::move(nextStrides);
|
||||
}
|
||||
|
||||
return view;
|
||||
|
||||
@@ -100,18 +100,27 @@ DenseMap<HSliceId, DenseMap<CoreId, SmallVector<Value>>> tileMatrix(
|
||||
return tiles;
|
||||
}
|
||||
|
||||
tensor::SplatOp
|
||||
broadcastToVector(Value scalarToBroadcast, int64_t length, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
Value broadcastToVector(Value scalarToBroadcast, int64_t length, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
auto oldType = cast<RankedTensorType>(scalarToBroadcast.getType());
|
||||
Type elementType = oldType.getElementType();
|
||||
int64_t shape[2] = {1, length};
|
||||
Type type = oldType.cloneWith(ArrayRef(shape), elementType);
|
||||
|
||||
auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0).getResult();
|
||||
SmallVector<Value> index(oldType.getRank(), zero);
|
||||
auto elementValue = tensor::ExtractOp::create(rewriter, loc, scalarToBroadcast, index).getResult();
|
||||
auto buildBroadcast = [&](Value input) -> Value {
|
||||
auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0).getResult();
|
||||
SmallVector<Value> index(oldType.getRank(), zero);
|
||||
auto elementValue = tensor::ExtractOp::create(rewriter, loc, input, index).getResult();
|
||||
return tensor::SplatOp::create(rewriter, loc, type, elementValue);
|
||||
};
|
||||
|
||||
return tensor::SplatOp::create(rewriter, loc, type, elementValue);
|
||||
if (isHostFoldableValue(scalarToBroadcast))
|
||||
return buildBroadcast(scalarToBroadcast);
|
||||
|
||||
auto broadcastCompute =
|
||||
createSpatCompute<1>(rewriter, loc, TypeRange {type}, {}, ValueRange {scalarToBroadcast}, [&](Value input) {
|
||||
spatial::SpatYieldOp::create(rewriter, loc, buildBroadcast(input));
|
||||
});
|
||||
return broadcastCompute.getResult(0);
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -136,9 +136,9 @@ tileMatrix(mlir::Value& matrixToTile,
|
||||
mlir::ConversionPatternRewriter& rewriter,
|
||||
mlir::Location& loc);
|
||||
|
||||
mlir::tensor::SplatOp broadcastToVector(mlir::Value scalarToBroadcast,
|
||||
int64_t length,
|
||||
mlir::ConversionPatternRewriter& rewriter,
|
||||
mlir::Location loc);
|
||||
mlir::Value broadcastToVector(mlir::Value scalarToBroadcast,
|
||||
int64_t length,
|
||||
mlir::ConversionPatternRewriter& rewriter,
|
||||
mlir::Location loc);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -1,8 +1,12 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallBitVector.h"
|
||||
#include "llvm/ADT/SmallPtrSet.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
@@ -18,6 +22,11 @@ static bool hasStaticUnitStrides(tensor::ExtractSliceOp extractSliceOp) {
|
||||
return llvm::all_of(extractSliceOp.getStaticStrides(), [](int64_t stride) { return stride == 1; });
|
||||
}
|
||||
|
||||
/// Returns true when every index operand of `extractOp` is produced by an
/// `arith::ConstantIndexOp`.
static bool hasConstantIndices(tensor::ExtractOp extractOp) {
  for (Value indexOperand : extractOp.getIndices()) {
    Operation* producer = indexOperand.getDefiningOp();
    // A value without a defining op (e.g. a block argument) is not a constant.
    if (!producer || !isa<arith::ConstantIndexOp>(producer))
      return false;
  }
  return true;
}
|
||||
|
||||
static bool isStaticTensorResult(Operation* op) {
|
||||
return llvm::all_of(op->getResultTypes(), [](Type type) {
|
||||
auto shapedType = dyn_cast<ShapedType>(type);
|
||||
@@ -25,6 +34,167 @@ static bool isStaticTensorResult(Operation* op) {
|
||||
});
|
||||
}
|
||||
|
||||
/// Computes element strides for a dense row-major layout of `shape`: the last
/// dimension has stride 1 and each earlier dimension's stride is the product
/// of all later extents.
static SmallVector<int64_t> computeRowMajorStrides(ArrayRef<int64_t> shape) {
  SmallVector<int64_t> strides(shape.size(), 1);
  // Walk from the innermost dimension outwards; `inner` is the dimension whose
  // stride is already known, `inner - 1` the one being filled in.
  for (size_t inner = shape.size(); inner > 1; --inner)
    strides[inner - 2] = strides[inner - 1] * shape[inner - 1];
  return strides;
}
|
||||
|
||||
/// Materializes the transpose of a dense constant.
///
/// `perms` must be a permutation of [0, rank); result dimension `d` takes its
/// extent (and index) from source dimension `perms[d]`. Fails when the
/// attribute is not a ranked tensor or `perms` is not a valid permutation.
static FailureOr<DenseElementsAttr> transposeDenseElements(DenseElementsAttr denseAttr, ArrayRef<int64_t> perms) {
  auto sourceType = dyn_cast<RankedTensorType>(denseAttr.getType());
  if (!sourceType)
    return failure();

  const int64_t rank = sourceType.getRank();
  if (static_cast<int64_t>(perms.size()) != rank)
    return failure();

  // Validate the permutation while building the permuted shape.
  ArrayRef<int64_t> sourceShape = sourceType.getShape();
  llvm::SmallBitVector usedDims(rank);
  SmallVector<int64_t> permutedShape;
  permutedShape.reserve(rank);
  for (int64_t sourceDim : perms) {
    const bool inRange = sourceDim >= 0 && sourceDim < rank;
    if (!inRange || usedDims.test(sourceDim))
      return failure();
    usedDims.set(sourceDim);
    permutedShape.push_back(sourceShape[sourceDim]);
  }

  auto permutedType = RankedTensorType::get(permutedShape, sourceType.getElementType(), sourceType.getEncoding());

  // A splat has the same value everywhere, so only the type changes.
  if (denseAttr.isSplat())
    return DenseElementsAttr::get(permutedType, denseAttr.getSplatValue<Attribute>());

  SmallVector<Attribute> sourceValues(denseAttr.getValues<Attribute>());
  SmallVector<Attribute> permutedValues(sourceValues.size());
  const SmallVector<int64_t> sourceStrides = computeRowMajorStrides(sourceShape);
  const SmallVector<int64_t> permutedStrides = computeRowMajorStrides(permutedShape);
  SmallVector<int64_t> sourceCoords(rank);

  for (size_t sourceLinear = 0, count = sourceValues.size(); sourceLinear < count; ++sourceLinear) {
    // Split the row-major linear index into per-dimension coordinates.
    int64_t rest = static_cast<int64_t>(sourceLinear);
    for (int64_t dim = 0; dim < rank; ++dim) {
      sourceCoords[dim] = rest / sourceStrides[dim];
      rest %= sourceStrides[dim];
    }

    // Re-linearize the permuted coordinates in the transposed layout.
    int64_t permutedLinear = 0;
    for (int64_t dim = 0; dim < rank; ++dim)
      permutedLinear += sourceCoords[perms[dim]] * permutedStrides[dim];

    permutedValues[permutedLinear] = sourceValues[sourceLinear];
  }

  return DenseElementsAttr::get(permutedType, permutedValues);
}
|
||||
|
||||
/// Rebuilds `denseAttr` with type `resultType`, keeping the element order.
/// Fails when the attribute is not a ranked tensor, `resultType` is null, or
/// the element counts differ.
static FailureOr<DenseElementsAttr> reshapeDenseElements(DenseElementsAttr denseAttr, RankedTensorType resultType) {
  auto inputType = dyn_cast<RankedTensorType>(denseAttr.getType());
  const bool typesUsable = inputType && resultType;
  if (!typesUsable || inputType.getNumElements() != resultType.getNumElements())
    return failure();

  if (!denseAttr.isSplat()) {
    // The flat value list is reused unchanged under the new type.
    SmallVector<Attribute> flatValues(denseAttr.getValues<Attribute>());
    return DenseElementsAttr::get(resultType, flatValues);
  }

  return DenseElementsAttr::get(resultType, denseAttr.getSplatValue<Attribute>());
}
|
||||
|
||||
/// Evaluates a static, unit-stride `tensor.extract_slice` on a dense constant.
///
/// Returns the sliced constant typed as the op's result type, or failure when
/// the source/result types are not static ranked tensors, any offset, size or
/// stride is dynamic, a stride differs from 1, the slice does not fit inside
/// the source, or the slice's element count disagrees with the result type.
///
/// The slice is walked using the op's static `sizes`, which always have the
/// source rank, rather than the result shape. The result type may have fewer
/// dimensions than the source (rank-reducing slice); indexing offsets and
/// source strides by result dimension would then read the wrong elements.
/// NOTE(review): this assumes the only difference between `sizes` and the
/// result shape is dropped unit dimensions, so both enumerate the elements in
/// the same row-major order — confirm against the op verifier.
static FailureOr<DenseElementsAttr> extractSliceDenseElements(DenseElementsAttr denseAttr,
    tensor::ExtractSliceOp extractSliceOp) {
  auto sourceType = dyn_cast<RankedTensorType>(denseAttr.getType());
  auto resultType = dyn_cast<RankedTensorType>(extractSliceOp.getType());
  if (!sourceType || !resultType || !sourceType.hasStaticShape() || !resultType.hasStaticShape())
    return failure();

  ArrayRef<int64_t> offsets = extractSliceOp.getStaticOffsets();
  ArrayRef<int64_t> sizes = extractSliceOp.getStaticSizes();
  ArrayRef<int64_t> strides = extractSliceOp.getStaticStrides();
  if (llvm::any_of(offsets, [](int64_t value) { return ShapedType::isDynamic(value); })
      || llvm::any_of(sizes, [](int64_t value) { return ShapedType::isDynamic(value); })
      || llvm::any_of(strides, [](int64_t stride) { return ShapedType::isDynamic(stride) || stride != 1; }))
    return failure();

  // Offsets, sizes and strides describe the slice in source coordinates.
  const size_t sourceRank = static_cast<size_t>(sourceType.getRank());
  if (offsets.size() != sourceRank || sizes.size() != sourceRank || strides.size() != sourceRank)
    return failure();

  // Reject slices that reach outside the source so the element lookup below
  // can never index past `sourceValues`.
  ArrayRef<int64_t> sourceShape = sourceType.getShape();
  int64_t sliceElementCount = 1;
  for (size_t dim = 0; dim < sourceRank; ++dim) {
    if (offsets[dim] < 0 || sizes[dim] < 0 || offsets[dim] > sourceShape[dim] - sizes[dim])
      return failure();
    sliceElementCount *= sizes[dim];
  }
  if (sliceElementCount != resultType.getNumElements())
    return failure();

  // Every element of a splat is identical, so only the type changes.
  if (denseAttr.isSplat())
    return DenseElementsAttr::get(resultType, denseAttr.getSplatValue<Attribute>());

  SmallVector<Attribute> sourceValues(denseAttr.getValues<Attribute>());
  SmallVector<int64_t> sourceStrides = computeRowMajorStrides(sourceShape);
  SmallVector<int64_t> sliceStrides = computeRowMajorStrides(sizes);
  SmallVector<Attribute> resultValues;
  resultValues.reserve(sliceElementCount);

  for (int64_t linearIndex = 0; linearIndex < sliceElementCount; ++linearIndex) {
    // Decompose the slice-local linear index into source-rank coordinates and
    // shift each coordinate by the slice offset.
    int64_t remaining = linearIndex;
    int64_t sourceLinearIndex = 0;
    for (size_t dim = 0; dim < sourceRank; ++dim) {
      const int64_t sliceIndex = remaining / sliceStrides[dim];
      remaining %= sliceStrides[dim];
      sourceLinearIndex += (offsets[dim] + sliceIndex) * sourceStrides[dim];
    }
    resultValues.push_back(sourceValues[sourceLinearIndex]);
  }

  return DenseElementsAttr::get(resultType, resultValues);
}
|
||||
|
||||
/// Returns the dense payload when `value` is produced directly by an
/// `arith.constant` or an ONNX constant op; returns null otherwise, including
/// when the constant's payload is not a `DenseElementsAttr`.
static DenseElementsAttr getDirectDenseConstantAttr(Value value) {
  Operation* producer = value.getDefiningOp();
  if (!producer)
    return nullptr;

  if (auto arithConstant = dyn_cast<arith::ConstantOp>(producer))
    return dyn_cast<DenseElementsAttr>(arithConstant.getValue());

  // The ONNX constant's value attribute may be absent, hence the null-tolerant cast.
  if (auto onnxConstant = dyn_cast<ONNXConstantOp>(producer))
    return dyn_cast_or_null<DenseElementsAttr>(onnxConstant.getValueAttr());

  return nullptr;
}
|
||||
|
||||
/// Reconstructs the dense constant that `value` evaluates to by folding a
/// chain of transpose / collapse_shape / expand_shape / extract_slice ops that
/// ends in a direct constant.
///
/// `visited` records the defining ops already entered; a value with no
/// defining op, or one whose defining op was seen before, yields null.
/// Returns null whenever any link of the chain cannot be folded.
static DenseElementsAttr getHostFoldableDenseElementsAttrImpl(Value value, llvm::SmallPtrSetImpl<Operation*>& visited) {
  auto* definingOp = value.getDefiningOp();
  if (!definingOp || !visited.insert(definingOp).second)
    return nullptr;

  // Rebuild dense attributes through view-only host-foldable chains so later
  // lowering stages can still recognize grouped/sliced constants.
  if (auto denseAttr = getDirectDenseConstantAttr(value))
    return denseAttr;

  if (auto transposeOp = dyn_cast<ONNXTransposeOp>(definingOp)) {
    auto inputAttr = getHostFoldableDenseElementsAttrImpl(transposeOp.getData(), visited);
    if (!inputAttr)
      return nullptr;

    SmallVector<int64_t> perm;
    if (ArrayAttr permAttr = transposeOp.getPermAttr()) {
      perm.reserve(permAttr.size());
      for (IntegerAttr attr : permAttr.getAsRange<IntegerAttr>())
        perm.push_back(attr.getInt());
    }
    else {
      // `perm` is optional on ONNX Transpose; when it is absent the operator
      // reverses the dimension order. Dereferencing the null attribute here
      // would crash instead.
      const int64_t rank = inputAttr.getType().getRank();
      perm.reserve(rank);
      for (int64_t dim = rank - 1; dim >= 0; --dim)
        perm.push_back(dim);
    }

    auto transposedAttr = transposeDenseElements(inputAttr, perm);
    return succeeded(transposedAttr) ? *transposedAttr : nullptr;
  }

  if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(definingOp)) {
    auto inputAttr = getHostFoldableDenseElementsAttrImpl(collapseShapeOp.getSrc(), visited);
    if (!inputAttr)
      return nullptr;
    auto reshapedAttr = reshapeDenseElements(inputAttr, cast<RankedTensorType>(collapseShapeOp.getType()));
    return succeeded(reshapedAttr) ? *reshapedAttr : nullptr;
  }

  if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(definingOp)) {
    auto inputAttr = getHostFoldableDenseElementsAttrImpl(expandShapeOp.getSrc(), visited);
    if (!inputAttr)
      return nullptr;
    auto reshapedAttr = reshapeDenseElements(inputAttr, cast<RankedTensorType>(expandShapeOp.getType()));
    return succeeded(reshapedAttr) ? *reshapedAttr : nullptr;
  }

  if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(definingOp)) {
    auto inputAttr = getHostFoldableDenseElementsAttrImpl(extractSliceOp.getSource(), visited);
    if (!inputAttr)
      return nullptr;
    auto slicedAttr = extractSliceDenseElements(inputAttr, extractSliceOp);
    return succeeded(slicedAttr) ? *slicedAttr : nullptr;
  }

  // Any other producer is not a view-only op this helper knows how to fold.
  return nullptr;
}
|
||||
|
||||
static bool isHostFoldableOpImpl(Operation* op, llvm::SmallPtrSetImpl<Operation*>& visited) {
|
||||
if (!op || !visited.insert(op).second)
|
||||
return false;
|
||||
@@ -32,6 +202,9 @@ static bool isHostFoldableOpImpl(Operation* op, llvm::SmallPtrSetImpl<Operation*
|
||||
if (isa<arith::ConstantOp, ONNXConstantOp, ONNXNoneOp>(op))
|
||||
return true;
|
||||
|
||||
if (auto extractOp = dyn_cast<tensor::ExtractOp>(op))
|
||||
return hasConstantIndices(extractOp) && isHostFoldableValue(extractOp.getTensor());
|
||||
|
||||
if (!isStaticTensorResult(op))
|
||||
return false;
|
||||
|
||||
@@ -47,6 +220,9 @@ static bool isHostFoldableOpImpl(Operation* op, llvm::SmallPtrSetImpl<Operation*
|
||||
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op))
|
||||
return hasStaticUnitStrides(extractSliceOp) && isHostFoldableValue(extractSliceOp.getSource());
|
||||
|
||||
if (auto splatOp = dyn_cast<tensor::SplatOp>(op))
|
||||
return isHostFoldableValue(splatOp.getInput());
|
||||
|
||||
if (auto extractRowsOp = dyn_cast<spatial::SpatExtractRowsOp>(op))
|
||||
return isHostFoldableValue(extractRowsOp.getInput());
|
||||
|
||||
@@ -72,4 +248,9 @@ bool isHostFoldableOp(Operation* op) {
|
||||
return isHostFoldableOpImpl(op, visited);
|
||||
}
|
||||
|
||||
/// Public entry point: tries to fold `value` into a dense constant, starting
/// each query with an empty set of visited ops. Returns null when the value
/// cannot be folded.
DenseElementsAttr getHostFoldableDenseElementsAttr(Value value) {
  llvm::SmallPtrSet<Operation*, 8> seenOps;
  DenseElementsAttr folded = getHostFoldableDenseElementsAttrImpl(value, seenOps);
  return folded;
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
|
||||
@@ -9,4 +10,6 @@ bool isHostFoldableValue(mlir::Value value);
|
||||
|
||||
bool isHostFoldableOp(mlir::Operation* op);
|
||||
|
||||
mlir::DenseElementsAttr getHostFoldableDenseElementsAttr(mlir::Value value);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostLegality.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
@@ -11,7 +12,7 @@ using namespace mlir;
|
||||
namespace onnx_mlir {
|
||||
|
||||
LogicalResult verifyONNXToSpatialHostLegality(func::FuncOp funcOp) {
|
||||
bool hasFailure = false;
|
||||
pim::CappedDiagnosticReporter diagnostics;
|
||||
|
||||
for (Operation& op : funcOp.getFunctionBody().front()) {
|
||||
if (isa<func::ReturnOp, spatial::SpatCompute, spatial::SpatComputeBatch>(&op))
|
||||
@@ -19,11 +20,15 @@ LogicalResult verifyONNXToSpatialHostLegality(func::FuncOp funcOp) {
|
||||
if (isHostFoldableOp(&op))
|
||||
continue;
|
||||
|
||||
op.emitOpError("non-foldable top-level runtime op remains after ONNX-to-Spatial; lower it inside spat.compute");
|
||||
hasFailure = true;
|
||||
diagnostics.report(&op, [](Operation* illegalOp) {
|
||||
illegalOp->emitOpError("non-foldable top-level runtime op remains after ONNX-to-Spatial; lower it inside "
|
||||
"spat.compute");
|
||||
});
|
||||
}
|
||||
|
||||
return success(!hasFailure);
|
||||
diagnostics.emitSuppressedSummary(funcOp, "ONNX-to-Spatial host legality failures");
|
||||
|
||||
return success(!diagnostics.hasFailure());
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -5,17 +5,15 @@
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
|
||||
#include "Common/Common.hpp"
|
||||
#include "Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostLegality.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PostPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PrePatterns.hpp"
|
||||
@@ -87,17 +85,68 @@ static void populateEmptyFunction(func::FuncOp funcOp) {
|
||||
returnOp.setOperand(index, computeResult);
|
||||
}
|
||||
|
||||
/// Moves every non-host-foldable `ONNXTransposeOp` found directly in the entry
/// block of `funcOp` into its own single-input `spat.compute` region and
/// replaces the original op with that compute op's result.
static void wrapTopLevelRuntimeTransposes(func::FuncOp funcOp) {
  IRRewriter rewriter(funcOp.getContext());
  Block& body = funcOp.getFunctionBody().front();

  // Early-inc iteration: the op under the cursor is replaced inside the loop.
  for (Operation& candidate : llvm::make_early_inc_range(body)) {
    auto transposeOp = dyn_cast<ONNXTransposeOp>(&candidate);
    if (!transposeOp)
      continue;
    if (isHostFoldableOp(transposeOp))
      continue;

    // Transpose stays globally legal because constant/view-only cases are
    // allowed on the host. Any residual runtime transpose must be sunk into
    // spat.compute before the host legality check.
    Location loc = transposeOp.getLoc();
    auto resultType = transposeOp.getResult().getType();
    rewriter.setInsertionPoint(transposeOp);
    auto computeOp = createSpatCompute<1>(
        rewriter, loc, TypeRange {resultType}, {}, ValueRange {transposeOp.getData()}, [&](Value input) {
          // Re-create the transpose on the region's input and yield it.
          Value transposed = ONNXTransposeOp::create(rewriter, loc, resultType, input, transposeOp.getPermAttr());
          spatial::SpatYieldOp::create(rewriter, loc, transposed);
        });
    rewriter.replaceOp(transposeOp, computeOp.getResult(0));
  }
}
|
||||
|
||||
void ONNXToSpatialPass::runOnOperation() {
|
||||
ModuleOp moduleOp = getOperation();
|
||||
MLIRContext* ctx = &getContext();
|
||||
|
||||
ConversionTarget preTarget(*ctx);
|
||||
preTarget.addLegalDialect<spatial::SpatialDialect,
|
||||
ONNXDialect,
|
||||
tensor::TensorDialect,
|
||||
arith::ArithDialect,
|
||||
scf::SCFDialect>();
|
||||
preTarget.addIllegalOp<ONNXConstantOp, ONNXFlattenOp>();
|
||||
|
||||
RewritePatternSet prePatterns(ctx);
|
||||
populatePrePatterns(prePatterns, ctx);
|
||||
if (failed(applyPatternsGreedily(moduleOp, std::move(prePatterns))))
|
||||
moduleOp.emitWarning("failed to apply ONNX-to-Spatial pre-patterns; continuing");
|
||||
if (failed(applyPartialConversion(moduleOp, preTarget, std::move(prePatterns)))) {
|
||||
moduleOp.emitError("failed to apply ONNX-to-Spatial pre-rewrites");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
auto entryFunc = getPimEntryFunc(moduleOp);
|
||||
if (failed(entryFunc)) {
|
||||
moduleOp.emitError("failed to locate the PIM entry function during ONNX-to-Spatial lowering");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
RewritePatternSet matmulPatterns(ctx);
|
||||
populateMatMulRewritePatterns(matmulPatterns, ctx);
|
||||
walkAndApplyPatterns(moduleOp, std::move(matmulPatterns));
|
||||
|
||||
bool hasUnloweredMatMul = false;
|
||||
moduleOp.walk([&](ONNXMatMulOp matmulOp) {
|
||||
hasUnloweredMatMul = true;
|
||||
matmulOp.emitOpError("remaining ONNX MatMul before the required ONNX-to-Spatial conversion");
|
||||
});
|
||||
if (hasUnloweredMatMul) {
|
||||
moduleOp.emitError("failed to lower all ONNX MatMul ops before ONNX-to-Spatial conversion");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
@@ -130,31 +179,28 @@ void ONNXToSpatialPass::runOnOperation() {
|
||||
RewritePatternSet conversionPatterns(ctx);
|
||||
populateConversionPatterns(conversionPatterns, ctx);
|
||||
if (failed(applyPartialConversion(moduleOp, target, std::move(conversionPatterns)))) {
|
||||
moduleOp.emitError("failed to convert required ONNX ops to Spatial ops");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
ConversionTarget earlyPostTarget(*ctx);
|
||||
earlyPostTarget.addLegalDialect<spatial::SpatialDialect,
|
||||
ONNXDialect,
|
||||
tensor::TensorDialect,
|
||||
arith::ArithDialect,
|
||||
scf::SCFDialect>();
|
||||
earlyPostTarget.addDynamicallyLegalOp<spatial::SpatComputeBatch>(
|
||||
[](spatial::SpatComputeBatch batchOp) { return !requiresEarlyPostRewrite(batchOp); });
|
||||
|
||||
RewritePatternSet earlyPostPatterns(ctx);
|
||||
populateEarlyPostPatterns(earlyPostPatterns, ctx);
|
||||
if (failed(applyPatternsGreedily(*entryFunc, std::move(earlyPostPatterns)))) {
|
||||
if (failed(applyPartialConversion(*entryFunc, earlyPostTarget, std::move(earlyPostPatterns)))) {
|
||||
moduleOp.emitError("failed to normalize single-lane spat.compute_batch ops before core assignment checks");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
if (coresCount != -1) {
|
||||
int computeOpsCount = 0;
|
||||
for (Operation& op : entryFunc->getFunctionBody().front().getOperations())
|
||||
if (isa<spatial::SpatCompute>(op))
|
||||
computeOpsCount++;
|
||||
|
||||
if (computeOpsCount > coresCount) {
|
||||
entryFunc->emitError() << "number of compute ops (" << computeOpsCount << ") exceeds the core count ("
|
||||
<< coresCount << ")";
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
PassManager cleanupPM(ctx);
|
||||
cleanupPM.addPass(createCanonicalizerPass());
|
||||
if (failed(cleanupPM.run(moduleOp)))
|
||||
@@ -162,14 +208,29 @@ void ONNXToSpatialPass::runOnOperation() {
|
||||
|
||||
annotateWeightsConstants(*entryFunc);
|
||||
|
||||
ConversionTarget postTarget(*ctx);
|
||||
postTarget.addLegalDialect<spatial::SpatialDialect,
|
||||
ONNXDialect,
|
||||
tensor::TensorDialect,
|
||||
arith::ArithDialect,
|
||||
scf::SCFDialect>();
|
||||
postTarget.addDynamicallyLegalOp<spatial::SpatCompute>(
|
||||
[](spatial::SpatCompute computeOp) { return !requiresPostRewrite(computeOp); });
|
||||
postTarget.addDynamicallyLegalOp<spatial::SpatComputeBatch>(
|
||||
[](spatial::SpatComputeBatch computeOp) { return !requiresPostRewrite(computeOp); });
|
||||
|
||||
RewritePatternSet postPatterns(ctx);
|
||||
populatePostPatterns(postPatterns, ctx);
|
||||
if (failed(applyPatternsGreedily(*entryFunc, std::move(postPatterns)))) {
|
||||
if (failed(applyPartialConversion(*entryFunc, postTarget, std::move(postPatterns)))) {
|
||||
moduleOp.emitError("failed to normalize weight-like Spatial compute operands before Spatial-to-PIM lowering");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
wrapTopLevelRuntimeTransposes(*entryFunc);
|
||||
|
||||
if (failed(verifyONNXToSpatialHostLegality(*entryFunc))) {
|
||||
moduleOp.emitError("ONNX-to-Spatial host legality verification failed");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
@@ -27,16 +28,6 @@ struct ConvToGemm : OpConversionPattern<ONNXConvOp> {
|
||||
ConversionPatternRewriter& rewriter) const override;
|
||||
};
|
||||
|
||||
/// Returns the dense-elements payload of `value` when it is produced by an
/// `arith.constant` or an ONNX constant op whose value is a
/// `DenseElementsAttr`; returns a null attribute in every other case
/// (including values without a defining op).
static DenseElementsAttr getDenseConstantAttr(Value value) {
  Operation* producer = value.getDefiningOp();

  // arith.constant always carries a value attribute, so a plain dyn_cast is enough.
  if (auto arithConstant = dyn_cast_or_null<arith::ConstantOp>(producer))
    return dyn_cast<DenseElementsAttr>(arithConstant.getValue());

  // The ONNX constant's `value` attribute may be absent, hence the null-tolerant cast.
  if (auto onnxConstant = dyn_cast_or_null<ONNXConstantOp>(producer))
    return dyn_cast_or_null<DenseElementsAttr>(onnxConstant.getValueAttr());

  return nullptr;
}
|
||||
|
||||
/// Reads element `idx` of `arr` as a signless 64-bit integer.
/// The element must be an `IntegerAttr`; `cast` asserts otherwise.
static int64_t getI64FromArrayAttr(ArrayAttr arr, size_t idx) {
  auto element = cast<IntegerAttr>(arr[idx]);
  return element.getInt();
}
|
||||
|
||||
static Value expandBiasIfNeeded(Value bias, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
@@ -355,49 +346,22 @@ static Value createCollectedConvOutput(ValueRange gemmRows,
|
||||
return collectComputeOp.getResult(0);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
||||
ONNXConvOpAdaptor convOpAdaptor,
|
||||
ConversionPatternRewriter& rewriter) const {
|
||||
Location loc = convOp.getLoc();
|
||||
Value x = convOpAdaptor.getX();
|
||||
Value w = convOpAdaptor.getW();
|
||||
Value b = convOpAdaptor.getB();
|
||||
|
||||
auto xType = cast<RankedTensorType>(x.getType());
|
||||
auto wType = cast<RankedTensorType>(w.getType());
|
||||
auto outType = cast<RankedTensorType>(convOp.getY().getType());
|
||||
|
||||
if (!xType.hasStaticShape()) {
|
||||
pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv input");
|
||||
return failure();
|
||||
}
|
||||
if (!wType.hasStaticShape()) {
|
||||
pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv weight");
|
||||
return failure();
|
||||
}
|
||||
if (!outType.hasStaticShape()) {
|
||||
pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv result");
|
||||
return failure();
|
||||
}
|
||||
if (xType.getRank() != 4) {
|
||||
pim::emitUnsupportedRankDiagnostic(convOp, "conv input", xType.getRank(), {4});
|
||||
return failure();
|
||||
}
|
||||
if (wType.getRank() != 4) {
|
||||
pim::emitUnsupportedRankDiagnostic(convOp, "conv weight", wType.getRank(), {4});
|
||||
return failure();
|
||||
}
|
||||
if (outType.getRank() != 4) {
|
||||
pim::emitUnsupportedRankDiagnostic(convOp, "conv result", outType.getRank(), {4});
|
||||
return failure();
|
||||
}
|
||||
if (convOp.getGroup() != 1) {
|
||||
convOp.emitOpError("only group=1 convolution is supported for Spatial lowering");
|
||||
return failure();
|
||||
}
|
||||
|
||||
static Value lowerSingleConvGroup(Value x,
|
||||
Value w,
|
||||
Value b,
|
||||
RankedTensorType xType,
|
||||
RankedTensorType wType,
|
||||
RankedTensorType outType,
|
||||
int64_t padHeightBegin,
|
||||
int64_t padHeightEnd,
|
||||
int64_t padWidthBegin,
|
||||
int64_t padWidthEnd,
|
||||
int64_t strideHeight,
|
||||
int64_t strideWidth,
|
||||
int64_t dilationHeight,
|
||||
int64_t dilationWidth,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
const int64_t batchSize = xType.getDimSize(0);
|
||||
const int64_t numChannelsIn = xType.getDimSize(1);
|
||||
const int64_t xHeight = xType.getDimSize(2);
|
||||
@@ -408,71 +372,6 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
||||
const int64_t outHeight = outType.getDimSize(2);
|
||||
const int64_t outWidth = outType.getDimSize(3);
|
||||
|
||||
// Read optional conv attributes (ONNX defaults: stride=1, dilation=1, pad=0)
|
||||
const auto stridesAttr = convOp.getStrides();
|
||||
const auto dilationsAttr = convOp.getDilations();
|
||||
const auto padsAttr = convOp.getPads();
|
||||
|
||||
if (stridesAttr && stridesAttr->size() != 2) {
|
||||
convOp.emitOpError("requires exactly two stride values for Spatial lowering");
|
||||
return failure();
|
||||
}
|
||||
if (dilationsAttr && dilationsAttr->size() != 2) {
|
||||
convOp.emitOpError("requires exactly two dilation values for Spatial lowering");
|
||||
return failure();
|
||||
}
|
||||
if (padsAttr && padsAttr->size() != 4) {
|
||||
convOp.emitOpError("requires exactly four pad values for 2D Spatial lowering");
|
||||
return failure();
|
||||
}
|
||||
|
||||
const int64_t strideHeight = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 0) : 1;
|
||||
const int64_t strideWidth = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 1) : 1;
|
||||
const int64_t dilationHeight = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 0) : 1;
|
||||
const int64_t dilationWidth = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 1) : 1;
|
||||
|
||||
int64_t padHeightBegin = 0;
|
||||
int64_t padHeightEnd = 0;
|
||||
int64_t padWidthBegin = 0;
|
||||
int64_t padWidthEnd = 0;
|
||||
|
||||
if (padsAttr) {
|
||||
padHeightBegin = getI64FromArrayAttr(*padsAttr, 0);
|
||||
padWidthBegin = getI64FromArrayAttr(*padsAttr, 1);
|
||||
padHeightEnd = getI64FromArrayAttr(*padsAttr, 2);
|
||||
padWidthEnd = getI64FromArrayAttr(*padsAttr, 3);
|
||||
}
|
||||
else {
|
||||
// Compute padding from auto_pad attribute
|
||||
const auto autoPad = convOp.getAutoPad();
|
||||
if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") {
|
||||
const int64_t effectiveKernelH = (wHeight - 1) * dilationHeight + 1;
|
||||
const int64_t effectiveKernelW = (wWidth - 1) * dilationWidth + 1;
|
||||
const int64_t totalPadH =
|
||||
std::max(static_cast<int64_t>(0), (outHeight - 1) * strideHeight + effectiveKernelH - xHeight);
|
||||
const int64_t totalPadW =
|
||||
std::max(static_cast<int64_t>(0), (outWidth - 1) * strideWidth + effectiveKernelW - xWidth);
|
||||
|
||||
if (autoPad == "SAME_UPPER") {
|
||||
padHeightBegin = totalPadH / 2;
|
||||
padHeightEnd = totalPadH - padHeightBegin;
|
||||
padWidthBegin = totalPadW / 2;
|
||||
padWidthEnd = totalPadW - padWidthBegin;
|
||||
}
|
||||
else { // SAME_LOWER
|
||||
padHeightEnd = totalPadH / 2;
|
||||
padHeightBegin = totalPadH - padHeightEnd;
|
||||
padWidthEnd = totalPadW / 2;
|
||||
padWidthBegin = totalPadW - padWidthEnd;
|
||||
}
|
||||
}
|
||||
else if (autoPad != "NOTSET" && autoPad != "VALID") {
|
||||
convOp.emitOpError() << "unsupported auto_pad value `" << autoPad << "` for Spatial lowering";
|
||||
return failure();
|
||||
}
|
||||
// "NOTSET" or "VALID" -> all pads stay 0
|
||||
}
|
||||
|
||||
// im2col layout (flipped with respect to the standard, so filters sit in B = crossbar):
|
||||
// A (im2col): [numPatches, patchSize] -- one row per output spatial position
|
||||
// B (weights): [patchSize, cOut] -- W^T, stored in crossbar columns
|
||||
@@ -492,7 +391,7 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
||||
const int64_t xbarSize = static_cast<int64_t>(crossbarSize.getValue());
|
||||
const int64_t wMaxDim = std::max(patchSize, numChannelsOut);
|
||||
const int64_t maxParallelPixels = std::max<int64_t>(1, xbarSize / wMaxDim);
|
||||
auto wDenseAttr = getDenseConstantAttr(w);
|
||||
auto wDenseAttr = getHostFoldableDenseElementsAttr(w);
|
||||
|
||||
// Prepare weight matrix W for crossbar storage:
|
||||
// W: [numChannelsOut, numChannelsIn, wHeight, wWidth] -> [numChannelsOut, patchSize] -> [patchSize, numChannelsOut]
|
||||
@@ -513,7 +412,7 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
||||
DenseElementsAttr biasDenseAttr;
|
||||
if (hasB) {
|
||||
gemmBias = b;
|
||||
biasDenseAttr = getDenseConstantAttr(b);
|
||||
biasDenseAttr = getHostFoldableDenseElementsAttr(b);
|
||||
biasMatrix = expandBiasIfNeeded(b, rewriter, loc);
|
||||
}
|
||||
const bool canPackWeightsAsConstants = static_cast<bool>(wDenseAttr);
|
||||
@@ -589,17 +488,246 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
||||
rewriter.getBoolAttr(false))
|
||||
.getY();
|
||||
|
||||
rewriter.replaceOp(convOp,
|
||||
createCollectedConvOutput(ValueRange {gemmRows},
|
||||
convOp.getType(),
|
||||
gemmOutType,
|
||||
nhwcType,
|
||||
outType,
|
||||
numPatches,
|
||||
numChannelsOut,
|
||||
effectiveMaxParallelPixels,
|
||||
rewriter,
|
||||
loc));
|
||||
return createCollectedConvOutput(ValueRange {gemmRows},
|
||||
outType,
|
||||
gemmOutType,
|
||||
nhwcType,
|
||||
outType,
|
||||
numPatches,
|
||||
numChannelsOut,
|
||||
effectiveMaxParallelPixels,
|
||||
rewriter,
|
||||
loc);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
/// Lowers an `onnx.Conv` to Spatial ops by delegating to `lowerSingleConvGroup`.
///
/// Steps:
///   1. Validate operand/result types: static shapes and rank 4 for input,
///      weight and result.
///   2. Validate the group partitioning: `group >= 1`, input and output
///      channels divisible by `group`, weight channel dims consistent with the
///      input and result types.
///   3. Read strides/dilations (default 1) and resolve the padding, either from
///      the explicit `pads` attribute or from `auto_pad`.
///   4. `group == 1`: lower the convolution directly.
///      `group > 1`: slice input channels (axis 1), weights (axis 0) and the
///      optional bias per group, lower every group independently and
///      concatenate the per-group results along the channel axis.
///
/// Every rejected precondition emits a diagnostic on `convOp` and returns
/// `failure()`.
LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
                                          ONNXConvOpAdaptor convOpAdaptor,
                                          ConversionPatternRewriter& rewriter) const {
  Location loc = convOp.getLoc();
  Value x = convOpAdaptor.getX();
  Value w = convOpAdaptor.getW();
  Value b = convOpAdaptor.getB();

  auto xType = cast<RankedTensorType>(x.getType());
  auto wType = cast<RankedTensorType>(w.getType());
  auto outType = cast<RankedTensorType>(convOp.getY().getType());

  if (!xType.hasStaticShape()) {
    pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv input");
    return failure();
  }
  if (!wType.hasStaticShape()) {
    pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv weight");
    return failure();
  }
  if (!outType.hasStaticShape()) {
    pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv result");
    return failure();
  }
  if (xType.getRank() != 4) {
    pim::emitUnsupportedRankDiagnostic(convOp, "conv input", xType.getRank(), {4});
    return failure();
  }
  if (wType.getRank() != 4) {
    pim::emitUnsupportedRankDiagnostic(convOp, "conv weight", wType.getRank(), {4});
    return failure();
  }
  if (outType.getRank() != 4) {
    pim::emitUnsupportedRankDiagnostic(convOp, "conv result", outType.getRank(), {4});
    return failure();
  }
  if (convOp.getGroup() < 1) {
    convOp.emitOpError("requires group >= 1 for Spatial lowering");
    return failure();
  }

  const int64_t batchSize = xType.getDimSize(0);
  const int64_t numChannelsIn = xType.getDimSize(1);
  const int64_t xHeight = xType.getDimSize(2);
  const int64_t xWidth = xType.getDimSize(3);
  const int64_t numChannelsOut = wType.getDimSize(0);
  const int64_t wHeight = wType.getDimSize(2);
  const int64_t wWidth = wType.getDimSize(3);
  const int64_t resultChannels = outType.getDimSize(1);
  const int64_t outHeight = outType.getDimSize(2);
  const int64_t outWidth = outType.getDimSize(3);
  const int64_t group = convOp.getGroup();
  // A missing bias is represented by a value produced by ONNXNoneOp. Use the
  // null-safe typed lookup so a bias without a defining op (e.g. a block
  // argument) is treated as "bias present" instead of tripping an assertion.
  const bool hasB = !b.getDefiningOp<ONNXNoneOp>();

  if (numChannelsIn % group != 0) {
    convOp.emitOpError() << "requires input channels " << numChannelsIn << " to be divisible by group " << group
                         << " for Spatial lowering";
    return failure();
  }
  if (numChannelsOut % group != 0) {
    convOp.emitOpError() << "requires output channels " << numChannelsOut << " to be divisible by group " << group
                         << " for Spatial lowering";
    return failure();
  }

  const int64_t numChannelsInPerGroup = numChannelsIn / group;
  const int64_t numChannelsOutPerGroup = numChannelsOut / group;
  if (wType.getDimSize(1) != numChannelsInPerGroup) {
    convOp.emitOpError() << "requires grouped conv weight input channels " << wType.getDimSize(1)
                         << " to match input channels per group " << numChannelsInPerGroup << " for Spatial lowering";
    return failure();
  }
  // Compare against the result type's channel dimension: numChannelsOut is
  // itself derived from the weight, so comparing the weight against it could
  // never fail.
  if (numChannelsOut != resultChannels) {
    convOp.emitOpError() << "requires weight output channels " << numChannelsOut << " to match result channels "
                         << resultChannels << " for Spatial lowering";
    return failure();
  }

  // Read optional conv attributes (ONNX defaults: stride=1, dilation=1, pad=0)
  const auto stridesAttr = convOp.getStrides();
  const auto dilationsAttr = convOp.getDilations();
  const auto padsAttr = convOp.getPads();

  if (stridesAttr && stridesAttr->size() != 2) {
    convOp.emitOpError("requires exactly two stride values for Spatial lowering");
    return failure();
  }
  if (dilationsAttr && dilationsAttr->size() != 2) {
    convOp.emitOpError("requires exactly two dilation values for Spatial lowering");
    return failure();
  }
  if (padsAttr && padsAttr->size() != 4) {
    convOp.emitOpError("requires exactly four pad values for 2D Spatial lowering");
    return failure();
  }

  const int64_t strideHeight = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 0) : 1;
  const int64_t strideWidth = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 1) : 1;
  const int64_t dilationHeight = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 0) : 1;
  const int64_t dilationWidth = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 1) : 1;

  int64_t padHeightBegin = 0;
  int64_t padHeightEnd = 0;
  int64_t padWidthBegin = 0;
  int64_t padWidthEnd = 0;

  if (padsAttr) {
    // Explicit pads are laid out as [h_begin, w_begin, h_end, w_end].
    padHeightBegin = getI64FromArrayAttr(*padsAttr, 0);
    padWidthBegin = getI64FromArrayAttr(*padsAttr, 1);
    padHeightEnd = getI64FromArrayAttr(*padsAttr, 2);
    padWidthEnd = getI64FromArrayAttr(*padsAttr, 3);
  }
  else {
    // Compute padding from auto_pad attribute
    const auto autoPad = convOp.getAutoPad();
    if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") {
      const int64_t effectiveKernelH = (wHeight - 1) * dilationHeight + 1;
      const int64_t effectiveKernelW = (wWidth - 1) * dilationWidth + 1;
      const int64_t totalPadH =
          std::max(static_cast<int64_t>(0), (outHeight - 1) * strideHeight + effectiveKernelH - xHeight);
      const int64_t totalPadW =
          std::max(static_cast<int64_t>(0), (outWidth - 1) * strideWidth + effectiveKernelW - xWidth);

      if (autoPad == "SAME_UPPER") {
        // An odd total puts the extra padding element at the end.
        padHeightBegin = totalPadH / 2;
        padHeightEnd = totalPadH - padHeightBegin;
        padWidthBegin = totalPadW / 2;
        padWidthEnd = totalPadW - padWidthBegin;
      }
      else { // SAME_LOWER
        // An odd total puts the extra padding element at the beginning.
        padHeightEnd = totalPadH / 2;
        padHeightBegin = totalPadH - padHeightEnd;
        padWidthEnd = totalPadW / 2;
        padWidthBegin = totalPadW - padWidthEnd;
      }
    }
    else if (autoPad != "NOTSET" && autoPad != "VALID") {
      convOp.emitOpError() << "unsupported auto_pad value `" << autoPad << "` for Spatial lowering";
      return failure();
    }
    // "NOTSET" or "VALID" -> all pads stay 0
  }

  if (group == 1) {
    rewriter.replaceOp(convOp,
                       lowerSingleConvGroup(x,
                                            w,
                                            b,
                                            xType,
                                            wType,
                                            outType,
                                            padHeightBegin,
                                            padHeightEnd,
                                            padWidthBegin,
                                            padWidthEnd,
                                            strideHeight,
                                            strideWidth,
                                            dilationHeight,
                                            dilationWidth,
                                            rewriter,
                                            loc));
    return success();
  }

  // Grouped convolution: partition the operands so each group can be lowered
  // as an independent group=1 convolution.
  SmallVector<Value> xSlices = sliceTensor(x, /*axis=*/1, numChannelsInPerGroup, rewriter, loc);
  SmallVector<Value> wSlices = sliceTensor(w, /*axis=*/0, numChannelsOutPerGroup, rewriter, loc);
  SmallVector<Value> bSlices;
  if (hasB) {
    auto biasType = cast<RankedTensorType>(b.getType());
    int64_t biasAxis = -1;
    if (biasType.getRank() == 1)
      biasAxis = 0;
    else if (biasType.getRank() == 2)
      // Slice along axis 0 unless its extent is 1, in which case use axis 1.
      biasAxis = biasType.getDimSize(0) != 1 ? 0 : 1;
    else {
      convOp.emitOpError() << "requires rank-1 or rank-2 bias for grouped convolution Spatial lowering, but got rank "
                           << biasType.getRank();
      return failure();
    }
    bSlices = sliceTensor(b, biasAxis, numChannelsOutPerGroup, rewriter, loc);
  }

  if (xSlices.size() != static_cast<size_t>(group) || wSlices.size() != static_cast<size_t>(group)
      || (hasB && bSlices.size() != static_cast<size_t>(group))) {
    convOp.emitOpError("failed to partition grouped convolution operands for Spatial lowering");
    return failure();
  }

  SmallVector<Value> groupResults;
  groupResults.reserve(group);
  auto groupOutType =
      RankedTensorType::get({batchSize, numChannelsOutPerGroup, outHeight, outWidth}, outType.getElementType());
  // Shared "no bias" placeholder handed to every group when the conv has no bias.
  Value noBias = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());
  for (int64_t groupId = 0; groupId < group; groupId++) {
    Value groupX = xSlices[groupId];
    Value groupW = wSlices[groupId];
    Value groupB = hasB ? bSlices[groupId] : noBias;
    groupResults.push_back(lowerSingleConvGroup(groupX,
                                                groupW,
                                                groupB,
                                                cast<RankedTensorType>(groupX.getType()),
                                                cast<RankedTensorType>(groupW.getType()),
                                                groupOutType,
                                                padHeightBegin,
                                                padHeightEnd,
                                                padWidthBegin,
                                                padWidthEnd,
                                                strideHeight,
                                                strideWidth,
                                                dilationHeight,
                                                dilationWidth,
                                                rewriter,
                                                loc));
  }

  // Host-foldable group results can be concatenated directly; otherwise the
  // concat is wrapped in a spat.compute region.
  Value result;
  if (llvm::all_of(groupResults, isHostFoldableValue)) {
    result = createSpatConcat(rewriter, loc, /*axis=*/1, groupResults);
  }
  else {
    auto concatCompute = createSpatCompute(rewriter, loc, TypeRange {outType}, {}, groupResults, [&](ValueRange args) {
      spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, /*axis=*/1, args));
    });
    result = concatCompute.getResult(0);
  }

  rewriter.replaceOp(convOp, result);
  return success();
}
|
||||
|
||||
|
||||
@@ -502,9 +502,6 @@ LogicalResult GemmToSpatialComputeBatch::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
}
|
||||
(void) bType;
|
||||
|
||||
if (!isHostFoldableValue(b))
|
||||
return failure();
|
||||
|
||||
Value sharedBias;
|
||||
if (hasC) {
|
||||
auto scaledC = materializeScaledConstantTensor(c, gemmOpAdaptor.getBeta().convertToFloat(), rewriter, loc);
|
||||
|
||||
@@ -2,8 +2,12 @@
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include <functional>
|
||||
#include <numeric>
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||
@@ -19,6 +23,79 @@ static bool haveStaticPositiveShape(ArrayRef<int64_t> shape) {
|
||||
return llvm::all_of(shape, [](int64_t dim) { return dim > 0; });
|
||||
}
|
||||
|
||||
/// Returns the product of all extents in `shape`; an empty shape yields 1.
static int64_t getStaticShapeElementCount(ArrayRef<int64_t> shape) {
  int64_t elementCount = 1;
  for (int64_t extent : shape)
    elementCount *= extent;
  return elementCount;
}
|
||||
|
||||
/// Determines the batch shape shared by two matmul operands.
///
/// An operand without batch dimensions adopts the other operand's batch shape.
/// When both operands carry batch dimensions they must be identical; general
/// broadcasting between differing batch shapes is rejected with `failure()`.
static FailureOr<SmallVector<int64_t>> inferSupportedBatchShape(ArrayRef<int64_t> lhsBatchShape,
                                                                ArrayRef<int64_t> rhsBatchShape) {
  auto toVector = [](ArrayRef<int64_t> dims) { return SmallVector<int64_t>(dims.begin(), dims.end()); };

  if (lhsBatchShape.empty())
    return toVector(rhsBatchShape);
  if (rhsBatchShape.empty() || llvm::equal(lhsBatchShape, rhsBatchShape))
    return toVector(lhsBatchShape);
  return failure();
}
|
||||
|
||||
/// Collapses all leading (batch) dimensions of `value` into a single one,
/// producing a `[batchSize, rows, cols]` tensor.
///
/// Rank-2 and rank-3 values are returned unchanged. For higher ranks a
/// `tensor.collapse_shape` is created: directly when `value` is host-foldable,
/// otherwise inside a single-input `spat.compute` region.
static Value collapseBatchDims(Value value,
                               int64_t batchSize,
                               int64_t rows,
                               int64_t cols,
                               PatternRewriter& rewriter,
                               Location loc) {
  auto sourceType = cast<RankedTensorType>(value.getType());
  const int64_t rank = sourceType.getRank();
  if (rank == 2 || rank == 3)
    return value;

  // Group 0 gathers every batch dimension; the last two dims map one-to-one.
  ReassociationIndices batchDims;
  for (int64_t dim = 0, numBatchDims = rank - 2; dim < numBatchDims; ++dim)
    batchDims.push_back(dim);

  SmallVector<ReassociationIndices> reassociation;
  reassociation.push_back(batchDims);
  reassociation.push_back(ReassociationIndices {rank - 2});
  reassociation.push_back(ReassociationIndices {rank - 1});

  auto collapsedType =
      RankedTensorType::get({batchSize, rows, cols}, sourceType.getElementType(), sourceType.getEncoding());
  auto emitCollapse = [&](Value operand) -> Value {
    return tensor::CollapseShapeOp::create(rewriter, loc, collapsedType, operand, reassociation);
  };

  if (!isHostFoldableValue(value)) {
    auto wrapper =
        createSpatCompute<1>(rewriter, loc, TypeRange {collapsedType}, {}, ValueRange {value}, [&](Value operand) {
          spatial::SpatYieldOp::create(rewriter, loc, emitCollapse(operand));
        });
    return wrapper.getResult(0);
  }
  return emitCollapse(value);
}
|
||||
|
||||
/// Expands the leading dimension of `value` back into `batchRank` batch
/// dimensions so that the result has type `outputType`.
///
/// Returns `value` untouched when it already has `outputType`; otherwise emits
/// a `tensor.expand_shape` wrapped in a single-input `spat.compute` region.
static Value expandBatchDims(Value value,
                             RankedTensorType outputType,
                             size_t batchRank,
                             PatternRewriter& rewriter,
                             Location loc) {
  auto currentType = cast<RankedTensorType>(value.getType());
  if (currentType == outputType)
    return value;

  // Group 0 lists every batch dimension of the output; the two trailing
  // dimensions map one-to-one.
  ReassociationIndices batchDims;
  for (size_t dim = 0; dim != batchRank; ++dim)
    batchDims.push_back(static_cast<int64_t>(dim));

  const int64_t rowDim = static_cast<int64_t>(batchRank);
  SmallVector<ReassociationIndices> reassociation;
  reassociation.push_back(batchDims);
  reassociation.push_back(ReassociationIndices {rowDim});
  reassociation.push_back(ReassociationIndices {rowDim + 1});

  auto wrapper =
      createSpatCompute<1>(rewriter, loc, TypeRange {outputType}, {}, ValueRange {value}, [&](Value operand) {
        Value reshaped = tensor::ExpandShapeOp::create(rewriter, loc, outputType, operand, reassociation);
        spatial::SpatYieldOp::create(rewriter, loc, reshaped);
      });
  return wrapper.getResult(0);
}
|
||||
|
||||
static Value extractBatchMatrix(Value value,
|
||||
int64_t batchIndex,
|
||||
int64_t batchSize,
|
||||
@@ -62,13 +139,29 @@ static Value extractBatchMatrix(Value value,
|
||||
/// Swaps the last two dimensions of a rank-2 or rank-3 tensor value.
///
/// Rank 2 uses permutation {1, 0}; any other rank is treated as rank 3 and
/// uses permutation {0, 2, 1}. A host-foldable `value` gets a plain
/// `onnx.Transpose`; otherwise the transpose is emitted inside a single-input
/// `spat.compute` region.
static Value transposeLastTwoDims(Value value, PatternRewriter& rewriter, Location loc) {
  auto type = cast<RankedTensorType>(value.getType());
  auto shape = type.getShape();
  RankedTensorType transposedType;
  SmallVector<int64_t> perm;
  if (type.getRank() == 2) {
    transposedType = RankedTensorType::get({shape[1], shape[0]}, type.getElementType());
    perm = {1, 0};
  }
  else {
    // NOTE(review): assumes callers only pass rank-2 or rank-3 values — confirm.
    transposedType = RankedTensorType::get({shape[0], shape[2], shape[1]}, type.getElementType());
    perm = {0, 2, 1};
  }

  auto buildTranspose = [&](Value input) -> Value {
    return ONNXTransposeOp::create(rewriter, loc, transposedType, input, rewriter.getI64ArrayAttr(perm));
  };

  if (isHostFoldableValue(value))
    return buildTranspose(value);

  auto transposeCompute =
      createSpatCompute<1>(rewriter, loc, TypeRange {transposedType}, {}, ValueRange {value}, [&](Value input) {
        spatial::SpatYieldOp::create(rewriter, loc, buildTranspose(input));
      });
  return transposeCompute.getResult(0);
}
|
||||
|
||||
static Value transposeLastTwoDimsInCompute(Value value, PatternRewriter& rewriter, Location loc) {
|
||||
@@ -120,24 +213,25 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
|
||||
if (!lhsType || !rhsType || !outType || !lhsType.hasStaticShape() || !rhsType.hasStaticShape()
|
||||
|| !outType.hasStaticShape())
|
||||
return failure();
|
||||
if ((lhsType.getRank() != 2 && lhsType.getRank() != 3) || (rhsType.getRank() != 2 && rhsType.getRank() != 3)
|
||||
|| (outType.getRank() != 2 && outType.getRank() != 3))
|
||||
if (lhsType.getRank() < 2 || rhsType.getRank() < 2 || outType.getRank() < 2)
|
||||
return failure();
|
||||
if (!haveStaticPositiveShape(lhsType.getShape()) || !haveStaticPositiveShape(rhsType.getShape())
|
||||
|| !haveStaticPositiveShape(outType.getShape()))
|
||||
return failure();
|
||||
|
||||
const int64_t lhsBatch = lhsType.getRank() == 3 ? lhsType.getDimSize(0) : 1;
|
||||
const int64_t rhsBatch = rhsType.getRank() == 3 ? rhsType.getDimSize(0) : 1;
|
||||
const int64_t batch = std::max(lhsBatch, rhsBatch);
|
||||
|
||||
if ((lhsBatch != 1 && lhsBatch != batch) || (rhsBatch != 1 && rhsBatch != batch))
|
||||
SmallVector<int64_t> lhsBatchShape(lhsType.getShape().begin(), lhsType.getShape().end() - 2);
|
||||
SmallVector<int64_t> rhsBatchShape(rhsType.getShape().begin(), rhsType.getShape().end() - 2);
|
||||
auto batchShape = inferSupportedBatchShape(lhsBatchShape, rhsBatchShape);
|
||||
if (failed(batchShape))
|
||||
return failure();
|
||||
const int64_t lhsBatch = lhsBatchShape.empty() ? 1 : getStaticShapeElementCount(lhsBatchShape);
|
||||
const int64_t rhsBatch = rhsBatchShape.empty() ? 1 : getStaticShapeElementCount(rhsBatchShape);
|
||||
const int64_t batch = batchShape->empty() ? 1 : getStaticShapeElementCount(*batchShape);
|
||||
|
||||
const int64_t m = lhsType.getRank() == 3 ? lhsType.getDimSize(1) : lhsType.getDimSize(0);
|
||||
const int64_t k = lhsType.getRank() == 3 ? lhsType.getDimSize(2) : lhsType.getDimSize(1);
|
||||
const int64_t rhsK = rhsType.getRank() == 3 ? rhsType.getDimSize(1) : rhsType.getDimSize(0);
|
||||
const int64_t n = rhsType.getRank() == 3 ? rhsType.getDimSize(2) : rhsType.getDimSize(1);
|
||||
const int64_t m = lhsType.getDimSize(lhsType.getRank() - 2);
|
||||
const int64_t k = lhsType.getDimSize(lhsType.getRank() - 1);
|
||||
const int64_t rhsK = rhsType.getDimSize(rhsType.getRank() - 2);
|
||||
const int64_t n = rhsType.getDimSize(rhsType.getRank() - 1);
|
||||
if (k != rhsK)
|
||||
return failure();
|
||||
|
||||
@@ -146,15 +240,17 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
|
||||
return failure();
|
||||
}
|
||||
else {
|
||||
if (outType.getDimSize(0) != batch || outType.getDimSize(1) != m || outType.getDimSize(2) != n)
|
||||
SmallVector<int64_t> outBatchShape(outType.getShape().begin(), outType.getShape().end() - 2);
|
||||
if (!llvm::equal(outBatchShape, *batchShape) || outType.getDimSize(outType.getRank() - 2) != m
|
||||
|| outType.getDimSize(outType.getRank() - 1) != n)
|
||||
return failure();
|
||||
}
|
||||
|
||||
Location loc = matmulOp.getLoc();
|
||||
bool useTransposedForm = isHostFoldableValue(matmulOp.getA()) && !isHostFoldableValue(matmulOp.getB());
|
||||
|
||||
Value lhs = matmulOp.getA();
|
||||
Value rhs = matmulOp.getB();
|
||||
Value lhs = collapseBatchDims(matmulOp.getA(), lhsBatch, m, k, rewriter, loc);
|
||||
Value rhs = collapseBatchDims(matmulOp.getB(), rhsBatch, k, n, rewriter, loc);
|
||||
int64_t lhsBatchForGemm = lhsBatch;
|
||||
int64_t rhsBatchForGemm = rhsBatch;
|
||||
int64_t gemmM = m;
|
||||
@@ -239,6 +335,7 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
|
||||
}
|
||||
|
||||
Value result = concatValues(batchResults, /*axis=*/0, rewriter, loc);
|
||||
result = expandBatchDims(result, outType, batchShape->size(), rewriter, loc);
|
||||
rewriter.replaceOp(matmulOp, result);
|
||||
return success();
|
||||
}
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
@@ -22,53 +23,83 @@ static SmallVector<int64_t> permuteShape(ArrayRef<int64_t> shape, ArrayRef<int64
|
||||
return permutedShape;
|
||||
}
|
||||
|
||||
static Value createSoftmaxCompute(Value input, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
static Value buildLoopSoftmaxSlice(Value input,
|
||||
Value accumulator,
|
||||
RankedTensorType inputType,
|
||||
ArrayRef<Value> outerIndices,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
int64_t rank = inputType.getRank();
|
||||
SmallVector<int64_t> sliceShape(static_cast<size_t>(rank - 1), 1);
|
||||
sliceShape.push_back(inputType.getDimSize(rank - 1));
|
||||
auto sliceType = RankedTensorType::get(sliceShape, inputType.getElementType(), inputType.getEncoding());
|
||||
|
||||
SmallVector<OpFoldResult> offsets;
|
||||
SmallVector<OpFoldResult> sizes;
|
||||
SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
|
||||
offsets.reserve(rank);
|
||||
sizes.reserve(rank);
|
||||
|
||||
for (Value outerIndex : outerIndices) {
|
||||
offsets.push_back(outerIndex);
|
||||
sizes.push_back(rewriter.getIndexAttr(1));
|
||||
}
|
||||
offsets.push_back(rewriter.getIndexAttr(0));
|
||||
sizes.push_back(rewriter.getIndexAttr(inputType.getDimSize(rank - 1)));
|
||||
|
||||
Value inputSlice = tensor::ExtractSliceOp::create(rewriter, loc, sliceType, input, offsets, sizes, strides);
|
||||
Value softmaxSlice = spatial::SpatSoftmaxOp::create(rewriter, loc, sliceType, inputSlice).getResult();
|
||||
return tensor::InsertSliceOp::create(rewriter, loc, softmaxSlice, accumulator, offsets, sizes, strides);
|
||||
}
|
||||
|
||||
static Value buildLoopSoftmaxNest(Value input,
|
||||
Value accumulator,
|
||||
RankedTensorType inputType,
|
||||
int64_t axis,
|
||||
SmallVectorImpl<Value>& outerIndices,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
if (axis == inputType.getRank() - 1)
|
||||
return buildLoopSoftmaxSlice(input, accumulator, inputType, outerIndices, rewriter, loc);
|
||||
|
||||
Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
|
||||
Value c1 = arith::ConstantIndexOp::create(rewriter, loc, 1);
|
||||
Value cUpper = arith::ConstantIndexOp::create(rewriter, loc, inputType.getDimSize(axis));
|
||||
|
||||
auto loop = scf::ForOp::create(rewriter, loc, c0, cUpper, c1, ValueRange {accumulator});
|
||||
rewriter.setInsertionPointToStart(loop.getBody());
|
||||
|
||||
Value loopIndex = loop.getInductionVar();
|
||||
Value loopAccumulator = loop.getRegionIterArgs().front();
|
||||
outerIndices.push_back(loopIndex);
|
||||
Value updatedAccumulator =
|
||||
buildLoopSoftmaxNest(input, loopAccumulator, inputType, axis + 1, outerIndices, rewriter, loc);
|
||||
outerIndices.pop_back();
|
||||
|
||||
scf::YieldOp::create(rewriter, loc, updatedAccumulator);
|
||||
rewriter.setInsertionPointAfter(loop);
|
||||
return loop.getResult(0);
|
||||
}
|
||||
|
||||
static Value createLoopSoftmaxCompute(Value input, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
auto inputType = cast<RankedTensorType>(input.getType());
|
||||
constexpr size_t numInputs = 1;
|
||||
auto computeOp =
|
||||
createSpatCompute<numInputs>(rewriter, loc, TypeRange {inputType}, {}, ValueRange {input}, [&](Value x) {
|
||||
auto softmaxOp = spatial::SpatSoftmaxOp::create(rewriter, loc, inputType, x);
|
||||
spatial::SpatYieldOp::create(rewriter, loc, softmaxOp.getResult());
|
||||
if (inputType.getRank() == 1) {
|
||||
Value softmax = spatial::SpatSoftmaxOp::create(rewriter, loc, inputType, x).getResult();
|
||||
spatial::SpatYieldOp::create(rewriter, loc, softmax);
|
||||
return;
|
||||
}
|
||||
|
||||
Value outputInit = tensor::EmptyOp::create(rewriter, loc, inputType.getShape(), inputType.getElementType());
|
||||
SmallVector<Value> outerIndices;
|
||||
Value result = buildLoopSoftmaxNest(x, outputInit, inputType, /*axis=*/0, outerIndices, rewriter, loc);
|
||||
spatial::SpatYieldOp::create(rewriter, loc, result);
|
||||
});
|
||||
return computeOp.getResult(0);
|
||||
}
|
||||
|
||||
static Value concatValues(ValueRange inputs, int64_t axis, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
auto firstType = cast<RankedTensorType>(inputs.front().getType());
|
||||
SmallVector<int64_t> outputShape(firstType.getShape().begin(), firstType.getShape().end());
|
||||
int64_t concatDimSize = 0;
|
||||
for (Value input : inputs)
|
||||
concatDimSize += cast<RankedTensorType>(input.getType()).getDimSize(axis);
|
||||
outputShape[axis] = concatDimSize;
|
||||
auto resultType = RankedTensorType::get(outputShape, firstType.getElementType(), firstType.getEncoding());
|
||||
|
||||
if (llvm::all_of(inputs, isHostFoldableValue))
|
||||
return createSpatConcat(rewriter, loc, axis, inputs);
|
||||
|
||||
auto concatCompute = createSpatCompute(rewriter, loc, TypeRange {resultType}, {}, inputs, [&](ValueRange args) {
|
||||
spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, axis, args));
|
||||
});
|
||||
return concatCompute.getResult(0);
|
||||
}
|
||||
|
||||
static Value
|
||||
buildSoftmax(Value input, int64_t softmaxAxis, int64_t axis, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
auto inputType = cast<RankedTensorType>(input.getType());
|
||||
if (axis == inputType.getRank())
|
||||
return createSoftmaxCompute(input, rewriter, loc);
|
||||
|
||||
if (axis == softmaxAxis)
|
||||
return buildSoftmax(input, softmaxAxis, axis + 1, rewriter, loc);
|
||||
|
||||
SmallVector<Value> slices = sliceTensor(input, axis, /*sliceSize=*/1, rewriter, loc);
|
||||
SmallVector<Value> rebuiltSlices;
|
||||
rebuiltSlices.reserve(slices.size());
|
||||
for (Value slice : slices)
|
||||
rebuiltSlices.push_back(buildSoftmax(slice, softmaxAxis, axis + 1, rewriter, loc));
|
||||
|
||||
return concatValues(rebuiltSlices, axis, rewriter, loc);
|
||||
}
|
||||
|
||||
struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> {
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
|
||||
@@ -86,7 +117,7 @@ struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> {
|
||||
Value input = adaptor.getInput();
|
||||
Value result;
|
||||
if (axis == inputType.getRank() - 1) {
|
||||
result = buildSoftmax(input, axis, /*axis=*/0, rewriter, softmaxOp.getLoc());
|
||||
result = createLoopSoftmaxCompute(input, rewriter, softmaxOp.getLoc());
|
||||
}
|
||||
else {
|
||||
SmallVector<int64_t> permutation;
|
||||
@@ -109,8 +140,7 @@ struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> {
|
||||
spatial::SpatYieldOp::create(rewriter, softmaxOp.getLoc(), transposed);
|
||||
});
|
||||
Value transposedInput = preTransposeCompute.getResult(0);
|
||||
Value transposedResult = buildSoftmax(
|
||||
transposedInput, /*softmaxAxis=*/inputType.getRank() - 1, /*axis=*/0, rewriter, softmaxOp.getLoc());
|
||||
Value transposedResult = createLoopSoftmaxCompute(transposedInput, rewriter, softmaxOp.getLoc());
|
||||
auto postTransposeCompute =
|
||||
createSpatCompute<1>(rewriter, softmaxOp.getLoc(), TypeRange {inputType}, {}, transposedResult, [&](Value x) {
|
||||
Value transposed = ONNXTransposeOp::create(
|
||||
|
||||
@@ -80,6 +80,22 @@ static bool inferExpandReassociation(ArrayRef<int64_t> sourceShape,
|
||||
return sourceIdx == sourceShape.size() && resultIdx == resultShape.size();
|
||||
}
|
||||
|
||||
static SmallVector<ReassociationIndices> getCollapseTo1DReassociation(size_t rank) {
|
||||
SmallVector<ReassociationIndices> reassociation(1);
|
||||
reassociation.front().reserve(rank);
|
||||
for (size_t dim = 0; dim < rank; ++dim)
|
||||
reassociation.front().push_back(static_cast<int64_t>(dim));
|
||||
return reassociation;
|
||||
}
|
||||
|
||||
static SmallVector<ReassociationIndices> getExpandFrom1DReassociation(size_t rank) {
|
||||
SmallVector<ReassociationIndices> reassociation(1);
|
||||
reassociation.front().reserve(rank);
|
||||
for (size_t dim = 0; dim < rank; ++dim)
|
||||
reassociation.front().push_back(static_cast<int64_t>(dim));
|
||||
return reassociation;
|
||||
}
|
||||
|
||||
struct Reshape : OpConversionPattern<ONNXReshapeOp> {
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
|
||||
@@ -126,6 +142,23 @@ struct Reshape : OpConversionPattern<ONNXReshapeOp> {
|
||||
return tensor::ExpandShapeOp::create(rewriter, reshapeOp.getLoc(), resultType, data, reassociation);
|
||||
});
|
||||
|
||||
if (sourceType.getNumElements() != resultType.getNumElements())
|
||||
return failure();
|
||||
|
||||
return replaceWithReshape([&](Value data) -> Value {
|
||||
Value reshaped = data;
|
||||
if (sourceType.getRank() != 1) {
|
||||
auto flatType = RankedTensorType::get({sourceType.getNumElements()}, sourceType.getElementType());
|
||||
reshaped = tensor::CollapseShapeOp::create(
|
||||
rewriter, reshapeOp.getLoc(), flatType, reshaped, getCollapseTo1DReassociation(sourceType.getRank()));
|
||||
}
|
||||
if (resultType.getRank() == 1)
|
||||
return reshaped;
|
||||
return tensor::ExpandShapeOp::create(
|
||||
rewriter, reshapeOp.getLoc(), resultType, reshaped, getExpandFrom1DReassociation(resultType.getRank()))
|
||||
.getResult();
|
||||
});
|
||||
|
||||
return failure();
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
@@ -15,42 +15,88 @@ using namespace mlir;
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
static Value
|
||||
extractSliceAt(Value input, int64_t axis, int64_t offset, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
auto inputType = cast<RankedTensorType>(input.getType());
|
||||
SmallVector<OpFoldResult> offsets(inputType.getRank(), rewriter.getIndexAttr(0));
|
||||
SmallVector<OpFoldResult> sizes;
|
||||
SmallVector<OpFoldResult> strides(inputType.getRank(), rewriter.getIndexAttr(1));
|
||||
sizes.reserve(inputType.getRank());
|
||||
for (int64_t dim : inputType.getShape())
|
||||
sizes.push_back(rewriter.getIndexAttr(dim));
|
||||
offsets[axis] = rewriter.getIndexAttr(offset);
|
||||
sizes[axis] = rewriter.getIndexAttr(1);
|
||||
return tensor::ExtractSliceOp::create(rewriter, loc, input, offsets, sizes, strides);
|
||||
static Value buildNearestAsymmetricIndex(
|
||||
Value outputIndex, int64_t inputDim, int64_t outputDim, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
Value cInputDim = arith::ConstantIndexOp::create(rewriter, loc, inputDim);
|
||||
Value cOutputDim = arith::ConstantIndexOp::create(rewriter, loc, outputDim);
|
||||
Value cInputDimLast = arith::ConstantIndexOp::create(rewriter, loc, inputDim - 1);
|
||||
Value scaledIndex = arith::MulIOp::create(rewriter, loc, outputIndex, cInputDim);
|
||||
Value inputIndex = arith::DivUIOp::create(rewriter, loc, scaledIndex, cOutputDim);
|
||||
return arith::MinUIOp::create(rewriter, loc, inputIndex, cInputDimLast);
|
||||
}
|
||||
|
||||
static int64_t nearestAsymmetricIndex(int64_t outputIndex, int64_t inputDim, int64_t outputDim) {
|
||||
return std::min<int64_t>((outputIndex * inputDim) / outputDim, inputDim - 1);
|
||||
}
|
||||
static Value buildNearestResizeLoop(Value input,
|
||||
RankedTensorType inputType,
|
||||
RankedTensorType resultType,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
auto elemType = resultType.getElementType();
|
||||
SmallVector<int64_t> unitShape(resultType.getRank(), 1);
|
||||
auto unitTensorType = RankedTensorType::get(unitShape, elemType);
|
||||
|
||||
static Value buildNearestResize(Value input,
|
||||
ArrayRef<int64_t> inputShape,
|
||||
ArrayRef<int64_t> outputShape,
|
||||
int64_t axis,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
if (axis == static_cast<int64_t>(outputShape.size()))
|
||||
return input;
|
||||
SmallVector<OpFoldResult> unitSizes(resultType.getRank(), rewriter.getIndexAttr(1));
|
||||
SmallVector<OpFoldResult> unitStrides(resultType.getRank(), rewriter.getIndexAttr(1));
|
||||
|
||||
SmallVector<Value> slices;
|
||||
slices.reserve(outputShape[axis]);
|
||||
for (int64_t outputIndex = 0; outputIndex < outputShape[axis]; ++outputIndex) {
|
||||
int64_t inputIndex = nearestAsymmetricIndex(outputIndex, inputShape[axis], outputShape[axis]);
|
||||
Value slice = extractSliceAt(input, axis, inputIndex, rewriter, loc);
|
||||
slices.push_back(buildNearestResize(slice, inputShape, outputShape, axis + 1, rewriter, loc));
|
||||
}
|
||||
Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
|
||||
Value c1 = arith::ConstantIndexOp::create(rewriter, loc, 1);
|
||||
Value cOutputN = arith::ConstantIndexOp::create(rewriter, loc, resultType.getDimSize(0));
|
||||
Value cOutputC = arith::ConstantIndexOp::create(rewriter, loc, resultType.getDimSize(1));
|
||||
Value cOutputH = arith::ConstantIndexOp::create(rewriter, loc, resultType.getDimSize(2));
|
||||
Value cOutputW = arith::ConstantIndexOp::create(rewriter, loc, resultType.getDimSize(3));
|
||||
|
||||
return createSpatConcat(rewriter, loc, axis, slices);
|
||||
Value outputInit = tensor::EmptyOp::create(rewriter, loc, resultType.getShape(), elemType);
|
||||
|
||||
auto batchLoop = scf::ForOp::create(rewriter, loc, c0, cOutputN, c1, ValueRange {outputInit});
|
||||
rewriter.setInsertionPointToStart(batchLoop.getBody());
|
||||
|
||||
Value outputN = batchLoop.getInductionVar();
|
||||
Value outputBatchAcc = batchLoop.getRegionIterArgs().front();
|
||||
Value inputN = buildNearestAsymmetricIndex(outputN, inputType.getDimSize(0), resultType.getDimSize(0), rewriter, loc);
|
||||
|
||||
auto channelLoop = scf::ForOp::create(rewriter, loc, c0, cOutputC, c1, ValueRange {outputBatchAcc});
|
||||
rewriter.setInsertionPointToStart(channelLoop.getBody());
|
||||
|
||||
Value outputC = channelLoop.getInductionVar();
|
||||
Value outputChannelAcc = channelLoop.getRegionIterArgs().front();
|
||||
Value inputC =
|
||||
buildNearestAsymmetricIndex(outputC, inputType.getDimSize(1), resultType.getDimSize(1), rewriter, loc);
|
||||
|
||||
auto heightLoop = scf::ForOp::create(rewriter, loc, c0, cOutputH, c1, ValueRange {outputChannelAcc});
|
||||
rewriter.setInsertionPointToStart(heightLoop.getBody());
|
||||
|
||||
Value outputH = heightLoop.getInductionVar();
|
||||
Value outputHeightAcc = heightLoop.getRegionIterArgs().front();
|
||||
Value inputH =
|
||||
buildNearestAsymmetricIndex(outputH, inputType.getDimSize(2), resultType.getDimSize(2), rewriter, loc);
|
||||
|
||||
auto widthLoop = scf::ForOp::create(rewriter, loc, c0, cOutputW, c1, ValueRange {outputHeightAcc});
|
||||
rewriter.setInsertionPointToStart(widthLoop.getBody());
|
||||
|
||||
Value outputW = widthLoop.getInductionVar();
|
||||
Value outputWidthAcc = widthLoop.getRegionIterArgs().front();
|
||||
Value inputW =
|
||||
buildNearestAsymmetricIndex(outputW, inputType.getDimSize(3), resultType.getDimSize(3), rewriter, loc);
|
||||
|
||||
SmallVector<OpFoldResult> inputOffsets = {inputN, inputC, inputH, inputW};
|
||||
Value inputSlice =
|
||||
tensor::ExtractSliceOp::create(rewriter, loc, unitTensorType, input, inputOffsets, unitSizes, unitStrides);
|
||||
|
||||
SmallVector<OpFoldResult> outputOffsets = {outputN, outputC, outputH, outputW};
|
||||
Value updatedOutput =
|
||||
tensor::InsertSliceOp::create(rewriter, loc, inputSlice, outputWidthAcc, outputOffsets, unitSizes, unitStrides);
|
||||
scf::YieldOp::create(rewriter, loc, updatedOutput);
|
||||
|
||||
rewriter.setInsertionPointAfter(widthLoop);
|
||||
scf::YieldOp::create(rewriter, loc, widthLoop.getResult(0));
|
||||
|
||||
rewriter.setInsertionPointAfter(heightLoop);
|
||||
scf::YieldOp::create(rewriter, loc, heightLoop.getResult(0));
|
||||
|
||||
rewriter.setInsertionPointAfter(channelLoop);
|
||||
scf::YieldOp::create(rewriter, loc, channelLoop.getResult(0));
|
||||
|
||||
rewriter.setInsertionPointAfter(batchLoop);
|
||||
return batchLoop.getResult(0);
|
||||
}
|
||||
|
||||
struct Resize : OpConversionPattern<ONNXResizeOp> {
|
||||
@@ -62,20 +108,22 @@ struct Resize : OpConversionPattern<ONNXResizeOp> {
|
||||
auto inputType = dyn_cast<RankedTensorType>(adaptor.getX().getType());
|
||||
auto resultType = dyn_cast<RankedTensorType>(resizeOp.getY().getType());
|
||||
if (!inputType || !resultType || !inputType.hasStaticShape() || !resultType.hasStaticShape())
|
||||
return failure();
|
||||
return rewriter.notifyMatchFailure(resizeOp, "resize lowering requires static ranked tensor types.");
|
||||
if (inputType.getRank() != 4 || resultType.getRank() != 4)
|
||||
return rewriter.notifyMatchFailure(resizeOp, "resize lowering currently supports only rank-4 NCHW tensors.");
|
||||
|
||||
if (resizeOp.getMode() != "nearest" || resizeOp.getCoordinateTransformationMode() != "asymmetric"
|
||||
|| resizeOp.getNearestMode() != "floor")
|
||||
return failure();
|
||||
return rewriter.notifyMatchFailure(
|
||||
resizeOp, "resize lowering currently supports only nearest + asymmetric + floor.");
|
||||
|
||||
if (llvm::any_of(inputType.getShape(), [](int64_t dim) { return dim <= 0; })
|
||||
|| llvm::any_of(resultType.getShape(), [](int64_t dim) { return dim <= 0; }))
|
||||
return failure();
|
||||
return rewriter.notifyMatchFailure(resizeOp, "resize lowering requires positive static dimensions.");
|
||||
|
||||
auto computeOp =
|
||||
createSpatCompute<1>(rewriter, resizeOp.getLoc(), TypeRange {resultType}, {}, adaptor.getX(), [&](Value x) {
|
||||
Value result =
|
||||
buildNearestResize(x, inputType.getShape(), resultType.getShape(), /*axis=*/0, rewriter, resizeOp.getLoc());
|
||||
Value result = buildNearestResizeLoop(x, inputType, resultType, rewriter, resizeOp.getLoc());
|
||||
spatial::SpatYieldOp::create(rewriter, resizeOp.getLoc(), result);
|
||||
});
|
||||
rewriter.replaceOp(resizeOp, computeOp.getResults());
|
||||
|
||||
@@ -31,6 +31,21 @@ static bool isDirectConstantValue(Value value) {
|
||||
return isa_and_nonnull<arith::ConstantOp, ONNXConstantOp>(value.getDefiningOp());
|
||||
}
|
||||
|
||||
template <typename ComputeOpTy>
|
||||
static bool hasPromotableWeightLikeInputs(ComputeOpTy compute) {
|
||||
Block& block = compute.getBody().front();
|
||||
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
||||
if (inputIdx >= block.getNumArguments())
|
||||
continue;
|
||||
if (!isWeightLikeComputeOperand(input))
|
||||
continue;
|
||||
if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(block.getArgument(inputIdx)))
|
||||
continue;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Collapses one-lane batches so later phases do not carry batch-only structure unnecessarily.
|
||||
struct FoldSingleLaneComputeBatchPattern : OpRewritePattern<spatial::SpatComputeBatch> {
|
||||
using OpRewritePattern<spatial::SpatComputeBatch>::OpRewritePattern;
|
||||
@@ -262,4 +277,10 @@ void annotateWeightsConstants(func::FuncOp funcOp) {
|
||||
});
|
||||
}
|
||||
|
||||
bool requiresEarlyPostRewrite(spatial::SpatComputeBatch batchOp) { return batchOp.getLaneCount() == 1; }
|
||||
|
||||
bool requiresPostRewrite(spatial::SpatCompute computeOp) { return hasPromotableWeightLikeInputs(computeOp); }
|
||||
|
||||
bool requiresPostRewrite(spatial::SpatComputeBatch computeOp) { return hasPromotableWeightLikeInputs(computeOp); }
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -3,8 +3,16 @@
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
bool requiresEarlyPostRewrite(spatial::SpatComputeBatch batchOp);
|
||||
|
||||
bool requiresPostRewrite(spatial::SpatCompute computeOp);
|
||||
|
||||
bool requiresPostRewrite(spatial::SpatComputeBatch computeOp);
|
||||
|
||||
void populateEarlyPostPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
void populatePostPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
@@ -17,9 +17,7 @@ void populatePrePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* c
|
||||
patterns.add<convAddToConvWithBiasLeft>(ctx);
|
||||
patterns.add<convAddToConvWithBiasRight>(ctx);
|
||||
patterns.add<matMulAddToGemm>(ctx);
|
||||
patterns.add<matMulToGemm>(ctx);
|
||||
patterns.add<removeFlattenSameShape>(ctx);
|
||||
populateMatMulRewritePatterns(patterns, ctx);
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -202,6 +202,7 @@ void SpatialToGraphvizPass::runOnOperation() {
|
||||
|
||||
auto entryFunc = getPimEntryFunc(module);
|
||||
if (failed(entryFunc)) {
|
||||
module.emitError("failed to locate the PIM entry function for Spatial graph visualization");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -138,12 +138,13 @@ static Value padHVectorInputToCrossbarSize(IRRewriter& rewriter, Location loc, V
|
||||
}
|
||||
|
||||
void SpatialToPimPass::runOnOperation() {
|
||||
coreId = 1;
|
||||
coreId = 0;
|
||||
ModuleOp moduleOp = getOperation();
|
||||
MLIRContext* ctx = moduleOp.getContext();
|
||||
|
||||
auto entryFunc = getPimEntryFunc(moduleOp);
|
||||
if (failed(entryFunc)) {
|
||||
moduleOp.emitError("failed to locate the PIM entry function during Spatial-to-PIM lowering");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
@@ -169,26 +170,22 @@ void SpatialToPimPass::runOnOperation() {
|
||||
spatial::SpatChannelSendTensorBatchOp,
|
||||
spatial::SpatExtractRowsOp>();
|
||||
|
||||
{
|
||||
RewritePatternSet patterns(ctx);
|
||||
populateWithGenerated(patterns);
|
||||
|
||||
if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
RewritePatternSet initialPatterns(ctx);
|
||||
populateWithGenerated(initialPatterns);
|
||||
if (failed(applyPartialConversion(moduleOp, target, std::move(initialPatterns)))) {
|
||||
moduleOp.emitError("failed to lower required Spatial ops to the initial PIM form");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
{
|
||||
RewritePatternSet patterns(ctx);
|
||||
populateGlobalTensorMaterializationPatterns(patterns);
|
||||
|
||||
walkAndApplyPatterns(moduleOp, std::move(patterns));
|
||||
}
|
||||
RewritePatternSet globalTensorPatterns(ctx);
|
||||
populateGlobalTensorMaterializationPatterns(globalTensorPatterns);
|
||||
walkAndApplyPatterns(moduleOp, std::move(globalTensorPatterns));
|
||||
auto returnOp = cast<func::ReturnOp>(funcOp.front().getTerminator());
|
||||
|
||||
addReturnOutputBuffers(returnOp, rewriter, outputTensors);
|
||||
if (failed(allocateAndInitializeCoreLocalVariables(funcOp, rewriter))) {
|
||||
funcOp.emitOpError("failed to allocate or initialize core-local tensors during Spatial-to-PIM lowering");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
@@ -197,6 +194,7 @@ void SpatialToPimPass::runOnOperation() {
|
||||
for (auto computeOp : funcOp.getOps<spatial::SpatCompute>()) {
|
||||
markOpToRemove(computeOp);
|
||||
if (failed(lowerComputeOp(computeOp, coreLoweringState, rewriter))) {
|
||||
computeOp.emitOpError("failed to lower spat.compute to pim.core");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
@@ -205,17 +203,16 @@ void SpatialToPimPass::runOnOperation() {
|
||||
for (auto computeBatchOp : funcOp.getOps<spatial::SpatComputeBatch>()) {
|
||||
markOpToRemove(computeBatchOp);
|
||||
if (failed(lowerComputeBatchOp(computeBatchOp, coreLoweringState, rewriter))) {
|
||||
computeBatchOp.emitOpError("failed to lower spat.compute_batch to pim.core_batch");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
RewritePatternSet patterns(ctx);
|
||||
populateTensorPackingPatterns(patterns);
|
||||
walkAndApplyPatterns(funcOp, std::move(patterns));
|
||||
eraseUnusedTensorPackingOps(funcOp, rewriter);
|
||||
}
|
||||
RewritePatternSet initialTensorPackingPatterns(ctx);
|
||||
populateTensorPackingPatterns(initialTensorPackingPatterns);
|
||||
walkAndApplyPatterns(funcOp, std::move(initialTensorPackingPatterns));
|
||||
eraseUnusedTensorPackingOps(funcOp, rewriter);
|
||||
|
||||
SmallVector<spatial::SpatChannelReceiveOp> receiveOps;
|
||||
for (auto op : funcOp.getOps<spatial::SpatChannelReceiveOp>())
|
||||
@@ -229,27 +226,27 @@ void SpatialToPimPass::runOnOperation() {
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
RewritePatternSet coreBodyPatterns(ctx);
|
||||
populateWithGenerated(coreBodyPatterns);
|
||||
FrozenRewritePatternSet frozenCoreBodyPatterns(std::move(coreBodyPatterns));
|
||||
RewritePatternSet coreBodyPatterns(ctx);
|
||||
populateWithGenerated(coreBodyPatterns);
|
||||
FrozenRewritePatternSet frozenCoreBodyPatterns(std::move(coreBodyPatterns));
|
||||
|
||||
SmallVector<pim::PimCoreOp> coreOps;
|
||||
funcOp.walk([&](pim::PimCoreOp coreOp) { coreOps.push_back(coreOp); });
|
||||
for (auto coreOp : coreOps) {
|
||||
if (failed(applyFullConversion(coreOp.getOperation(), target, frozenCoreBodyPatterns))) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
SmallVector<pim::PimCoreOp> coreOps;
|
||||
funcOp.walk([&](pim::PimCoreOp coreOp) { coreOps.push_back(coreOp); });
|
||||
for (auto coreOp : coreOps) {
|
||||
if (failed(applyFullConversion(coreOp.getOperation(), target, frozenCoreBodyPatterns))) {
|
||||
coreOp.emitOpError("failed to convert nested Spatial ops inside pim.core");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
SmallVector<pim::PimCoreBatchOp> coreBatchOps;
|
||||
funcOp.walk([&](pim::PimCoreBatchOp coreBatchOp) { coreBatchOps.push_back(coreBatchOp); });
|
||||
for (auto coreBatchOp : coreBatchOps) {
|
||||
if (failed(applyFullConversion(coreBatchOp.getOperation(), target, frozenCoreBodyPatterns))) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
SmallVector<pim::PimCoreBatchOp> coreBatchOps;
|
||||
funcOp.walk([&](pim::PimCoreBatchOp coreBatchOp) { coreBatchOps.push_back(coreBatchOp); });
|
||||
for (auto coreBatchOp : coreBatchOps) {
|
||||
if (failed(applyFullConversion(coreBatchOp.getOperation(), target, frozenCoreBodyPatterns))) {
|
||||
coreBatchOp.emitOpError("failed to convert nested Spatial ops inside pim.core_batch");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -259,44 +256,43 @@ void SpatialToPimPass::runOnOperation() {
|
||||
|
||||
SmallVector<Operation*> pendingRemovals(operationsToRemove.begin(), operationsToRemove.end());
|
||||
if (failed(erasePendingOps(pendingRemovals, rewriter))) {
|
||||
funcOp.emitOpError("failed to erase obsolete Spatial ops after lowering to PIM");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
{
|
||||
RewritePatternSet patterns(ctx);
|
||||
populateTensorPackingPatterns(patterns);
|
||||
walkAndApplyPatterns(funcOp, std::move(patterns));
|
||||
eraseUnusedTensorPackingOps(funcOp, rewriter);
|
||||
}
|
||||
RewritePatternSet finalTensorPackingPatterns(ctx);
|
||||
populateTensorPackingPatterns(finalTensorPackingPatterns);
|
||||
walkAndApplyPatterns(funcOp, std::move(finalTensorPackingPatterns));
|
||||
eraseUnusedTensorPackingOps(funcOp, rewriter);
|
||||
|
||||
{
|
||||
ConversionTarget communicationTarget(*ctx);
|
||||
communicationTarget.addLegalDialect<PimDialect,
|
||||
tensor::TensorDialect,
|
||||
arith::ArithDialect,
|
||||
bufferization::BufferizationDialect,
|
||||
func::FuncDialect,
|
||||
memref::MemRefDialect,
|
||||
scf::SCFDialect,
|
||||
BuiltinDialect>();
|
||||
communicationTarget.addLegalOp<ModuleOp>();
|
||||
communicationTarget.addIllegalOp<spatial::SpatConcatOp,
|
||||
spatial::SpatChannelReceiveOp,
|
||||
spatial::SpatChannelReceiveTensorOp,
|
||||
spatial::SpatChannelSendOp,
|
||||
spatial::SpatChannelSendTensorOp,
|
||||
spatial::SpatExtractRowsOp>();
|
||||
ConversionTarget communicationTarget(*ctx);
|
||||
communicationTarget.addLegalDialect<PimDialect,
|
||||
tensor::TensorDialect,
|
||||
arith::ArithDialect,
|
||||
bufferization::BufferizationDialect,
|
||||
func::FuncDialect,
|
||||
memref::MemRefDialect,
|
||||
scf::SCFDialect,
|
||||
BuiltinDialect>();
|
||||
communicationTarget.addLegalOp<ModuleOp>();
|
||||
communicationTarget.addIllegalOp<spatial::SpatConcatOp,
|
||||
spatial::SpatChannelReceiveOp,
|
||||
spatial::SpatChannelReceiveTensorOp,
|
||||
spatial::SpatChannelSendOp,
|
||||
spatial::SpatChannelSendTensorOp,
|
||||
spatial::SpatExtractRowsOp>();
|
||||
|
||||
RewritePatternSet communicationPatterns(ctx);
|
||||
populateChannelLoweringPatterns(communicationPatterns);
|
||||
if (failed(applyFullConversion(funcOp, communicationTarget, std::move(communicationPatterns)))) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
RewritePatternSet communicationPatterns(ctx);
|
||||
populateChannelLoweringPatterns(communicationPatterns);
|
||||
if (failed(applyFullConversion(funcOp, communicationTarget, std::move(communicationPatterns)))) {
|
||||
funcOp.emitOpError("failed to lower Spatial communication ops to PIM communication ops");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
if (failed(verifySpatialToPimBoundary(moduleOp))) {
|
||||
moduleOp.emitError("Spatial-to-PIM boundary verification failed");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/TensorPackingPatterns.hpp"
|
||||
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
@@ -75,16 +74,14 @@ struct PackSpatialConcatInputsPattern final : OpRewritePattern<spatial::SpatConc
|
||||
return failure();
|
||||
|
||||
auto outputType = cast<ShapedType>(concatOp.getOutput().getType());
|
||||
auto newConcat = pim::PimConcatOp::create(rewriter,
|
||||
concatOp.getLoc(),
|
||||
concatOp.getOutput().getType(),
|
||||
concatOp.getAxisAttr(),
|
||||
ValueRange(packedInputs),
|
||||
tensor::EmptyOp::create(rewriter,
|
||||
concatOp.getLoc(),
|
||||
outputType.getShape(),
|
||||
outputType.getElementType())
|
||||
.getResult());
|
||||
auto newConcat = pim::PimConcatOp::create(
|
||||
rewriter,
|
||||
concatOp.getLoc(),
|
||||
concatOp.getOutput().getType(),
|
||||
concatOp.getAxisAttr(),
|
||||
ValueRange(packedInputs),
|
||||
tensor::EmptyOp::create(rewriter, concatOp.getLoc(), outputType.getShape(), outputType.getElementType())
|
||||
.getResult());
|
||||
rewriter.replaceOp(concatOp, newConcat.getOutput());
|
||||
return success();
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
@@ -79,6 +79,7 @@ void PimBufferizationPass::runOnOperation() {
|
||||
return WalkResult::skip();
|
||||
});
|
||||
if (hasFailed) {
|
||||
moduleOp.emitError("failed to lower memref.copy-like ops inside PIM core bodies during bufferization");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -1,15 +1,16 @@
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/Transforms/StaticMemoryCoalescing/StaticMemoryCoalescing.hpp"
|
||||
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallPtrSet.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
|
||||
#include <limits>
|
||||
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/Transforms/StaticMemoryCoalescing/StaticMemoryCoalescing.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
@@ -29,9 +30,8 @@ static uint64_t getTypeSizeBytes(MemRefType type) {
|
||||
return static_cast<uint64_t>(type.getNumElements() * type.getElementTypeBitWidth() / 8);
|
||||
}
|
||||
|
||||
static FailureOr<uint64_t> getLastUseInstruction(memref::AllocOp allocOp,
|
||||
Block& body,
|
||||
const DenseMap<Operation*, uint64_t>& opOrder) {
|
||||
static FailureOr<uint64_t>
|
||||
getLastUseInstruction(memref::AllocOp allocOp, Block& body, const DenseMap<Operation*, uint64_t>& opOrder) {
|
||||
uint64_t endInstruction = opOrder.lookup(allocOp);
|
||||
SmallPtrSet<Operation*, 16> visited;
|
||||
SmallVector<Value> pendingValues;
|
||||
@@ -45,9 +45,15 @@ static FailureOr<uint64_t> getLastUseInstruction(memref::AllocOp allocOp,
|
||||
if (!visited.insert(user).second)
|
||||
continue;
|
||||
|
||||
if (isSupportedAliasOp(user)) {
|
||||
if (isSupportedAliasOp(user))
|
||||
for (Value result : user->getResults())
|
||||
pendingValues.push_back(result);
|
||||
|
||||
if (auto forOp = dyn_cast<scf::ForOp>(user)) {
|
||||
for (auto [index, initArg] : llvm::enumerate(forOp.getInitArgs())) {
|
||||
if (initArg == value)
|
||||
pendingValues.push_back(forOp.getResult(index));
|
||||
}
|
||||
}
|
||||
|
||||
if (auto dpsOp = dyn_cast<DestinationStyleOpInterface>(user)) {
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
|
||||
+20
-19
@@ -45,9 +45,7 @@ struct CoalescingReportEntry {
|
||||
CoalescingReportRow row;
|
||||
};
|
||||
|
||||
static std::string formatMemory(uint64_t bytes) {
|
||||
return formatReportMemory(bytes);
|
||||
}
|
||||
/// Formats a byte count for the coalescing report by delegating to formatReportMemory.
static std::string formatMemory(uint64_t bytes) { return formatReportMemory(bytes); }
|
||||
|
||||
static SmallVector<int32_t> getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp) {
|
||||
auto coreIdsAttr = coreBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName);
|
||||
@@ -58,9 +56,10 @@ static SmallVector<int32_t> getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp) {
|
||||
static void printReportRow(raw_ostream& os, const CoalescingReportRow& row) {
|
||||
llvm::SmallVector<ReportField, 4> fields = {
|
||||
{"Number of candidates", std::to_string(row.numCandidates)},
|
||||
{"Skipped allocations", std::to_string(row.numSkipped)},
|
||||
{"Removed allocations", std::to_string(row.numRemoved)},
|
||||
{"Saved memory", formatMemory(row.savedBytes)}};
|
||||
{"Skipped allocations", std::to_string(row.numSkipped) },
|
||||
{"Removed allocations", std::to_string(row.numRemoved) },
|
||||
{"Saved memory", formatMemory(row.savedBytes) }
|
||||
};
|
||||
printReportFlatFields(os, fields);
|
||||
}
|
||||
|
||||
@@ -87,10 +86,12 @@ static void emitReport(ArrayRef<CoalescingReportEntry> entries) {
|
||||
totalRow.savedBytes += entryTotal.savedBytes;
|
||||
}
|
||||
|
||||
llvm::SmallVector<ReportField, 4> totalFields = {{"Number of candidates", std::to_string(totalRow.numCandidates)},
|
||||
{"Skipped allocations", std::to_string(totalRow.numSkipped)},
|
||||
{"Removed allocations", std::to_string(totalRow.numRemoved)},
|
||||
{"Saved memory", formatMemory(totalRow.savedBytes)}};
|
||||
llvm::SmallVector<ReportField, 4> totalFields = {
|
||||
{"Number of candidates", std::to_string(totalRow.numCandidates)},
|
||||
{"Skipped allocations", std::to_string(totalRow.numSkipped) },
|
||||
{"Removed allocations", std::to_string(totalRow.numRemoved) },
|
||||
{"Saved memory", formatMemory(totalRow.savedBytes) }
|
||||
};
|
||||
printReportTotalsBlock(os, totalFields);
|
||||
if (!entries.empty())
|
||||
os << "\n";
|
||||
@@ -127,15 +128,17 @@ static void emitReport(ArrayRef<CoalescingReportEntry> entries) {
|
||||
if (sortedEntries[index].kind == CoalescingReportEntry::Kind::Batch) {
|
||||
llvm::SmallVector<ReportField, 4> perCoreFields = {
|
||||
{"Number of candidates", std::to_string(sortedEntries[index].row.numCandidates)},
|
||||
{"Skipped allocations", std::to_string(sortedEntries[index].row.numSkipped)},
|
||||
{"Removed allocations", std::to_string(sortedEntries[index].row.numRemoved)},
|
||||
{"Saved memory", formatMemory(sortedEntries[index].row.savedBytes)}};
|
||||
{"Skipped allocations", std::to_string(sortedEntries[index].row.numSkipped) },
|
||||
{"Removed allocations", std::to_string(sortedEntries[index].row.numRemoved) },
|
||||
{"Saved memory", formatMemory(sortedEntries[index].row.savedBytes) }
|
||||
};
|
||||
CoalescingReportRow totalRow = getTotalRow(sortedEntries[index]);
|
||||
llvm::SmallVector<ReportField, 4> totalFields = {
|
||||
{"Number of candidates", std::to_string(totalRow.numCandidates)},
|
||||
{"Skipped allocations", std::to_string(totalRow.numSkipped)},
|
||||
{"Removed allocations", std::to_string(totalRow.numRemoved)},
|
||||
{"Saved memory", formatMemory(totalRow.savedBytes)}};
|
||||
{"Skipped allocations", std::to_string(totalRow.numSkipped) },
|
||||
{"Removed allocations", std::to_string(totalRow.numRemoved) },
|
||||
{"Saved memory", formatMemory(totalRow.savedBytes) }
|
||||
};
|
||||
printReportPerCoreAndTotalFields(os, perCoreFields, totalFields);
|
||||
}
|
||||
else {
|
||||
@@ -196,8 +199,6 @@ struct StaticMemoryCoalescingPass : PassWrapper<StaticMemoryCoalescingPass, Oper
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<Pass> createPimStaticMemoryCoalescingPass() {
|
||||
return std::make_unique<StaticMemoryCoalescingPass>();
|
||||
}
|
||||
/// Factory returning a freshly allocated StaticMemoryCoalescingPass.
std::unique_ptr<Pass> createPimStaticMemoryCoalescingPass() { return std::make_unique<StaticMemoryCoalescingPass>(); }
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -8,7 +8,14 @@ add_pim_library(SpatialOps
|
||||
SpatialOpsVerify.cpp
|
||||
SpatialOpsCanonicalization.cpp
|
||||
Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp
|
||||
Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp
|
||||
Transforms/MergeComputeNodes/PostMergeCompaction.cpp
|
||||
Transforms/MergeComputeNodes/RegularOpCompaction.cpp
|
||||
Transforms/MergeComputeNodes/Scheduling/ComputeGraph.cpp
|
||||
Transforms/MergeComputeNodes/Scheduling/ComputeInstanceUtils.cpp
|
||||
Transforms/MergeComputeNodes/Scheduling/DcpScheduler.cpp
|
||||
Transforms/MergeComputeNodes/Scheduling/MergeSchedulingAnalysis.cpp
|
||||
Transforms/MergeComputeNodes/Scheduling/PeftScheduler.cpp
|
||||
Transforms/MergeComputeNodes/DCPGraph/Graph.cpp
|
||||
Transforms/MergeComputeNodes/DCPGraph/GraphDebug.cpp
|
||||
Transforms/MergeComputeNodes/DCPGraph/GraphSupport.cpp
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
#include "mlir/IR/Diagnostics.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
|
||||
#include "llvm/ADT/DenseSet.h"
|
||||
#include "llvm/Support/LogicalResult.h"
|
||||
@@ -338,6 +339,19 @@ LogicalResult SpatConcatOp::verify() {
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Verifies that no result of a spatial compute op is consumed from inside the
/// body of another compute op.
///
/// \param op The operation to check; it must be a SpatCompute or a
///           SpatComputeBatch, otherwise an error is emitted on it.
/// \returns success() when every user of every result has neither a
///          SpatCompute nor a SpatComputeBatch ancestor; otherwise an error
///          diagnostic attached to `op`.
LogicalResult verifyComputeResultsUses(Operation* op) {
  if (!isa<SpatCompute, SpatComputeBatch>(op))
    return op->emitError("verifyComputeResultsUses: Op is not a SpatCompute/SpatComputeBatch operation");

  // A user is "nested" when it lives inside the region of some compute op.
  // The parameter is named `user` so it does not shadow the verified `op`.
  auto isNestedInCompute = [](Operation* user) {
    return user->getParentOfType<SpatCompute>() || user->getParentOfType<SpatComputeBatch>();
  };

  for (Value result : op->getResults())
    if (llvm::any_of(result.getUsers(), isNestedInCompute))
      return op->emitError("ComputeResult used directly inside another Compute");

  return success();
}
|
||||
|
||||
LogicalResult SpatCompute::verify() {
|
||||
auto& block = getBody().front();
|
||||
if (block.mightHaveTerminator()) {
|
||||
@@ -375,7 +389,8 @@ LogicalResult SpatCompute::verify() {
|
||||
for (auto arg : block.getArguments())
|
||||
if (arg.use_empty())
|
||||
return emitError("ComputeOp block argument is not used");
|
||||
|
||||
if (failed(verifyComputeResultsUses(this->getOperation())))
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
|
||||
@@ -465,8 +480,8 @@ LogicalResult SpatComputeBatch::verify() {
|
||||
return emitError("compute_batch coreIds attribute must be a dense i32 array");
|
||||
if (coreIdsAttr.size() != static_cast<int64_t>(laneCountSz))
|
||||
return emitError("compute_batch coreIds array length must match laneCount");
|
||||
if (llvm::any_of(coreIdsAttr.asArrayRef(), [](int32_t coreId) { return coreId <= 0; }))
|
||||
return emitError("compute_batch coreIds values must be positive");
|
||||
if (llvm::any_of(coreIdsAttr.asArrayRef(), [](int32_t coreId) { return coreId < 0; }))
|
||||
return emitError("compute_batch coreIds values must be non-negative");
|
||||
llvm::SmallDenseSet<int32_t, 8> seenCoreIds;
|
||||
for (int32_t coreId : coreIdsAttr.asArrayRef())
|
||||
if (!seenCoreIds.insert(coreId).second)
|
||||
@@ -485,6 +500,8 @@ LogicalResult SpatComputeBatch::verify() {
|
||||
return emitError("body block argument type must match input type");
|
||||
}
|
||||
|
||||
if (failed(verifyComputeResultsUses(this->getOperation())))
|
||||
return failure();
|
||||
return verifyBatchBody(getOperation(), block, getResultTypes(), weightsPerLane);
|
||||
}
|
||||
|
||||
|
||||
@@ -1,802 +1,19 @@
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/IR/ValueRange.h"
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdlib>
|
||||
#include <iterator>
|
||||
#include <numeric>
|
||||
#include <optional>
|
||||
#include <queue>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "DCPAnalysis.hpp"
|
||||
#include "Graph.hpp"
|
||||
#include "../Scheduling/ComputeGraph.hpp"
|
||||
#include "../Scheduling/DcpScheduler.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Support/TypeUtilities.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace spatial {
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
using SpatCompute = onnx_mlir::spatial::SpatCompute;
|
||||
using SpatComputeBatch = onnx_mlir::spatial::SpatComputeBatch;
|
||||
|
||||
/// Returns true when the DCP_COARSEN_DEBUG environment variable is set (any value, including empty).
bool isDcpCoarsenDebugEnabled() { return std::getenv("DCP_COARSEN_DEBUG") != nullptr; }
|
||||
|
||||
struct VirtualNode {
|
||||
SmallVector<size_t, 4> originalComputeIndices;
|
||||
Weight weight = 0;
|
||||
CrossbarUsage crossbarUsage = 0;
|
||||
};
|
||||
|
||||
struct VirtualGraph {
|
||||
std::vector<VirtualNode> nodes;
|
||||
std::vector<IndexedEdge> edges;
|
||||
};
|
||||
|
||||
struct TimingInfo {
|
||||
std::vector<Time> aest;
|
||||
std::vector<Time> alst;
|
||||
std::vector<size_t> topologicalOrder;
|
||||
bool valid = false;
|
||||
};
|
||||
|
||||
struct WindowScheduleResult {
|
||||
std::vector<std::vector<size_t>> mergeGroups;
|
||||
CPU cpuCount = 0;
|
||||
size_t mergedNodeCount = 0;
|
||||
size_t maxMergeGroupSize = 0;
|
||||
};
|
||||
|
||||
size_t getSchedulingCpuBudget() {
|
||||
if (coresCount.getValue() > 0)
|
||||
return static_cast<size_t>(coresCount.getValue());
|
||||
return std::numeric_limits<size_t>::max();
|
||||
}
|
||||
|
||||
size_t getBatchChunkTargetCount(int32_t laneCount) {
|
||||
assert(laneCount > 0 && "laneCount must be positive");
|
||||
return std::min(static_cast<size_t>(laneCount), std::max<size_t>(1, getSchedulingCpuBudget()));
|
||||
}
|
||||
|
||||
/// Returns the lane range covered by chunk `chunkIndex` of `batch`.
///
/// Lanes are split as evenly as possible: the first `laneTotal % chunkTotal`
/// chunks receive one extra lane.
ComputeInstance getBatchChunkForIndex(SpatComputeBatch batch, size_t chunkIndex) {
  const size_t laneTotal = static_cast<size_t>(batch.getLaneCount());
  const size_t chunkTotal = getBatchChunkTargetCount(batch.getLaneCount());
  const size_t smallChunkSize = laneTotal / chunkTotal;
  const size_t oversizedChunks = laneTotal % chunkTotal;

  // Every oversized chunk before this one shifts the start by one extra lane.
  const size_t firstLane = chunkIndex * smallChunkSize + std::min(chunkIndex, oversizedChunks);
  const size_t lanesInChunk = chunkIndex < oversizedChunks ? smallChunkSize + 1 : smallChunkSize;
  return {batch.getOperation(), static_cast<uint32_t>(firstLane), static_cast<uint32_t>(lanesInChunk)};
}
|
||||
|
||||
/// Returns the chunk of `batch` that contains lane `lane`.
///
/// Uses the same even split as getBatchChunkForIndex: the leading oversized
/// chunks hold `smallChunkSize + 1` lanes, the remaining ones `smallChunkSize`.
ComputeInstance getBatchChunkForLane(SpatComputeBatch batch, uint32_t lane) {
  const size_t laneTotal = static_cast<size_t>(batch.getLaneCount());
  const size_t chunkTotal = getBatchChunkTargetCount(batch.getLaneCount());
  const size_t smallChunkSize = laneTotal / chunkTotal;
  const size_t oversizedChunks = laneTotal % chunkTotal;
  const size_t oversizedSpan = oversizedChunks * (smallChunkSize + 1);
  const size_t laneIndex = static_cast<size_t>(lane);

  const size_t chunkIndex = laneIndex < oversizedSpan
                              ? laneIndex / (smallChunkSize + 1)
                              : oversizedChunks + (laneIndex - oversizedSpan) / smallChunkSize;
  return getBatchChunkForIndex(batch, chunkIndex);
}
|
||||
|
||||
/// Collapses parallel edges and drops self-loops.
///
/// For every (start, end) pair only the largest weight is kept. The result is
/// sorted by start index, then by end index.
std::vector<IndexedEdge> aggregateEdges(ArrayRef<IndexedEdge> edges) {
  llvm::DenseMap<std::pair<size_t, size_t>, Weight> heaviestByEndpoints;
  for (const IndexedEdge& edge : edges) {
    const size_t source = static_cast<size_t>(std::get<0>(edge));
    const size_t target = static_cast<size_t>(std::get<1>(edge));
    if (source == target)
      continue;

    const Weight candidate = static_cast<Weight>(std::get<2>(edge));
    auto [slot, isNew] = heaviestByEndpoints.try_emplace(std::make_pair(source, target), candidate);
    if (!isNew && slot->second < candidate)
      slot->second = candidate;
  }

  std::vector<IndexedEdge> merged;
  merged.reserve(heaviestByEndpoints.size());
  for (const auto& entry : heaviestByEndpoints) {
    merged.push_back({static_cast<int64_t>(entry.first.first),
                      static_cast<int64_t>(entry.first.second),
                      static_cast<int64_t>(entry.second)});
  }

  // Map iteration order is unspecified; sort so the output is deterministic.
  llvm::sort(merged, [](const IndexedEdge& lhs, const IndexedEdge& rhs) {
    return std::make_pair(std::get<0>(lhs), std::get<1>(lhs)) < std::make_pair(std::get<0>(rhs), std::get<1>(rhs));
  });
  return merged;
}
|
||||
|
||||
/// Estimates the weight of a compute body as a fixed cost of 100 per
/// operation found directly in the region's blocks.
Weight getComputeBodyWeight(Region& body) {
  constexpr Weight kOperationWeight = 100;

  Weight opCount = 0;
  for (Block& block : body) {
    for (auto it = block.begin(), last = block.end(); it != last; ++it)
      opCount = checkedAdd(opCount, static_cast<Weight>(1));
  }
  return checkedMultiply(opCount, kOperationWeight);
}
|
||||
|
||||
/// Counts the SpatVMMOp operations found directly in the blocks of `body`;
/// each one accounts for one unit of crossbar usage.
CrossbarUsage getComputeBodyCrossbarUsage(Region& body) {
  CrossbarUsage vmmCount = 0;
  for (Block& block : body) {
    for (Operation& nested : block) {
      if (!isa<SpatVMMOp>(nested))
        continue;
      vmmCount = checkedAdd(vmmCount, static_cast<CrossbarUsage>(1));
    }
  }
  return vmmCount;
}
|
||||
|
||||
/// Returns the weight of a compute instance.
///
/// A SpatCompute uses getSpatComputeWeight; a chunk of a SpatComputeBatch is
/// the per-lane body weight multiplied by the number of lanes in the chunk.
Weight getComputeInstanceWeight(const ComputeInstance& instance) {
  auto single = dyn_cast<SpatCompute>(instance.op);
  if (!single) {
    auto batch = cast<SpatComputeBatch>(instance.op);
    const Weight perLane = getComputeBodyWeight(batch.getBody());
    return checkedMultiply(perLane, static_cast<Weight>(instance.laneCount));
  }
  return getSpatComputeWeight(single);
}
|
||||
|
||||
/// Returns the crossbar usage of a compute instance.
///
/// A SpatCompute uses getSpatComputeCrossbarUsage; a chunk of a
/// SpatComputeBatch is the per-lane body usage multiplied by its lane count.
CrossbarUsage getComputeInstanceCrossbarUsage(const ComputeInstance& instance) {
  auto single = dyn_cast<SpatCompute>(instance.op);
  if (!single) {
    auto batch = cast<SpatComputeBatch>(instance.op);
    const CrossbarUsage perLane = getComputeBodyCrossbarUsage(batch.getBody());
    return checkedMultiply(perLane, static_cast<CrossbarUsage>(instance.laneCount));
  }
  return getSpatComputeCrossbarUsage(single);
}
|
||||
|
||||
/// Returns the input values consumed by a compute instance.
///
/// For a SpatCompute these are all of its inputs; for a batch chunk, the batch
/// inputs at positions [laneStart, laneStart + laneCount).
SmallVector<Value, 4> getComputeInstanceInputs(const ComputeInstance& instance) {
  if (auto single = dyn_cast<SpatCompute>(instance.op)) {
    auto operands = single.getInputs();
    return SmallVector<Value, 4>(operands.begin(), operands.end());
  }

  auto batch = cast<SpatComputeBatch>(instance.op);
  SmallVector<Value, 4> chunkInputs;
  chunkInputs.reserve(instance.laneCount);
  uint32_t lane = instance.laneStart;
  while (lane < instance.laneStart + instance.laneCount) {
    chunkInputs.push_back(batch.getInputs()[lane]);
    ++lane;
  }
  return chunkInputs;
}
|
||||
|
||||
/// Finds the compute instance that produced `value`.
///
/// Chains of tensor.extract_slice are looked through. Returns std::nullopt for
/// block arguments and for values produced by anything other than a
/// SpatCompute or SpatComputeBatch. For a batch, the result number selects the
/// lane whose chunk is returned.
std::optional<ComputeInstance> getOriginalComputeInstance(Value value) {
  for (;;) {
    Operation* producer = value.getDefiningOp();
    if (!producer)
      return std::nullopt;

    if (auto slice = dyn_cast<tensor::ExtractSliceOp>(producer)) {
      // Keep walking up through the slice's source.
      value = slice.getSource();
      continue;
    }

    if (auto single = dyn_cast<SpatCompute>(producer))
      return ComputeInstance {single.getOperation(), 0, 1};
    if (auto batch = dyn_cast<SpatComputeBatch>(producer))
      return getBatchChunkForLane(batch, static_cast<uint32_t>(cast<OpResult>(value).getResultNumber()));
    return std::nullopt;
  }
}
|
||||
|
||||
/// Collects the schedulable compute instances found directly in the blocks of
/// `entryOp`'s regions, in program order.
///
/// A SpatCompute contributes one instance; a SpatComputeBatch contributes one
/// instance per chunk. Ops whose results are all used, and used exclusively as
/// weights of other compute ops, are left out.
SmallVector<ComputeInstance> collectComputeInstances(Operation* entryOp) {
  // True when `result` appears among the weights of compute op `user`.
  auto isWeightOperandOf = [](Operation* user, Value result) {
    if (auto single = dyn_cast<SpatCompute>(user))
      return llvm::is_contained(single.getWeights(), result);
    if (auto batch = dyn_cast<SpatComputeBatch>(user))
      return llvm::is_contained(batch.getWeights(), result);
    return false;
  };

  // True when the op has results, each result has uses, and every user is a
  // compute op that takes the result as a weight.
  auto feedsOnlyWeights = [&isWeightOperandOf](Operation* producer) {
    if (producer->getNumResults() == 0)
      return false;
    return llvm::all_of(producer->getResults(), [&](Value result) {
      if (result.use_empty())
        return false;
      return llvm::all_of(result.getUsers(), [&](Operation* user) { return isWeightOperandOf(user, result); });
    });
  };

  SmallVector<ComputeInstance> collected;
  for (Region& region : entryOp->getRegions()) {
    for (Block& block : region) {
      for (Operation& op : block) {
        if (!isa<SpatCompute, SpatComputeBatch>(&op) || feedsOnlyWeights(&op))
          continue;

        if (auto batch = dyn_cast<SpatComputeBatch>(&op)) {
          const size_t chunkTotal = getBatchChunkTargetCount(batch.getLaneCount());
          for (size_t chunk = 0; chunk != chunkTotal; ++chunk)
            collected.push_back(getBatchChunkForIndex(batch, chunk));
        }
        else {
          collected.push_back({&op, 0, 1});
        }
      }
    }
  }
  return collected;
}
|
||||
|
||||
/// Builds the starting virtual graph: one node per compute instance (carrying
/// its weight and crossbar usage) plus the aggregated edge list.
VirtualGraph buildInitialVirtualGraph(ArrayRef<ComputeInstance> computeInstances, ArrayRef<IndexedEdge> edges) {
  VirtualGraph graph;
  const size_t instanceCount = computeInstances.size();
  graph.nodes.reserve(instanceCount);

  for (size_t position = 0; position < instanceCount; ++position) {
    const ComputeInstance& instance = computeInstances[position];
    VirtualNode& node = graph.nodes.emplace_back();
    node.originalComputeIndices.push_back(position);
    node.weight = getComputeInstanceWeight(instance);
    node.crossbarUsage = getComputeInstanceCrossbarUsage(instance);
  }

  graph.edges = aggregateEdges(edges);
  return graph;
}
|
||||
|
||||
/// Computes earliest (aest) and latest (alst) start times for every node.
///
/// Nodes are first ordered topologically; among ready nodes, the one with the
/// smallest first original compute index (then smallest node index) goes
/// first. If the graph is not acyclic the partial order is returned with
/// `valid == false` and all times left at zero.
TimingInfo computeTiming(const VirtualGraph& graph) {
  using WeightedLink = std::pair<size_t, Weight>;
  const size_t nodeCount = graph.nodes.size();

  TimingInfo timing;
  timing.aest.assign(nodeCount, 0);
  timing.alst.assign(nodeCount, 0);
  timing.topologicalOrder.reserve(nodeCount);

  // Adjacency in both directions, plus the in-degree of each node.
  std::vector<std::vector<WeightedLink>> predecessors(nodeCount);
  std::vector<std::vector<WeightedLink>> successors(nodeCount);
  std::vector<size_t> unresolvedParents(nodeCount, 0);
  for (auto [from, to, cost] : graph.edges) {
    const size_t source = static_cast<size_t>(from);
    const size_t target = static_cast<size_t>(to);
    const Weight transferCost = static_cast<Weight>(cost);
    assert(source < nodeCount && target < nodeCount && "virtual edge endpoint out of range");
    successors[source].push_back({target, transferCost});
    predecessors[target].push_back({source, transferCost});
    ++unresolvedParents[target];
  }

  auto orderKeyOf = [&graph](size_t nodeIndex) -> size_t {
    const auto& members = graph.nodes[nodeIndex].originalComputeIndices;
    return members.empty() ? nodeIndex : members.front();
  };
  // "Greater" comparator, so the priority queue pops the smallest key first.
  auto schedulesAfter = [&orderKeyOf](size_t lhs, size_t rhs) {
    return std::make_pair(orderKeyOf(lhs), lhs) > std::make_pair(orderKeyOf(rhs), rhs);
  };

  std::priority_queue<size_t, std::vector<size_t>, decltype(schedulesAfter)> ready(schedulesAfter);
  for (size_t nodeIndex = 0; nodeIndex < nodeCount; ++nodeIndex) {
    if (unresolvedParents[nodeIndex] == 0)
      ready.push(nodeIndex);
  }

  while (!ready.empty()) {
    const size_t current = ready.top();
    ready.pop();
    timing.topologicalOrder.push_back(current);
    for (const WeightedLink& link : successors[current]) {
      size_t& remaining = unresolvedParents[link.first];
      assert(remaining > 0 && "incoming edge count underflow");
      if (--remaining == 0)
        ready.push(link.first);
    }
  }

  // Some node never became ready: the graph has a cycle.
  if (timing.topologicalOrder.size() != nodeCount)
    return timing;

  // Forward pass: earliest start times and the overall critical path length.
  Time criticalPathLength = 0;
  for (size_t nodeIndex : timing.topologicalOrder) {
    Time earliestStart = 0;
    for (auto [parent, transferCost] : predecessors[nodeIndex]) {
      earliestStart =
        std::max(earliestStart, addOrMax(addOrMax(timing.aest[parent], graph.nodes[parent].weight), transferCost));
    }
    timing.aest[nodeIndex] = earliestStart;
    criticalPathLength = std::max(criticalPathLength, addOrMax(earliestStart, graph.nodes[nodeIndex].weight));
  }

  // Backward pass: latest start times that still meet the critical path length.
  for (size_t nodeIndex : llvm::reverse(timing.topologicalOrder)) {
    Time latestStart = std::numeric_limits<Time>::max();
    if (successors[nodeIndex].empty())
      latestStart = subtractOrZero(criticalPathLength, graph.nodes[nodeIndex].weight);
    for (auto [child, transferCost] : successors[nodeIndex]) {
      latestStart =
        std::min(latestStart, subtractOrZero(timing.alst[child], addOrMax(graph.nodes[nodeIndex].weight, transferCost)));
    }
    timing.alst[nodeIndex] = latestStart;
  }

  timing.valid = true;
  return timing;
}
|
||||
|
||||
std::vector<std::vector<size_t>> buildUndirectedAdjacency(const VirtualGraph& graph) {
|
||||
std::vector<std::vector<size_t>> adjacency(graph.nodes.size());
|
||||
for (auto [start, end, weight] : graph.edges) {
|
||||
(void) weight;
|
||||
size_t startIndex = static_cast<size_t>(start);
|
||||
size_t endIndex = static_cast<size_t>(end);
|
||||
assert(startIndex < graph.nodes.size() && endIndex < graph.nodes.size() && "virtual edge endpoint out of range");
|
||||
adjacency[startIndex].push_back(endIndex);
|
||||
adjacency[endIndex].push_back(startIndex);
|
||||
}
|
||||
for (auto& neighbours : adjacency) {
|
||||
llvm::sort(neighbours);
|
||||
neighbours.erase(std::unique(neighbours.begin(), neighbours.end()), neighbours.end());
|
||||
}
|
||||
return adjacency;
|
||||
}
|
||||
|
||||
std::vector<size_t> selectCriticalWindow(const VirtualGraph& graph, const TimingInfo& timing, size_t windowSize) {
|
||||
std::vector<size_t> ranked(timing.aest.size());
|
||||
std::iota(ranked.begin(), ranked.end(), 0);
|
||||
auto isHigherPriority = [&](size_t lhs, size_t rhs) {
|
||||
Time lhsSlack = slackOrZero(timing.aest[lhs], timing.alst[lhs]);
|
||||
Time rhsSlack = slackOrZero(timing.aest[rhs], timing.alst[rhs]);
|
||||
if (lhsSlack != rhsSlack)
|
||||
return lhsSlack < rhsSlack;
|
||||
if (timing.aest[lhs] != timing.aest[rhs])
|
||||
return timing.aest[lhs] < timing.aest[rhs];
|
||||
return lhs < rhs;
|
||||
};
|
||||
|
||||
windowSize = std::min(windowSize, ranked.size());
|
||||
if (windowSize == 0)
|
||||
return {};
|
||||
if (windowSize == ranked.size()) {
|
||||
llvm::sort(ranked, isHigherPriority);
|
||||
return ranked;
|
||||
}
|
||||
|
||||
size_t criticalPoolSize = std::min(ranked.size(), std::max(windowSize, windowSize * 2));
|
||||
if (criticalPoolSize < ranked.size())
|
||||
std::nth_element(
|
||||
ranked.begin(), ranked.begin() + static_cast<std::ptrdiff_t>(criticalPoolSize), ranked.end(), isHigherPriority);
|
||||
|
||||
std::vector<char> inCriticalPool(ranked.size(), false);
|
||||
for (size_t i = 0; i < criticalPoolSize; ++i)
|
||||
inCriticalPool[ranked[i]] = true;
|
||||
|
||||
size_t seed = *std::min_element(ranked.begin(), ranked.end(), isHigherPriority);
|
||||
std::vector<std::vector<size_t>> adjacency = buildUndirectedAdjacency(graph);
|
||||
std::vector<size_t> selected;
|
||||
std::vector<char> inWindow(ranked.size(), false);
|
||||
selected.reserve(windowSize);
|
||||
|
||||
struct FrontierEntry {
|
||||
size_t node;
|
||||
};
|
||||
auto frontierCompare = [&](FrontierEntry lhs, FrontierEntry rhs) { return isHigherPriority(rhs.node, lhs.node); };
|
||||
std::priority_queue<FrontierEntry, std::vector<FrontierEntry>, decltype(frontierCompare)> frontier(frontierCompare);
|
||||
|
||||
auto addToWindow = [&](size_t node, const std::vector<char>& eligible) {
|
||||
if (inWindow[node])
|
||||
return;
|
||||
inWindow[node] = true;
|
||||
selected.push_back(node);
|
||||
for (size_t neighbour : adjacency[node])
|
||||
if (!inWindow[neighbour] && eligible[neighbour])
|
||||
frontier.push({neighbour});
|
||||
};
|
||||
|
||||
addToWindow(seed, inCriticalPool);
|
||||
while (!frontier.empty() && selected.size() < windowSize) {
|
||||
size_t node = frontier.top().node;
|
||||
frontier.pop();
|
||||
if (!inWindow[node])
|
||||
addToWindow(node, inCriticalPool);
|
||||
}
|
||||
|
||||
if (selected.size() < windowSize) {
|
||||
std::vector<char> anyNode(ranked.size(), true);
|
||||
for (size_t node : selected)
|
||||
for (size_t neighbour : adjacency[node])
|
||||
if (!inWindow[neighbour])
|
||||
frontier.push({neighbour});
|
||||
while (!frontier.empty() && selected.size() < windowSize) {
|
||||
size_t node = frontier.top().node;
|
||||
frontier.pop();
|
||||
if (!inWindow[node])
|
||||
addToWindow(node, anyNode);
|
||||
}
|
||||
}
|
||||
|
||||
if (selected.size() < windowSize) {
|
||||
llvm::sort(ranked, isHigherPriority);
|
||||
for (size_t node : ranked) {
|
||||
if (selected.size() == windowSize)
|
||||
break;
|
||||
if (!inWindow[node]) {
|
||||
inWindow[node] = true;
|
||||
selected.push_back(node);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
llvm::sort(selected, isHigherPriority);
|
||||
return selected;
|
||||
}
|
||||
|
||||
std::vector<IndexedEdge> buildWindowEdges(const VirtualGraph& graph, const std::vector<int64_t>& nodeToWindowIndex) {
|
||||
std::vector<IndexedEdge> windowEdges;
|
||||
windowEdges.reserve(graph.edges.size());
|
||||
for (auto [start, end, weight] : graph.edges) {
|
||||
int64_t mappedStart = nodeToWindowIndex[static_cast<size_t>(start)];
|
||||
int64_t mappedEnd = nodeToWindowIndex[static_cast<size_t>(end)];
|
||||
if (mappedStart == -1 || mappedEnd == -1)
|
||||
continue;
|
||||
windowEdges.push_back({mappedStart, mappedEnd, weight});
|
||||
}
|
||||
return aggregateEdges(windowEdges);
|
||||
}
|
||||
|
||||
/// Runs the DCP scheduler on the sub-graph induced by `selectedNodes` and
/// reports which nodes were placed on the same CPU.
///
/// Only CPUs that received at least two tasks produce a merge group; group
/// members are expressed as indices into `graph.nodes`.
WindowScheduleResult scheduleWindow(const VirtualGraph& graph, ArrayRef<size_t> selectedNodes, MLIRContext* context) {
  const size_t windowNodeCount = selectedNodes.size();

  std::vector<Weight> weights;
  std::vector<CrossbarUsage> crossbarUsages;
  std::vector<int64_t> orderKeys;
  weights.reserve(windowNodeCount);
  crossbarUsages.reserve(windowNodeCount);
  orderKeys.reserve(windowNodeCount);

  // -1 marks graph nodes that are not part of this window.
  std::vector<int64_t> windowIndexOf(graph.nodes.size(), -1);
  for (size_t windowIndex = 0; windowIndex < windowNodeCount; ++windowIndex) {
    const size_t nodeIndex = selectedNodes[windowIndex];
    const VirtualNode& node = graph.nodes[nodeIndex];
    windowIndexOf[nodeIndex] = static_cast<int64_t>(windowIndex);
    weights.push_back(node.weight);
    crossbarUsages.push_back(node.crossbarUsage);
    orderKeys.push_back(static_cast<int64_t>(nodeIndex));
  }

  GraphDCP windowGraph(weights, buildWindowEdges(graph, windowIndexOf), orderKeys, crossbarUsages);
  if (coresCount.getValue() > 0)
    windowGraph.setMaxCpuCount(static_cast<int>(coresCount.getValue()));
  windowGraph.setContext(context);
  windowGraph.runDcp();

  WindowScheduleResult result;
  result.cpuCount = windowGraph.cpuCount();
  for (CPU cpu = 0; cpu < windowGraph.cpuCount(); ++cpu) {
    auto tasks = windowGraph.getScheduledTasks(cpu);
    // A CPU with a single task merges nothing.
    if (tasks.size() < 2)
      continue;

    result.mergedNodeCount += tasks.size();
    result.maxMergeGroupSize = std::max(result.maxMergeGroupSize, tasks.size());

    std::vector<size_t>& group = result.mergeGroups.emplace_back();
    group.reserve(tasks.size());
    for (const auto& task : tasks)
      group.push_back(selectedNodes[task.nodeIndex]);
  }
  return result;
}
|
||||
|
||||
bool coarsenGraph(const VirtualGraph& graph,
|
||||
ArrayRef<std::vector<size_t>> mergeGroups,
|
||||
VirtualGraph& coarsenedGraph,
|
||||
std::vector<size_t>& oldToNewNode) {
|
||||
TimingInfo timing = computeTiming(graph);
|
||||
std::vector<size_t> topologicalRank(graph.nodes.size());
|
||||
std::iota(topologicalRank.begin(), topologicalRank.end(), 0);
|
||||
if (timing.valid)
|
||||
for (auto [rank, nodeIndex] : llvm::enumerate(timing.topologicalOrder))
|
||||
topologicalRank[nodeIndex] = rank;
|
||||
|
||||
std::vector<std::vector<size_t>> orderedMergeGroups;
|
||||
orderedMergeGroups.reserve(mergeGroups.size());
|
||||
for (const auto& mergeGroup : mergeGroups) {
|
||||
orderedMergeGroups.emplace_back(mergeGroup.begin(), mergeGroup.end());
|
||||
std::stable_sort(orderedMergeGroups.back().begin(), orderedMergeGroups.back().end(), [&](size_t lhs, size_t rhs) {
|
||||
if (topologicalRank[lhs] != topologicalRank[rhs])
|
||||
return topologicalRank[lhs] < topologicalRank[rhs];
|
||||
return lhs < rhs;
|
||||
});
|
||||
}
|
||||
|
||||
std::vector<int64_t> nodeToMergeGroup(graph.nodes.size(), -1);
|
||||
for (auto [groupIndex, mergeGroup] : llvm::enumerate(orderedMergeGroups)) {
|
||||
if (mergeGroup.size() < 2)
|
||||
continue;
|
||||
for (size_t nodeIndex : mergeGroup) {
|
||||
assert(nodeIndex < graph.nodes.size() && "merge group node out of range");
|
||||
nodeToMergeGroup[nodeIndex] = static_cast<int64_t>(groupIndex);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::optional<size_t>> mergeGroupToNewNode(orderedMergeGroups.size());
|
||||
std::vector<size_t> newNodeRank;
|
||||
oldToNewNode.assign(graph.nodes.size(), 0);
|
||||
bool mergedAny = false;
|
||||
coarsenedGraph.nodes.clear();
|
||||
coarsenedGraph.edges.clear();
|
||||
coarsenedGraph.nodes.reserve(graph.nodes.size());
|
||||
newNodeRank.reserve(graph.nodes.size());
|
||||
|
||||
for (size_t nodeIndex = 0; nodeIndex < graph.nodes.size(); ++nodeIndex) {
|
||||
int64_t mergeGroupIndex = nodeToMergeGroup[nodeIndex];
|
||||
if (mergeGroupIndex == -1) {
|
||||
oldToNewNode[nodeIndex] = coarsenedGraph.nodes.size();
|
||||
coarsenedGraph.nodes.push_back(graph.nodes[nodeIndex]);
|
||||
newNodeRank.push_back(topologicalRank[nodeIndex]);
|
||||
continue;
|
||||
}
|
||||
|
||||
auto& newNodeIndex = mergeGroupToNewNode[static_cast<size_t>(mergeGroupIndex)];
|
||||
if (newNodeIndex.has_value()) {
|
||||
oldToNewNode[nodeIndex] = *newNodeIndex;
|
||||
continue;
|
||||
}
|
||||
|
||||
VirtualNode mergedNode;
|
||||
for (size_t memberIndex : orderedMergeGroups[static_cast<size_t>(mergeGroupIndex)]) {
|
||||
const VirtualNode& memberNode = graph.nodes[memberIndex];
|
||||
mergedNode.originalComputeIndices.append(memberNode.originalComputeIndices.begin(),
|
||||
memberNode.originalComputeIndices.end());
|
||||
mergedNode.weight = addOrMax(mergedNode.weight, memberNode.weight);
|
||||
mergedNode.crossbarUsage = addOrMax(mergedNode.crossbarUsage, memberNode.crossbarUsage);
|
||||
}
|
||||
std::sort(mergedNode.originalComputeIndices.begin(), mergedNode.originalComputeIndices.end());
|
||||
|
||||
mergedAny = true;
|
||||
newNodeIndex = coarsenedGraph.nodes.size();
|
||||
for (size_t memberIndex : orderedMergeGroups[static_cast<size_t>(mergeGroupIndex)])
|
||||
oldToNewNode[memberIndex] = *newNodeIndex;
|
||||
newNodeRank.push_back(topologicalRank[orderedMergeGroups[static_cast<size_t>(mergeGroupIndex)].front()]);
|
||||
coarsenedGraph.nodes.push_back(std::move(mergedNode));
|
||||
}
|
||||
|
||||
if (!mergedAny)
|
||||
return false;
|
||||
|
||||
std::vector<IndexedEdge> remappedEdges;
|
||||
remappedEdges.reserve(graph.edges.size());
|
||||
for (auto [start, end, weight] : graph.edges) {
|
||||
size_t newStart = oldToNewNode[static_cast<size_t>(start)];
|
||||
size_t newEnd = oldToNewNode[static_cast<size_t>(end)];
|
||||
if (newStart == newEnd)
|
||||
continue;
|
||||
if (newNodeRank[newStart] >= newNodeRank[newEnd])
|
||||
continue;
|
||||
remappedEdges.push_back({static_cast<int64_t>(newStart), static_cast<int64_t>(newEnd), weight});
|
||||
}
|
||||
coarsenedGraph.edges = aggregateEdges(remappedEdges);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// Returns the scheduling CPU budget converted to the CPU type.
CPU getVirtualGraphMaxCpuCount() {
  const auto cpuBudget = getSchedulingCpuBudget();
  return static_cast<CPU>(cpuBudget);
}
|
||||
|
||||
size_t getDcpCoarseningWindowSize(size_t nodeCount) {
|
||||
size_t windowSize = std::min(dcpCriticalWindowSize.getValue(), nodeCount);
|
||||
CPU maxCpuCount = std::max<CPU>(1, getVirtualGraphMaxCpuCount());
|
||||
if (nodeCount > static_cast<size_t>(maxCpuCount))
|
||||
windowSize = std::max(windowSize, std::min(nodeCount, static_cast<size_t>(maxCpuCount) + 1));
|
||||
return windowSize;
|
||||
}
|
||||
|
||||
// Builds an analysis result in which every virtual node of `graph` becomes one CPU.
//
// The CPU number of a virtual node is its position in the topological order
// reported by computeTiming(); when the timing is not valid, the node index
// itself is used. Every original compute instance inherits the CPU of the
// virtual node that lists it in `originalComputeIndices`; instances listed by no
// node stay on CPU 0. Slots within a CPU follow the order of `computeInstances`.
DCPAnalysisResult buildResultFromVirtualGraph(const VirtualGraph& graph, ArrayRef<ComputeInstance> computeInstances) {
  DCPAnalysisResult result;

  TimingInfo timing = computeTiming(graph);
  std::vector<size_t> nodeOrder;
  if (!timing.valid) {
    nodeOrder.resize(graph.nodes.size());
    std::iota(nodeOrder.begin(), nodeOrder.end(), 0);
  }
  else {
    nodeOrder = std::move(timing.topologicalOrder);
  }

  // Position in `nodeOrder` doubles as the CPU identifier.
  std::vector<size_t> cpuOfOriginal(computeInstances.size(), 0);
  for (size_t cpu = 0; cpu < nodeOrder.size(); ++cpu) {
    const VirtualNode& node = graph.nodes[nodeOrder[cpu]];
    for (size_t originalIndex : node.originalComputeIndices)
      cpuOfOriginal[originalIndex] = cpu;
  }

  result.dominanceOrderCompute.reserve(computeInstances.size());
  llvm::DenseMap<size_t, size_t> nextCpuSlot;
  for (size_t originalIndex = 0; originalIndex < computeInstances.size(); ++originalIndex) {
    const ComputeInstance& instance = computeInstances[originalIndex];
    const size_t cpu = cpuOfOriginal[originalIndex];
    const size_t slot = nextCpuSlot[cpu]++;

    result.dominanceOrderCompute.push_back(instance);
    result.computeToCpuMap[instance] = cpu;
    result.computeToCpuSlotMap[instance] = slot;
    // The original index is stored as the start-time key.
    result.computeToAestMap[instance] = originalIndex;
    // Overwritten on each visit, so the last instance seen for a CPU wins.
    result.cpuToLastComputeMap[cpu] = instance;
  }

  for (const auto& cpuAndLast : result.cpuToLastComputeMap)
    result.isLastComputeOfCpu.insert(cpuAndLast.second);

  return result;
}
|
||||
|
||||
// Converts the per-CPU task lists of an already scheduled GraphDCP into an
// analysis result.
//
// `task.nodeIndex` is used as an index into `computeInstances`. CPUs without
// scheduled tasks contribute nothing to the result.
DCPAnalysisResult buildResultFromScheduledGraph(GraphDCP& graphDCP, ArrayRef<ComputeInstance> computeInstances) {
  DCPAnalysisResult result;
  result.dominanceOrderCompute.assign(computeInstances.begin(), computeInstances.end());

  for (CPU cpu = 0; cpu < graphDCP.cpuCount(); ++cpu) {
    auto tasksOnCpu = graphDCP.getScheduledTasks(cpu);
    if (!tasksOnCpu.empty()) {
      size_t slot = 0;
      for (const auto& task : tasksOnCpu) {
        ComputeInstance instance = computeInstances[task.nodeIndex];
        result.computeToCpuMap[instance] = cpu;
        result.computeToCpuSlotMap[instance] = slot;
        result.computeToAestMap[instance] = static_cast<uint64_t>(task.aest);
        ++slot;
      }

      // The final task in the list is recorded as the CPU's last compute.
      const ComputeInstance& lastInstance = computeInstances[tasksOnCpu.back().nodeIndex];
      result.cpuToLastComputeMap[cpu] = lastInstance;
      result.isLastComputeOfCpu.insert(lastInstance);
    }
  }

  return result;
}
|
||||
|
||||
// Runs the non-coarsened DCP scheduler over all compute instances.
//
// Each instance becomes one graph node with its weight, its crossbar usage and
// its position in `computeInstances` as ordering key. A positive `coresCount`
// caps the number of CPUs the scheduler may use.
DCPAnalysisResult
runLegacyDcp(ArrayRef<ComputeInstance> computeInstances, ArrayRef<IndexedEdge> edges, MLIRContext* context) {
  const size_t instanceCount = computeInstances.size();

  SmallVector<Weight> nodeWeights;
  SmallVector<CrossbarUsage> nodeCrossbarUsage;
  SmallVector<int64_t> nodeOrderKeys;
  nodeWeights.reserve(instanceCount);
  nodeCrossbarUsage.reserve(instanceCount);
  nodeOrderKeys.reserve(instanceCount);

  for (size_t index = 0; index < instanceCount; ++index) {
    const ComputeInstance& instance = computeInstances[index];
    nodeWeights.push_back(getComputeInstanceWeight(instance));
    nodeCrossbarUsage.push_back(getComputeInstanceCrossbarUsage(instance));
    nodeOrderKeys.push_back(static_cast<int64_t>(index));
  }

  GraphDCP graphDCP(nodeWeights, edges, nodeOrderKeys, nodeCrossbarUsage);
  const auto requestedCores = coresCount.getValue();
  if (requestedCores > 0)
    graphDCP.setMaxCpuCount(static_cast<int>(requestedCores));
  graphDCP.setContext(context);
  graphDCP.runDcp();

  return buildResultFromScheduledGraph(graphDCP, computeInstances);
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Walks backwards through any chain of tensor.extract_slice ops starting at `op`
// and returns the SpatCompute found at the end of that chain.
//
// Returns a null SpatCompute when `op` is null, when a slice source has no
// defining op, or when the op reached is not a SpatCompute.
SpatCompute getOriginalSpatCompute(Operation* op) {
  for (Operation* current = op; current != nullptr;) {
    auto slice = dyn_cast<tensor::ExtractSliceOp>(current);
    if (!slice)
      return dyn_cast<SpatCompute>(current);
    // Step to whatever produced the sliced tensor (null for block arguments).
    current = slice.getSource().getDefiningOp();
  }
  return {};
}
|
||||
|
||||
// Computes the schedule for all compute instances reachable from `entryOp`.
//
// Steps visible in this body:
//   1. Collect the compute instances and build one dependency edge per input that
//      is produced by another collected instance; the edge weight is the byte
//      size of the input's shaped type.
//   2. Fall back to runLegacyDcp() when a batch has more lanes than the
//      scheduling CPU budget, or when the critical-window size is 0.
//   3. Otherwise coarsen a virtual graph window by window and derive the result
//      from the coarsened graph.
//
// NOTE(review): everything after the first unconditional
// `return buildResultFromVirtualGraph(...)` is unreachable as written. It looks
// like the replacement implementation (buildComputeGraph + runDcpScheduler)
// merged in next to the previous body — confirm which of the two is intended.
DCPAnalysisResult DCPAnalysis::run() {
|
||||
SmallVector<ComputeInstance> computeInstances = collectComputeInstances(entryOp);
|
||||
SmallVector<IndexedEdge, 10> edges;
|
||||
|
||||
// Reverse lookup: compute instance -> position in `computeInstances`.
llvm::DenseMap<ComputeInstance, size_t> instanceToIndex;
|
||||
instanceToIndex.reserve(computeInstances.size());
|
||||
for (auto [index, instance] : llvm::enumerate(computeInstances))
|
||||
instanceToIndex[instance] = index;
|
||||
|
||||
// One edge producer -> consumer for every input that has a producing instance.
for (auto [indexEndEdge, computeInstance] : llvm::enumerate(computeInstances)) {
|
||||
for (Value input : getComputeInstanceInputs(computeInstance)) {
|
||||
if (auto producerInstance = getOriginalComputeInstance(input)) {
|
||||
auto producerIt = instanceToIndex.find(*producerInstance);
|
||||
// NOTE(review): only checked in assert-enabled builds; the iterator is
// dereferenced unconditionally on the next line.
assert(producerIt != instanceToIndex.end());
|
||||
auto indexStartEdge = producerIt->second;
|
||||
edges.push_back({static_cast<int64_t>(indexStartEdge),
|
||||
static_cast<int64_t>(indexEndEdge),
|
||||
static_cast<int64_t>(getSizeInBytes(cast<ShapedType>(input.getType())))});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// With an explicit core count, any batch wider than the CPU budget forces the
// legacy scheduler.
if (coresCount.getValue() > 0) {
|
||||
size_t schedulingCpuBudget = getSchedulingCpuBudget();
|
||||
bool needsExactScheduledBatches = llvm::any_of(computeInstances, [&](const ComputeInstance& instance) {
|
||||
auto batch = dyn_cast<SpatComputeBatch>(instance.op);
|
||||
return batch && static_cast<size_t>(batch.getLaneCount()) > schedulingCpuBudget;
|
||||
});
|
||||
if (needsExactScheduledBatches)
|
||||
return runLegacyDcp(computeInstances, edges, entryOp->getContext());
|
||||
}
|
||||
|
||||
// A window size of 0 selects the legacy scheduler.
if (dcpCriticalWindowSize.getValue() == 0)
|
||||
return runLegacyDcp(computeInstances, edges, entryOp->getContext());
|
||||
|
||||
VirtualGraph virtualGraph = buildInitialVirtualGraph(computeInstances, edges);
|
||||
size_t iteration = 0;
|
||||
bool debugCoarsening = isDcpCoarsenDebugEnabled();
|
||||
// Schedules the selected window and, if that yields merge groups, replaces
// `virtualGraph` with the coarsened graph. Returns true only when the graph
// changed. Debug lines are emitted only for graphs with at least 200 nodes.
auto tryCoarsenSelectedNodes = [&](ArrayRef<size_t> selectedNodes) {
|
||||
size_t oldNodeCount = virtualGraph.nodes.size();
|
||||
WindowScheduleResult windowSchedule = scheduleWindow(virtualGraph, selectedNodes, entryOp->getContext());
|
||||
if (windowSchedule.mergeGroups.empty()) {
|
||||
if (debugCoarsening && oldNodeCount >= 200)
|
||||
llvm::errs() << llvm::formatv("[DCP-COARSEN] iter={0} old={1} selected={2} windowCpus={3} "
|
||||
"groups=0 mergedNodes=0 maxGroup=0 new={1} changed=0\n",
|
||||
iteration,
|
||||
oldNodeCount,
|
||||
selectedNodes.size(),
|
||||
windowSchedule.cpuCount);
|
||||
return false;
|
||||
}
|
||||
|
||||
VirtualGraph coarsenedGraph;
|
||||
std::vector<size_t> oldToNewNode;
|
||||
if (!coarsenGraph(virtualGraph, windowSchedule.mergeGroups, coarsenedGraph, oldToNewNode))
|
||||
return false;
|
||||
if (debugCoarsening && (oldNodeCount >= 200 || coarsenedGraph.nodes.size() >= 200))
|
||||
llvm::errs() << llvm::formatv("[DCP-COARSEN] iter={0} old={1} selected={2} windowCpus={3} "
|
||||
"groups={4} mergedNodes={5} maxGroup={6} new={7} changed={8}\n",
|
||||
iteration,
|
||||
oldNodeCount,
|
||||
selectedNodes.size(),
|
||||
windowSchedule.cpuCount,
|
||||
windowSchedule.mergeGroups.size(),
|
||||
windowSchedule.mergedNodeCount,
|
||||
windowSchedule.maxMergeGroupSize,
|
||||
coarsenedGraph.nodes.size(),
|
||||
oldNodeCount - coarsenedGraph.nodes.size());
|
||||
virtualGraph = std::move(coarsenedGraph);
|
||||
return true;
|
||||
};
|
||||
|
||||
// Coarsening loop. Stops when the graph fits the CPU budget, the timing is
// invalid, fewer than two nodes are selected, or a round merges nothing.
while (virtualGraph.nodes.size() > 1) {
|
||||
if (virtualGraph.nodes.size() <= getSchedulingCpuBudget()) {
|
||||
if (debugCoarsening && virtualGraph.nodes.size() >= 200)
|
||||
llvm::errs() << llvm::formatv(
|
||||
"[DCP-COARSEN] iter={0} old={1} stop=cpu-budget\n", iteration, virtualGraph.nodes.size());
|
||||
break;
|
||||
}
|
||||
|
||||
iteration++;
|
||||
TimingInfo timing = computeTiming(virtualGraph);
|
||||
if (!timing.valid) {
|
||||
if (debugCoarsening && virtualGraph.nodes.size() >= 200)
|
||||
llvm::errs() << llvm::formatv(
|
||||
"[DCP-COARSEN] iter={0} old={1} invalid-timing\n", iteration, virtualGraph.nodes.size());
|
||||
break;
|
||||
}
|
||||
|
||||
SmallVector<size_t> selectedNodes;
|
||||
auto criticalWindow =
|
||||
selectCriticalWindow(virtualGraph, timing, getDcpCoarseningWindowSize(virtualGraph.nodes.size()));
|
||||
selectedNodes.append(criticalWindow.begin(), criticalWindow.end());
|
||||
|
||||
if (selectedNodes.size() < 2) {
|
||||
if (debugCoarsening && virtualGraph.nodes.size() >= 200)
|
||||
llvm::errs() << llvm::formatv("[DCP-COARSEN] iter={0} old={1} selected={2} stop=small-window\n",
|
||||
iteration,
|
||||
virtualGraph.nodes.size(),
|
||||
selectedNodes.size());
|
||||
break;
|
||||
}
|
||||
|
||||
if (tryCoarsenSelectedNodes(selectedNodes))
|
||||
continue;
|
||||
if (debugCoarsening && virtualGraph.nodes.size() >= 200)
|
||||
llvm::errs() << llvm::formatv(
|
||||
"[DCP-COARSEN] iter={0} old={1} stop=no-merge\n", iteration, virtualGraph.nodes.size());
|
||||
break;
|
||||
}
|
||||
|
||||
return buildResultFromVirtualGraph(virtualGraph, computeInstances);
|
||||
// NOTE(review): unreachable from here on (see the note at the top of the function).
ComputeGraph graph = buildComputeGraph(entryOp);
|
||||
DcpScheduleOptions options;
|
||||
if (coresCount.getValue() > 0)
|
||||
options.processorCount = static_cast<size_t>(coresCount.getValue());
|
||||
options.criticalWindowSize = dcpCriticalWindowSize.getValue();
|
||||
options.allowFallbackForAutoCoreCount = true;
|
||||
return runDcpScheduler(graph, options, entryOp->getContext());
|
||||
}
|
||||
|
||||
} // namespace spatial
|
||||
|
||||
@@ -2,64 +2,27 @@
|
||||
|
||||
#include "mlir/IR/Operation.h"
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/DenseSet.h"
|
||||
|
||||
#include <cstdint>
|
||||
#include <vector>
|
||||
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
// A scheduling identity that covers both spat.compute and scheduled shards of
|
||||
// spat.compute_batch.
|
||||
struct ComputeInstance {
|
||||
mlir::Operation* op = nullptr;
|
||||
uint32_t laneStart = 0;
|
||||
uint32_t laneCount = 1;
|
||||
|
||||
bool operator==(const ComputeInstance& other) const {
|
||||
return op == other.op && laneStart == other.laneStart && laneCount == other.laneCount;
|
||||
}
|
||||
};
|
||||
|
||||
// Outcome of the scheduling analysis: which CPU and slot every compute instance
// was assigned to.
struct DCPAnalysisResult {
|
||||
// All scheduled compute instances, in the order they were collected.
std::vector<ComputeInstance> dominanceOrderCompute;
|
||||
// Compute instance -> CPU it is assigned to.
llvm::DenseMap<ComputeInstance, size_t> computeToCpuMap;
|
||||
// Compute instance -> position within its CPU.
llvm::DenseMap<ComputeInstance, size_t> computeToCpuSlotMap;
|
||||
// Compute instance -> start-time key (the scheduler's `aest`, or the original
// index when the result is built from the virtual graph).
llvm::DenseMap<ComputeInstance, uint64_t> computeToAestMap;
|
||||
// Instances that are the last one on their CPU.
llvm::DenseSet<ComputeInstance> isLastComputeOfCpu;
|
||||
// CPU -> last compute instance assigned to it.
llvm::DenseMap<size_t, ComputeInstance> cpuToLastComputeMap;
|
||||
};
|
||||
#include "../Scheduling/MergeSchedule.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace spatial {
|
||||
using DCPAnalysisResult = MergeScheduleResult;
|
||||
|
||||
struct DCPAnalysis {
|
||||
private:
|
||||
DCPAnalysisResult result;
|
||||
mlir::Operation* entryOp;
|
||||
mlir::Operation *entryOp;
|
||||
DCPAnalysisResult run();
|
||||
|
||||
public:
|
||||
DCPAnalysis(mlir::Operation* op)
|
||||
DCPAnalysis(mlir::Operation *op)
|
||||
: entryOp(op) {
|
||||
result = run();
|
||||
}
|
||||
DCPAnalysisResult& getResult() { return result; }
|
||||
DCPAnalysisResult &getResult() { return result; }
|
||||
};
|
||||
|
||||
} // namespace spatial
|
||||
} // namespace onnx_mlir
|
||||
|
||||
namespace llvm {
|
||||
template <>
|
||||
struct DenseMapInfo<ComputeInstance> {
|
||||
static ComputeInstance getEmptyKey() {
|
||||
return {DenseMapInfo<mlir::Operation*>::getEmptyKey(), UINT32_MAX, UINT32_MAX};
|
||||
}
|
||||
static ComputeInstance getTombstoneKey() {
|
||||
return {DenseMapInfo<mlir::Operation*>::getTombstoneKey(), UINT32_MAX, UINT32_MAX};
|
||||
}
|
||||
static unsigned getHashValue(const ComputeInstance& v) { return llvm::hash_combine(v.op, v.laneStart, v.laneCount); }
|
||||
static bool isEqual(const ComputeInstance& a, const ComputeInstance& b) { return a == b; }
|
||||
};
|
||||
} // namespace llvm
|
||||
using DCPAnalysisResult = onnx_mlir::spatial::DCPAnalysisResult;
|
||||
|
||||
@@ -0,0 +1,636 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/DenseSet.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <functional>
|
||||
#include <optional>
|
||||
#include <utility>
|
||||
|
||||
#include "MaterializeMergeSchedule.hpp"
|
||||
#include "Scheduling/ComputeInstanceUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace spatial {
|
||||
|
||||
namespace {
|
||||
|
||||
using SpatCompute = spatial::SpatCompute;
|
||||
using ProducerValueRef = spatial::ProducerValueRef;
|
||||
using spatial::getComputeInstanceInputs;
|
||||
using spatial::getComputeInstanceOutputTypes;
|
||||
using spatial::getComputeInstanceOutputValues;
|
||||
using spatial::getComputeInstanceTemplateBlock;
|
||||
using spatial::getComputeInstanceWeights;
|
||||
using spatial::getProducerValueRef;
|
||||
|
||||
class MergeScheduleMaterializerImpl {
|
||||
public:
|
||||
explicit MergeScheduleMaterializerImpl(func::FuncOp funcOp)
|
||||
: func(funcOp), loc(funcOp.getLoc()), returnOp(cast<func::ReturnOp>(funcOp.getBody().front().getTerminator())) {}
|
||||
|
||||
// Materializes `scheduleResult` by running the planning, creation, cloning and
// clean-up phases in order.
//
// `nextChannelIdRef` is the channel-id counter; it is advanced for every channel
// allocated here. Returns failure as soon as cloning task bodies or erasing the
// old scheduled ops fails; the remaining phases are then skipped.
LogicalResult run(const MergeScheduleResult& scheduleResult, int64_t& nextChannelIdRef) {
  schedule = &scheduleResult;
  nextChannelId = &nextChannelIdRef;

  // Planning phases.
  collectScheduledTasks();
  buildTaskIndex();
  collectExternalInputsAndWeights();
  planRemoteChannels();
  planReceiveReordering();

  // IR construction phases.
  createCpuComputeOps();
  const LogicalResult cloneStatus = cloneTaskBodies();
  if (failed(cloneStatus))
    return failure();

  // Rewiring and clean-up phases.
  replaceExternalUses();
  const LogicalResult eraseStatus = eraseOldScheduledOps();
  if (failed(eraseStatus))
    return failure();
  moveExternalUsersBeforeReturn();

  return success();
}
|
||||
|
||||
private:
|
||||
// One compute instance together with its placement in the schedule.
struct ScheduledTask {
|
||||
// The scheduled compute instance.
ComputeInstance computeInstance;
|
||||
// CPU the instance is assigned to (looked up in computeToCpuMap).
size_t cpu = 0;
|
||||
// Slot of the instance within its CPU (looked up in computeToCpuSlotMap).
size_t orderWithinCpu = 0;
|
||||
};
|
||||
|
||||
// Describes one channel between two CPUs. All fields default to -1 (unset).
struct ChannelInfo {
|
||||
// Identifier taken from the shared channel-id counter.
int64_t channelId = -1;
|
||||
// CPU of the producing task.
int32_t sourceCoreId = -1;
|
||||
// CPU of the consuming task.
int32_t targetCoreId = -1;
|
||||
};
|
||||
|
||||
// The merged compute op created for one CPU, plus the lookup tables used while
// cloning task bodies into it.
struct CpuProgram {
|
||||
// The new per-CPU compute op.
SpatCompute op;
|
||||
// External input value -> block argument of the new op's body.
DenseMap<Value, Value> externalInputMap;
|
||||
// Weight value -> its position in the CPU's weight list.
DenseMap<Value, size_t> weightToIndex;
|
||||
};
|
||||
|
||||
// Bookkeeping for one value a task has to send to a task on another CPU.
struct RemoteSendInfo {
|
||||
// Channel used for the transfer.
ChannelInfo channelInfo;
|
||||
// Task that consumes the value.
ComputeInstance consumer;
|
||||
// Index of the consumer input that receives the value.
size_t inputIndex = 0;
|
||||
// Slot of the consumer within its CPU.
size_t consumerOrder = 0;
|
||||
// Running send number for the (source, target) CPU pair; assigned in
// planRemoteChannels().
size_t sourceOrder = 0;
|
||||
};
|
||||
|
||||
// One pending receive in a per-CPU, per-pair receive queue.
struct RemoteReceiveEntry {
|
||||
// Channel the value arrives on.
ChannelInfo channelInfo;
|
||||
// Task that consumes the value.
ComputeInstance consumer;
|
||||
// Index of the consumer input that receives the value.
size_t inputIndex = 0;
|
||||
// Send number of the matching send for the (source, target) pair; the queues
// are sorted by it.
size_t sourceOrder = 0;
|
||||
};
|
||||
|
||||
static uint64_t getRemoteSendPairKey(const ChannelInfo& channelInfo) {
|
||||
return (static_cast<uint64_t>(static_cast<uint32_t>(channelInfo.sourceCoreId)) << 32)
|
||||
| static_cast<uint32_t>(channelInfo.targetCoreId);
|
||||
}
|
||||
|
||||
// Records `op` and, transitively, every user of its results in
// `externalUsersToMove`. Old compute ops and func.return are neither recorded
// nor followed. Ops that are already recorded are not visited again.
void collectExternalUsers(Operation* op) {
  const bool newlyRecorded = externalUsersToMove.insert(op).second;
  if (!newlyRecorded)
    return;

  for (Value result : op->getResults()) {
    for (Operation* user : result.getUsers()) {
      const bool shouldFollow = !oldComputeOps.contains(user) && !isa<func::ReturnOp>(user);
      if (shouldFollow)
        collectExternalUsers(user);
    }
  }
}
|
||||
|
||||
void collectScheduledTasks() {
|
||||
for (ComputeInstance scheduledInstance : schedule->dominanceOrderCompute) {
|
||||
oldComputeOps.insert(scheduledInstance.op);
|
||||
scheduledTasks.push_back({scheduledInstance,
|
||||
schedule->computeToCpuMap.lookup(scheduledInstance),
|
||||
schedule->computeToCpuSlotMap.lookup(scheduledInstance)});
|
||||
}
|
||||
}
|
||||
|
||||
void buildTaskIndex() {
|
||||
auto markCpuSeen = [&](size_t cpu) {
|
||||
if (seenCpus.insert(cpu).second)
|
||||
orderedCpus.push_back(cpu);
|
||||
};
|
||||
|
||||
for (const ScheduledTask& task : scheduledTasks) {
|
||||
taskByComputeInstance[task.computeInstance] = task;
|
||||
tasksByCpu[task.cpu].push_back(task);
|
||||
markCpuSeen(task.cpu);
|
||||
}
|
||||
|
||||
llvm::sort(orderedCpus);
|
||||
for (size_t cpu : orderedCpus)
|
||||
llvm::stable_sort(tasksByCpu[cpu], [&](const ScheduledTask& lhs, const ScheduledTask& rhs) {
|
||||
return lhs.orderWithinCpu < rhs.orderWithinCpu;
|
||||
});
|
||||
}
|
||||
|
||||
void collectExternalInputsAndWeights() {
|
||||
for (size_t cpu : orderedCpus) {
|
||||
for (const ScheduledTask& task : tasksByCpu[cpu]) {
|
||||
auto& thisCpuWeights = cpuWeights[cpu];
|
||||
auto& thisSeenWeights = seenWeightsByCpu[cpu];
|
||||
auto taskWeights = getComputeInstanceWeights(task.computeInstance);
|
||||
for (Value weight : taskWeights)
|
||||
if (thisSeenWeights.insert(weight).second)
|
||||
thisCpuWeights.push_back(weight);
|
||||
|
||||
auto taskInputs = getComputeInstanceInputs(task.computeInstance);
|
||||
auto& remoteInputs = remoteInputsByTask[task.computeInstance];
|
||||
remoteInputs.resize(taskInputs.size());
|
||||
for (auto [inputIndex, input] : llvm::enumerate(taskInputs)) {
|
||||
auto producerRef = getProducerValueRef(input);
|
||||
if (producerRef) {
|
||||
auto producerIt = taskByComputeInstance.find(producerRef->instance);
|
||||
if (producerIt != taskByComputeInstance.end()) {
|
||||
if (producerIt->second.cpu != cpu) {
|
||||
ChannelInfo info {
|
||||
(*nextChannelId)++,
|
||||
static_cast<int32_t>(producerIt->second.cpu),
|
||||
static_cast<int32_t>(cpu),
|
||||
};
|
||||
remoteInputs[inputIndex] = info;
|
||||
auto& perResultChannels = remoteSendsByTask[producerRef->instance];
|
||||
if (perResultChannels.empty())
|
||||
perResultChannels.resize(getComputeInstanceOutputTypes(producerIt->second.computeInstance).size());
|
||||
perResultChannels[producerRef->resultIndex].push_back(
|
||||
{info, task.computeInstance, inputIndex, task.orderWithinCpu, 0});
|
||||
}
|
||||
continue;
|
||||
}
|
||||
}
|
||||
if (seenExternalInputsByCpu[cpu].insert(input).second)
|
||||
cpuExternalInputs[cpu].push_back(input);
|
||||
}
|
||||
|
||||
auto taskOutputs = getComputeInstanceOutputValues(task.computeInstance);
|
||||
for (auto [resultIndex, output] : llvm::enumerate(taskOutputs)) {
|
||||
bool hasExternalUser = false;
|
||||
for (auto& use : output.getUses()) {
|
||||
Operation* useOwner = use.getOwner();
|
||||
if (oldComputeOps.contains(useOwner))
|
||||
continue;
|
||||
hasExternalUser = true;
|
||||
if (!isa<func::ReturnOp>(useOwner))
|
||||
collectExternalUsers(useOwner);
|
||||
}
|
||||
if (hasExternalUser)
|
||||
cpuExternalOutputs[cpu].push_back({task.computeInstance, resultIndex});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void planRemoteChannels() {
|
||||
for (size_t cpu : orderedCpus) {
|
||||
DenseMap<uint64_t, size_t> nextSourceOrderByPair;
|
||||
DenseMap<uint64_t, size_t> lastConsumerOrderByPair;
|
||||
for (const ScheduledTask& task : tasksByCpu[cpu]) {
|
||||
auto sendsIt = remoteSendsByTask.find(task.computeInstance);
|
||||
if (sendsIt == remoteSendsByTask.end())
|
||||
continue;
|
||||
for (auto& sendInfos : sendsIt->second) {
|
||||
for (RemoteSendInfo& sendInfo : sendInfos) {
|
||||
uint64_t pairKey = getRemoteSendPairKey(sendInfo.channelInfo);
|
||||
sendInfo.sourceOrder = nextSourceOrderByPair[pairKey]++;
|
||||
auto [it, inserted] = lastConsumerOrderByPair.try_emplace(pairKey, sendInfo.consumerOrder);
|
||||
if (!inserted) {
|
||||
if (sendInfo.consumerOrder < it->second)
|
||||
pairsNeedingReceiveReorder.insert(pairKey);
|
||||
it->second = sendInfo.consumerOrder;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void planReceiveReordering() {
|
||||
DenseMap<uint64_t, SmallVector<RemoteSendInfo*>> reorderedSendsByPair;
|
||||
for (auto& taskSends : remoteSendsByTask) {
|
||||
for (auto& sendInfos : taskSends.second) {
|
||||
for (RemoteSendInfo& sendInfo : sendInfos) {
|
||||
uint64_t pairKey = getRemoteSendPairKey(sendInfo.channelInfo);
|
||||
if (pairsNeedingReceiveReorder.contains(pairKey))
|
||||
reorderedSendsByPair[pairKey].push_back(&sendInfo);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (auto& pairSends : reorderedSendsByPair) {
|
||||
llvm::stable_sort(pairSends.second, [](const RemoteSendInfo* lhs, const RemoteSendInfo* rhs) {
|
||||
if (lhs->sourceOrder != rhs->sourceOrder)
|
||||
return lhs->sourceOrder < rhs->sourceOrder;
|
||||
return lhs->channelInfo.channelId < rhs->channelInfo.channelId;
|
||||
});
|
||||
for (RemoteSendInfo* sendInfo : pairSends.second) {
|
||||
int64_t channelId = (*nextChannelId)++;
|
||||
sendInfo->channelInfo.channelId = channelId;
|
||||
auto remoteInputsIt = remoteInputsByTask.find(sendInfo->consumer);
|
||||
assert(remoteInputsIt != remoteInputsByTask.end() && "missing remote input for reordered send");
|
||||
assert(sendInfo->inputIndex < remoteInputsIt->second.size() && "remote input index out of range");
|
||||
assert(remoteInputsIt->second[sendInfo->inputIndex] && "missing reordered remote input channel");
|
||||
remoteInputsIt->second[sendInfo->inputIndex]->channelId = channelId;
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto& taskSends : remoteSendsByTask) {
|
||||
for (const auto& sendInfos : taskSends.second) {
|
||||
for (const RemoteSendInfo& sendInfo : sendInfos) {
|
||||
auto remoteInputsIt = remoteInputsByTask.find(sendInfo.consumer);
|
||||
assert(remoteInputsIt != remoteInputsByTask.end() && "missing remote input for send");
|
||||
assert(sendInfo.inputIndex < remoteInputsIt->second.size() && "remote input index out of range");
|
||||
assert(remoteInputsIt->second[sendInfo.inputIndex] && "missing remote input channel");
|
||||
remoteInputsIt->second[sendInfo.inputIndex] = sendInfo.channelInfo;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (auto& taskSends : remoteSendsByTask) {
|
||||
for (const auto& sendInfos : taskSends.second) {
|
||||
for (const RemoteSendInfo& sendInfo : sendInfos) {
|
||||
uint64_t pairKey = getRemoteSendPairKey(sendInfo.channelInfo);
|
||||
if (!pairsNeedingReceiveReorder.contains(pairKey))
|
||||
continue;
|
||||
size_t targetCpu = static_cast<size_t>(sendInfo.channelInfo.targetCoreId);
|
||||
receiveQueuesByCpu[targetCpu][pairKey].push_back(
|
||||
{sendInfo.channelInfo, sendInfo.consumer, sendInfo.inputIndex, sendInfo.sourceOrder});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (auto& cpuQueues : receiveQueuesByCpu) {
|
||||
for (auto& pairQueue : cpuQueues.second) {
|
||||
llvm::stable_sort(pairQueue.second, [](const RemoteReceiveEntry& lhs, const RemoteReceiveEntry& rhs) {
|
||||
if (lhs.sourceOrder != rhs.sourceOrder)
|
||||
return lhs.sourceOrder < rhs.sourceOrder;
|
||||
return lhs.channelInfo.channelId < rhs.channelInfo.channelId;
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void createCpuComputeOps() {
|
||||
IRRewriter rewriter(func.getContext());
|
||||
for (size_t cpu : orderedCpus) {
|
||||
SmallVector<Value> operands;
|
||||
operands.reserve(cpuWeights[cpu].size() + cpuExternalInputs[cpu].size());
|
||||
llvm::append_range(operands, cpuWeights[cpu]);
|
||||
llvm::append_range(operands, cpuExternalInputs[cpu]);
|
||||
|
||||
SmallVector<Type> resultTypes;
|
||||
resultTypes.reserve(cpuExternalOutputs[cpu].size());
|
||||
for (ProducerValueRef outputRef : cpuExternalOutputs[cpu]) {
|
||||
ScheduledTask task = taskByComputeInstance.at(outputRef.instance);
|
||||
resultTypes.push_back(getComputeInstanceOutputTypes(task.computeInstance)[outputRef.resultIndex]);
|
||||
}
|
||||
|
||||
rewriter.setInsertionPoint(returnOp);
|
||||
auto newCompute = SpatCompute::create(rewriter, loc, TypeRange(resultTypes), ValueRange(operands));
|
||||
newCompute.getProperties().setOperandSegmentSizes(
|
||||
{static_cast<int>(cpuWeights[cpu].size()), static_cast<int>(cpuExternalInputs[cpu].size())});
|
||||
newCompute->setAttr(onnx_mlir::kCoreIdAttrName, rewriter.getI32IntegerAttr(static_cast<int32_t>(cpu)));
|
||||
|
||||
SmallVector<Type> blockArgTypes;
|
||||
SmallVector<Location> blockArgLocs;
|
||||
blockArgTypes.reserve(cpuExternalInputs[cpu].size());
|
||||
blockArgLocs.reserve(cpuExternalInputs[cpu].size());
|
||||
for (Value input : cpuExternalInputs[cpu]) {
|
||||
blockArgTypes.push_back(input.getType());
|
||||
blockArgLocs.push_back(loc);
|
||||
}
|
||||
Block* newBlock =
|
||||
rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
|
||||
|
||||
CpuProgram program;
|
||||
program.op = newCompute;
|
||||
for (auto [weightIndex, weight] : llvm::enumerate(cpuWeights[cpu]))
|
||||
program.weightToIndex[weight] = weightIndex;
|
||||
for (auto [inputIndex, input] : llvm::enumerate(cpuExternalInputs[cpu]))
|
||||
program.externalInputMap[input] = newBlock->getArgument(inputIndex);
|
||||
for (auto [resultIndex, outputRef] : llvm::enumerate(cpuExternalOutputs[cpu])) {
|
||||
ScheduledTask task = taskByComputeInstance.at(outputRef.instance);
|
||||
oldToNewExternalValueMap[getComputeInstanceOutputValues(task.computeInstance)[outputRef.resultIndex]] =
|
||||
newCompute.getResult(resultIndex);
|
||||
}
|
||||
cpuPrograms[cpu] = std::move(program);
|
||||
}
|
||||
}
|
||||
|
||||
// Emits receives from the reordered queue of the requested channel's
// (source, target) pair until the receive for the requested consumer input has
// been emitted, and returns that value.
//
// Every receive emitted on the way is stored in `preReceivedInputsByTask` under
// its own consumer and input index. `receiveQueueIndices` holds the per-pair
// read position and is advanced for every entry consumed.
//
// Fails when `cpu` has no queue for the pair, when a queued consumer is unknown
// or its input index is out of range, or when the queue is exhausted without
// reaching the requested input.
FailureOr<Value> receiveThroughInput(IRRewriter& rewriter,
                                     size_t cpu,
                                     DenseMap<uint64_t, size_t>& receiveQueueIndices,
                                     DenseMap<ComputeInstance, SmallVector<Value>>& preReceivedInputsByTask,
                                     const ChannelInfo& requestedChannelInfo,
                                     ComputeInstance requestedConsumer,
                                     size_t requestedInputIndex) {
  const uint64_t pairKey = getRemoteSendPairKey(requestedChannelInfo);

  auto queuesForCpuIt = receiveQueuesByCpu.find(cpu);
  if (queuesForCpuIt == receiveQueuesByCpu.end())
    return failure();
  auto pendingIt = queuesForCpuIt->second.find(pairKey);
  if (pendingIt == queuesForCpuIt->second.end())
    return failure();
  auto& pending = pendingIt->second;

  for (size_t& cursor = receiveQueueIndices[pairKey]; cursor < pending.size();) {
    const RemoteReceiveEntry& entry = pending[cursor];
    // Consume the entry before anything below can bail out.
    ++cursor;

    auto consumerTaskIt = taskByComputeInstance.find(entry.consumer);
    if (consumerTaskIt == taskByComputeInstance.end())
      return failure();
    SmallVector<Value> consumerInputs = getComputeInstanceInputs(consumerTaskIt->second.computeInstance);
    if (entry.inputIndex >= consumerInputs.size())
      return failure();

    // The receive produces a value of the type the consumer input expects.
    Type receivedType = consumerInputs[entry.inputIndex].getType();
    auto receive = spatial::SpatChannelReceiveOp::create(rewriter,
                                                         loc,
                                                         receivedType,
                                                         rewriter.getI64IntegerAttr(entry.channelInfo.channelId),
                                                         rewriter.getI32IntegerAttr(entry.channelInfo.sourceCoreId),
                                                         rewriter.getI32IntegerAttr(entry.channelInfo.targetCoreId));

    auto& receivedForConsumer = preReceivedInputsByTask[entry.consumer];
    if (receivedForConsumer.size() <= entry.inputIndex)
      receivedForConsumer.resize(entry.inputIndex + 1);
    receivedForConsumer[entry.inputIndex] = receive.getResult();

    const bool isRequested = entry.consumer == requestedConsumer && entry.inputIndex == requestedInputIndex;
    if (isRequested)
      return receive.getResult();
  }

  return failure();
}
|
||||
|
||||
LogicalResult cloneTaskBodies() {
|
||||
for (size_t cpu : orderedCpus) {
|
||||
CpuProgram& program = cpuPrograms[cpu];
|
||||
IRRewriter rewriter(func.getContext());
|
||||
rewriter.setInsertionPointToEnd(&program.op.getBody().front());
|
||||
DenseMap<uint64_t, size_t> receiveQueueIndices;
|
||||
DenseMap<ComputeInstance, SmallVector<Value>> preReceivedInputsByTask;
|
||||
|
||||
auto lookupPreReceivedInput = [&](ComputeInstance consumer, size_t inputIndex) -> std::optional<Value> {
|
||||
auto inputsIt = preReceivedInputsByTask.find(consumer);
|
||||
if (inputsIt == preReceivedInputsByTask.end() || inputsIt->second.size() <= inputIndex)
|
||||
return std::nullopt;
|
||||
Value value = inputsIt->second[inputIndex];
|
||||
if (!value)
|
||||
return std::nullopt;
|
||||
return value;
|
||||
};
|
||||
|
||||
for (const ScheduledTask& task : tasksByCpu[cpu]) {
|
||||
SmallVector<Value> taskInputs = getComputeInstanceInputs(task.computeInstance);
|
||||
auto taskWeights = getComputeInstanceWeights(task.computeInstance);
|
||||
Block& templateBlock = getComputeInstanceTemplateBlock(task.computeInstance);
|
||||
|
||||
SmallVector<Value> resolvedInputs;
|
||||
resolvedInputs.reserve(taskInputs.size());
|
||||
auto remoteInputsIt = remoteInputsByTask.find(task.computeInstance);
|
||||
for (auto [inputIndex, input] : llvm::enumerate(taskInputs)) {
|
||||
auto producerRef = getProducerValueRef(input);
|
||||
if (producerRef) {
|
||||
auto producerIt = taskByComputeInstance.find(producerRef->instance);
|
||||
if (producerIt != taskByComputeInstance.end()) {
|
||||
if (producerIt->second.cpu == cpu) {
|
||||
auto producedIt = producedValuesByTask.find(producerRef->instance);
|
||||
if (producedIt == producedValuesByTask.end() || producedIt->second.size() <= producerRef->resultIndex) {
|
||||
task.computeInstance.op->emitOpError("missing local producer value during per-cpu merge materialization")
|
||||
<< " consumerCpu=" << cpu << " producerCpu=" << producerIt->second.cpu
|
||||
<< " producerLaneStart=" << producerRef->instance.laneStart
|
||||
<< " producerLaneCount=" << producerRef->instance.laneCount;
|
||||
return failure();
|
||||
}
|
||||
resolvedInputs.push_back(producedIt->second[producerRef->resultIndex]);
|
||||
continue;
|
||||
}
|
||||
const ChannelInfo& channelInfo = *remoteInputsIt->second[inputIndex];
|
||||
uint64_t pairKey = getRemoteSendPairKey(channelInfo);
|
||||
if (pairsNeedingReceiveReorder.contains(pairKey)) {
|
||||
if (std::optional<Value> preReceived = lookupPreReceivedInput(task.computeInstance, inputIndex)) {
|
||||
resolvedInputs.push_back(*preReceived);
|
||||
continue;
|
||||
}
|
||||
FailureOr<Value> received = receiveThroughInput(rewriter,
|
||||
cpu,
|
||||
receiveQueueIndices,
|
||||
preReceivedInputsByTask,
|
||||
channelInfo,
|
||||
task.computeInstance,
|
||||
inputIndex);
|
||||
if (failed(received)) {
|
||||
task.computeInstance.op->emitOpError("failed to materialize reordered remote receive")
|
||||
<< " consumerCpu=" << cpu << " sourceCoreId=" << channelInfo.sourceCoreId
|
||||
<< " targetCoreId=" << channelInfo.targetCoreId << " channelId=" << channelInfo.channelId;
|
||||
return failure();
|
||||
}
|
||||
resolvedInputs.push_back(*received);
|
||||
continue;
|
||||
}
|
||||
auto receive =
|
||||
spatial::SpatChannelReceiveOp::create(rewriter,
|
||||
loc,
|
||||
input.getType(),
|
||||
rewriter.getI64IntegerAttr(channelInfo.channelId),
|
||||
rewriter.getI32IntegerAttr(channelInfo.sourceCoreId),
|
||||
rewriter.getI32IntegerAttr(channelInfo.targetCoreId));
|
||||
resolvedInputs.push_back(receive.getResult());
|
||||
continue;
|
||||
}
|
||||
}
|
||||
resolvedInputs.push_back(program.externalInputMap.at(input));
|
||||
}
|
||||
|
||||
SmallVector<Value> taskYieldValues;
|
||||
rewriter.setInsertionPointToEnd(&program.op.getBody().front());
|
||||
if (isa<SpatCompute>(task.computeInstance.op)) {
|
||||
IRMapping mapper;
|
||||
for (auto [argIndex, oldArg] : llvm::enumerate(templateBlock.getArguments()))
|
||||
mapper.map(oldArg, resolvedInputs[argIndex]);
|
||||
|
||||
for (Operation& op : templateBlock) {
|
||||
if (auto yield = dyn_cast<spatial::SpatYieldOp>(&op)) {
|
||||
for (Value yieldOperand : yield.getOperands())
|
||||
taskYieldValues.push_back(mapper.lookup(yieldOperand));
|
||||
continue;
|
||||
}
|
||||
|
||||
Operation* clonedOp = rewriter.clone(op, mapper);
|
||||
if (auto oldWeightedMvmOp = dyn_cast<spatial::SpatMVMOp>(&op)) {
|
||||
auto newWeightedMvmOp = cast<spatial::SpatMVMOp>(clonedOp);
|
||||
Value weight = taskWeights[oldWeightedMvmOp.getWeightIndex()];
|
||||
newWeightedMvmOp.setWeightIndex(program.weightToIndex.at(weight));
|
||||
}
|
||||
if (auto oldWeightedVmmOp = dyn_cast<spatial::SpatVMMOp>(&op)) {
|
||||
auto newWeightedVmmOp = cast<spatial::SpatVMMOp>(clonedOp);
|
||||
Value weight = taskWeights[oldWeightedVmmOp.getWeightIndex()];
|
||||
newWeightedVmmOp.setWeightIndex(program.weightToIndex.at(weight));
|
||||
}
|
||||
}
|
||||
}
|
||||
else {
|
||||
for (size_t laneOffset = 0; laneOffset < task.computeInstance.laneCount; ++laneOffset) {
|
||||
IRMapping mapper;
|
||||
if (templateBlock.getNumArguments() == 1)
|
||||
mapper.map(templateBlock.getArgument(0), resolvedInputs[laneOffset]);
|
||||
|
||||
for (Operation& op : templateBlock) {
|
||||
if (auto yield = dyn_cast<spatial::SpatYieldOp>(&op)) {
|
||||
for (Value yieldOperand : yield.getOperands())
|
||||
taskYieldValues.push_back(mapper.lookup(yieldOperand));
|
||||
continue;
|
||||
}
|
||||
|
||||
Operation* clonedOp = rewriter.clone(op, mapper);
|
||||
if (auto oldWeightedMvmOp = dyn_cast<spatial::SpatMVMOp>(&op)) {
|
||||
if (oldWeightedMvmOp.getWeightIndex() != 0) {
|
||||
task.computeInstance.op->emitOpError(
|
||||
"batched per-cpu merge materialization expects lane-local weight index 0");
|
||||
return failure();
|
||||
}
|
||||
auto newWeightedMvmOp = cast<spatial::SpatMVMOp>(clonedOp);
|
||||
newWeightedMvmOp.setWeightIndex(program.weightToIndex.at(taskWeights[laneOffset]));
|
||||
}
|
||||
if (auto oldWeightedVmmOp = dyn_cast<spatial::SpatVMMOp>(&op)) {
|
||||
if (oldWeightedVmmOp.getWeightIndex() != 0) {
|
||||
task.computeInstance.op->emitOpError(
|
||||
"batched per-cpu merge materialization expects lane-local weight index 0");
|
||||
return failure();
|
||||
}
|
||||
auto newWeightedVmmOp = cast<spatial::SpatVMMOp>(clonedOp);
|
||||
newWeightedVmmOp.setWeightIndex(program.weightToIndex.at(taskWeights[laneOffset]));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
producedValuesByTask[task.computeInstance] = taskYieldValues;
|
||||
if (auto sendsIt = remoteSendsByTask.find(task.computeInstance); sendsIt != remoteSendsByTask.end()) {
|
||||
for (auto [resultIndex, sendInfos] : llvm::enumerate(sendsIt->second)) {
|
||||
if (sendInfos.empty())
|
||||
continue;
|
||||
Value producedValue = taskYieldValues[resultIndex];
|
||||
for (const RemoteSendInfo& sendInfo : sendInfos) {
|
||||
spatial::SpatChannelSendOp::create(rewriter,
|
||||
loc,
|
||||
rewriter.getI64IntegerAttr(sendInfo.channelInfo.channelId),
|
||||
rewriter.getI32IntegerAttr(sendInfo.channelInfo.sourceCoreId),
|
||||
rewriter.getI32IntegerAttr(sendInfo.channelInfo.targetCoreId),
|
||||
producedValue);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
SmallVector<Value> yieldValues;
|
||||
yieldValues.reserve(cpuExternalOutputs[cpu].size());
|
||||
for (ProducerValueRef outputRef : cpuExternalOutputs[cpu]) {
|
||||
auto producedIt = producedValuesByTask.find(outputRef.instance);
|
||||
if (producedIt == producedValuesByTask.end() || producedIt->second.size() <= outputRef.resultIndex) {
|
||||
ScheduledTask task = taskByComputeInstance.at(outputRef.instance);
|
||||
task.computeInstance.op->emitOpError("missing yielded external value during per-cpu merge materialization")
|
||||
<< " cpu=" << cpu << " laneStart=" << outputRef.instance.laneStart;
|
||||
return failure();
|
||||
}
|
||||
yieldValues.push_back(producedIt->second[outputRef.resultIndex]);
|
||||
}
|
||||
spatial::SpatYieldOp::create(rewriter, loc, ValueRange(yieldValues));
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
void replaceExternalUses() {
|
||||
for (auto [oldValue, newValue] : oldToNewExternalValueMap) {
|
||||
for (auto& use : llvm::make_early_inc_range(oldValue.getUses()))
|
||||
if (!oldComputeOps.contains(use.getOwner()))
|
||||
use.assign(newValue);
|
||||
}
|
||||
}
|
||||
|
||||
/// Erases the operations of the function's entry block that are listed in
/// `oldComputeOps`, visiting them in reverse block order.
///
/// Fails (after emitting an op error with one note per remaining user) as soon
/// as an operation that still has users is reached; operations visited before
/// that point have already been erased.
LogicalResult eraseOldScheduledOps() {
  SmallVector<Operation*> pendingErasures;
  for (Operation& candidate : func.getBody().front()) {
    if (!oldComputeOps.contains(&candidate))
      continue;
    pendingErasures.push_back(&candidate);
  }

  for (auto it = pendingErasures.rbegin(), last = pendingErasures.rend(); it != last; ++it) {
    Operation* doomed = *it;

    SmallVector<Operation*> liveUsers;
    for (Value produced : doomed->getResults())
      for (Operation* consumer : produced.getUsers())
        liveUsers.push_back(consumer);

    if (liveUsers.empty()) {
      doomed->erase();
      continue;
    }

    InFlightDiagnostic diagnostic = doomed->emitOpError("still has uses during per-cpu merge cleanup")
                                 << "; erase-set=" << (oldComputeOps.contains(doomed) ? "yes" : "no");
    for (Operation* consumer : liveUsers) {
      diagnostic.attachNote(consumer->getLoc())
          << "remaining user " << consumer->getName()
          << "; erase-set=" << (oldComputeOps.contains(consumer) ? "yes" : "no");
    }
    return failure();
  }

  return success();
}
|
||||
|
||||
void moveExternalUsersBeforeReturn() {
|
||||
SmallVector<Operation*> orderedUsersToMove;
|
||||
for (Operation& op : func.getBody().front()) {
|
||||
if (&op == returnOp.getOperation())
|
||||
break;
|
||||
if (externalUsersToMove.contains(&op))
|
||||
orderedUsersToMove.push_back(&op);
|
||||
}
|
||||
for (Operation* op : orderedUsersToMove)
|
||||
op->moveBefore(returnOp);
|
||||
}
|
||||
|
||||
func::FuncOp func;
|
||||
const MergeScheduleResult* schedule = nullptr;
|
||||
int64_t* nextChannelId = nullptr;
|
||||
Location loc;
|
||||
func::ReturnOp returnOp;
|
||||
|
||||
SmallVector<ScheduledTask> scheduledTasks;
|
||||
DenseSet<Operation*> oldComputeOps;
|
||||
DenseMap<ComputeInstance, ScheduledTask> taskByComputeInstance;
|
||||
DenseMap<size_t, SmallVector<ScheduledTask>> tasksByCpu;
|
||||
SmallVector<size_t> orderedCpus;
|
||||
DenseSet<size_t> seenCpus;
|
||||
DenseSet<Operation*> externalUsersToMove;
|
||||
DenseMap<ComputeInstance, SmallVector<SmallVector<RemoteSendInfo>>> remoteSendsByTask;
|
||||
DenseMap<ComputeInstance, SmallVector<std::optional<ChannelInfo>>> remoteInputsByTask;
|
||||
DenseMap<size_t, SmallVector<Value>> cpuExternalInputs;
|
||||
DenseMap<size_t, SmallVector<Value>> cpuWeights;
|
||||
DenseMap<size_t, SmallVector<ProducerValueRef>> cpuExternalOutputs;
|
||||
DenseMap<size_t, DenseSet<Value>> seenExternalInputsByCpu;
|
||||
DenseMap<size_t, DenseSet<Value>> seenWeightsByCpu;
|
||||
DenseSet<uint64_t> pairsNeedingReceiveReorder;
|
||||
DenseMap<size_t, DenseMap<uint64_t, SmallVector<RemoteReceiveEntry>>> receiveQueuesByCpu;
|
||||
DenseMap<size_t, CpuProgram> cpuPrograms;
|
||||
DenseMap<Value, Value> oldToNewExternalValueMap;
|
||||
DenseMap<ComputeInstance, SmallVector<Value>> producedValuesByTask;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
/// Public entry point: constructs the file-local implementation object for
/// `func` and hands the schedule and the channel-id counter over to it.
LogicalResult
MergeScheduleMaterializer::run(func::FuncOp func, const MergeScheduleResult& schedule, int64_t& nextChannelId) {
  MergeScheduleMaterializerImpl impl(func);
  return impl.run(schedule, nextChannelId);
}
|
||||
|
||||
} // namespace spatial
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,18 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
|
||||
#include "Scheduling/MergeSchedule.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace spatial {
|
||||
|
||||
// Facade for materializing a merge schedule onto a function. The class holds
// no state of its own; the work is done by an implementation type defined in
// the corresponding source file.
class MergeScheduleMaterializer {
|
||||
public:
|
||||
// Applies `schedule` to `func`. `nextChannelId` is taken by mutable reference
// and forwarded to the implementation (presumably so it can allocate fresh
// channel ids -- confirm against the implementation).
mlir::LogicalResult
|
||||
run(mlir::func::FuncOp func, const MergeScheduleResult &schedule, int64_t &nextChannelId);
|
||||
};
|
||||
|
||||
} // namespace spatial
|
||||
} // namespace onnx_mlir
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,459 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/DenseSet.h"
|
||||
#include "llvm/ADT/Hashing.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
#include <chrono>
|
||||
#include <cstdlib>
|
||||
#include <limits>
|
||||
#include <optional>
|
||||
|
||||
#include "PostMergeCompaction.hpp"
|
||||
#include "RegularOpCompaction.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
using SpatCompute = spatial::SpatCompute;
|
||||
using SpatComputeBatch = spatial::SpatComputeBatch;
|
||||
|
||||
/// Tells whether merge-phase profiling was requested: true exactly when the
/// RAPTOR_PROFILE_MERGE environment variable exists (its value is ignored).
bool isMergeProfilingEnabled() {
  const char* profilingFlag = std::getenv("RAPTOR_PROFILE_MERGE");
  return profilingFlag != nullptr;
}
|
||||
|
||||
class ScopedMergePhaseTimer {
|
||||
public:
|
||||
explicit ScopedMergePhaseTimer(StringRef phaseName)
|
||||
: enabled(isMergeProfilingEnabled()), phase(phaseName.str()) {
|
||||
if (enabled)
|
||||
start = std::chrono::steady_clock::now();
|
||||
}
|
||||
|
||||
~ScopedMergePhaseTimer() {
|
||||
if (!enabled)
|
||||
return;
|
||||
auto elapsed = std::chrono::steady_clock::now() - start;
|
||||
double millis = std::chrono::duration<double, std::milli>(elapsed).count();
|
||||
llvm::errs() << "[merge-profile] " << phase << ": " << llvm::formatv("{0:F3}", millis) << " ms\n";
|
||||
}
|
||||
|
||||
private:
|
||||
bool enabled = false;
|
||||
std::string phase;
|
||||
std::chrono::steady_clock::time_point start;
|
||||
};
|
||||
|
||||
/// Reads the integer attribute named `kCoreIdAttrName` from `compute` and
/// returns it narrowed to int32_t, or std::nullopt when no such IntegerAttr
/// is attached.
std::optional<int32_t> getComputeCoreId(SpatCompute compute) {
  auto coreIdAttr = compute->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName);
  if (!coreIdAttr)
    return std::nullopt;
  return static_cast<int32_t>(coreIdAttr.getInt());
}
|
||||
|
||||
static constexpr StringLiteral kRebatchPhaseAttrName = "_pim_rebatch_phase";
|
||||
|
||||
/// Reads the integer attribute named `kRebatchPhaseAttrName` from `compute`
/// and returns it as uint64_t, or std::nullopt when no such IntegerAttr is
/// attached.
std::optional<uint64_t> getComputeRebatchPhase(SpatCompute compute) {
  auto phaseAttr = compute->getAttrOfType<IntegerAttr>(kRebatchPhaseAttrName);
  if (!phaseAttr)
    return std::nullopt;
  return static_cast<uint64_t>(phaseAttr.getInt());
}
|
||||
|
||||
/// Bucket key used to pre-group compute ops before the full equivalence
/// check. Field order matters: the key is built with aggregate initialization.
struct RebatchKey {
  unsigned inputCount = 0;
  unsigned resultCount = 0;
  unsigned weightCount = 0;
  // Rebatch phase value; only meaningful when `hasPhase` is true.
  uint64_t phase = 0;
  bool hasPhase = false;
  // Hash of the compute body's structure (see computeRebatchKey).
  uint64_t structureHash = 0;

  /// Field-wise equality over all six members.
  bool operator==(const RebatchKey& other) const {
    if (inputCount != other.inputCount || resultCount != other.resultCount || weightCount != other.weightCount)
      return false;
    if (phase != other.phase || hasPhase != other.hasPhase)
      return false;
    return structureHash == other.structureHash;
  }
};
|
||||
|
||||
struct RebatchKeyInfo {
|
||||
static inline RebatchKey getEmptyKey() { return {std::numeric_limits<unsigned>::max(), 0, 0, 0, false, 0}; }
|
||||
|
||||
static inline RebatchKey getTombstoneKey() { return {std::numeric_limits<unsigned>::max() - 1, 0, 0, 0, false, 0}; }
|
||||
|
||||
static unsigned getHashValue(const RebatchKey& key) {
|
||||
return static_cast<unsigned>(
|
||||
llvm::hash_combine(key.inputCount, key.resultCount, key.weightCount, key.phase, key.hasPhase, key.structureHash));
|
||||
}
|
||||
|
||||
static bool isEqual(const RebatchKey& lhs, const RebatchKey& rhs) { return lhs == rhs; }
|
||||
};
|
||||
|
||||
uint64_t getTypeHash(Type type) { return reinterpret_cast<uintptr_t>(type.getAsOpaquePointer()); }
|
||||
|
||||
uint64_t getValueHash(Value value) { return reinterpret_cast<uintptr_t>(value.getAsOpaquePointer()); }
|
||||
|
||||
uint64_t getAttributeHash(Attribute attr) { return reinterpret_cast<uintptr_t>(attr.getAsOpaquePointer()); }
|
||||
|
||||
/// Builds the bucket key used to pre-group `compute` with other compute ops
/// that might be equivalent for rebatching.
///
/// The key carries the input/result/weight counts, the optional rebatch phase
/// and a structural hash covering: the weight values, the phase, the body
/// block's argument types and, per body op, its name, operand/result/region
/// counts, result types and attributes.
///
/// The key must never be stricter than areEquivalentForRebatch, otherwise
/// equivalent computes land in different buckets and are never compared.
/// That check ignores the attributes of channel send/receive ops (their
/// channel and core ids legitimately differ between lanes), so those
/// attributes are left out of the hash here as well.
RebatchKey computeRebatchKey(SpatCompute compute) {
  // Queried once; used both for the hash and for the key fields.
  const std::optional<uint64_t> phase = getComputeRebatchPhase(compute);

  llvm::hash_code structureHash =
      llvm::hash_combine(compute.getInputs().size(), compute.getResultTypes().size(), compute.getWeights().size());

  for (Value weight : compute.getWeights())
    structureHash = llvm::hash_combine(structureHash, getValueHash(weight));
  if (phase)
    structureHash = llvm::hash_combine(structureHash, *phase);

  Block& body = compute.getBody().front();
  structureHash = llvm::hash_combine(structureHash, body.getNumArguments());
  for (BlockArgument arg : body.getArguments())
    structureHash = llvm::hash_combine(structureHash, getTypeHash(arg.getType()));

  for (Operation& op : body) {
    structureHash = llvm::hash_combine(
        structureHash, op.getName().getStringRef(), op.getNumOperands(), op.getNumResults(), op.getNumRegions());
    for (Type type : op.getResultTypes())
      structureHash = llvm::hash_combine(structureHash, getTypeHash(type));

    // Per-lane channel/core ids must not split otherwise equivalent computes.
    if (isa<spatial::SpatChannelReceiveOp, spatial::SpatChannelSendOp>(op))
      continue;
    for (NamedAttribute attr : op.getAttrs())
      structureHash = llvm::hash_combine(structureHash, attr.getName().strref(), getAttributeHash(attr.getValue()));
  }

  return {static_cast<unsigned>(compute.getInputs().size()),
          static_cast<unsigned>(compute.getResultTypes().size()),
          static_cast<unsigned>(compute.getWeights().size()),
          phase.value_or(0),
          phase.has_value(),
          static_cast<uint64_t>(structureHash)};
}
|
||||
|
||||
/// Decides whether two compute ops can become lanes of one batched compute.
///
/// Both must be non-null and agree on input count, result types, weight count,
/// rebatch phase and the exact weight values. Their body blocks are then
/// compared op by op in lock-step: same op name, operand/result counts, no
/// regions, matching result types, and operands that either correspond through
/// earlier block arguments/results or are the very same value. Attributes must
/// match too, except on channel send/receive ops, where only the transported
/// type is compared.
bool areEquivalentForRebatch(SpatCompute lhs, SpatCompute rhs) {
  if (!lhs || !rhs)
    return false;
  if (lhs.getInputs().size() != rhs.getInputs().size() || lhs.getResultTypes() != rhs.getResultTypes()
      || lhs.getWeights().size() != rhs.getWeights().size())
    return false;
  if (getComputeRebatchPhase(lhs) != getComputeRebatchPhase(rhs))
    return false;
  if (!llvm::equal(lhs.getWeights(), rhs.getWeights()))
    return false;

  Block& lhsBlock = lhs.getBody().front();
  Block& rhsBlock = rhs.getBody().front();
  if (lhsBlock.getNumArguments() != rhsBlock.getNumArguments())
    return false;

  // Correspondence from values defined inside lhs to their rhs counterparts.
  DenseMap<Value, Value> lhsToRhs;
  for (auto [lhsArg, rhsArg] : llvm::zip(lhsBlock.getArguments(), rhsBlock.getArguments())) {
    if (lhsArg.getType() != rhsArg.getType())
      return false;
    lhsToRhs[lhsArg] = rhsArg;
  }

  auto lhsCursor = lhsBlock.begin();
  auto rhsCursor = rhsBlock.begin();
  while (lhsCursor != lhsBlock.end() && rhsCursor != rhsBlock.end()) {
    Operation& lhsOp = *lhsCursor;
    Operation& rhsOp = *rhsCursor;
    ++lhsCursor;
    ++rhsCursor;

    const bool sameShape = lhsOp.getName() == rhsOp.getName() && lhsOp.getNumOperands() == rhsOp.getNumOperands()
                        && lhsOp.getNumResults() == rhsOp.getNumResults();
    if (!sameShape)
      return false;
    if (lhsOp.getNumRegions() != 0 || rhsOp.getNumRegions() != 0)
      return false;

    for (auto [lhsOperand, rhsOperand] : llvm::zip(lhsOp.getOperands(), rhsOp.getOperands())) {
      // Values without a recorded counterpart must be identical on both sides.
      Value expected = lhsOperand;
      auto known = lhsToRhs.find(lhsOperand);
      if (known != lhsToRhs.end())
        expected = known->second;
      if (expected != rhsOperand)
        return false;
    }

    if (auto lhsReceive = dyn_cast<spatial::SpatChannelReceiveOp>(lhsOp)) {
      auto rhsReceive = cast<spatial::SpatChannelReceiveOp>(rhsOp);
      if (lhsReceive.getOutput().getType() != rhsReceive.getOutput().getType())
        return false;
    }
    else if (auto lhsSend = dyn_cast<spatial::SpatChannelSendOp>(lhsOp)) {
      auto rhsSend = cast<spatial::SpatChannelSendOp>(rhsOp);
      if (lhsSend.getInput().getType() != rhsSend.getInput().getType())
        return false;
    }
    else if (lhsOp.getAttrs() != rhsOp.getAttrs()) {
      return false;
    }

    if (lhsOp.getResultTypes() != rhsOp.getResultTypes())
      return false;
    for (auto [lhsResult, rhsResult] : llvm::zip(lhsOp.getResults(), rhsOp.getResults()))
      lhsToRhs[lhsResult] = rhsResult;
  }

  // Equivalent only if both blocks were exhausted together.
  return lhsCursor == lhsBlock.end() && rhsCursor == rhsBlock.end();
}
|
||||
|
||||
/// Fuses groups of equivalent, result-less compute ops with at most one input
/// into a single SpatComputeBatch each.
///
/// For every not-yet-consumed anchor compute, later computes from the same key
/// bucket that pass areEquivalentForRebatch and do not reuse a core id already
/// in the group are collected. If the group has more than one member, a batch
/// op is created at the anchor's position: weights and inputs are concatenated
/// lane by lane, the anchor body is cloned once, and channel receive/send ops
/// are replaced by batch variants carrying the per-lane channel, source-core
/// and target-core ids. The grouped computes are then erased. Finally the
/// rebatch-phase attribute is stripped from all remaining compute ops.
void rebatchEquivalentComputes(func::FuncOp funcOp) {
  IRRewriter rewriter(funcOp.getContext());
  SmallVector<SpatCompute> allComputes(funcOp.getOps<SpatCompute>());
  DenseSet<Operation*> alreadyBatched;
  DenseMap<Operation*, size_t> positionOf;
  DenseMap<RebatchKey, SmallVector<SpatCompute>, RebatchKeyInfo> buckets;

  auto isRebatchCandidate = [](SpatCompute compute) {
    return compute.getInputs().size() <= 1 && compute.getResults().empty();
  };

  for (auto [position, compute] : llvm::enumerate(allComputes)) {
    positionOf[compute.getOperation()] = position;
    if (isRebatchCandidate(compute))
      buckets[computeRebatchKey(compute)].push_back(compute);
  }

  for (size_t anchorIndex = 0; anchorIndex < allComputes.size(); ++anchorIndex) {
    auto anchor = allComputes[anchorIndex];
    // The membership test comes first: a consumed op is erased and must not be inspected.
    if (alreadyBatched.contains(anchor) || !isRebatchCandidate(anchor))
      continue;

    SmallVector<SpatCompute> group {anchor};
    llvm::SmallDenseSet<int32_t, 8> claimedCoreIds;
    if (auto anchorCoreId = getComputeCoreId(anchor))
      claimedCoreIds.insert(*anchorCoreId);

    auto bucketIt = buckets.find(computeRebatchKey(anchor));
    if (bucketIt == buckets.end())
      continue;

    for (auto candidate : bucketIt->second) {
      const bool isLaterAndAlive =
          positionOf.lookup(candidate.getOperation()) > anchorIndex && !alreadyBatched.contains(candidate);
      if (!isLaterAndAlive || !areEquivalentForRebatch(anchor, candidate))
        continue;

      // Each core id may appear at most once per batch.
      auto candidateCoreId = getComputeCoreId(candidate);
      if (candidateCoreId && !claimedCoreIds.insert(*candidateCoreId).second)
        continue;

      group.push_back(candidate);
    }

    if (group.size() <= 1)
      continue;

    // Captured before sorting: the batch is created where the anchor stands.
    auto insertionAnchor = group.front();
    const bool everyLaneHasCoreId =
        llvm::all_of(group, [](SpatCompute compute) { return getComputeCoreId(compute).has_value(); });
    if (everyLaneHasCoreId) {
      llvm::stable_sort(
          group, [](SpatCompute lhs, SpatCompute rhs) { return *getComputeCoreId(lhs) < *getComputeCoreId(rhs); });
    }

    SmallVector<Value> weights;
    weights.reserve(group.size() * anchor.getWeights().size());
    SmallVector<Value> inputs;
    inputs.reserve(group.size() * anchor.getInputs().size());
    SmallVector<int32_t> coreIds;
    coreIds.reserve(group.size());
    for (auto member : group) {
      llvm::append_range(weights, member.getWeights());
      llvm::append_range(inputs, member.getInputs());
      if (everyLaneHasCoreId)
        coreIds.push_back(*getComputeCoreId(member));
    }

    rewriter.setInsertionPoint(insertionAnchor);
    auto rebatched = SpatComputeBatch::create(rewriter,
                                              insertionAnchor.getLoc(),
                                              TypeRange {},
                                              rewriter.getI32IntegerAttr(static_cast<int32_t>(group.size())),
                                              ValueRange(weights),
                                              ValueRange(inputs));
    rebatched.getProperties().setOperandSegmentSizes(
        {static_cast<int>(weights.size()), static_cast<int>(inputs.size())});
    if (everyLaneHasCoreId)
      rebatched->setAttr(onnx_mlir::kCoreIdsAttrName, rewriter.getDenseI32ArrayAttr(coreIds));

    auto& anchorBlock = anchor.getBody().front();
    SmallVector<Type> blockArgTypes;
    SmallVector<Location> blockArgLocs;
    for (BlockArgument arg : anchorBlock.getArguments()) {
      blockArgTypes.push_back(arg.getType());
      blockArgLocs.push_back(arg.getLoc());
    }
    auto* newBlock =
        rewriter.createBlock(&rebatched.getBody(), rebatched.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
    rewriter.setInsertionPointToEnd(newBlock);

    IRMapping mapper;
    for (auto [oldArg, newArg] : llvm::zip(anchorBlock.getArguments(), newBlock->getArguments()))
      mapper.map(oldArg, newArg);

    // One cursor per lane, advanced in lock-step with the anchor body.
    auto laneCursors =
        llvm::map_to_vector(group, [](SpatCompute compute) { return compute.getBody().front().begin(); });
    auto advanceAllLanes = [&laneCursors]() {
      for (auto& laneCursor : laneCursors)
        ++laneCursor;
    };

    for (Operation& anchorOp : anchorBlock) {
      if (auto receiveOp = dyn_cast<spatial::SpatChannelReceiveOp>(&anchorOp)) {
        SmallVector<int64_t> channelIds;
        SmallVector<int32_t> sourceCoreIds;
        SmallVector<int32_t> targetCoreIds;
        channelIds.reserve(group.size());
        sourceCoreIds.reserve(group.size());
        targetCoreIds.reserve(group.size());
        for (auto& laneCursor : laneCursors) {
          auto laneReceive = cast<spatial::SpatChannelReceiveOp>(&*laneCursor);
          channelIds.push_back(static_cast<int64_t>(laneReceive.getChannelId()));
          sourceCoreIds.push_back(static_cast<int32_t>(laneReceive.getSourceCoreId()));
          targetCoreIds.push_back(static_cast<int32_t>(laneReceive.getTargetCoreId()));
          ++laneCursor;
        }
        auto batchReceive = spatial::SpatChannelReceiveBatchOp::create(rewriter,
                                                                       receiveOp.getLoc(),
                                                                       receiveOp.getOutput().getType(),
                                                                       rewriter.getDenseI64ArrayAttr(channelIds),
                                                                       rewriter.getDenseI32ArrayAttr(sourceCoreIds),
                                                                       rewriter.getDenseI32ArrayAttr(targetCoreIds));
        mapper.map(receiveOp.getOutput(), batchReceive.getOutput());
        continue;
      }

      if (auto sendOp = dyn_cast<spatial::SpatChannelSendOp>(&anchorOp)) {
        SmallVector<int64_t> channelIds;
        SmallVector<int32_t> sourceCoreIds;
        SmallVector<int32_t> targetCoreIds;
        channelIds.reserve(group.size());
        sourceCoreIds.reserve(group.size());
        targetCoreIds.reserve(group.size());
        for (auto& laneCursor : laneCursors) {
          auto laneSend = cast<spatial::SpatChannelSendOp>(&*laneCursor);
          channelIds.push_back(static_cast<int64_t>(laneSend.getChannelId()));
          sourceCoreIds.push_back(static_cast<int32_t>(laneSend.getSourceCoreId()));
          targetCoreIds.push_back(static_cast<int32_t>(laneSend.getTargetCoreId()));
          ++laneCursor;
        }
        spatial::SpatChannelSendBatchOp::create(rewriter,
                                                sendOp.getLoc(),
                                                rewriter.getDenseI64ArrayAttr(channelIds),
                                                rewriter.getDenseI32ArrayAttr(sourceCoreIds),
                                                rewriter.getDenseI32ArrayAttr(targetCoreIds),
                                                mapper.lookup(sendOp.getInput()));
        continue;
      }

      if (isa<spatial::SpatYieldOp>(anchorOp)) {
        advanceAllLanes();
        spatial::SpatYieldOp::create(rewriter, anchorOp.getLoc(), ValueRange {});
        continue;
      }

      // Any other op is identical across lanes, so the anchor's copy is cloned once.
      Operation* cloned = rewriter.clone(anchorOp, mapper);
      for (auto [originalResult, clonedResult] : llvm::zip(anchorOp.getResults(), cloned->getResults()))
        mapper.map(originalResult, clonedResult);
      advanceAllLanes();
    }

    for (auto member : group) {
      member->removeAttr(kRebatchPhaseAttrName);
      alreadyBatched.insert(member);
      rewriter.eraseOp(member);
    }
  }

  for (auto compute : funcOp.getOps<SpatCompute>())
    compute->removeAttr(kRebatchPhaseAttrName);
}
|
||||
|
||||
/// Removes unused packing helpers from `funcOp`: first unused
/// tensor.extract_slice ops, then unused SpatConcatOp, then unused
/// SpatExtractRowsOp. Within each kind the ops are visited in reverse walk
/// order, so a later op's erasure can free an earlier op of the same kind.
void cleanupDeadPackingOps(func::FuncOp funcOp) {
  // The argument is only a type tag; its value is never inspected.
  auto sweepUnused = [&funcOp](auto opTag) {
    using OpKind = decltype(opTag);
    SmallVector<OpKind> found;
    funcOp.walk([&found](OpKind op) { found.push_back(op); });
    for (auto it = found.rbegin(), last = found.rend(); it != last; ++it) {
      OpKind op = *it;
      if (!op->use_empty())
        continue;
      op.erase();
    }
  };

  sweepUnused(tensor::ExtractSliceOp {});
  sweepUnused(spatial::SpatConcatOp {});
  sweepUnused(spatial::SpatExtractRowsOp {});
}
|
||||
|
||||
} // namespace
|
||||
|
||||
/// Runs the post-merge compaction phases on `funcOp` in a fixed order, each
/// wrapped in a ScopedMergePhaseTimer labelled with the phase name. The scalar
/// and batch channel compactions run twice, once before and once after the
/// regular-op and row-wise compactions. Always returns success.
LogicalResult runPostMergeCompactionPipeline(func::FuncOp funcOp, int64_t& nextChannelId) {
  // The timer lives exactly as long as the phase body runs.
  auto runTimedPhase = [](StringRef phaseName, auto&& phaseBody) {
    ScopedMergePhaseTimer timer(phaseName);
    phaseBody();
  };

  runTimedPhase("order-bilateral-channel-ops", [&] { orderBilateralChannelOps(funcOp); });
  runTimedPhase("rebatch-equivalent-computes", [&] { rebatchEquivalentComputes(funcOp); });
  runTimedPhase("compact-scalar-channel-runs-1", [&] { compactScalarChannelRuns(funcOp, nextChannelId); });
  runTimedPhase("compact-batch-channel-runs-1", [&] { compactBatchChannelRuns(funcOp); });
  runTimedPhase("compact-regular-op-runs", [&] { compactRegularOpRuns(funcOp); });
  runTimedPhase("compact-row-wise-wvmm-runs", [&] { compactRowWiseWvmmRuns(funcOp); });
  runTimedPhase("compact-scalar-channel-runs-2", [&] { compactScalarChannelRuns(funcOp, nextChannelId); });
  runTimedPhase("compact-batch-channel-runs-2", [&] { compactBatchChannelRuns(funcOp); });
  runTimedPhase("cleanup-dead-packing-ops", [&] { cleanupDeadPackingOps(funcOp); });

  return success();
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,12 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
// Runs the post-merge compaction phases on `funcOp`. `nextChannelId` is passed
// by mutable reference to the scalar channel compaction phases (presumably so
// they can allocate fresh channel ids -- confirm against their definitions).
mlir::LogicalResult runPostMergeCompactionPipeline(mlir::func::FuncOp funcOp, int64_t &nextChannelId);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -7,12 +7,13 @@
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/DenseSet.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include <tuple>
|
||||
|
||||
#include "RegularOpCompaction.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/TensorPackingPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
@@ -42,6 +43,47 @@ struct RegularChunk {
|
||||
Value output;
|
||||
};
|
||||
|
||||
// Outcome of compacting one run of regular chunks (returned by
// compactRegularChunkRun). A default-constructed value means nothing changed.
struct RegularCompactionResult {
|
||||
// True once the run was actually rewritten.
bool changed = false;
|
||||
// Set by compactRegularChunkRun to the node that follows the operation it
// created; null when nothing changed. NOTE(review): presumably the point
// where the caller resumes scanning -- confirm at the call site.
Operation* resumeAfter = nullptr;
|
||||
};
|
||||
|
||||
// Result of collectConsecutiveRun: the adjacent operations of type OpTy that
// were accepted, plus the position where collection stopped.
template <typename OpTy>
|
||||
struct ConsecutiveRun {
|
||||
// Accepted operations, in block order.
SmallVector<OpTy> ops;
|
||||
// Iterator to the first operation that was not accepted, or the block end
// passed to collectConsecutiveRun if every operation was accepted.
Block::iterator end;
|
||||
};
|
||||
|
||||
template <typename OpTy, typename Predicate>
|
||||
static ConsecutiveRun<OpTy>
|
||||
collectConsecutiveRun(Block::iterator start, Block::iterator blockEnd, Predicate predicate) {
|
||||
ConsecutiveRun<OpTy> run;
|
||||
run.end = start;
|
||||
while (run.end != blockEnd) {
|
||||
auto current = dyn_cast<OpTy>(&*run.end);
|
||||
if (!current || !predicate(current))
|
||||
break;
|
||||
run.ops.push_back(current);
|
||||
++run.end;
|
||||
}
|
||||
return run;
|
||||
}
|
||||
|
||||
/// Packs a (source core, target core) pair into one 64-bit key: the source id
/// occupies the upper 32 bits and the target id the lower 32 bits.
static uint64_t getEndpointKey(uint32_t sourceCoreId, uint32_t targetCoreId) {
  uint64_t key = sourceCoreId;
  key <<= 32;
  key |= targetCoreId;
  return key;
}
|
||||
|
||||
/// Appends one channel's ids to three parallel vectors, converting the
/// unsigned inputs to the signed element types of the vectors.
static void appendChannelAttrs(SmallVectorImpl<int64_t>& channelIds,
                               SmallVectorImpl<int32_t>& sourceCoreIds,
                               SmallVectorImpl<int32_t>& targetCoreIds,
                               uint64_t channelId,
                               uint32_t sourceCoreId,
                               uint32_t targetCoreId) {
  const auto signedChannelId = static_cast<int64_t>(channelId);
  const auto signedSourceCoreId = static_cast<int32_t>(sourceCoreId);
  const auto signedTargetCoreId = static_cast<int32_t>(targetCoreId);
  channelIds.push_back(signedChannelId);
  sourceCoreIds.push_back(signedSourceCoreId);
  targetCoreIds.push_back(signedTargetCoreId);
}
|
||||
|
||||
static spatial::SpatConcatOp getContiguousConcatUse(ValueRange values, unsigned& startOperandIndex) {
|
||||
if (values.empty() || !values.front().hasOneUse())
|
||||
return {};
|
||||
@@ -168,6 +210,17 @@ static bool areEquivalentRegularChunks(const RegularChunk& lhs, const RegularChu
|
||||
[](auto pair) { return areEquivalentRegularSteps(std::get<0>(pair), std::get<1>(pair)); });
|
||||
}
|
||||
|
||||
/// Tells whether `value` is a payload that is merely passed through `block`.
///
/// A value with no defining op, or defined outside `block`, counts as
/// forwarded. tensor.extract_slice ops inside `block` are looked through to
/// their source. Any other producer inside `block` counts only if it is a
/// channel receive (scalar or tensor variant).
static bool isForwardedChannelPayload(Value value, Block& block) {
  Value current = value;
  while (true) {
    Operation* producer = current.getDefiningOp();
    if (!producer || producer->getBlock() != &block)
      return true;

    auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(producer);
    if (!sliceOp)
      return isa<spatial::SpatChannelReceiveOp, spatial::SpatChannelReceiveTensorOp>(producer);

    // Keep following the chain of slices toward its origin.
    current = sliceOp.getSource();
  }
}
|
||||
|
||||
static FailureOr<RegularChunk> analyzeRegularChunk(spatial::SpatVMMOp startOp) {
|
||||
RegularChunk chunk;
|
||||
chunk.startOp = startOp.getOperation();
|
||||
@@ -202,9 +255,10 @@ static FailureOr<RegularChunk> analyzeRegularChunk(spatial::SpatVMMOp startOp) {
|
||||
return chunk;
|
||||
}
|
||||
|
||||
static void compactRegularChunkRun(IRRewriter& rewriter, ArrayRef<RegularChunk> run) {
|
||||
static RegularCompactionResult compactRegularChunkRun(IRRewriter& rewriter, ArrayRef<RegularChunk> run) {
|
||||
assert(!run.empty() && "expected a non-empty regular chunk run");
|
||||
const RegularChunk& anchorChunk = run.front();
|
||||
RegularCompactionResult result;
|
||||
|
||||
SmallVector<Value> inputs;
|
||||
inputs.reserve(run.size());
|
||||
@@ -214,7 +268,7 @@ static void compactRegularChunkRun(IRRewriter& rewriter, ArrayRef<RegularChunk>
|
||||
rewriter.setInsertionPoint(anchorChunk.startOp);
|
||||
Value packedInput = createPackedTensorForValues(ValueRange(inputs), rewriter, anchorChunk.startOp->getLoc());
|
||||
if (!packedInput)
|
||||
return;
|
||||
return result;
|
||||
|
||||
auto inputType = cast<RankedTensorType>(anchorChunk.input.getType());
|
||||
auto outputType = cast<RankedTensorType>(anchorChunk.output.getType());
|
||||
@@ -317,10 +371,79 @@ static void compactRegularChunkRun(IRRewriter& rewriter, ArrayRef<RegularChunk>
|
||||
llvm::append_range(opsToErase, chunk.ops);
|
||||
for (Operation* op : llvm::reverse(opsToErase))
|
||||
rewriter.eraseOp(op);
|
||||
|
||||
result.changed = true;
|
||||
result.resumeAfter = loop.getOperation()->getNextNode();
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Reorders scalar channel receives inside every SpatCompute body that carries
// a core id attribute (computes without the attribute are left untouched).
//
// Pass 1: a receive addressed to this core from a lower-numbered core is moved
//         in front of the first forwarded send (see isForwardedChannelPayload)
//         that appears earlier in the block and goes from this core to that
//         same lower-numbered core.
// Pass 2: inside each consecutive run of same-typed receives coming from
//         lower-numbered cores, the receives are stably reordered by
//         descending source core id.
//
// NOTE(review): presumably this gives both endpoints of a core pair a
// consistent send/receive order — confirm against the channel lowering.
void orderBilateralChannelOps(func::FuncOp funcOp) {
  for (auto compute : funcOp.getOps<spatial::SpatCompute>()) {
    auto coreIdAttr = compute->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName);
    if (!coreIdAttr)
      continue;

    const int32_t signedCoreId = static_cast<int32_t>(coreIdAttr.getInt());
    const uint32_t selfCoreId = static_cast<uint32_t>(signedCoreId);
    Block& body = compute.getBody().front();

    // Pass 1: record the moves first, apply them once the walk is finished so
    // the block is not mutated while it is being iterated.
    SmallVector<std::pair<spatial::SpatChannelReceiveOp, Operation*>> pendingMoves;
    DenseMap<uint64_t, Operation*> earliestForwardedSend;

    for (Operation& op : body) {
      if (auto send = dyn_cast<spatial::SpatChannelSendOp>(&op)) {
        const bool isOwnForwardedSend =
            send.getSourceCoreId() == selfCoreId && isForwardedChannelPayload(send.getInput(), body);
        if (isOwnForwardedSend) {
          // try_emplace keeps the first send seen for each endpoint pair.
          earliestForwardedSend.try_emplace(getEndpointKey(send.getSourceCoreId(), send.getTargetCoreId()),
                                            send.getOperation());
        }
        continue;
      }

      auto receive = dyn_cast<spatial::SpatChannelReceiveOp>(&op);
      if (!receive)
        continue;
      if (receive.getTargetCoreId() != selfCoreId || receive.getSourceCoreId() >= selfCoreId)
        continue;

      // Look for a send going the opposite way: from this core to the sender.
      auto matchingSend = earliestForwardedSend.find(getEndpointKey(selfCoreId, receive.getSourceCoreId()));
      if (matchingSend == earliestForwardedSend.end())
        continue;
      pendingMoves.push_back({receive, matchingSend->second});
    }

    for (auto [receive, anchor] : pendingMoves)
      receive->moveBefore(anchor);

    // Pass 2: sort each run of compatible receives by descending source core.
    auto cursor = body.begin();
    while (cursor != body.end()) {
      auto head = dyn_cast<spatial::SpatChannelReceiveOp>(&*cursor);
      if (!head || head.getSourceCoreId() >= selfCoreId) {
        ++cursor;
        continue;
      }

      Type headType = head.getOutput().getType();
      auto run = collectConsecutiveRun<spatial::SpatChannelReceiveOp>(
          cursor, body.end(), [&](spatial::SpatChannelReceiveOp candidate) {
            return candidate.getOutput().getType() == headType && candidate.getSourceCoreId() < selfCoreId;
          });

      if (run.ops.size() > 1) {
        SmallVector<spatial::SpatChannelReceiveOp> descending(run.ops);
        llvm::stable_sort(descending, [](spatial::SpatChannelReceiveOp lhs, spatial::SpatChannelReceiveOp rhs) {
          return lhs.getSourceCoreId() > rhs.getSourceCoreId();
        });
        // Re-inserting every op right before the end of the run lays the run
        // out in sorted order.
        Block::iterator insertIt = run.end;
        for (auto op : descending)
          op->moveBefore(&body, insertIt);
      }

      cursor = run.end;
    }
  }
}
|
||||
|
||||
void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
|
||||
IRRewriter rewriter(funcOp.getContext());
|
||||
|
||||
@@ -329,18 +452,23 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
|
||||
for (auto it = block.begin(); it != block.end();) {
|
||||
auto receiveOp = dyn_cast<spatial::SpatChannelReceiveOp>(&*it);
|
||||
if (receiveOp) {
|
||||
SmallVector<spatial::SpatChannelReceiveOp> run;
|
||||
Type outputType = receiveOp.getOutput().getType();
|
||||
auto runIt = it;
|
||||
while (runIt != block.end()) {
|
||||
auto current = dyn_cast<spatial::SpatChannelReceiveOp>(&*runIt);
|
||||
if (!current || current.getOutput().getType() != outputType)
|
||||
auto run = collectConsecutiveRun<spatial::SpatChannelReceiveOp>(
|
||||
it, block.end(), [&](spatial::SpatChannelReceiveOp current) {
|
||||
return current.getOutput().getType() == outputType;
|
||||
});
|
||||
|
||||
bool hasRepeatedEndpoint = false;
|
||||
DenseSet<uint64_t> seenEndpoints;
|
||||
for (auto op : run.ops) {
|
||||
uint64_t endpointKey = getEndpointKey(op.getSourceCoreId(), op.getTargetCoreId());
|
||||
if (!seenEndpoints.insert(endpointKey).second) {
|
||||
hasRepeatedEndpoint = true;
|
||||
break;
|
||||
run.push_back(current);
|
||||
++runIt;
|
||||
}
|
||||
}
|
||||
|
||||
if (run.size() > 1) {
|
||||
if (run.ops.size() > 1 && !hasRepeatedEndpoint) {
|
||||
struct ReceiveEntry {
|
||||
spatial::SpatChannelReceiveOp op;
|
||||
size_t originalIndex = 0;
|
||||
@@ -349,13 +477,9 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
|
||||
uint64_t channelId = 0;
|
||||
};
|
||||
SmallVector<ReceiveEntry> sortedEntries;
|
||||
sortedEntries.reserve(run.size());
|
||||
for (auto [originalIndex, op] : llvm::enumerate(run))
|
||||
sortedEntries.reserve(run.ops.size());
|
||||
for (auto [originalIndex, op] : llvm::enumerate(run.ops))
|
||||
sortedEntries.push_back({op, originalIndex, op.getSourceCoreId(), op.getTargetCoreId(), op.getChannelId()});
|
||||
llvm::stable_sort(sortedEntries, [](const ReceiveEntry& lhs, const ReceiveEntry& rhs) {
|
||||
return std::tuple(lhs.sourceCoreId, lhs.targetCoreId, lhs.channelId)
|
||||
< std::tuple(rhs.sourceCoreId, rhs.targetCoreId, rhs.channelId);
|
||||
});
|
||||
|
||||
SmallVector<int64_t> channelIds;
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
@@ -364,13 +488,11 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
|
||||
sourceCoreIds.reserve(sortedEntries.size());
|
||||
targetCoreIds.reserve(sortedEntries.size());
|
||||
for (ReceiveEntry& entry : sortedEntries) {
|
||||
(void) entry;
|
||||
channelIds.push_back(nextChannelId++);
|
||||
sourceCoreIds.push_back(static_cast<int32_t>(entry.sourceCoreId));
|
||||
targetCoreIds.push_back(static_cast<int32_t>(entry.targetCoreId));
|
||||
appendChannelAttrs(
|
||||
channelIds, sourceCoreIds, targetCoreIds, entry.channelId, entry.sourceCoreId, entry.targetCoreId);
|
||||
}
|
||||
|
||||
auto rowType = cast<RankedTensorType>(run.front().getOutput().getType());
|
||||
auto rowType = cast<RankedTensorType>(run.ops.front().getOutput().getType());
|
||||
auto fallbackPackedType = getPackedTensorType(rowType, static_cast<int64_t>(sortedEntries.size()));
|
||||
SmallVector<Value> sortedOutputs;
|
||||
sortedOutputs.reserve(sortedEntries.size());
|
||||
@@ -383,10 +505,10 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
|
||||
concatOp ? getPackedConcatSliceType(concatOp, concatStartIndex, static_cast<unsigned>(sortedOutputs.size()))
|
||||
: RankedTensorType {};
|
||||
auto packedType = concatPackedType ? concatPackedType : fallbackPackedType;
|
||||
rewriter.setInsertionPoint(run.front());
|
||||
rewriter.setInsertionPoint(run.ops.front());
|
||||
auto compactReceive =
|
||||
spatial::SpatChannelReceiveTensorOp::create(rewriter,
|
||||
run.front().getLoc(),
|
||||
run.ops.front().getLoc(),
|
||||
packedType,
|
||||
rewriter.getDenseI64ArrayAttr(channelIds),
|
||||
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
|
||||
@@ -403,7 +525,7 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
|
||||
entry.op.getOutput().replaceAllUsesWith(extractPackedChunk(
|
||||
compactReceive.getOutput(), rowType, static_cast<unsigned>(sortedIndex), rewriter, entry.op.getLoc()));
|
||||
}
|
||||
for (auto op : run)
|
||||
for (auto op : run.ops)
|
||||
rewriter.eraseOp(op);
|
||||
|
||||
it = compactReceive->getIterator();
|
||||
@@ -414,18 +536,13 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
|
||||
|
||||
auto sendOp = dyn_cast<spatial::SpatChannelSendOp>(&*it);
|
||||
if (sendOp) {
|
||||
SmallVector<spatial::SpatChannelSendOp> run;
|
||||
Type inputType = sendOp.getInput().getType();
|
||||
auto runIt = it;
|
||||
while (runIt != block.end()) {
|
||||
auto current = dyn_cast<spatial::SpatChannelSendOp>(&*runIt);
|
||||
if (!current || current.getInput().getType() != inputType)
|
||||
break;
|
||||
run.push_back(current);
|
||||
++runIt;
|
||||
}
|
||||
auto run =
|
||||
collectConsecutiveRun<spatial::SpatChannelSendOp>(it, block.end(), [&](spatial::SpatChannelSendOp current) {
|
||||
return current.getInput().getType() == inputType;
|
||||
});
|
||||
|
||||
if (run.size() > 1) {
|
||||
if (run.ops.size() > 1) {
|
||||
struct SendEntry {
|
||||
spatial::SpatChannelSendOp op;
|
||||
uint32_t sourceCoreId = 0;
|
||||
@@ -433,13 +550,9 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
|
||||
uint64_t channelId = 0;
|
||||
};
|
||||
SmallVector<SendEntry> sortedEntries;
|
||||
sortedEntries.reserve(run.size());
|
||||
for (auto op : run)
|
||||
sortedEntries.reserve(run.ops.size());
|
||||
for (auto op : run.ops)
|
||||
sortedEntries.push_back({op, op.getSourceCoreId(), op.getTargetCoreId(), op.getChannelId()});
|
||||
llvm::stable_sort(sortedEntries, [](const SendEntry& lhs, const SendEntry& rhs) {
|
||||
return std::tuple(lhs.sourceCoreId, lhs.targetCoreId, lhs.channelId)
|
||||
< std::tuple(rhs.sourceCoreId, rhs.targetCoreId, rhs.channelId);
|
||||
});
|
||||
|
||||
SmallVector<int64_t> channelIds;
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
@@ -450,26 +563,24 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
|
||||
targetCoreIds.reserve(sortedEntries.size());
|
||||
inputs.reserve(sortedEntries.size());
|
||||
for (SendEntry& entry : sortedEntries) {
|
||||
(void) entry;
|
||||
channelIds.push_back(nextChannelId++);
|
||||
sourceCoreIds.push_back(static_cast<int32_t>(entry.sourceCoreId));
|
||||
targetCoreIds.push_back(static_cast<int32_t>(entry.targetCoreId));
|
||||
appendChannelAttrs(
|
||||
channelIds, sourceCoreIds, targetCoreIds, entry.channelId, entry.sourceCoreId, entry.targetCoreId);
|
||||
inputs.push_back(entry.op.getInput());
|
||||
}
|
||||
|
||||
rewriter.setInsertionPoint(run.front());
|
||||
Value packedInput = createPackedTensorForValues(ValueRange(inputs), rewriter, run.front().getLoc());
|
||||
rewriter.setInsertionPoint(run.ops.front());
|
||||
Value packedInput = createPackedTensorForValues(ValueRange(inputs), rewriter, run.ops.front().getLoc());
|
||||
if (packedInput) {
|
||||
spatial::SpatChannelSendTensorOp::create(rewriter,
|
||||
run.front().getLoc(),
|
||||
run.ops.front().getLoc(),
|
||||
rewriter.getDenseI64ArrayAttr(channelIds),
|
||||
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
|
||||
rewriter.getDenseI32ArrayAttr(targetCoreIds),
|
||||
packedInput);
|
||||
for (auto op : run)
|
||||
for (auto op : run.ops)
|
||||
rewriter.eraseOp(op);
|
||||
|
||||
it = runIt;
|
||||
it = run.end;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
@@ -488,32 +599,27 @@ void compactBatchChannelRuns(func::FuncOp funcOp) {
|
||||
for (auto it = block.begin(); it != block.end();) {
|
||||
auto receiveOp = dyn_cast<spatial::SpatChannelReceiveBatchOp>(&*it);
|
||||
if (receiveOp) {
|
||||
SmallVector<spatial::SpatChannelReceiveBatchOp> run;
|
||||
Type outputType = receiveOp.getOutput().getType();
|
||||
auto runIt = it;
|
||||
while (runIt != block.end()) {
|
||||
auto current = dyn_cast<spatial::SpatChannelReceiveBatchOp>(&*runIt);
|
||||
if (!current || current.getOutput().getType() != outputType)
|
||||
break;
|
||||
run.push_back(current);
|
||||
++runIt;
|
||||
}
|
||||
auto run = collectConsecutiveRun<spatial::SpatChannelReceiveBatchOp>(
|
||||
it, block.end(), [&](spatial::SpatChannelReceiveBatchOp current) {
|
||||
return current.getOutput().getType() == outputType;
|
||||
});
|
||||
|
||||
if (run.size() > 1) {
|
||||
if (run.ops.size() > 1) {
|
||||
SmallVector<int64_t> channelIds;
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
SmallVector<int32_t> targetCoreIds;
|
||||
for (auto op : run) {
|
||||
for (auto op : run.ops) {
|
||||
llvm::append_range(channelIds, op.getChannelIds());
|
||||
llvm::append_range(sourceCoreIds, op.getSourceCoreIds());
|
||||
llvm::append_range(targetCoreIds, op.getTargetCoreIds());
|
||||
}
|
||||
|
||||
auto rowType = cast<RankedTensorType>(run.front().getOutput().getType());
|
||||
auto fallbackPackedType = getPackedTensorType(rowType, static_cast<int64_t>(run.size()));
|
||||
auto rowType = cast<RankedTensorType>(run.ops.front().getOutput().getType());
|
||||
auto fallbackPackedType = getPackedTensorType(rowType, static_cast<int64_t>(run.ops.size()));
|
||||
SmallVector<Value> outputs;
|
||||
outputs.reserve(run.size());
|
||||
for (auto op : run)
|
||||
outputs.reserve(run.ops.size());
|
||||
for (auto op : run.ops)
|
||||
outputs.push_back(op.getOutput());
|
||||
|
||||
unsigned concatStartIndex = 0;
|
||||
@@ -522,10 +628,10 @@ void compactBatchChannelRuns(func::FuncOp funcOp) {
|
||||
concatOp ? getPackedConcatSliceType(concatOp, concatStartIndex, static_cast<unsigned>(outputs.size()))
|
||||
: RankedTensorType {};
|
||||
auto packedType = concatPackedType ? concatPackedType : fallbackPackedType;
|
||||
rewriter.setInsertionPoint(run.front());
|
||||
rewriter.setInsertionPoint(run.ops.front());
|
||||
auto compactReceive =
|
||||
spatial::SpatChannelReceiveTensorBatchOp::create(rewriter,
|
||||
run.front().getLoc(),
|
||||
run.ops.front().getLoc(),
|
||||
packedType,
|
||||
rewriter.getDenseI64ArrayAttr(channelIds),
|
||||
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
|
||||
@@ -535,11 +641,11 @@ void compactBatchChannelRuns(func::FuncOp funcOp) {
|
||||
concatOp, concatStartIndex, static_cast<unsigned>(outputs.size()), compactReceive.getOutput(), rewriter);
|
||||
}
|
||||
else {
|
||||
for (auto [index, op] : llvm::enumerate(run))
|
||||
for (auto [index, op] : llvm::enumerate(run.ops))
|
||||
op.getOutput().replaceAllUsesWith(extractPackedChunk(
|
||||
compactReceive.getOutput(), rowType, static_cast<unsigned>(index), rewriter, op.getLoc()));
|
||||
}
|
||||
for (auto op : run)
|
||||
for (auto op : run.ops)
|
||||
rewriter.eraseOp(op);
|
||||
|
||||
it = compactReceive->getIterator();
|
||||
@@ -550,43 +656,38 @@ void compactBatchChannelRuns(func::FuncOp funcOp) {
|
||||
|
||||
auto sendOp = dyn_cast<spatial::SpatChannelSendBatchOp>(&*it);
|
||||
if (sendOp) {
|
||||
SmallVector<spatial::SpatChannelSendBatchOp> run;
|
||||
Type inputType = sendOp.getInput().getType();
|
||||
auto runIt = it;
|
||||
while (runIt != block.end()) {
|
||||
auto current = dyn_cast<spatial::SpatChannelSendBatchOp>(&*runIt);
|
||||
if (!current || current.getInput().getType() != inputType)
|
||||
break;
|
||||
run.push_back(current);
|
||||
++runIt;
|
||||
}
|
||||
auto run = collectConsecutiveRun<spatial::SpatChannelSendBatchOp>(
|
||||
it, block.end(), [&](spatial::SpatChannelSendBatchOp current) {
|
||||
return current.getInput().getType() == inputType;
|
||||
});
|
||||
|
||||
if (run.size() > 1) {
|
||||
if (run.ops.size() > 1) {
|
||||
SmallVector<int64_t> channelIds;
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
SmallVector<int32_t> targetCoreIds;
|
||||
SmallVector<Value> inputs;
|
||||
inputs.reserve(run.size());
|
||||
for (auto op : run) {
|
||||
inputs.reserve(run.ops.size());
|
||||
for (auto op : run.ops) {
|
||||
llvm::append_range(channelIds, op.getChannelIds());
|
||||
llvm::append_range(sourceCoreIds, op.getSourceCoreIds());
|
||||
llvm::append_range(targetCoreIds, op.getTargetCoreIds());
|
||||
inputs.push_back(op.getInput());
|
||||
}
|
||||
|
||||
rewriter.setInsertionPoint(run.front());
|
||||
Value packedInput = createPackedTensorForValues(ValueRange(inputs), rewriter, run.front().getLoc());
|
||||
rewriter.setInsertionPoint(run.ops.front());
|
||||
Value packedInput = createPackedTensorForValues(ValueRange(inputs), rewriter, run.ops.front().getLoc());
|
||||
if (packedInput) {
|
||||
spatial::SpatChannelSendTensorBatchOp::create(rewriter,
|
||||
run.front().getLoc(),
|
||||
run.ops.front().getLoc(),
|
||||
rewriter.getDenseI64ArrayAttr(channelIds),
|
||||
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
|
||||
rewriter.getDenseI32ArrayAttr(targetCoreIds),
|
||||
packedInput);
|
||||
for (auto op : run)
|
||||
for (auto op : run.ops)
|
||||
rewriter.eraseOp(op);
|
||||
|
||||
it = runIt;
|
||||
it = run.end;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
@@ -614,8 +715,9 @@ void compactRegularOpRuns(func::FuncOp funcOp) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto anchorEndIt = std::next(it, static_cast<std::ptrdiff_t>(anchorChunk->ops.size()));
|
||||
SmallVector<RegularChunk> run {*anchorChunk};
|
||||
auto runIt = std::next(it, static_cast<std::ptrdiff_t>(anchorChunk->ops.size()));
|
||||
auto runIt = anchorEndIt;
|
||||
while (runIt != block.end()) {
|
||||
auto candidateStart = dyn_cast<spatial::SpatVMMOp>(&*runIt);
|
||||
if (!candidateStart)
|
||||
@@ -630,12 +732,26 @@ void compactRegularOpRuns(func::FuncOp funcOp) {
|
||||
}
|
||||
|
||||
if (run.size() <= 1) {
|
||||
++it;
|
||||
it = anchorEndIt;
|
||||
continue;
|
||||
}
|
||||
|
||||
compactRegularChunkRun(rewriter, run);
|
||||
it = runIt;
|
||||
size_t originalOpCount = 0;
|
||||
for (const RegularChunk& chunk : run)
|
||||
originalOpCount += chunk.ops.size();
|
||||
|
||||
RegularCompactionResult result = compactRegularChunkRun(rewriter, run);
|
||||
if (result.changed) {
|
||||
assert(originalOpCount > anchorChunk->ops.size() && "successful regular compaction must consume the run");
|
||||
if (!result.resumeAfter) {
|
||||
it = block.end();
|
||||
continue;
|
||||
}
|
||||
it = result.resumeAfter->getIterator();
|
||||
continue;
|
||||
}
|
||||
|
||||
it = anchorEndIt;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -666,37 +782,32 @@ void compactRowWiseWvmmRuns(func::FuncOp funcOp) {
|
||||
continue;
|
||||
}
|
||||
|
||||
SmallVector<spatial::SpatVMMOp> run;
|
||||
auto runIt = it;
|
||||
int64_t expectedRow = static_cast<int64_t>(rowResult.getResultNumber());
|
||||
while (runIt != block.end()) {
|
||||
auto current = dyn_cast<spatial::SpatVMMOp>(&*runIt);
|
||||
if (!current || current.getWeightIndex() != wvmmOp.getWeightIndex()
|
||||
auto run = collectConsecutiveRun<spatial::SpatVMMOp>(it, block.end(), [&](spatial::SpatVMMOp current) {
|
||||
if (current.getWeightIndex() != wvmmOp.getWeightIndex()
|
||||
|| current.getInput().getDefiningOp<spatial::SpatExtractRowsOp>() != extractRowsOp
|
||||
|| current.getInput().getType() != wvmmOp.getInput().getType()
|
||||
|| current.getOutput().getType() != wvmmOp.getOutput().getType()) {
|
||||
break;
|
||||
}
|
||||
|| current.getOutput().getType() != wvmmOp.getOutput().getType())
|
||||
return false;
|
||||
|
||||
auto currentRow = dyn_cast<OpResult>(current.getInput());
|
||||
if (!currentRow || currentRow.getResultNumber() != static_cast<unsigned>(expectedRow))
|
||||
break;
|
||||
return false;
|
||||
|
||||
run.push_back(current);
|
||||
++expectedRow;
|
||||
++runIt;
|
||||
}
|
||||
return true;
|
||||
});
|
||||
|
||||
if (run.size() <= 1) {
|
||||
if (run.ops.size() <= 1) {
|
||||
++it;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!run.front().getOutput().hasOneUse()) {
|
||||
if (!run.ops.front().getOutput().hasOneUse()) {
|
||||
++it;
|
||||
continue;
|
||||
}
|
||||
auto concatUse = run.front().getOutput().getUses().begin();
|
||||
auto concatUse = run.ops.front().getOutput().getUses().begin();
|
||||
auto concatOp = dyn_cast<spatial::SpatConcatOp>(concatUse->getOwner());
|
||||
if (!concatOp) {
|
||||
++it;
|
||||
@@ -705,7 +816,7 @@ void compactRowWiseWvmmRuns(func::FuncOp funcOp) {
|
||||
|
||||
unsigned concatStartIndex = concatUse->getOperandNumber();
|
||||
bool validConcatRun = true;
|
||||
for (auto [index, op] : llvm::enumerate(run)) {
|
||||
for (auto [index, op] : llvm::enumerate(run.ops)) {
|
||||
if (!op.getOutput().hasOneUse()) {
|
||||
validConcatRun = false;
|
||||
break;
|
||||
@@ -736,17 +847,17 @@ void compactRowWiseWvmmRuns(func::FuncOp funcOp) {
|
||||
}
|
||||
|
||||
int64_t firstRow = static_cast<int64_t>(rowResult.getResultNumber());
|
||||
int64_t runLength = static_cast<int64_t>(run.size());
|
||||
int64_t runLength = static_cast<int64_t>(run.ops.size());
|
||||
auto packedType = RankedTensorType::get({runLength, outputCols}, outputType.getElementType());
|
||||
|
||||
rewriter.setInsertionPoint(run.front());
|
||||
auto zero = arith::ConstantIndexOp::create(rewriter, run.front().getLoc(), 0);
|
||||
auto upper = arith::ConstantIndexOp::create(rewriter, run.front().getLoc(), runLength);
|
||||
auto step = arith::ConstantIndexOp::create(rewriter, run.front().getLoc(), 1);
|
||||
rewriter.setInsertionPoint(run.ops.front());
|
||||
auto zero = arith::ConstantIndexOp::create(rewriter, run.ops.front().getLoc(), 0);
|
||||
auto upper = arith::ConstantIndexOp::create(rewriter, run.ops.front().getLoc(), runLength);
|
||||
auto step = arith::ConstantIndexOp::create(rewriter, run.ops.front().getLoc(), 1);
|
||||
auto packedInit =
|
||||
tensor::EmptyOp::create(rewriter, run.front().getLoc(), packedType.getShape(), packedType.getElementType());
|
||||
tensor::EmptyOp::create(rewriter, run.ops.front().getLoc(), packedType.getShape(), packedType.getElementType());
|
||||
auto loop =
|
||||
scf::ForOp::create(rewriter, run.front().getLoc(), zero, upper, step, ValueRange {packedInit.getResult()});
|
||||
scf::ForOp::create(rewriter, run.ops.front().getLoc(), zero, upper, step, ValueRange {packedInit.getResult()});
|
||||
|
||||
{
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
@@ -757,41 +868,41 @@ void compactRowWiseWvmmRuns(func::FuncOp funcOp) {
|
||||
|
||||
Value sourceRow = iv;
|
||||
if (firstRow != 0) {
|
||||
auto firstRowValue = arith::ConstantIndexOp::create(rewriter, run.front().getLoc(), firstRow);
|
||||
sourceRow = arith::AddIOp::create(rewriter, run.front().getLoc(), iv, firstRowValue);
|
||||
auto firstRowValue = arith::ConstantIndexOp::create(rewriter, run.ops.front().getLoc(), firstRow);
|
||||
sourceRow = arith::AddIOp::create(rewriter, run.ops.front().getLoc(), iv, firstRowValue);
|
||||
}
|
||||
|
||||
SmallVector<OpFoldResult> extractOffsets = {sourceRow, rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> extractSizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(inputCols)};
|
||||
SmallVector<OpFoldResult> extractStrides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
auto extractedRow = tensor::ExtractSliceOp::create(rewriter,
|
||||
run.front().getLoc(),
|
||||
run.ops.front().getLoc(),
|
||||
inputType,
|
||||
extractRowsOp.getInput(),
|
||||
extractOffsets,
|
||||
extractSizes,
|
||||
extractStrides);
|
||||
auto loopWvmm = spatial::SpatVMMOp::create(
|
||||
rewriter, run.front().getLoc(), outputType, wvmmOp.getWeightIndex(), extractedRow.getResult());
|
||||
rewriter, run.ops.front().getLoc(), outputType, wvmmOp.getWeightIndex(), extractedRow.getResult());
|
||||
|
||||
SmallVector<OpFoldResult> insertOffsets = {iv, rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> insertSizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(outputCols)};
|
||||
SmallVector<OpFoldResult> insertStrides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
auto inserted = tensor::InsertSliceOp::create(
|
||||
rewriter, run.front().getLoc(), loopWvmm.getResult(), acc, insertOffsets, insertSizes, insertStrides);
|
||||
scf::YieldOp::create(rewriter, run.front().getLoc(), inserted.getResult());
|
||||
rewriter, run.ops.front().getLoc(), loopWvmm.getResult(), acc, insertOffsets, insertSizes, insertStrides);
|
||||
scf::YieldOp::create(rewriter, run.ops.front().getLoc(), inserted.getResult());
|
||||
}
|
||||
|
||||
SmallVector<Value> newConcatInputs;
|
||||
newConcatInputs.reserve(concatOp.getInputs().size() - run.size() + 1);
|
||||
newConcatInputs.reserve(concatOp.getInputs().size() - run.ops.size() + 1);
|
||||
for (auto [operandIndex, operand] : llvm::enumerate(concatOp.getInputs())) {
|
||||
if (operandIndex == concatStartIndex)
|
||||
newConcatInputs.push_back(loop.getResult(0));
|
||||
if (operandIndex < concatStartIndex || operandIndex >= concatStartIndex + run.size())
|
||||
if (operandIndex < concatStartIndex || operandIndex >= concatStartIndex + run.ops.size())
|
||||
newConcatInputs.push_back(operand);
|
||||
}
|
||||
rewriter.modifyOpInPlace(concatOp, [&] { concatOp->setOperands(newConcatInputs); });
|
||||
for (auto op : run)
|
||||
for (auto op : run.ops)
|
||||
rewriter.eraseOp(op);
|
||||
|
||||
it = loop->getIterator();
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
void orderBilateralChannelOps(mlir::func::FuncOp funcOp);
|
||||
void compactScalarChannelRuns(mlir::func::FuncOp funcOp, int64_t& nextChannelId);
|
||||
void compactBatchChannelRuns(mlir::func::FuncOp funcOp);
|
||||
void compactRegularOpRuns(mlir::func::FuncOp funcOp);
|
||||
|
||||
@@ -0,0 +1,189 @@
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <limits>
|
||||
#include <optional>
|
||||
#include <queue>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "ComputeGraph.hpp"
|
||||
#include "src/Support/TypeUtilities.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace spatial {
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
|
||||
/// Estimates the weight of a compute body as a fixed cost of 100 per
/// operation, counting only the operations that sit directly in the blocks of
/// `body` (nested regions are not descended into). The sum and the final
/// product go through checkedAdd / checkedMultiply.
Weight getComputeBodyWeight(Region &body) {
  constexpr Weight kOperationWeight = 100;
  const Weight unit = static_cast<Weight>(1);

  Weight operationCount = 0;
  for (Block &block : body) {
    for (auto opIt = block.begin(), opEnd = block.end(); opIt != opEnd; ++opIt)
      operationCount = checkedAdd(operationCount, unit);
  }
  return checkedMultiply(operationCount, kOperationWeight);
}
|
||||
|
||||
/// Counts the SpatVMMOp operations located directly in the blocks of `body`;
/// each one contributes 1 to the returned usage (accumulated via checkedAdd).
CrossbarUsage getComputeBodyCrossbarUsage(Region &body) {
  const CrossbarUsage perVmm = static_cast<CrossbarUsage>(1);

  CrossbarUsage usage = 0;
  for (Block &block : body) {
    for (Operation &op : block) {
      if (!isa<SpatVMMOp>(op))
        continue;
      usage = checkedAdd(usage, perVmm);
    }
  }
  return usage;
}
|
||||
|
||||
/// Returns true when `producerOp` has at least one result, every result has at
/// least one use, and every user of every result is a SpatCompute or
/// SpatComputeBatch that lists that result among its weights.
bool isUsedAsWeightOnly(Operation *producerOp) {
  if (producerOp->getNumResults() == 0)
    return false;

  // True when `user` is a compute op whose weight list contains `result`.
  auto consumesAsWeight = [](Operation *user, Value result) -> bool {
    if (auto compute = dyn_cast<SpatCompute>(user))
      return llvm::is_contained(compute.getWeights(), result);
    if (auto batch = dyn_cast<SpatComputeBatch>(user))
      return llvm::is_contained(batch.getWeights(), result);
    return false;
  };

  for (Value result : producerOp->getResults()) {
    // An unused result disqualifies the producer.
    if (result.use_empty())
      return false;
    for (Operation *user : result.getUsers())
      if (!consumesAsWeight(user, result))
        return false;
  }
  return true;
}
|
||||
|
||||
/// Collapses parallel edges: self-loops are dropped, and for each
/// (source, target) pair only the largest transfer cost is kept. The result
/// is sorted by source index, then by target index.
std::vector<ComputeGraphEdge> aggregateEdges(llvm::ArrayRef<ComputeGraphEdge> edges) {
  using EndpointPair = std::pair<size_t, size_t>;

  llvm::DenseMap<EndpointPair, Weight> maxCostByEndpoints;
  for (const ComputeGraphEdge &edge : edges) {
    if (edge.source == edge.target)
      continue;
    auto [slot, isNew] = maxCostByEndpoints.try_emplace(EndpointPair(edge.source, edge.target), edge.transferCost);
    if (!isNew && slot->second < edge.transferCost)
      slot->second = edge.transferCost;
  }

  std::vector<ComputeGraphEdge> merged;
  merged.reserve(maxCostByEndpoints.size());
  for (const auto &entry : maxCostByEndpoints)
    merged.push_back({entry.first.first, entry.first.second, entry.second});

  // Sort explicitly so the output order does not follow the map's iteration order.
  llvm::sort(merged, [](const ComputeGraphEdge &lhs, const ComputeGraphEdge &rhs) {
    return std::make_pair(lhs.source, lhs.target) < std::make_pair(rhs.source, rhs.target);
  });
  return merged;
}
|
||||
|
||||
} // namespace
|
||||
|
||||
/// Weight of a compute instance: a SpatCompute delegates to
/// getSpatComputeWeight; otherwise the op must be a SpatComputeBatch and the
/// weight is the body weight multiplied by the instance's lane count.
Weight getComputeInstanceWeight(const ComputeInstance &instance) {
  auto spatCompute = dyn_cast<SpatCompute>(instance.op);
  if (!spatCompute) {
    auto batch = cast<SpatComputeBatch>(instance.op);
    const Weight perLaneWeight = getComputeBodyWeight(batch.getBody());
    return checkedMultiply(perLaneWeight, static_cast<Weight>(instance.laneCount));
  }
  return getSpatComputeWeight(spatCompute);
}
|
||||
|
||||
/// Crossbar usage of a compute instance: a SpatCompute delegates to
/// getSpatComputeCrossbarUsage; otherwise the op must be a SpatComputeBatch
/// and the usage is the body's usage multiplied by the instance's lane count.
CrossbarUsage getComputeInstanceCrossbarUsage(const ComputeInstance &instance) {
  auto spatCompute = dyn_cast<SpatCompute>(instance.op);
  if (!spatCompute) {
    auto batch = cast<SpatComputeBatch>(instance.op);
    const CrossbarUsage perLaneUsage = getComputeBodyCrossbarUsage(batch.getBody());
    return checkedMultiply(perLaneUsage, static_cast<CrossbarUsage>(instance.laneCount));
  }
  return getSpatComputeCrossbarUsage(spatCompute);
}
|
||||
|
||||
/// Builds the dependency graph of the compute instances found directly in the
/// regions of `entryOp` (ops nested deeper are not visited).
///
/// Nodes: one per SpatCompute, plus one per chunk of every SpatComputeBatch
/// (chunking given by getBatchChunkTargetCount / getBatchChunkForIndex). Ops
/// whose results are used only as weights are left out. Each node records its
/// weight, crossbar usage and insertion order.
///
/// Edges: one producer -> consumer edge per instance input whose producer is
/// itself a node, weighted by the input's size in bytes; duplicates and
/// self-loops are resolved by aggregateEdges. `successors` / `predecessors`
/// are filled from the aggregated edge list.
ComputeGraph buildComputeGraph(Operation *entryOp) {
  ComputeGraph graph;

  // Appends `instance` as a new node and indexes it for the edge lookup below.
  auto addNode = [&graph](const ComputeInstance &instance) {
    size_t index = graph.nodes.size();
    graph.nodes.push_back(
        {instance, getComputeInstanceWeight(instance), getComputeInstanceCrossbarUsage(instance), index});
    graph.instanceToIndex[instance] = index;
  };

  for (Region &region : entryOp->getRegions()) {
    for (Block &block : region) {
      for (Operation &op : block) {
        if (auto spatCompute = dyn_cast<SpatCompute>(&op)) {
          if (isUsedAsWeightOnly(spatCompute.getOperation()))
            continue;
          ComputeInstance instance {spatCompute.getOperation(), 0, 1};
          addNode(instance);
          continue;
        }
        if (auto batch = dyn_cast<SpatComputeBatch>(&op)) {
          if (isUsedAsWeightOnly(batch.getOperation()))
            continue;
          // A batch contributes one node per chunk of lanes.
          size_t chunkCount = getBatchChunkTargetCount(batch.getLaneCount());
          for (size_t chunkIndex = 0; chunkIndex < chunkCount; ++chunkIndex)
            addNode(getBatchChunkForIndex(batch, chunkIndex));
        }
      }
    }
  }

  llvm::SmallVector<ComputeGraphEdge, 16> rawEdges;
  for (const auto &[targetIndex, node] : llvm::enumerate(graph.nodes)) {
    for (Value input : getComputeInstanceInputs(node.instance)) {
      auto producerInstance = getComputeProducerInstance(input);
      if (!producerInstance)
        continue;
      // Producers that are not graph nodes do not create edges.
      auto producerIt = graph.instanceToIndex.find(*producerInstance);
      if (producerIt == graph.instanceToIndex.end())
        continue;
      rawEdges.push_back(
          {producerIt->second, targetIndex, static_cast<Weight>(getSizeInBytes(cast<ShapedType>(input.getType())))});
    }
  }

  std::vector<ComputeGraphEdge> aggregatedEdges = aggregateEdges(rawEdges);
  graph.edges.append(aggregatedEdges.begin(), aggregatedEdges.end());
  graph.successors.assign(graph.nodes.size(), {});
  graph.predecessors.assign(graph.nodes.size(), {});
  for (const ComputeGraphEdge &edge : graph.edges) {
    graph.successors[edge.source].push_back({edge.target, edge.transferCost});
    graph.predecessors[edge.target].push_back({edge.source, edge.transferCost});
  }

  return graph;
}
|
||||
|
||||
bool verifyAcyclic(const ComputeGraph &graph) {
|
||||
std::vector<size_t> remainingParents(graph.nodes.size(), 0);
|
||||
std::queue<size_t> readyNodes;
|
||||
for (size_t node = 0; node < graph.nodes.size(); ++node) {
|
||||
remainingParents[node] = graph.predecessors[node].size();
|
||||
if (remainingParents[node] == 0)
|
||||
readyNodes.push(node);
|
||||
}
|
||||
|
||||
size_t visited = 0;
|
||||
while (!readyNodes.empty()) {
|
||||
size_t node = readyNodes.front();
|
||||
readyNodes.pop();
|
||||
++visited;
|
||||
for (const auto &[child, weight] : graph.successors[node]) {
|
||||
(void) weight;
|
||||
assert(remainingParents[child] > 0 && "remaining parent count underflow");
|
||||
if (--remainingParents[child] == 0)
|
||||
readyNodes.push(child);
|
||||
}
|
||||
}
|
||||
|
||||
return visited == graph.nodes.size();
|
||||
}
|
||||
|
||||
} // namespace spatial
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,49 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include <cstddef>
|
||||
#include <optional>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "../DCPGraph/Utils.hpp"
|
||||
#include "ComputeInstance.hpp"
|
||||
#include "ComputeInstanceUtils.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace spatial {
|
||||
|
||||
// One vertex of the compute graph.
struct ComputeGraphNode {
|
||||
// The compute instance (operation + lane range) this node stands for.
ComputeInstance instance;
|
||||
// Node cost; buildComputeGraph fills it from getComputeInstanceWeight(instance).
Weight weight = 0;
|
||||
// Crossbar demand; buildComputeGraph fills it from getComputeInstanceCrossbarUsage(instance).
CrossbarUsage crossbarUsage = 0;
|
||||
// Position of the node in ComputeGraph::nodes at the time it was created.
size_t originalOrder = 0;
|
||||
};
|
||||
|
||||
// Directed producer -> consumer dependency between two compute graph nodes.
struct ComputeGraphEdge {
|
||||
// Index of the producing node in ComputeGraph::nodes.
size_t source = 0;
|
||||
// Index of the consuming node in ComputeGraph::nodes.
size_t target = 0;
|
||||
// Cost of the edge; buildComputeGraph derives it from the byte size of the
// value flowing from source to target.
Weight transferCost = 0;
|
||||
};
|
||||
|
||||
// Dependency graph over compute instances, with adjacency lists and a reverse
// lookup from instance to node index.
struct ComputeGraph {
|
||||
// All nodes; node indices used elsewhere refer to positions in this vector.
llvm::SmallVector<ComputeGraphNode> nodes;
|
||||
// Aggregated edges between nodes.
llvm::SmallVector<ComputeGraphEdge> edges;
|
||||
// successors[i] lists (child index, transfer cost) for every edge leaving node i.
std::vector<std::vector<std::pair<size_t, Weight>>> successors;
|
||||
// predecessors[i] lists (parent index, transfer cost) for every edge entering node i.
std::vector<std::vector<std::pair<size_t, Weight>>> predecessors;
|
||||
// Maps a compute instance to the index of its node in `nodes`.
llvm::DenseMap<ComputeInstance, size_t> instanceToIndex;
|
||||
};
|
||||
|
||||
ComputeGraph buildComputeGraph(mlir::Operation *entryOp);
|
||||
bool verifyAcyclic(const ComputeGraph &graph);
|
||||
|
||||
Weight getComputeInstanceWeight(const ComputeInstance &instance);
|
||||
CrossbarUsage getComputeInstanceCrossbarUsage(const ComputeInstance &instance);
|
||||
|
||||
} // namespace spatial
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,45 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/Operation.h"
|
||||
|
||||
#include "llvm/ADT/DenseMapInfo.h"
|
||||
#include "llvm/ADT/Hashing.h"
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace spatial {
|
||||
|
||||
struct ComputeInstance {
|
||||
mlir::Operation *op = nullptr;
|
||||
uint32_t laneStart = 0;
|
||||
uint32_t laneCount = 1;
|
||||
|
||||
bool operator==(const ComputeInstance &other) const {
|
||||
return op == other.op && laneStart == other.laneStart && laneCount == other.laneCount;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace spatial
|
||||
} // namespace onnx_mlir
|
||||
|
||||
using ComputeInstance = onnx_mlir::spatial::ComputeInstance;
|
||||
|
||||
namespace llvm {
|
||||
template <>
|
||||
struct DenseMapInfo<onnx_mlir::spatial::ComputeInstance> {
|
||||
static onnx_mlir::spatial::ComputeInstance getEmptyKey() {
|
||||
return {DenseMapInfo<mlir::Operation *>::getEmptyKey(), UINT32_MAX, UINT32_MAX};
|
||||
}
|
||||
static onnx_mlir::spatial::ComputeInstance getTombstoneKey() {
|
||||
return {DenseMapInfo<mlir::Operation *>::getTombstoneKey(), UINT32_MAX, UINT32_MAX};
|
||||
}
|
||||
static unsigned getHashValue(const onnx_mlir::spatial::ComputeInstance &value) {
|
||||
return llvm::hash_combine(value.op, value.laneStart, value.laneCount);
|
||||
}
|
||||
static bool isEqual(const onnx_mlir::spatial::ComputeInstance &lhs,
|
||||
const onnx_mlir::spatial::ComputeInstance &rhs) {
|
||||
return lhs == rhs;
|
||||
}
|
||||
};
|
||||
} // namespace llvm
|
||||
+151
@@ -0,0 +1,151 @@
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
|
||||
#include <limits>
|
||||
|
||||
#include "ComputeInstanceUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace spatial {
|
||||
|
||||
size_t getSchedulingCpuBudget() {
|
||||
if (coresCount.getValue() > 0)
|
||||
return static_cast<size_t>(coresCount.getValue());
|
||||
return std::numeric_limits<size_t>::max();
|
||||
}
|
||||
|
||||
size_t getBatchChunkTargetCount(int32_t laneCount) {
|
||||
assert(laneCount > 0 && "laneCount must be positive");
|
||||
return std::min(static_cast<size_t>(laneCount), std::max<size_t>(1, getSchedulingCpuBudget()));
|
||||
}
|
||||
|
||||
/// Returns the lane range of chunk `chunkIndex` of `batch`.
///
/// Lanes are distributed as evenly as possible: every chunk gets
/// `totalLanes / chunkCount` lanes and the first `totalLanes % chunkCount`
/// chunks get one extra lane.
ComputeInstance getBatchChunkForIndex(SpatComputeBatch batch, size_t chunkIndex) {
  const size_t lanes = static_cast<size_t>(batch.getLaneCount());
  const size_t chunks = getBatchChunkTargetCount(batch.getLaneCount());
  const size_t smallChunkSize = lanes / chunks;
  const size_t enlargedChunks = lanes % chunks;

  // Every enlarged chunk located before this one shifts the start by one lane.
  const bool isEnlarged = chunkIndex < enlargedChunks;
  const size_t enlargedBefore = isEnlarged ? chunkIndex : enlargedChunks;
  const size_t firstLane = chunkIndex * smallChunkSize + enlargedBefore;
  const size_t width = isEnlarged ? smallChunkSize + 1 : smallChunkSize;

  return ComputeInstance {batch.getOperation(), static_cast<uint32_t>(firstLane), static_cast<uint32_t>(width)};
}
|
||||
|
||||
/// Returns the chunk of `batch` that contains `lane`.
///
/// This is the inverse of the partition used by getBatchChunkForIndex: the
/// leading `totalLanes % chunkCount` chunks are one lane wider than the rest.
ComputeInstance getBatchChunkForLane(SpatComputeBatch batch, uint32_t lane) {
  const size_t lanes = static_cast<size_t>(batch.getLaneCount());
  const size_t chunks = getBatchChunkTargetCount(batch.getLaneCount());
  const size_t smallChunkSize = lanes / chunks;
  const size_t enlargedChunks = lanes % chunks;

  // Number of lanes covered by the enlarged chunks at the front.
  const size_t enlargedSpan = enlargedChunks * (smallChunkSize + 1);
  const size_t laneIndex = static_cast<size_t>(lane);

  const size_t chunkIndex = laneIndex >= enlargedSpan
                                ? enlargedChunks + (laneIndex - enlargedSpan) / smallChunkSize
                                : laneIndex / (smallChunkSize + 1);
  return getBatchChunkForIndex(batch, chunkIndex);
}
|
||||
|
||||
/// Looks through any chain of `tensor.extract_slice` ops starting at `op` and
/// returns the `SpatCompute` found at the end of it.
///
/// Returns a null `SpatCompute` when `op` is null, when the chain ends in a
/// value without a defining op, or when the final op is not a `SpatCompute`.
SpatCompute getOriginalSpatCompute(Operation *op) {
  for (Operation *current = op; current != nullptr;) {
    auto slice = dyn_cast<tensor::ExtractSliceOp>(current);
    if (!slice)
      return dyn_cast<SpatCompute>(current);
    // Step to whatever produced the sliced tensor.
    current = slice.getSource().getDefiningOp();
  }
  return {};
}
|
||||
|
||||
/// Resolves `value` to the compute instance that produces it, together with the
/// index of the result inside that instance.
///
/// `tensor.extract_slice` ops between `value` and its producer are skipped.
/// Returns `std::nullopt` when the walk ends at a value with no defining op or
/// at an op that is neither a `SpatCompute` nor a `SpatComputeBatch`.
std::optional<ProducerValueRef> getProducerValueRef(Value value) {
  // TODO: extract_slice is not the only global non-compute operation; other
  // legal ops should be looked through as well.
  Value current = value;
  Operation *definingOp = current.getDefiningOp();
  while (definingOp) {
    auto slice = dyn_cast<tensor::ExtractSliceOp>(definingOp);
    if (!slice)
      break;
    current = slice.getSource();
    definingOp = current.getDefiningOp();
  }
  if (!definingOp)
    return std::nullopt;

  if (auto compute = dyn_cast<SpatCompute>(definingOp)) {
    // A plain compute op is always the single instance {op, 0, 1}.
    const size_t resultIndex = static_cast<size_t>(cast<OpResult>(current).getResultNumber());
    return ProducerValueRef {ComputeInstance {compute.getOperation(), 0, 1}, resultIndex};
  }

  if (auto batch = dyn_cast<SpatComputeBatch>(definingOp)) {
    // For a batch the result number is used as the lane; the result index is
    // the lane's offset inside the chunk that owns it.
    const uint32_t lane = static_cast<uint32_t>(cast<OpResult>(current).getResultNumber());
    const ComputeInstance owner = getBatchChunkForLane(batch, lane);
    const size_t resultIndex = static_cast<size_t>(lane - owner.laneStart);
    return ProducerValueRef {owner, resultIndex};
  }

  return std::nullopt;
}
|
||||
|
||||
/// Returns only the producing compute instance of `value` (see
/// getProducerValueRef), or `std::nullopt` when there is none.
std::optional<ComputeInstance> getComputeProducerInstance(Value value) {
  const std::optional<ProducerValueRef> producer = getProducerValueRef(value);
  if (!producer)
    return std::nullopt;
  return producer->instance;
}
|
||||
|
||||
/// Collects the input values of `instance`.
///
/// For a `SpatCompute` this is every input of the op. For a chunk of a
/// `SpatComputeBatch` it is the input at each lane of the chunk, or nothing
/// when the batch has no inputs at all.
llvm::SmallVector<Value, 4> getComputeInstanceInputs(const ComputeInstance &instance) {
  if (auto compute = dyn_cast<SpatCompute>(instance.op)) {
    auto allInputs = compute.getInputs();
    return llvm::SmallVector<Value, 4>(allInputs.begin(), allInputs.end());
  }

  auto batch = cast<SpatComputeBatch>(instance.op);
  llvm::SmallVector<Value, 4> inputs;
  inputs.reserve(instance.laneCount);

  auto batchInputs = batch.getInputs();
  if (batchInputs.empty())
    return inputs;

  const uint32_t laneEnd = instance.laneStart + instance.laneCount;
  for (uint32_t lane = instance.laneStart; lane < laneEnd; ++lane)
    inputs.push_back(batchInputs[lane]);
  return inputs;
}
|
||||
|
||||
/// Collects the weight values of `instance`.
///
/// For a `SpatCompute` this is every weight of the op. For a chunk of a
/// `SpatComputeBatch` it is the weight at each lane of the chunk.
llvm::SmallVector<Value, 4> getComputeInstanceWeights(const ComputeInstance &instance) {
  if (auto compute = dyn_cast<SpatCompute>(instance.op)) {
    auto allWeights = compute.getWeights();
    return llvm::SmallVector<Value, 4>(allWeights.begin(), allWeights.end());
  }

  auto batch = cast<SpatComputeBatch>(instance.op);
  llvm::SmallVector<Value, 4> weights;
  weights.reserve(instance.laneCount);

  auto batchWeights = batch.getWeights();
  const uint32_t laneEnd = instance.laneStart + instance.laneCount;
  for (uint32_t lane = instance.laneStart; lane < laneEnd; ++lane)
    weights.push_back(batchWeights[lane]);
  return weights;
}
|
||||
|
||||
/// Collects the output values of `instance`.
///
/// For a `SpatCompute` this is every result of the op. For a chunk of a
/// `SpatComputeBatch` it is the output at each lane of the chunk, or nothing
/// when the batch has no outputs at all.
llvm::SmallVector<Value, 4> getComputeInstanceOutputValues(const ComputeInstance &instance) {
  if (auto compute = dyn_cast<SpatCompute>(instance.op)) {
    auto allResults = compute.getResults();
    return llvm::SmallVector<Value, 4>(allResults.begin(), allResults.end());
  }

  auto batch = cast<SpatComputeBatch>(instance.op);
  llvm::SmallVector<Value, 4> outputs;
  outputs.reserve(instance.laneCount);

  auto batchOutputs = batch.getOutputs();
  if (batchOutputs.empty())
    return outputs;

  const uint32_t laneEnd = instance.laneStart + instance.laneCount;
  for (uint32_t lane = instance.laneStart; lane < laneEnd; ++lane)
    outputs.push_back(batchOutputs[lane]);
  return outputs;
}
|
||||
|
||||
/// Returns the type of every output value of `instance`, in the order given by
/// getComputeInstanceOutputValues.
llvm::SmallVector<Type, 4> getComputeInstanceOutputTypes(const ComputeInstance &instance) {
  const llvm::SmallVector<Value, 4> outputs = getComputeInstanceOutputValues(instance);
  llvm::SmallVector<Type, 4> outputTypes;
  for (size_t position = 0, count = outputs.size(); position != count; ++position)
    outputTypes.push_back(outputs[position].getType());
  return outputTypes;
}
|
||||
|
||||
/// Returns the first block of the body region of the op behind `instance`,
/// which must be either a `SpatCompute` or a `SpatComputeBatch`.
Block &getComputeInstanceTemplateBlock(const ComputeInstance &instance) {
  Operation *owner = instance.op;
  auto compute = dyn_cast<SpatCompute>(owner);
  if (!compute)
    return cast<SpatComputeBatch>(owner).getBody().front();
  return compute.getBody().front();
}
|
||||
|
||||
} // namespace spatial
|
||||
} // namespace onnx_mlir
|
||||
+40
@@ -0,0 +1,40 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/Block.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include <cstddef>
|
||||
#include <optional>
|
||||
|
||||
#include "ComputeInstance.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace spatial {
|
||||
|
||||
// Names the producer of a value: which compute instance yields it and which of
// that instance's results it is.
struct ProducerValueRef {
|
||||
// Compute instance that produces the value.
ComputeInstance instance;
|
||||
// Result position inside `instance`: the op result number for a SpatCompute,
// or the lane offset from the chunk's laneStart for a SpatComputeBatch chunk.
size_t resultIndex = 0;
|
||||
};
|
||||
|
||||
size_t getSchedulingCpuBudget();
|
||||
size_t getBatchChunkTargetCount(int32_t laneCount);
|
||||
ComputeInstance getBatchChunkForIndex(SpatComputeBatch batch, size_t chunkIndex);
|
||||
ComputeInstance getBatchChunkForLane(SpatComputeBatch batch, uint32_t lane);
|
||||
|
||||
SpatCompute getOriginalSpatCompute(mlir::Operation *op);
|
||||
std::optional<ProducerValueRef> getProducerValueRef(mlir::Value value);
|
||||
std::optional<ComputeInstance> getComputeProducerInstance(mlir::Value value);
|
||||
|
||||
llvm::SmallVector<mlir::Value, 4> getComputeInstanceInputs(const ComputeInstance &instance);
|
||||
llvm::SmallVector<mlir::Value, 4> getComputeInstanceWeights(const ComputeInstance &instance);
|
||||
llvm::SmallVector<mlir::Value, 4> getComputeInstanceOutputValues(const ComputeInstance &instance);
|
||||
llvm::SmallVector<mlir::Type, 4> getComputeInstanceOutputTypes(const ComputeInstance &instance);
|
||||
mlir::Block &getComputeInstanceTemplateBlock(const ComputeInstance &instance);
|
||||
|
||||
} // namespace spatial
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,720 @@
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdlib>
|
||||
#include <limits>
|
||||
#include <numeric>
|
||||
#include <optional>
|
||||
#include <queue>
|
||||
#include <vector>
|
||||
|
||||
#include "DcpScheduler.hpp"
|
||||
#include "../DCPGraph/Graph.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace spatial {
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
|
||||
/// Returns true when the `DCP_COARSEN_DEBUG` environment variable is set,
/// whatever its value.
bool isDcpCoarsenDebugEnabled() {
  const char *flag = std::getenv("DCP_COARSEN_DEBUG");
  return flag != nullptr;
}
|
||||
|
||||
// A node of the coarsened graph: one or more ComputeGraph nodes merged together.
struct VirtualNode {
|
||||
// Indices of the ComputeGraph nodes folded into this virtual node.
llvm::SmallVector<size_t, 4> originalNodeIndices;
|
||||
// Accumulated weight of the folded nodes (combined with addOrMax when merging).
Weight weight = 0;
|
||||
// Accumulated crossbar usage of the folded nodes (combined with addOrMax when merging).
CrossbarUsage crossbarUsage = 0;
|
||||
};
|
||||
|
||||
// Working graph of the coarsening loop: virtual nodes plus indexed edges.
struct VirtualGraph {
|
||||
// Virtual nodes; edge endpoints are positions in this vector.
std::vector<VirtualNode> nodes;
|
||||
// (source, target, weight) edges between virtual nodes, produced by aggregateEdges.
std::vector<IndexedEdge> edges;
|
||||
};
|
||||
|
||||
// Result of computeTiming for a VirtualGraph.
struct TimingInfo {
|
||||
// Per node: max over parents of (parent aest + parent weight + transfer cost); 0 for roots.
std::vector<Time> aest;
|
||||
// Per node: latest start derived backwards from the longest path length and the children's alst.
std::vector<Time> alst;
|
||||
// Node indices in the topological order found by computeTiming.
std::vector<size_t> topologicalOrder;
|
||||
// True only when every node was ordered, i.e. the graph turned out to be acyclic.
bool valid = false;
|
||||
};
|
||||
|
||||
// Outcome of scheduling one window of virtual nodes with GraphDCP.
struct WindowScheduleResult {
|
||||
// One entry per CPU that received at least two tasks: the virtual node indices placed on it.
std::vector<std::vector<size_t>> mergeGroups;
|
||||
// CPU count reported by the window graph after scheduling.
CPU cpuCount = 0;
|
||||
// Total number of nodes across all merge groups.
size_t mergedNodeCount = 0;
|
||||
// Size of the largest merge group.
size_t maxMergeGroupSize = 0;
|
||||
};
|
||||
|
||||
size_t getSchedulingCpuBudget(const DcpScheduleOptions &options) {
|
||||
if (options.processorCount > 0)
|
||||
return options.processorCount;
|
||||
return std::numeric_limits<size_t>::max();
|
||||
}
|
||||
|
||||
/// Collapses parallel edges and drops self-loops.
///
/// For every distinct (source, target) pair the largest weight is kept. The
/// result is sorted by source and then by target, so it does not depend on the
/// iteration order of the intermediate map.
std::vector<IndexedEdge> aggregateEdges(llvm::ArrayRef<IndexedEdge> edges) {
  llvm::DenseMap<std::pair<size_t, size_t>, Weight> heaviestByEndpoints;
  for (auto [rawSource, rawTarget, rawWeight] : edges) {
    const size_t source = static_cast<size_t>(rawSource);
    const size_t target = static_cast<size_t>(rawTarget);
    if (source == target)
      continue;

    const Weight candidate = static_cast<Weight>(rawWeight);
    auto insertion = heaviestByEndpoints.try_emplace(std::make_pair(source, target), candidate);
    const bool alreadyPresent = !insertion.second;
    if (alreadyPresent && insertion.first->second < candidate)
      insertion.first->second = candidate;
  }

  std::vector<IndexedEdge> aggregatedEdges;
  aggregatedEdges.reserve(heaviestByEndpoints.size());
  for (const auto &entry : heaviestByEndpoints) {
    aggregatedEdges.push_back({static_cast<int64_t>(entry.first.first),
                               static_cast<int64_t>(entry.first.second),
                               static_cast<int64_t>(entry.second)});
  }

  // Lexicographic order on (source, target).
  auto endpointsLess = [](const IndexedEdge &lhs, const IndexedEdge &rhs) {
    if (std::get<0>(lhs) < std::get<0>(rhs))
      return true;
    if (std::get<0>(rhs) < std::get<0>(lhs))
      return false;
    return std::get<1>(lhs) < std::get<1>(rhs);
  };
  llvm::sort(aggregatedEdges, endpointsLess);
  return aggregatedEdges;
}
|
||||
|
||||
/// Builds the starting point of the coarsening loop: one virtual node per
/// compute graph node (same weight and crossbar usage) and the compute graph's
/// edges passed through aggregateEdges.
VirtualGraph buildInitialVirtualGraph(const ComputeGraph &graph) {
  VirtualGraph virtualGraph;

  const size_t nodeCount = graph.nodes.size();
  virtualGraph.nodes.reserve(nodeCount);
  for (size_t index = 0; index < nodeCount; ++index) {
    const ComputeGraphNode &original = graph.nodes[index];
    VirtualNode &fresh = virtualGraph.nodes.emplace_back();
    fresh.originalNodeIndices.push_back(index);
    fresh.weight = original.weight;
    fresh.crossbarUsage = original.crossbarUsage;
  }

  std::vector<IndexedEdge> indexedEdges;
  indexedEdges.reserve(graph.edges.size());
  for (const ComputeGraphEdge &edge : graph.edges) {
    indexedEdges.push_back({static_cast<int64_t>(edge.source),
                            static_cast<int64_t>(edge.target),
                            static_cast<int64_t>(edge.transferCost)});
  }
  virtualGraph.edges = aggregateEdges(indexedEdges);

  return virtualGraph;
}
|
||||
|
||||
/// Computes a topological order and per-node start-time bounds for `graph`.
///
/// The order is built with Kahn's algorithm; ready nodes are popped by
/// ascending order key (first original node index of the virtual node, falling
/// back to the virtual index), then by ascending virtual index. If the graph
/// is cyclic the function returns early with `valid == false`, zeroed
/// `aest`/`alst` and a partial `topologicalOrder`.
///
/// Otherwise `aest` is filled by a forward pass (latest parent arrival), the
/// longest path length is recorded, and `alst` is filled by a backward pass
/// from that length.
TimingInfo computeTiming(const VirtualGraph &graph) {
  const size_t nodeCount = graph.nodes.size();

  TimingInfo timing;
  timing.aest.assign(nodeCount, 0);
  timing.alst.assign(nodeCount, 0);
  timing.topologicalOrder.reserve(nodeCount);

  // Adjacency in both directions plus the in-degree of every node.
  using Link = std::pair<size_t, Weight>;
  std::vector<std::vector<Link>> incoming(nodeCount);
  std::vector<std::vector<Link>> outgoing(nodeCount);
  std::vector<size_t> pendingParents(nodeCount, 0);
  for (auto [rawSource, rawTarget, rawWeight] : graph.edges) {
    const size_t source = static_cast<size_t>(rawSource);
    const size_t target = static_cast<size_t>(rawTarget);
    const Weight cost = static_cast<Weight>(rawWeight);
    assert(source < nodeCount && target < nodeCount && "virtual edge endpoint out of range");
    outgoing[source].push_back({target, cost});
    incoming[target].push_back({source, cost});
    ++pendingParents[target];
  }

  auto orderKeyOf = [&graph](size_t nodeIndex) -> size_t {
    const auto &members = graph.nodes[nodeIndex].originalNodeIndices;
    return members.empty() ? nodeIndex : members.front();
  };
  // "Greater" comparator so that the priority queue pops the smallest key first.
  auto popsLater = [&orderKeyOf](size_t lhs, size_t rhs) {
    const size_t lhsKey = orderKeyOf(lhs);
    const size_t rhsKey = orderKeyOf(rhs);
    return lhsKey == rhsKey ? lhs > rhs : lhsKey > rhsKey;
  };
  std::priority_queue<size_t, std::vector<size_t>, decltype(popsLater)> ready(popsLater);
  for (size_t nodeIndex = 0; nodeIndex < nodeCount; ++nodeIndex) {
    if (pendingParents[nodeIndex] == 0)
      ready.push(nodeIndex);
  }

  while (!ready.empty()) {
    const size_t current = ready.top();
    ready.pop();
    timing.topologicalOrder.push_back(current);
    for (auto [child, unusedCost] : outgoing[current]) {
      (void) unusedCost;
      assert(pendingParents[child] > 0 && "incoming edge count underflow");
      --pendingParents[child];
      if (pendingParents[child] == 0)
        ready.push(child);
    }
  }

  // A cycle leaves some nodes unordered; report that through `valid == false`.
  if (timing.topologicalOrder.size() != nodeCount)
    return timing;

  // Forward pass: earliest start is the latest arrival among all parents.
  Time longestPath = 0;
  for (size_t nodeIndex : timing.topologicalOrder) {
    Time earliest = 0;
    for (auto [parent, transferCost] : incoming[nodeIndex]) {
      const Time arrival = addOrMax(addOrMax(timing.aest[parent], graph.nodes[parent].weight), transferCost);
      earliest = std::max(earliest, arrival);
    }
    timing.aest[nodeIndex] = earliest;
    longestPath = std::max(longestPath, addOrMax(earliest, graph.nodes[nodeIndex].weight));
  }

  // Backward pass: latest start that still meets every child's latest start;
  // sinks are anchored at the longest path length.
  for (size_t nodeIndex : llvm::reverse(timing.topologicalOrder)) {
    Time latest = std::numeric_limits<Time>::max();
    if (outgoing[nodeIndex].empty())
      latest = subtractOrZero(longestPath, graph.nodes[nodeIndex].weight);
    for (auto [child, transferCost] : outgoing[nodeIndex]) {
      const Time bound = subtractOrZero(timing.alst[child], addOrMax(graph.nodes[nodeIndex].weight, transferCost));
      latest = std::min(latest, bound);
    }
    timing.alst[nodeIndex] = latest;
  }

  timing.valid = true;
  return timing;
}
|
||||
|
||||
std::vector<std::vector<size_t>> buildUndirectedAdjacency(const VirtualGraph &graph) {
|
||||
std::vector<std::vector<size_t>> adjacency(graph.nodes.size());
|
||||
for (auto [start, end, weight] : graph.edges) {
|
||||
(void) weight;
|
||||
size_t startIndex = static_cast<size_t>(start);
|
||||
size_t endIndex = static_cast<size_t>(end);
|
||||
assert(startIndex < graph.nodes.size() && endIndex < graph.nodes.size() && "virtual edge endpoint out of range");
|
||||
adjacency[startIndex].push_back(endIndex);
|
||||
adjacency[endIndex].push_back(startIndex);
|
||||
}
|
||||
for (auto &neighbours : adjacency) {
|
||||
llvm::sort(neighbours);
|
||||
neighbours.erase(std::unique(neighbours.begin(), neighbours.end()), neighbours.end());
|
||||
}
|
||||
return adjacency;
|
||||
}
|
||||
|
||||
std::vector<size_t> selectCriticalWindow(const VirtualGraph &graph, const TimingInfo &timing, size_t windowSize) {
|
||||
std::vector<size_t> ranked(timing.aest.size());
|
||||
std::iota(ranked.begin(), ranked.end(), 0);
|
||||
auto isHigherPriority = [&](size_t lhs, size_t rhs) {
|
||||
Time lhsSlack = slackOrZero(timing.aest[lhs], timing.alst[lhs]);
|
||||
Time rhsSlack = slackOrZero(timing.aest[rhs], timing.alst[rhs]);
|
||||
if (lhsSlack != rhsSlack)
|
||||
return lhsSlack < rhsSlack;
|
||||
if (timing.aest[lhs] != timing.aest[rhs])
|
||||
return timing.aest[lhs] < timing.aest[rhs];
|
||||
return lhs < rhs;
|
||||
};
|
||||
|
||||
windowSize = std::min(windowSize, ranked.size());
|
||||
if (windowSize == 0)
|
||||
return {};
|
||||
if (windowSize == ranked.size()) {
|
||||
llvm::sort(ranked, isHigherPriority);
|
||||
return ranked;
|
||||
}
|
||||
|
||||
size_t criticalPoolSize = std::min(ranked.size(), std::max(windowSize, windowSize * 2));
|
||||
if (criticalPoolSize < ranked.size())
|
||||
std::nth_element(
|
||||
ranked.begin(), ranked.begin() + static_cast<std::ptrdiff_t>(criticalPoolSize), ranked.end(), isHigherPriority);
|
||||
|
||||
std::vector<char> inCriticalPool(ranked.size(), false);
|
||||
for (size_t i = 0; i < criticalPoolSize; ++i)
|
||||
inCriticalPool[ranked[i]] = true;
|
||||
|
||||
size_t seed = *std::min_element(ranked.begin(), ranked.end(), isHigherPriority);
|
||||
std::vector<std::vector<size_t>> adjacency = buildUndirectedAdjacency(graph);
|
||||
std::vector<size_t> selected;
|
||||
std::vector<char> inWindow(ranked.size(), false);
|
||||
selected.reserve(windowSize);
|
||||
|
||||
struct FrontierEntry {
|
||||
size_t node;
|
||||
};
|
||||
auto frontierCompare = [&](FrontierEntry lhs, FrontierEntry rhs) { return isHigherPriority(rhs.node, lhs.node); };
|
||||
std::priority_queue<FrontierEntry, std::vector<FrontierEntry>, decltype(frontierCompare)> frontier(frontierCompare);
|
||||
|
||||
auto addToWindow = [&](size_t node, const std::vector<char> &eligible) {
|
||||
if (inWindow[node])
|
||||
return;
|
||||
inWindow[node] = true;
|
||||
selected.push_back(node);
|
||||
for (size_t neighbour : adjacency[node])
|
||||
if (!inWindow[neighbour] && eligible[neighbour])
|
||||
frontier.push({neighbour});
|
||||
};
|
||||
|
||||
addToWindow(seed, inCriticalPool);
|
||||
while (!frontier.empty() && selected.size() < windowSize) {
|
||||
size_t node = frontier.top().node;
|
||||
frontier.pop();
|
||||
if (!inWindow[node])
|
||||
addToWindow(node, inCriticalPool);
|
||||
}
|
||||
|
||||
if (selected.size() < windowSize) {
|
||||
std::vector<char> anyNode(ranked.size(), true);
|
||||
for (size_t node : selected)
|
||||
for (size_t neighbour : adjacency[node])
|
||||
if (!inWindow[neighbour])
|
||||
frontier.push({neighbour});
|
||||
while (!frontier.empty() && selected.size() < windowSize) {
|
||||
size_t node = frontier.top().node;
|
||||
frontier.pop();
|
||||
if (!inWindow[node])
|
||||
addToWindow(node, anyNode);
|
||||
}
|
||||
}
|
||||
|
||||
if (selected.size() < windowSize) {
|
||||
llvm::sort(ranked, isHigherPriority);
|
||||
for (size_t node : ranked) {
|
||||
if (selected.size() == windowSize)
|
||||
break;
|
||||
if (!inWindow[node]) {
|
||||
inWindow[node] = true;
|
||||
selected.push_back(node);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
llvm::sort(selected, isHigherPriority);
|
||||
return selected;
|
||||
}
|
||||
|
||||
std::vector<IndexedEdge> buildWindowEdges(const VirtualGraph &graph, const std::vector<int64_t> &nodeToWindowIndex) {
|
||||
std::vector<IndexedEdge> windowEdges;
|
||||
windowEdges.reserve(graph.edges.size());
|
||||
for (auto [start, end, weight] : graph.edges) {
|
||||
int64_t mappedStart = nodeToWindowIndex[static_cast<size_t>(start)];
|
||||
int64_t mappedEnd = nodeToWindowIndex[static_cast<size_t>(end)];
|
||||
if (mappedStart == -1 || mappedEnd == -1)
|
||||
continue;
|
||||
windowEdges.push_back({mappedStart, mappedEnd, weight});
|
||||
}
|
||||
return aggregateEdges(windowEdges);
|
||||
}
|
||||
|
||||
/// Schedules the nodes in `selectedNodes` with GraphDCP and reports which of
/// them ended up on the same CPU.
///
/// A window-local graph is built from the selected nodes' weights, crossbar
/// usage and the edges between them; the virtual node index serves as order
/// key. When `options.processorCount` is positive it caps the CPU count. Every
/// CPU that received at least two tasks yields one merge group, expressed in
/// indices of `graph`.
WindowScheduleResult scheduleWindow(const VirtualGraph &graph,
                                    llvm::ArrayRef<size_t> selectedNodes,
                                    const DcpScheduleOptions &options,
                                    mlir::MLIRContext *context) {
  const size_t windowNodeCount = selectedNodes.size();

  std::vector<int64_t> nodeToWindowIndex(graph.nodes.size(), -1);
  std::vector<Weight> windowWeights;
  std::vector<CrossbarUsage> windowCrossbarUsage;
  std::vector<int64_t> windowNodeOrderKeys;
  windowWeights.reserve(windowNodeCount);
  windowCrossbarUsage.reserve(windowNodeCount);
  windowNodeOrderKeys.reserve(windowNodeCount);

  for (size_t windowIndex = 0; windowIndex < windowNodeCount; ++windowIndex) {
    const size_t nodeIndex = selectedNodes[windowIndex];
    const VirtualNode &node = graph.nodes[nodeIndex];
    nodeToWindowIndex[nodeIndex] = static_cast<int64_t>(windowIndex);
    windowWeights.push_back(node.weight);
    windowCrossbarUsage.push_back(node.crossbarUsage);
    windowNodeOrderKeys.push_back(static_cast<int64_t>(nodeIndex));
  }

  GraphDCP windowGraph(
      windowWeights, buildWindowEdges(graph, nodeToWindowIndex), windowNodeOrderKeys, windowCrossbarUsage);
  if (options.processorCount > 0)
    windowGraph.setMaxCpuCount(static_cast<int>(options.processorCount));
  windowGraph.setContext(context);
  windowGraph.runDcp();

  WindowScheduleResult result;
  result.cpuCount = windowGraph.cpuCount();
  for (CPU cpu = 0; cpu < windowGraph.cpuCount(); ++cpu) {
    auto scheduledTasks = windowGraph.getScheduledTasks(cpu);
    const size_t taskCount = scheduledTasks.size();
    // A CPU holding a single task gives nothing to merge.
    if (taskCount < 2)
      continue;

    result.mergedNodeCount += taskCount;
    if (result.maxMergeGroupSize < taskCount)
      result.maxMergeGroupSize = taskCount;

    // Translate window-local task indices back to indices of `graph`.
    std::vector<size_t> &mergeGroup = result.mergeGroups.emplace_back();
    mergeGroup.reserve(taskCount);
    for (const auto &task : scheduledTasks)
      mergeGroup.push_back(selectedNodes[task.nodeIndex]);
  }
  return result;
}
|
||||
|
||||
bool coarsenGraph(const VirtualGraph &graph,
|
||||
llvm::ArrayRef<std::vector<size_t>> mergeGroups,
|
||||
VirtualGraph &coarsenedGraph,
|
||||
std::vector<size_t> &oldToNewNode) {
|
||||
TimingInfo timing = computeTiming(graph);
|
||||
std::vector<size_t> topologicalRank(graph.nodes.size());
|
||||
std::iota(topologicalRank.begin(), topologicalRank.end(), 0);
|
||||
if (timing.valid)
|
||||
for (auto [rank, nodeIndex] : llvm::enumerate(timing.topologicalOrder))
|
||||
topologicalRank[nodeIndex] = rank;
|
||||
|
||||
std::vector<std::vector<size_t>> orderedMergeGroups;
|
||||
orderedMergeGroups.reserve(mergeGroups.size());
|
||||
for (const auto &mergeGroup : mergeGroups) {
|
||||
orderedMergeGroups.emplace_back(mergeGroup.begin(), mergeGroup.end());
|
||||
std::stable_sort(orderedMergeGroups.back().begin(), orderedMergeGroups.back().end(), [&](size_t lhs, size_t rhs) {
|
||||
if (topologicalRank[lhs] != topologicalRank[rhs])
|
||||
return topologicalRank[lhs] < topologicalRank[rhs];
|
||||
return lhs < rhs;
|
||||
});
|
||||
}
|
||||
|
||||
std::vector<int64_t> nodeToMergeGroup(graph.nodes.size(), -1);
|
||||
for (auto [groupIndex, mergeGroup] : llvm::enumerate(orderedMergeGroups)) {
|
||||
if (mergeGroup.size() < 2)
|
||||
continue;
|
||||
for (size_t nodeIndex : mergeGroup) {
|
||||
assert(nodeIndex < graph.nodes.size() && "merge group node out of range");
|
||||
nodeToMergeGroup[nodeIndex] = static_cast<int64_t>(groupIndex);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::optional<size_t>> mergeGroupToNewNode(orderedMergeGroups.size());
|
||||
std::vector<size_t> newNodeRank;
|
||||
oldToNewNode.assign(graph.nodes.size(), 0);
|
||||
bool mergedAny = false;
|
||||
coarsenedGraph.nodes.clear();
|
||||
coarsenedGraph.edges.clear();
|
||||
coarsenedGraph.nodes.reserve(graph.nodes.size());
|
||||
newNodeRank.reserve(graph.nodes.size());
|
||||
|
||||
for (size_t nodeIndex = 0; nodeIndex < graph.nodes.size(); ++nodeIndex) {
|
||||
int64_t mergeGroupIndex = nodeToMergeGroup[nodeIndex];
|
||||
if (mergeGroupIndex == -1) {
|
||||
oldToNewNode[nodeIndex] = coarsenedGraph.nodes.size();
|
||||
coarsenedGraph.nodes.push_back(graph.nodes[nodeIndex]);
|
||||
newNodeRank.push_back(topologicalRank[nodeIndex]);
|
||||
continue;
|
||||
}
|
||||
|
||||
auto &newNodeIndex = mergeGroupToNewNode[static_cast<size_t>(mergeGroupIndex)];
|
||||
if (newNodeIndex.has_value()) {
|
||||
oldToNewNode[nodeIndex] = *newNodeIndex;
|
||||
continue;
|
||||
}
|
||||
|
||||
VirtualNode mergedNode;
|
||||
for (size_t memberIndex : orderedMergeGroups[static_cast<size_t>(mergeGroupIndex)]) {
|
||||
const VirtualNode &memberNode = graph.nodes[memberIndex];
|
||||
mergedNode.originalNodeIndices.append(memberNode.originalNodeIndices.begin(), memberNode.originalNodeIndices.end());
|
||||
mergedNode.weight = addOrMax(mergedNode.weight, memberNode.weight);
|
||||
mergedNode.crossbarUsage = addOrMax(mergedNode.crossbarUsage, memberNode.crossbarUsage);
|
||||
}
|
||||
std::sort(mergedNode.originalNodeIndices.begin(), mergedNode.originalNodeIndices.end());
|
||||
|
||||
mergedAny = true;
|
||||
newNodeIndex = coarsenedGraph.nodes.size();
|
||||
for (size_t memberIndex : orderedMergeGroups[static_cast<size_t>(mergeGroupIndex)])
|
||||
oldToNewNode[memberIndex] = *newNodeIndex;
|
||||
newNodeRank.push_back(topologicalRank[orderedMergeGroups[static_cast<size_t>(mergeGroupIndex)].front()]);
|
||||
coarsenedGraph.nodes.push_back(std::move(mergedNode));
|
||||
}
|
||||
|
||||
if (!mergedAny)
|
||||
return false;
|
||||
|
||||
std::vector<IndexedEdge> remappedEdges;
|
||||
remappedEdges.reserve(graph.edges.size());
|
||||
for (auto [start, end, weight] : graph.edges) {
|
||||
size_t newStart = oldToNewNode[static_cast<size_t>(start)];
|
||||
size_t newEnd = oldToNewNode[static_cast<size_t>(end)];
|
||||
if (newStart == newEnd)
|
||||
continue;
|
||||
if (newNodeRank[newStart] >= newNodeRank[newEnd])
|
||||
continue;
|
||||
remappedEdges.push_back({static_cast<int64_t>(newStart), static_cast<int64_t>(newEnd), weight});
|
||||
}
|
||||
coarsenedGraph.edges = aggregateEdges(remappedEdges);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
size_t getDcpCoarseningWindowSize(size_t nodeCount, const DcpScheduleOptions &options) {
|
||||
size_t windowSize = std::min(options.criticalWindowSize, nodeCount);
|
||||
CPU maxCpuCount = std::max<CPU>(1, static_cast<CPU>(getSchedulingCpuBudget(options)));
|
||||
if (nodeCount > static_cast<size_t>(maxCpuCount))
|
||||
windowSize = std::max(windowSize, std::min(nodeCount, static_cast<size_t>(maxCpuCount) + 1));
|
||||
return windowSize;
|
||||
}
|
||||
|
||||
void assignFeasibleAest(const ComputeGraph &graph, MergeScheduleResult &result) {
|
||||
llvm::DenseMap<ComputeInstance, size_t> nodeIndexByInstance;
|
||||
nodeIndexByInstance.reserve(graph.nodes.size());
|
||||
for (auto [nodeIndex, node] : llvm::enumerate(graph.nodes))
|
||||
nodeIndexByInstance[node.instance] = nodeIndex;
|
||||
|
||||
struct ScheduledEdge {
|
||||
size_t target = 0;
|
||||
Time delay = 0;
|
||||
};
|
||||
|
||||
std::vector<std::vector<ScheduledEdge>> scheduledChildren(graph.nodes.size());
|
||||
std::vector<size_t> incomingEdgeCount(graph.nodes.size(), 0);
|
||||
for (const ComputeGraphEdge &edge : graph.edges) {
|
||||
const ComputeInstance sourceInstance = graph.nodes[edge.source].instance;
|
||||
const ComputeInstance targetInstance = graph.nodes[edge.target].instance;
|
||||
const size_t sourceCpu = result.computeToCpuMap.lookup(sourceInstance);
|
||||
const size_t targetCpu = result.computeToCpuMap.lookup(targetInstance);
|
||||
|
||||
Time delay = graph.nodes[edge.source].weight;
|
||||
if (sourceCpu != targetCpu)
|
||||
delay = addOrMax(delay, edge.transferCost);
|
||||
|
||||
scheduledChildren[edge.source].push_back({edge.target, delay});
|
||||
incomingEdgeCount[edge.target]++;
|
||||
}
|
||||
|
||||
llvm::DenseMap<size_t, std::vector<std::pair<size_t, size_t>>> tasksByCpu;
|
||||
for (const ComputeGraphNode &node : graph.nodes) {
|
||||
size_t cpu = result.computeToCpuMap.lookup(node.instance);
|
||||
size_t slot = result.computeToCpuSlotMap.lookup(node.instance);
|
||||
tasksByCpu[cpu].push_back({slot, nodeIndexByInstance.lookup(node.instance)});
|
||||
}
|
||||
|
||||
for (auto &entry : tasksByCpu) {
|
||||
auto &scheduledTasks = entry.second;
|
||||
llvm::sort(scheduledTasks, [](const auto &lhs, const auto &rhs) {
|
||||
if (lhs.first != rhs.first)
|
||||
return lhs.first < rhs.first;
|
||||
return lhs.second < rhs.second;
|
||||
});
|
||||
|
||||
for (size_t i = 1; i < scheduledTasks.size(); ++i) {
|
||||
size_t sourceIndex = scheduledTasks[i - 1].second;
|
||||
size_t targetIndex = scheduledTasks[i].second;
|
||||
scheduledChildren[sourceIndex].push_back({targetIndex, graph.nodes[sourceIndex].weight});
|
||||
incomingEdgeCount[targetIndex]++;
|
||||
}
|
||||
}
|
||||
|
||||
auto readyNodeGreater = [&](size_t lhs, size_t rhs) {
|
||||
if (graph.nodes[lhs].originalOrder != graph.nodes[rhs].originalOrder)
|
||||
return graph.nodes[lhs].originalOrder > graph.nodes[rhs].originalOrder;
|
||||
return lhs > rhs;
|
||||
};
|
||||
std::priority_queue<size_t, std::vector<size_t>, decltype(readyNodeGreater)> readyNodes(readyNodeGreater);
|
||||
for (size_t nodeIndex = 0; nodeIndex < graph.nodes.size(); ++nodeIndex)
|
||||
if (incomingEdgeCount[nodeIndex] == 0)
|
||||
readyNodes.push(nodeIndex);
|
||||
|
||||
std::vector<Time> startTimes(graph.nodes.size(), 0);
|
||||
size_t processedNodeCount = 0;
|
||||
while (!readyNodes.empty()) {
|
||||
size_t sourceIndex = readyNodes.top();
|
||||
readyNodes.pop();
|
||||
processedNodeCount++;
|
||||
|
||||
for (const ScheduledEdge &edge : scheduledChildren[sourceIndex]) {
|
||||
startTimes[edge.target] = std::max(startTimes[edge.target], addOrMax(startTimes[sourceIndex], edge.delay));
|
||||
assert(incomingEdgeCount[edge.target] > 0 && "scheduled incoming edge count underflow");
|
||||
incomingEdgeCount[edge.target]--;
|
||||
if (incomingEdgeCount[edge.target] == 0)
|
||||
readyNodes.push(edge.target);
|
||||
}
|
||||
}
|
||||
|
||||
if (processedNodeCount != graph.nodes.size())
|
||||
llvm::report_fatal_error("merge scheduling: coarsened DCP schedule is cyclic");
|
||||
|
||||
for (auto [nodeIndex, node] : llvm::enumerate(graph.nodes))
|
||||
result.computeToAestMap[node.instance] = startTimes[nodeIndex];
|
||||
}
|
||||
|
||||
// Turns a coarsened virtual graph into a schedule for `originalGraph`.
//
// Each virtual node becomes one CPU; CPUs are numbered by the virtual node's
// position in the topological order reported by computeTiming, or by plain
// index order when the timing is not valid. Within a CPU the original nodes are
// given slots in the order they appear in `originalGraph.nodes`. Start times
// are then derived with assignFeasibleAest.
MergeScheduleResult buildResultFromVirtualGraph(const VirtualGraph &graph, const ComputeGraph &originalGraph) {
  MergeScheduleResult result;

  TimingInfo timing = computeTiming(graph);
  std::vector<size_t> cpuOrder;
  if (!timing.valid) {
    // No usable topological order: fall back to 0, 1, 2, ...
    cpuOrder.resize(graph.nodes.size());
    std::iota(cpuOrder.begin(), cpuOrder.end(), 0);
  }
  else {
    cpuOrder = std::move(timing.topologicalOrder);
  }

  // Original nodes not referenced by any virtual node stay on CPU 0.
  const size_t originalCount = originalGraph.nodes.size();
  std::vector<size_t> cpuOfOriginal(originalCount, 0);
  for (size_t cpu = 0; cpu < cpuOrder.size(); ++cpu) {
    const VirtualNode &merged = graph.nodes[cpuOrder[cpu]];
    for (size_t originalIndex : merged.originalNodeIndices)
      cpuOfOriginal[originalIndex] = cpu;
  }

  result.dominanceOrderCompute.reserve(originalCount);
  llvm::DenseMap<size_t, size_t> slotCursor;
  for (size_t originalIndex = 0; originalIndex < originalCount; ++originalIndex) {
    const ComputeGraphNode &node = originalGraph.nodes[originalIndex];
    const size_t cpu = cpuOfOriginal[originalIndex];
    result.dominanceOrderCompute.push_back(node.instance);
    result.computeToCpuMap[node.instance] = cpu;
    result.computeToCpuSlotMap[node.instance] = slotCursor[cpu]++;
    // Overwritten on every visit, so the final value is the CPU's last node.
    result.cpuToLastComputeMap[cpu] = node.instance;
  }

  for (const auto &cpuAndLast : result.cpuToLastComputeMap)
    result.isLastComputeOfCpu.insert(cpuAndLast.second);

  assignFeasibleAest(originalGraph, result);
  return result;
}
|
||||
|
||||
// Copies the schedule held by a GraphDCP instance into a MergeScheduleResult.
//
// `dominanceOrderCompute` follows the node order of `graph`. For every CPU with
// at least one task, each task's CPU, slot (its position in the CPU's task
// list) and start time (`task.aest`) are recorded, and the final task is
// registered as that CPU's last compute.
MergeScheduleResult buildResultFromScheduledGraph(GraphDCP &graphDCP, const ComputeGraph &graph) {
  MergeScheduleResult result;

  result.dominanceOrderCompute.reserve(graph.nodes.size());
  for (const ComputeGraphNode &node : graph.nodes)
    result.dominanceOrderCompute.push_back(node.instance);

  for (CPU cpu = 0; cpu < graphDCP.cpuCount(); ++cpu) {
    auto cpuTasks = graphDCP.getScheduledTasks(cpu);
    if (cpuTasks.empty())
      continue;

    size_t slot = 0;
    for (const auto &task : cpuTasks) {
      const ComputeInstance instance = graph.nodes[task.nodeIndex].instance;
      result.computeToCpuMap[instance] = cpu;
      result.computeToCpuSlotMap[instance] = slot;
      result.computeToAestMap[instance] = static_cast<uint64_t>(task.aest);
      ++slot;
    }

    const ComputeInstance finalInstance = graph.nodes[cpuTasks.back().nodeIndex].instance;
    result.cpuToLastComputeMap[cpu] = finalInstance;
    result.isLastComputeOfCpu.insert(finalInstance);
  }

  return result;
}
|
||||
|
||||
// Runs the GraphDCP scheduler on the whole compute graph, without coarsening.
//
// The graph is flattened into the parallel arrays GraphDCP is constructed from
// (weights, crossbar usage, order keys and indexed edges). A positive
// `options.processorCount` is forwarded as the maximum CPU count; zero leaves
// GraphDCP at whatever limit it uses by default.
MergeScheduleResult runLegacyDcp(const ComputeGraph &graph, const DcpScheduleOptions &options, mlir::MLIRContext *context) {
  const size_t nodeCount = graph.nodes.size();

  llvm::SmallVector<Weight> weights;
  llvm::SmallVector<CrossbarUsage> crossbarUsages;
  llvm::SmallVector<int64_t> orderKeys;
  weights.reserve(nodeCount);
  crossbarUsages.reserve(nodeCount);
  orderKeys.reserve(nodeCount);
  for (const ComputeGraphNode &node : graph.nodes) {
    weights.push_back(node.weight);
    crossbarUsages.push_back(node.crossbarUsage);
    orderKeys.push_back(static_cast<int64_t>(node.originalOrder));
  }

  llvm::SmallVector<IndexedEdge> indexedEdges;
  indexedEdges.reserve(graph.edges.size());
  for (const ComputeGraphEdge &edge : graph.edges) {
    indexedEdges.push_back(
        {static_cast<int64_t>(edge.source), static_cast<int64_t>(edge.target), static_cast<int64_t>(edge.transferCost)});
  }

  GraphDCP graphDCP(weights, indexedEdges, orderKeys, crossbarUsages);
  if (options.processorCount > 0)
    graphDCP.setMaxCpuCount(static_cast<int>(options.processorCount));
  graphDCP.setContext(context);
  graphDCP.runDcp();

  return buildResultFromScheduledGraph(graphDCP, graph);
}
|
||||
|
||||
bool needsExactScheduledBatches(const ComputeGraph &graph, const DcpScheduleOptions &options) {
|
||||
if (options.processorCount == 0 || !options.allowFallbackForAutoCoreCount)
|
||||
return false;
|
||||
size_t schedulingCpuBudget = getSchedulingCpuBudget(options);
|
||||
return llvm::any_of(graph.nodes, [&](const ComputeGraphNode &node) {
|
||||
auto batch = dyn_cast<SpatComputeBatch>(node.instance.op);
|
||||
return batch && static_cast<size_t>(batch.getLaneCount()) > schedulingCpuBudget;
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Entry point of the DCP merge scheduler.
//
// The legacy (whole-graph) DCP is used when needsExactScheduledBatches reports
// true or when `options.criticalWindowSize` is zero. Otherwise the graph is
// coarsened iteratively: a critical window of virtual nodes is selected,
// scheduled with scheduleWindow, and the resulting merge groups are collapsed
// into single virtual nodes. Coarsening stops when the virtual graph fits the
// CPU budget, the timing becomes invalid, the window has fewer than two nodes,
// or an iteration produces no merge. The final virtual graph is converted with
// buildResultFromVirtualGraph.
//
// When isDcpCoarsenDebugEnabled() is true, progress lines are printed to
// llvm::errs() for graphs with at least 200 nodes.
MergeScheduleResult
runDcpScheduler(const ComputeGraph &graph, const DcpScheduleOptions &options, mlir::MLIRContext *context) {
  if (needsExactScheduledBatches(graph, options) || options.criticalWindowSize == 0)
    return runLegacyDcp(graph, options, context);

  // Debug output is limited to graphs of at least this many nodes.
  constexpr size_t kDebugNodeThreshold = 200;

  VirtualGraph virtualGraph = buildInitialVirtualGraph(graph);
  size_t iteration = 0;
  const bool debugCoarsening = isDcpCoarsenDebugEnabled();

  // Schedules the selected window and, when merge groups come back, replaces
  // `virtualGraph` with its coarsened version. Returns whether the graph was
  // replaced.
  auto tryCoarsenSelectedNodes = [&](llvm::ArrayRef<size_t> selectedNodes) {
    const size_t previousCount = virtualGraph.nodes.size();
    WindowScheduleResult window = scheduleWindow(virtualGraph, selectedNodes, options, context);

    if (window.mergeGroups.empty()) {
      if (debugCoarsening && previousCount >= kDebugNodeThreshold)
        llvm::errs() << llvm::formatv("[DCP-COARSEN] iter={0} old={1} selected={2} windowCpus={3} "
                                      "groups=0 mergedNodes=0 maxGroup=0 new={1} changed=0\n",
                                      iteration,
                                      previousCount,
                                      selectedNodes.size(),
                                      window.cpuCount);
      return false;
    }

    VirtualGraph coarsened;
    std::vector<size_t> oldToNewNode;
    if (!coarsenGraph(virtualGraph, window.mergeGroups, coarsened, oldToNewNode))
      return false;

    const size_t coarsenedCount = coarsened.nodes.size();
    if (debugCoarsening && (previousCount >= kDebugNodeThreshold || coarsenedCount >= kDebugNodeThreshold))
      llvm::errs() << llvm::formatv("[DCP-COARSEN] iter={0} old={1} selected={2} windowCpus={3} "
                                    "groups={4} mergedNodes={5} maxGroup={6} new={7} changed={8}\n",
                                    iteration,
                                    previousCount,
                                    selectedNodes.size(),
                                    window.cpuCount,
                                    window.mergeGroups.size(),
                                    window.mergedNodeCount,
                                    window.maxMergeGroupSize,
                                    coarsenedCount,
                                    previousCount - coarsenedCount);

    virtualGraph = std::move(coarsened);
    return true;
  };

  while (virtualGraph.nodes.size() > 1) {
    const size_t currentCount = virtualGraph.nodes.size();
    const bool logThisGraph = debugCoarsening && currentCount >= kDebugNodeThreshold;

    // Every virtual node can already get its own CPU: nothing left to merge.
    if (currentCount <= getSchedulingCpuBudget(options)) {
      if (logThisGraph)
        llvm::errs() << llvm::formatv(
            "[DCP-COARSEN] iter={0} old={1} stop=cpu-budget\n", iteration, currentCount);
      break;
    }

    ++iteration;
    TimingInfo timing = computeTiming(virtualGraph);
    if (!timing.valid) {
      if (logThisGraph)
        llvm::errs() << llvm::formatv(
            "[DCP-COARSEN] iter={0} old={1} invalid-timing\n", iteration, currentCount);
      break;
    }

    auto criticalWindow = selectCriticalWindow(virtualGraph, timing, getDcpCoarseningWindowSize(currentCount, options));
    llvm::SmallVector<size_t> selectedNodes;
    selectedNodes.append(criticalWindow.begin(), criticalWindow.end());

    // A merge needs at least two candidates.
    if (selectedNodes.size() < 2) {
      if (logThisGraph)
        llvm::errs() << llvm::formatv("[DCP-COARSEN] iter={0} old={1} selected={2} stop=small-window\n",
                                      iteration,
                                      currentCount,
                                      selectedNodes.size());
      break;
    }

    if (!tryCoarsenSelectedNodes(selectedNodes)) {
      // The graph is unchanged here, so its size can be re-read for the log.
      if (debugCoarsening && virtualGraph.nodes.size() >= kDebugNodeThreshold)
        llvm::errs() << llvm::formatv(
            "[DCP-COARSEN] iter={0} old={1} stop=no-merge\n", iteration, virtualGraph.nodes.size());
      break;
    }
  }

  return buildResultFromVirtualGraph(virtualGraph, graph);
}
|
||||
|
||||
} // namespace spatial
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,21 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
|
||||
#include "ComputeGraph.hpp"
|
||||
#include "MergeSchedule.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace spatial {
|
||||
|
||||
struct DcpScheduleOptions {
|
||||
size_t processorCount = 0;
|
||||
size_t criticalWindowSize = 0;
|
||||
bool allowFallbackForAutoCoreCount = true;
|
||||
};
|
||||
|
||||
MergeScheduleResult
|
||||
runDcpScheduler(const ComputeGraph &graph, const DcpScheduleOptions &options, mlir::MLIRContext *context);
|
||||
|
||||
} // namespace spatial
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,25 @@
|
||||
#pragma once
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/DenseSet.h"
|
||||
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <vector>
|
||||
|
||||
#include "ComputeInstance.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace spatial {
|
||||
|
||||
struct MergeScheduleResult {
|
||||
std::vector<ComputeInstance> dominanceOrderCompute;
|
||||
llvm::DenseMap<ComputeInstance, size_t> computeToCpuMap;
|
||||
llvm::DenseMap<ComputeInstance, size_t> computeToCpuSlotMap;
|
||||
llvm::DenseMap<ComputeInstance, uint64_t> computeToAestMap;
|
||||
llvm::DenseSet<ComputeInstance> isLastComputeOfCpu;
|
||||
llvm::DenseMap<size_t, ComputeInstance> cpuToLastComputeMap;
|
||||
};
|
||||
|
||||
} // namespace spatial
|
||||
} // namespace onnx_mlir
|
||||
+139
@@ -0,0 +1,139 @@
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/Support/ErrorHandling.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
|
||||
#include <limits>
|
||||
#include <vector>
|
||||
|
||||
#include "ComputeGraph.hpp"
|
||||
#include "../DCPGraph/DCPAnalysis.hpp"
|
||||
#include "DcpScheduler.hpp"
|
||||
#include "MergeSchedulingAnalysis.hpp"
|
||||
#include "PeftScheduler.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace spatial {
|
||||
|
||||
namespace {
|
||||
|
||||
// Translates the `pimMergeScheduler` command-line option into the scheduler
// kind used by the analysis.
MergeSchedulerKind getSchedulerKind() {
  const auto selected = pimMergeScheduler.getValue();
  switch (selected) {
  case MergeSchedulerDcp:
    return MergeSchedulerKind::Dcp;
  case MergeSchedulerPeft:
    return MergeSchedulerKind::Peft;
  }
  llvm_unreachable("unknown merge scheduler kind");
}
|
||||
|
||||
void verifySchedule(const ComputeGraph &graph, const MergeScheduleResult &result, CrossbarUsage crossbarCapacity) {
|
||||
llvm::DenseMap<size_t, std::vector<std::pair<size_t, size_t>>> tasksByCpu;
|
||||
tasksByCpu.reserve(result.cpuToLastComputeMap.size());
|
||||
|
||||
for (size_t nodeIndex = 0; nodeIndex < graph.nodes.size(); ++nodeIndex) {
|
||||
const ComputeInstance instance = graph.nodes[nodeIndex].instance;
|
||||
if (!result.computeToCpuMap.count(instance))
|
||||
llvm::report_fatal_error("merge scheduling: missing CPU assignment");
|
||||
if (!result.computeToCpuSlotMap.count(instance))
|
||||
llvm::report_fatal_error("merge scheduling: missing CPU slot assignment");
|
||||
if (!result.computeToAestMap.count(instance))
|
||||
llvm::report_fatal_error("merge scheduling: missing start time");
|
||||
|
||||
tasksByCpu[result.computeToCpuMap.lookup(instance)].push_back(
|
||||
{result.computeToCpuSlotMap.lookup(instance), nodeIndex});
|
||||
}
|
||||
|
||||
for (auto &entry : tasksByCpu) {
|
||||
auto &scheduledTasks = entry.second;
|
||||
llvm::sort(scheduledTasks, [](const auto &lhs, const auto &rhs) {
|
||||
if (lhs.first != rhs.first)
|
||||
return lhs.first < rhs.first;
|
||||
return lhs.second < rhs.second;
|
||||
});
|
||||
|
||||
CrossbarUsage usedCrossbars = 0;
|
||||
for (size_t slot = 0; slot < scheduledTasks.size(); ++slot) {
|
||||
if (scheduledTasks[slot].first != slot)
|
||||
llvm::report_fatal_error("merge scheduling: CPU slots are not contiguous");
|
||||
usedCrossbars = addOrMax(usedCrossbars, graph.nodes[scheduledTasks[slot].second].crossbarUsage);
|
||||
if (usedCrossbars > crossbarCapacity)
|
||||
llvm::report_fatal_error("merge scheduling: CPU crossbar capacity exceeded");
|
||||
}
|
||||
|
||||
const ComputeInstance expectedLast = graph.nodes[scheduledTasks.back().second].instance;
|
||||
auto lastIt = result.cpuToLastComputeMap.find(entry.first);
|
||||
if (lastIt == result.cpuToLastComputeMap.end() || !(lastIt->second == expectedLast))
|
||||
llvm::report_fatal_error("merge scheduling: cpuToLastComputeMap does not match slot order");
|
||||
if (!result.isLastComputeOfCpu.count(expectedLast))
|
||||
llvm::report_fatal_error("merge scheduling: missing last-compute marker");
|
||||
}
|
||||
|
||||
for (const ComputeGraphEdge &edge : graph.edges) {
|
||||
const ComputeInstance source = graph.nodes[edge.source].instance;
|
||||
const ComputeInstance target = graph.nodes[edge.target].instance;
|
||||
const size_t sourceCpu = result.computeToCpuMap.lookup(source);
|
||||
const size_t targetCpu = result.computeToCpuMap.lookup(target);
|
||||
const size_t sourceSlot = result.computeToCpuSlotMap.lookup(source);
|
||||
const size_t targetSlot = result.computeToCpuSlotMap.lookup(target);
|
||||
const Time sourceStart = static_cast<Time>(result.computeToAestMap.lookup(source));
|
||||
const Time targetStart = static_cast<Time>(result.computeToAestMap.lookup(target));
|
||||
if (sourceCpu == targetCpu && sourceSlot >= targetSlot)
|
||||
llvm::report_fatal_error("merge scheduling: same-CPU dependency order is invalid");
|
||||
|
||||
Time earliestTargetStart = addOrMax(sourceStart, graph.nodes[edge.source].weight);
|
||||
if (sourceCpu != targetCpu)
|
||||
earliestTargetStart = addOrMax(earliestTargetStart, edge.transferCost);
|
||||
if (targetStart < earliestTargetStart) {
|
||||
std::string message = llvm::formatv("merge scheduling: dependency legality failed between tasks {0} and {1}",
|
||||
graph.nodes[edge.source].originalOrder,
|
||||
graph.nodes[edge.target].originalOrder)
|
||||
.str();
|
||||
llvm::report_fatal_error(llvm::StringRef(message));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Runs the merge scheduling for `op` eagerly; the outcome is available through
// getResult() as soon as construction finishes.
MergeSchedulingAnalysis::MergeSchedulingAnalysis(mlir::Operation *op)
    : entryOp(op) {
  // run() reads entryOp, which the initializer list has already set.
  result = run();
}
|
||||
|
||||
// Builds the compute graph rooted at `entryOp`, runs the scheduler selected on
// the command line (PEFT or DCP) and verifies the schedule before returning it.
//
// Aborts through llvm::report_fatal_error when the compute graph is cyclic;
// verifySchedule aborts likewise when the schedule is not legal.
MergeScheduleResult MergeSchedulingAnalysis::run() {
  verifyExplicitPimCoreCount();

  ComputeGraph graph = buildComputeGraph(entryOp);
  if (!verifyAcyclic(graph))
    llvm::report_fatal_error("merge scheduling: compute graph is cyclic");

  MergeSchedulingOptions options;
  options.kind = getSchedulerKind();
  // Non-positive core counts leave processorCount at its default of 0.
  if (coresCount.getValue() > 0)
    options.processorCount = static_cast<size_t>(coresCount.getValue());

  const auto crossbarCapacity = static_cast<CrossbarUsage>(crossbarCountInCore.getValue());
  auto *context = entryOp->getContext();

  MergeScheduleResult schedule;
  if (options.kind == MergeSchedulerKind::Peft) {
    schedule = runPeftScheduler(graph, PeftScheduleOptions {options.processorCount, crossbarCapacity, context});
  }
  else {
    schedule = runDcpScheduler(
        graph,
        DcpScheduleOptions {
            options.processorCount, dcpCriticalWindowSize.getValue(), options.allowDcpFallbackForAutoCoreCount},
        context);
  }

  verifySchedule(graph, schedule, crossbarCapacity);
  return schedule;
}
|
||||
|
||||
} // namespace spatial
|
||||
} // namespace onnx_mlir
|
||||
+36
@@ -0,0 +1,36 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/Operation.h"
|
||||
|
||||
#include <cstddef>
|
||||
|
||||
#include "MergeSchedule.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace spatial {
|
||||
|
||||
// Scheduling algorithms available to the merge scheduling analysis.
enum class MergeSchedulerKind {
  // Scheduler implemented by runDcpScheduler.
  Dcp,
  // Scheduler implemented by runPeftScheduler.
  Peft,
};
|
||||
|
||||
struct MergeSchedulingOptions {
|
||||
MergeSchedulerKind kind = MergeSchedulerKind::Peft;
|
||||
size_t processorCount = 0;
|
||||
bool allowDcpFallbackForAutoCoreCount = true;
|
||||
};
|
||||
|
||||
class MergeSchedulingAnalysis {
|
||||
public:
|
||||
explicit MergeSchedulingAnalysis(mlir::Operation *op);
|
||||
MergeScheduleResult &getResult() { return result; }
|
||||
|
||||
private:
|
||||
mlir::Operation *entryOp = nullptr;
|
||||
MergeScheduleResult result;
|
||||
|
||||
MergeScheduleResult run();
|
||||
};
|
||||
|
||||
} // namespace spatial
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,303 @@
|
||||
#include "mlir/IR/Threading.h"
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/Support/ErrorHandling.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
|
||||
#include <limits>
|
||||
#include <queue>
|
||||
#include <vector>
|
||||
|
||||
#include "PeftScheduler.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace spatial {
|
||||
|
||||
namespace {
|
||||
|
||||
struct ScheduledTask {
|
||||
size_t processor = std::numeric_limits<size_t>::max();
|
||||
Time startTime = 0;
|
||||
Time endTime = 0;
|
||||
size_t slot = 0;
|
||||
};
|
||||
|
||||
std::vector<std::vector<size_t>> buildReverseLevels(const ComputeGraph& graph) {
|
||||
std::vector<size_t> remainingSuccessors(graph.nodes.size(), 0);
|
||||
std::queue<size_t> readySinks;
|
||||
std::vector<std::vector<size_t>> reverseLevels;
|
||||
|
||||
for (size_t node = 0; node < graph.nodes.size(); ++node) {
|
||||
remainingSuccessors[node] = graph.successors[node].size();
|
||||
if (remainingSuccessors[node] == 0)
|
||||
readySinks.push(node);
|
||||
}
|
||||
|
||||
size_t levelizedCount = 0;
|
||||
while (!readySinks.empty()) {
|
||||
size_t levelSize = readySinks.size();
|
||||
std::vector<size_t> levelNodes;
|
||||
levelNodes.reserve(levelSize);
|
||||
for (size_t i = 0; i < levelSize; ++i) {
|
||||
size_t node = readySinks.front();
|
||||
readySinks.pop();
|
||||
levelNodes.push_back(node);
|
||||
++levelizedCount;
|
||||
for (const auto& [pred, weight] : graph.predecessors[node]) {
|
||||
assert(remainingSuccessors[pred] > 0 && "remaining successor count underflow");
|
||||
if (--remainingSuccessors[pred] == 0)
|
||||
readySinks.push(pred);
|
||||
}
|
||||
}
|
||||
reverseLevels.push_back(std::move(levelNodes));
|
||||
}
|
||||
|
||||
if (levelizedCount != graph.nodes.size())
|
||||
llvm::report_fatal_error("PEFT scheduler: compute graph is cyclic or malformed");
|
||||
|
||||
return reverseLevels;
|
||||
}
|
||||
|
||||
void verifyOctTableSize(size_t nodeCount, size_t processorCount) {
|
||||
constexpr size_t kMaxOctTableBytes = 1ull << 30;
|
||||
if (nodeCount == 0 || processorCount == 0)
|
||||
return;
|
||||
if (processorCount > std::numeric_limits<size_t>::max() / sizeof(Time))
|
||||
llvm::report_fatal_error("PEFT scheduler: OCT table size overflow");
|
||||
size_t rowBytes = processorCount * sizeof(Time);
|
||||
if (nodeCount > std::numeric_limits<size_t>::max() / rowBytes)
|
||||
llvm::report_fatal_error("PEFT scheduler: OCT table size overflow");
|
||||
size_t totalBytes = nodeCount * rowBytes;
|
||||
if (totalBytes > kMaxOctTableBytes) {
|
||||
std::string message = llvm::formatv("PEFT scheduler: OCT table would require {0} MiB, exceeding the 1024 MiB guard",
|
||||
totalBytes / (1024 * 1024))
|
||||
.str();
|
||||
llvm::report_fatal_error(llvm::StringRef(message));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Schedules `graph` onto `options.processorCount` processors with the PEFT
// (Predict Earliest Finish Time) list-scheduling heuristic.
//
// Steps:
//   1. Build the optimistic cost table (OCT) bottom-up over the reverse levels
//      of the graph, in parallel per level when `options.context` is set.
//   2. Rank every task by the sum of its OCT row and process ready tasks in
//      descending rank (ties: lower `originalOrder`, then lower node index).
//   3. For every ready task pick the processor with the smallest
//      "optimistic EFT" (EFT + OCT), skipping processors whose crossbar
//      capacity would be exceeded. The earliest start time is found by
//      insertion-based gap search on the processor's busy timeline.
//   4. Append the task to the processor's slot list. Slot order is the order in
//      which tasks were scheduled, which is a topological order of the graph.
//
// Aborts through llvm::report_fatal_error when the processor count is zero, the
// OCT table would be too large, the graph is cyclic, or a task fits on no
// processor.
//
// NOTE(review): a task placed into a gap receives a later slot than tasks that
// start after it on the same processor, so `computeToAestMap` is not
// necessarily monotonic in slot order. Confirm that consumers of the result do
// not rely on that.
MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftScheduleOptions& options) {
  constexpr size_t kNoProcessor = std::numeric_limits<size_t>::max();

  const size_t nodeCount = graph.nodes.size();
  const size_t processorCount = options.processorCount;
  if (processorCount == 0)
    llvm::report_fatal_error("PEFT scheduler: processor count must be positive");

  verifyOctTableSize(nodeCount, processorCount);
  std::vector<std::vector<size_t>> reverseLevels = buildReverseLevels(graph);

  // MOCK: Replace this with your actual heterogeneous cost lookup.
  // If graph.nodes[task] is modified to hold a vector of weights per processor, access it here.
  auto getComputeCost = [&](size_t task, size_t processor) -> Time {
    (void) processor;
    return graph.nodes[task].weight;
  };

  // oct[task * processorCount + processor]: optimistic cost of everything that
  // still has to run after `task` when `task` is placed on `processor`.
  std::vector<Time> oct(nodeCount * processorCount, 0);
  // Per task: min over processors of (oct + compute cost).
  std::vector<Time> minOctPlusComp(nodeCount, 0);

  // 1. O(P(E+V)) Heterogeneous OCT Calculation
  for (const std::vector<size_t>& levelNodes : reverseLevels) {
    // Tasks of one level only read rows of earlier levels and write their own
    // row, so they can be processed independently.
    auto computeNodeOct = [&](size_t levelIndex) {
      size_t task = levelNodes[levelIndex];
      std::vector<Time> maxVals(processorCount, 0);

      for (const auto& [succ, comm] : graph.successors[task]) {
        Time valDifferentCpu = addOrMax(minOctPlusComp[succ], comm);
        for (size_t processor = 0; processor < processorCount; ++processor) {
          Time valSameCpu = addOrMax(oct[succ * processorCount + processor], getComputeCost(succ, processor));
          Time bestSucc = std::min(valSameCpu, valDifferentCpu);
          maxVals[processor] = std::max(maxVals[processor], bestSucc);
        }
      }

      Time minForPreds = std::numeric_limits<Time>::max();
      for (size_t processor = 0; processor < processorCount; ++processor) {
        oct[task * processorCount + processor] = maxVals[processor];
        minForPreds = std::min(minForPreds, addOrMax(maxVals[processor], getComputeCost(task, processor)));
      }
      minOctPlusComp[task] = minForPreds == std::numeric_limits<Time>::max() ? 0 : minForPreds;
    };

    if (options.context != nullptr)
      mlir::parallelFor(options.context, 0, levelNodes.size(), computeNodeOct);
    else
      for (size_t i = 0; i < levelNodes.size(); ++i)
        computeNodeOct(i);
  }

  struct RankEntry {
    long double rank = 0.0L;
    size_t node = 0;
    size_t originalOrder = 0;
  };
  std::vector<RankEntry> ranks(nodeCount);
  auto computeRank = [&](size_t node) {
    long double rank = 0.0L;
    for (size_t processor = 0; processor < processorCount; ++processor)
      rank += static_cast<long double>(oct[node * processorCount + processor]);
    ranks[node] = {rank, node, graph.nodes[node].originalOrder};
  };

  if (options.context != nullptr)
    mlir::parallelFor(options.context, 0, nodeCount, computeRank);
  else
    for (size_t node = 0; node < nodeCount; ++node)
      computeRank(node);

  // Max-heap ordering: highest rank first, then lowest originalOrder, then
  // lowest node index.
  auto readyCompare = [&](size_t lhs, size_t rhs) {
    const RankEntry& lhsRank = ranks[lhs];
    const RankEntry& rhsRank = ranks[rhs];
    if (lhsRank.rank != rhsRank.rank)
      return lhsRank.rank < rhsRank.rank;
    if (lhsRank.originalOrder != rhsRank.originalOrder)
      return lhsRank.originalOrder > rhsRank.originalOrder;
    return lhs > rhs;
  };

  // Counters are size_t so that predecessor counts are stored without
  // narrowing.
  std::vector<size_t> remainingParents(nodeCount, 0);
  std::priority_queue<size_t, std::vector<size_t>, decltype(readyCompare)> readyQueue(readyCompare);
  for (size_t node = 0; node < nodeCount; ++node) {
    remainingParents[node] = graph.predecessors[node].size();
    if (remainingParents[node] == 0)
      readyQueue.push(node);
  }

  // Busy interval of one scheduled task on a processor.
  struct BusyInterval {
    Time start = 0;
    Time end = 0;
  };

  std::vector<char> scheduled(nodeCount, false);
  std::vector<CrossbarUsage> processorCrossbars(processorCount, 0);
  std::vector<ScheduledTask> schedules(nodeCount);
  // Slot order per processor: tasks in the order they were scheduled.
  std::vector<std::vector<size_t>> tasksByProcessor(processorCount);
  // Busy intervals per processor, kept sorted by start time. The gap search
  // must walk this list rather than the slot list: the slot list is in
  // scheduling order, and scanning it as if it were time-ordered can hand out
  // a gap that another task already occupies.
  std::vector<std::vector<BusyInterval>> busyByProcessor(processorCount);

  size_t scheduledCount = 0;
  while (!readyQueue.empty()) {
    size_t task = readyQueue.top();
    readyQueue.pop();
    if (scheduled[task])
      continue;

    size_t bestProcessor = kNoProcessor;
    Time bestEst = 0;
    Time bestEft = 0;
    Time bestOeft = std::numeric_limits<Time>::max();
    bool crossbarRejected = false;

    for (size_t processor = 0; processor < processorCount; ++processor) {
      if (graph.nodes[task].crossbarUsage != 0
          && addOrMax(processorCrossbars[processor], graph.nodes[task].crossbarUsage) > options.crossbarCapacity) {
        crossbarRejected = true;
        continue;
      }

      // Earliest time at which every input is available on this processor.
      Time dataReady = 0;
      for (const auto& [pred, comm] : graph.predecessors[task]) {
        const ScheduledTask& predSchedule = schedules[pred];
        Time commPenalty = predSchedule.processor == processor ? 0 : comm;
        dataReady = std::max(dataReady, addOrMax(predSchedule.endTime, commPenalty));
      }

      // 2. PEFT Gap-Filling EST Calculation (Maintains optimal scheduling math)
      Time compWeight = getComputeCost(task, processor);
      Time est = dataReady;
      Time currentEnd = 0;
      bool foundGap = false;

      for (const BusyInterval& busy : busyByProcessor[processor]) {
        Time gapStart = std::max(currentEnd, dataReady);
        if (addOrMax(gapStart, compWeight) <= busy.start) {
          est = gapStart;
          foundGap = true;
          break;
        }
        // max() keeps the cursor monotonic when zero-length intervals share a
        // start time with a longer one.
        currentEnd = std::max(currentEnd, busy.end);
      }

      if (!foundGap)
        est = std::max(currentEnd, dataReady);

      Time eft = addOrMax(est, compWeight);
      Time oeft = addOrMax(eft, oct[task * processorCount + processor]);

      // The first feasible processor is always accepted, so a task whose OEFT
      // saturates to the maximum Time is still placed instead of being
      // reported as having no valid processor.
      if (bestProcessor == kNoProcessor || oeft < bestOeft || (oeft == bestOeft && eft < bestEft)
          || (oeft == bestOeft && eft == bestEft && est < bestEst)
          || (oeft == bestOeft && eft == bestEft && est == bestEst && processor < bestProcessor)) {
        bestProcessor = processor;
        bestEst = est;
        bestEft = eft;
        bestOeft = oeft;
      }
    }

    if (bestProcessor == kNoProcessor) {
      if (crossbarRejected) {
        std::string message =
            llvm::formatv("PEFT scheduler: no valid processor for task {0}; crossbar capacity {1} is exhausted",
                          graph.nodes[task].originalOrder,
                          options.crossbarCapacity)
                .str();
        llvm::report_fatal_error(llvm::StringRef(message));
      }
      std::string message = llvm::formatv("PEFT scheduler: no valid processor for task {0} with {1} processors",
                                          graph.nodes[task].originalOrder,
                                          processorCount)
                                .str();
      llvm::report_fatal_error(llvm::StringRef(message));
    }

    schedules[task] = {bestProcessor, bestEst, bestEft, 0};
    scheduled[task] = true;
    ++scheduledCount;
    processorCrossbars[bestProcessor] = addOrMax(processorCrossbars[bestProcessor], graph.nodes[task].crossbarUsage);

    // Record the busy interval at its time-ordered position.
    std::vector<BusyInterval>& timeline = busyByProcessor[bestProcessor];
    auto insertPos = std::upper_bound(timeline.begin(),
                                      timeline.end(),
                                      bestEst,
                                      [](Time start, const BusyInterval& busy) { return start < busy.start; });
    timeline.insert(insertPos, BusyInterval {bestEst, bestEft});

    // 3. CRITICAL FIX: Topological Append
    // Because the readyQueue pops in strict topological order, simply pushing to the
    // back guarantees the Monoliths will be physically generated cycle-free.
    // The hardware will still benefit from the processor assignment chosen by PEFT.
    tasksByProcessor[bestProcessor].push_back(task);

    for (const auto& [child, weight] : graph.successors[task]) {
      (void) weight;
      assert(remainingParents[child] > 0 && "remaining parent count underflow");
      if (--remainingParents[child] == 0)
        readyQueue.push(child);
    }
  }

  if (scheduledCount != nodeCount)
    llvm::report_fatal_error("PEFT scheduler: failed to schedule every compute node");

  // 4. Build Strict Topological Dominance Order
  std::vector<size_t> scheduledOrder(nodeCount);
  for (size_t i = 0; i < nodeCount; ++i)
    scheduledOrder[i] = i;

  std::sort(scheduledOrder.begin(), scheduledOrder.end(), [&](size_t a, size_t b) {
    return graph.nodes[a].originalOrder < graph.nodes[b].originalOrder;
  });

  // 5. Populate Final Result
  MergeScheduleResult result;
  result.dominanceOrderCompute.reserve(nodeCount);

  for (size_t task : scheduledOrder)
    result.dominanceOrderCompute.push_back(graph.nodes[task].instance);

  for (size_t processor = 0; processor < processorCount; ++processor) {
    size_t currentSlot = 0;
    for (size_t task : tasksByProcessor[processor]) {
      const ComputeInstance instance = graph.nodes[task].instance;
      result.computeToCpuMap[instance] = processor;
      result.computeToCpuSlotMap[instance] = currentSlot++;
      result.computeToAestMap[instance] = schedules[task].startTime;
    }
    if (!tasksByProcessor[processor].empty()) {
      const ComputeInstance lastInstance = graph.nodes[tasksByProcessor[processor].back()].instance;
      result.cpuToLastComputeMap[processor] = lastInstance;
      result.isLastComputeOfCpu.insert(lastInstance);
    }
  }

  return result;
}
|
||||
} // namespace spatial
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
|
||||
#include "ComputeGraph.hpp"
|
||||
#include "MergeSchedule.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace spatial {
|
||||
|
||||
struct PeftScheduleOptions {
|
||||
size_t processorCount = 0;
|
||||
CrossbarUsage crossbarCapacity = 0;
|
||||
mlir::MLIRContext *context = nullptr;
|
||||
};
|
||||
|
||||
MergeScheduleResult runPeftScheduler(const ComputeGraph &graph, const PeftScheduleOptions &options);
|
||||
|
||||
} // namespace spatial
|
||||
} // namespace onnx_mlir
|
||||
@@ -24,8 +24,11 @@ struct EmitPimCodePass : PassWrapper<EmitPimCodePass, OperationPass<ModuleOp>> {
|
||||
createDirectory(pimDir);
|
||||
|
||||
int compiler_error_code = compileToPimCode(moduleOp, pimDir);
|
||||
if (compiler_error_code != CompilerSuccess)
|
||||
if (compiler_error_code != CompilerSuccess) {
|
||||
moduleOp.emitError() << "failed to emit PIM simulator code artifacts; compiler error code "
|
||||
<< compiler_error_code;
|
||||
signalPassFailure();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -32,14 +32,16 @@ struct HostConstantFoldingPass : PassWrapper<HostConstantFoldingPass, OperationP
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
ModuleOp moduleOp = getOperation();
|
||||
GreedyRewriteConfig config;
|
||||
config.enableFolding();
|
||||
if (failed(applyPatternsGreedily(getOperation(), *patterns, config))) {
|
||||
if (failed(applyPatternsGreedily(moduleOp, *patterns, config))) {
|
||||
moduleOp.emitError("PIM host constant folding failed in the greedy rewrite driver");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
dumpModule(getOperation(), "pim3_folded");
|
||||
dumpModule(moduleOp, "pim3_folded");
|
||||
}
|
||||
|
||||
std::shared_ptr<const FrozenRewritePatternSet> patterns;
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
|
||||
#include "../Common.hpp"
|
||||
#include "../Patterns.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
@@ -66,6 +66,83 @@ static Value buildSubviewChunk(const StaticSubviewInfo& info,
|
||||
return memref::SubViewOp::create(rewriter, loc, info.source, chunkOffsets, chunkSizes, chunkStrides);
|
||||
}
|
||||
|
||||
// Emits the arith ops that break `linearIndex` into one index per entry of
// `strides`.
//
// Walking `strides` in order, the running remainder is divided (unsigned) by
// the stride to produce that position's index, and the unsigned remainder of
// the same division is carried to the next stride. `shape` is used only to
// pre-size the returned vector. Every op is created through `rewriter` at the
// location of `linearIndex`.
static SmallVector<Value> delinearizeIndexValue(Value linearIndex,
                                                ArrayRef<int64_t> shape,
                                                ArrayRef<int64_t> strides,
                                                PatternRewriter& rewriter) {
  Location loc = linearIndex.getLoc();

  SmallVector<Value> perDimIndices;
  perDimIndices.reserve(shape.size());

  Value carry = linearIndex;
  for (int64_t stride : strides) {
    auto strideCst = arith::ConstantIndexOp::create(rewriter, loc, stride);
    Value quotient = arith::DivUIOp::create(rewriter, loc, carry, strideCst);
    perDimIndices.push_back(quotient);
    // The remainder op is emitted on every iteration, including the last one.
    carry = arith::RemUIOp::create(rewriter, loc, carry, strideCst);
  }

  return perDimIndices;
}
|
||||
|
||||
// Returns `baseOffset + extraOffset` as an OpFoldResult.
//
//  * Dynamic base (a Value): emits an arith.addi at the base value's location.
//  * Static base equal to 0: hands back `extraOffset` untouched, no op emitted.
//  * Any other static base: materializes the constant as an index and emits an
//    arith.addi, both at `extraOffset`'s location.
//
// A static base is cast to IntegerAttr, so any other attribute kind trips the
// cast.
static OpFoldResult addDynamicOffset(OpFoldResult baseOffset, Value extraOffset, PatternRewriter& rewriter) {
  auto staticBase = dyn_cast<Attribute>(baseOffset);
  if (!staticBase) {
    auto dynamicBase = cast<Value>(baseOffset);
    return arith::AddIOp::create(rewriter, dynamicBase.getLoc(), dynamicBase, extraOffset).getResult();
  }

  auto staticInt = cast<IntegerAttr>(staticBase);
  if (staticInt.getInt() == 0)
    return extraOffset;

  auto baseCst = arith::ConstantIndexOp::create(rewriter, extraOffset.getLoc(), staticInt.getInt());
  return arith::AddIOp::create(rewriter, extraOffset.getLoc(), baseCst, extraOffset).getResult();
}
|
||||
|
||||
// Creates a memref.subview of `info.source` that selects a single slice along
// the last dimension, with the other dimensions addressed by runtime indices.
//
// For every dimension except the last, the chunk has size 1 and its offset is
// the subview's own offset plus `outerIndices[dim]` (combined through
// addDynamicOffset); those dimensions are asserted to have unit stride. The
// last dimension keeps its original offset and spans `info.sizes.back()`
// elements. Strides are copied from `info.strides` for all dimensions.
// `outerIndices` is indexed for each non-last dimension.
static Value buildDynamicSubviewChunk(const StaticSubviewInfo& info,
                                      ArrayRef<Value> outerIndices,
                                      Location loc,
                                      PatternRewriter& rewriter) {
  const size_t rank = info.sizes.size();

  SmallVector<OpFoldResult> offsets;
  SmallVector<OpFoldResult> sizes;
  SmallVector<OpFoldResult> strides;
  offsets.reserve(info.offsets.size());
  sizes.reserve(rank);
  strides.reserve(info.strides.size());

  for (size_t dim = 0; dim < rank; ++dim) {
    const bool isLastDim = !(dim + 1 < rank);
    if (isLastDim) {
      offsets.push_back(info.offsets[dim]);
      sizes.push_back(rewriter.getIndexAttr(info.sizes.back()));
    }
    else {
      assert(info.strides[dim] == 1 && "loop-based subview rewrite requires unit strides");
      offsets.push_back(addDynamicOffset(info.offsets[dim], outerIndices[dim], rewriter));
      sizes.push_back(rewriter.getIndexAttr(1));
    }
    strides.push_back(rewriter.getIndexAttr(info.strides[dim]));
  }

  return memref::SubViewOp::create(rewriter, loc, info.source, offsets, sizes, strides);
}
|
||||
|
||||
// Creates a unit-stride memref.subview of `source` covering one slice along
// the last dimension of `copyShape`.
//
// Each non-last dimension gets offset `outerIndices[dim]` and size 1; the last
// dimension starts at offset 0 and spans `copyShape.back()` elements.
// `outerIndices` is indexed for each non-last dimension.
static Value buildContiguousChunk(
    Value source, ArrayRef<int64_t> copyShape, ArrayRef<Value> outerIndices, Location loc, PatternRewriter& rewriter) {
  const size_t rank = copyShape.size();

  SmallVector<OpFoldResult> offsets;
  SmallVector<OpFoldResult> sizes;
  SmallVector<OpFoldResult> strides;
  offsets.reserve(rank);
  sizes.reserve(rank);
  strides.reserve(rank);

  for (size_t dim = 0; dim < rank; ++dim) {
    if (dim + 1 < rank) {
      // Outer dimension: pick exactly the row selected by the runtime index.
      offsets.push_back(OpFoldResult(outerIndices[dim]));
      sizes.push_back(rewriter.getIndexAttr(1));
    }
    else {
      // Last dimension: take it whole, starting from the beginning.
      offsets.push_back(rewriter.getIndexAttr(0));
      sizes.push_back(rewriter.getIndexAttr(copyShape.back()));
    }
    strides.push_back(rewriter.getIndexAttr(1));
  }

  return memref::SubViewOp::create(rewriter, loc, source, offsets, sizes, strides);
}
|
||||
|
||||
template <typename CopyOp, typename CreateCopyOp>
|
||||
static LogicalResult rewriteSubviewCopyLikeOp(CopyOp copyOp,
|
||||
Value dst,
|
||||
@@ -73,6 +150,7 @@ static LogicalResult rewriteSubviewCopyLikeOp(CopyOp copyOp,
|
||||
int64_t dstOffset,
|
||||
int64_t srcOffset,
|
||||
int64_t size,
|
||||
bool allowLoopRewrite,
|
||||
PatternRewriter& rewriter,
|
||||
CreateCopyOp createCopyOp) {
|
||||
auto srcSubview = getStaticSubviewInfo(src);
|
||||
@@ -114,6 +192,27 @@ static LogicalResult rewriteSubviewCopyLikeOp(CopyOp copyOp,
|
||||
auto outerStrides = computeRowMajorStrides(outerShape);
|
||||
const int64_t numSlices = outerShape.empty() ? 1 : getNumElements(outerShape);
|
||||
|
||||
if (allowLoopRewrite && numSlices > 1 && srcOffset == 0 && dstOffset == 0
|
||||
&& sourceType.getRank() == static_cast<int64_t>(copyShape.size())
|
||||
&& dstType.getRank() == static_cast<int64_t>(copyShape.size())) {
|
||||
auto c0 = arith::ConstantIndexOp::create(rewriter, copyOp.getLoc(), 0);
|
||||
auto cUpper = arith::ConstantIndexOp::create(rewriter, copyOp.getLoc(), numSlices);
|
||||
auto cStep = arith::ConstantIndexOp::create(rewriter, copyOp.getLoc(), 1);
|
||||
|
||||
auto loop = scf::ForOp::create(rewriter, copyOp.getLoc(), c0, cUpper, cStep, ValueRange {});
|
||||
rewriter.setInsertionPointToStart(loop.getBody());
|
||||
|
||||
SmallVector<Value> outerIndices =
|
||||
outerShape.empty() ? SmallVector<Value> {}
|
||||
: delinearizeIndexValue(loop.getInductionVar(), outerShape, outerStrides, rewriter);
|
||||
Value chunkDst = splitDst ? buildDynamicSubviewChunk(*dstSubview, outerIndices, copyOp.getLoc(), rewriter)
|
||||
: buildContiguousChunk(dst, copyShape, outerIndices, copyOp.getLoc(), rewriter);
|
||||
Value chunkSrc = splitSrc ? buildDynamicSubviewChunk(*srcSubview, outerIndices, copyOp.getLoc(), rewriter)
|
||||
: buildContiguousChunk(src, copyShape, outerIndices, copyOp.getLoc(), rewriter);
|
||||
createCopyOp(cast<MemRefType>(chunkDst.getType()), chunkDst, chunkSrc, 0, 0, sliceBytes);
|
||||
return success();
|
||||
}
|
||||
|
||||
rewriter.setInsertionPoint(copyOp);
|
||||
for (int64_t linearIndex = 0; linearIndex < numSlices; ++linearIndex) {
|
||||
SmallVector<int64_t> outerIndices =
|
||||
@@ -143,6 +242,7 @@ struct RewriteCoreSubviewCopyPattern final : OpRewritePattern<pim::PimMemCopyOp>
|
||||
copyOp.getTargetOffset(),
|
||||
copyOp.getSourceOffset(),
|
||||
copyOp.getSize(),
|
||||
/*allowLoopRewrite=*/true,
|
||||
rewriter,
|
||||
[&](
|
||||
MemRefType resultType, Value dst, Value src, int64_t dstByteOffset, int64_t srcByteOffset, int64_t sliceBytes) {
|
||||
@@ -175,6 +275,7 @@ struct RewriteHostSubviewLoadPattern final : OpRewritePattern<pim::PimMemCopyHos
|
||||
copyOp.getDeviceTargetOffset(),
|
||||
copyOp.getHostSourceOffset(),
|
||||
copyOp.getSize(),
|
||||
/*allowLoopRewrite=*/true,
|
||||
rewriter,
|
||||
[&](
|
||||
MemRefType resultType, Value dst, Value src, int64_t dstByteOffset, int64_t srcByteOffset, int64_t sliceBytes) {
|
||||
@@ -207,6 +308,7 @@ struct RewriteHostSubviewStorePattern final : OpRewritePattern<pim::PimMemCopyDe
|
||||
copyOp.getHostTargetOffset(),
|
||||
copyOp.getDeviceSourceOffset(),
|
||||
copyOp.getSize(),
|
||||
/*allowLoopRewrite=*/false,
|
||||
rewriter,
|
||||
[&](
|
||||
MemRefType resultType, Value dst, Value src, int64_t dstByteOffset, int64_t srcByteOffset, int64_t sliceBytes) {
|
||||
|
||||
@@ -160,6 +160,7 @@ struct MaterializeHostConstantsPass : PassWrapper<MaterializeHostConstantsPass,
|
||||
}
|
||||
|
||||
if (hasFailure) {
|
||||
moduleOp.emitError("PIM host-constant materialization failed; see diagnostics above");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -6,10 +6,11 @@
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
@@ -67,6 +68,14 @@ static bool isCodegenAddressableValue(Value value) {
|
||||
|| isa<memref::AllocOp, memref::GetGlobalOp>(resolvedAddress->base.getDefiningOp());
|
||||
}
|
||||
|
||||
// Overload that resolves `value` with the supplied static knowledge.
//
// Returns false when resolveContiguousAddress fails. Otherwise the value is
// addressable when the resolved base is a block argument, or when the base is
// defined by a memref.alloc or memref.get_global.
static bool isCodegenAddressableValue(Value value, const StaticValueKnowledge& knowledge) {
  auto resolved = resolveContiguousAddress(value, knowledge);
  if (failed(resolved))
    return false;

  auto base = resolved->base;
  if (isa<BlockArgument>(base))
    return true;

  // Only reached when the base is not a block argument, so the defining op is
  // queried exactly as in the short-circuit form.
  return isa<memref::AllocOp, memref::GetGlobalOp>(base.getDefiningOp());
}
|
||||
|
||||
static bool isConstantGlobalView(Value value) {
|
||||
while (true) {
|
||||
Operation* defOp = value.getDefiningOp();
|
||||
@@ -88,6 +97,22 @@ static bool isConstantGlobalView(Value value) {
|
||||
value = cast.getSource();
|
||||
continue;
|
||||
}
|
||||
if (auto collapse = dyn_cast<memref::CollapseShapeOp>(defOp)) {
|
||||
auto srcType = dyn_cast<MemRefType>(collapse.getSrc().getType());
|
||||
auto resultType = dyn_cast<MemRefType>(collapse.getResult().getType());
|
||||
if (!srcType || !resultType || !srcType.hasStaticShape() || !resultType.hasStaticShape())
|
||||
return false;
|
||||
value = collapse.getSrc();
|
||||
continue;
|
||||
}
|
||||
if (auto expand = dyn_cast<memref::ExpandShapeOp>(defOp)) {
|
||||
auto srcType = dyn_cast<MemRefType>(expand.getSrc().getType());
|
||||
auto resultType = dyn_cast<MemRefType>(expand.getResult().getType());
|
||||
if (!srcType || !resultType || !srcType.hasStaticShape() || !resultType.hasStaticShape())
|
||||
return false;
|
||||
value = expand.getSrc();
|
||||
continue;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -144,14 +169,15 @@ struct VerificationPass : PassWrapper<VerificationPass, OperationPass<ModuleOp>>
|
||||
|
||||
void runOnOperation() override {
|
||||
ModuleOp moduleOp = getOperation();
|
||||
bool hasFailure = false;
|
||||
pim::CappedDiagnosticReporter diagnostics;
|
||||
|
||||
moduleOp.walk([&](Operation* op) {
|
||||
if (op->getDialect()->getNamespace() != "spat")
|
||||
return;
|
||||
|
||||
op->emitError("illegal Spatial operation reached PIM codegen verification");
|
||||
hasFailure = true;
|
||||
diagnostics.report(op, [](Operation* illegalOp) {
|
||||
illegalOp->emitError("illegal Spatial operation reached PIM codegen verification");
|
||||
});
|
||||
});
|
||||
|
||||
for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) {
|
||||
@@ -160,49 +186,56 @@ struct VerificationPass : PassWrapper<VerificationPass, OperationPass<ModuleOp>>
|
||||
|
||||
for (Operation& op : funcOp.getBody().front().getOperations()) {
|
||||
if (auto coreOp = dyn_cast<pim::PimCoreOp>(&op)) {
|
||||
if (failed(verifyCoreWeights(moduleOp, coreOp)) || failed(verifyCoreOperands(coreOp)))
|
||||
hasFailure = true;
|
||||
(void) verifyCoreWeights(moduleOp, coreOp, diagnostics);
|
||||
(void) verifyCoreOperands(coreOp, diagnostics);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto coreBatchOp = dyn_cast<pim::PimCoreBatchOp>(&op)) {
|
||||
if (failed(verifyCoreWeights(moduleOp, coreBatchOp)) || failed(verifyCoreOperands(coreBatchOp)))
|
||||
hasFailure = true;
|
||||
(void) verifyCoreWeights(moduleOp, coreBatchOp, diagnostics);
|
||||
(void) verifyCoreOperands(coreBatchOp, diagnostics);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto returnOp = dyn_cast<func::ReturnOp>(&op)) {
|
||||
if (failed(verifyReturnOp(returnOp)))
|
||||
hasFailure = true;
|
||||
(void) verifyReturnOp(returnOp, diagnostics);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!isAddressOnlyHostOp(&op)) {
|
||||
op.emitOpError("illegal host-side runtime op remains after PIM bufferization; "
|
||||
"fold it to constants or lower it into pim.core");
|
||||
hasFailure = true;
|
||||
diagnostics.report(&op, [](Operation* illegalOp) {
|
||||
illegalOp->emitOpError("illegal host-side runtime op remains after PIM bufferization; "
|
||||
"fold it to constants or lower it into pim.core");
|
||||
});
|
||||
continue;
|
||||
}
|
||||
|
||||
if (failed(verifyAddressOnlyHostOp(&op)))
|
||||
hasFailure = true;
|
||||
(void) verifyAddressOnlyHostOp(&op, diagnostics);
|
||||
}
|
||||
}
|
||||
|
||||
if (hasFailure)
|
||||
if (diagnostics.hasFailure()) {
|
||||
diagnostics.emitSuppressedSummary(moduleOp, "verification failures");
|
||||
moduleOp.emitError("PIM codegen verification failed; see diagnostics above");
|
||||
signalPassFailure();
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
template <typename CoreOpTy>
|
||||
static LogicalResult verifyCoreWeights(ModuleOp moduleOp, CoreOpTy coreOp) {
|
||||
static LogicalResult
|
||||
verifyCoreWeights(ModuleOp moduleOp, CoreOpTy coreOp, pim::CappedDiagnosticReporter& diagnostics) {
|
||||
bool hasFailure = false;
|
||||
for (auto [weightIndex, weight] : llvm::enumerate(coreOp.getWeights())) {
|
||||
for (auto it : llvm::enumerate(coreOp.getWeights())) {
|
||||
size_t weightIndex = it.index();
|
||||
Value weight = it.value();
|
||||
auto getGlobalOp = weight.template getDefiningOp<memref::GetGlobalOp>();
|
||||
if (!getGlobalOp && !isConstantGlobalView(weight)) {
|
||||
coreOp.emitOpError() << "weight #" << weightIndex
|
||||
<< " must be materialized as a constant memref.global or a static view of one before JSON "
|
||||
"codegen";
|
||||
diagnostics.report(coreOp.getOperation(), [&](Operation*) {
|
||||
coreOp.emitOpError() << "weight #" << weightIndex
|
||||
<< " must be materialized as a constant memref.global or a static view of one before "
|
||||
"JSON codegen";
|
||||
});
|
||||
hasFailure = true;
|
||||
continue;
|
||||
}
|
||||
@@ -212,14 +245,18 @@ private:
|
||||
|
||||
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
||||
if (!globalOp) {
|
||||
coreOp.emitOpError() << "weight #" << weightIndex << " references an unknown memref.global";
|
||||
diagnostics.report(coreOp.getOperation(), [&](Operation*) {
|
||||
coreOp.emitOpError() << "weight #" << weightIndex << " references an unknown memref.global";
|
||||
});
|
||||
hasFailure = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!globalOp.getConstant() || !globalOp.getInitialValue()) {
|
||||
coreOp.emitOpError() << "weight #" << weightIndex
|
||||
<< " must come from a constant memref.global with an initial value";
|
||||
diagnostics.report(coreOp.getOperation(), [&](Operation*) {
|
||||
coreOp.emitOpError() << "weight #" << weightIndex
|
||||
<< " must come from a constant memref.global with an initial value";
|
||||
});
|
||||
hasFailure = true;
|
||||
}
|
||||
}
|
||||
@@ -227,11 +264,15 @@ private:
|
||||
return success(!hasFailure);
|
||||
}
|
||||
|
||||
static LogicalResult verifyReturnOp(func::ReturnOp returnOp) {
|
||||
static LogicalResult verifyReturnOp(func::ReturnOp returnOp, pim::CappedDiagnosticReporter& diagnostics) {
|
||||
bool hasFailure = false;
|
||||
for (auto [resultIndex, operand] : llvm::enumerate(returnOp.getOperands())) {
|
||||
for (auto it : llvm::enumerate(returnOp.getOperands())) {
|
||||
size_t resultIndex = it.index();
|
||||
Value operand = it.value();
|
||||
if (!isCodegenAddressableValue(operand)) {
|
||||
returnOp.emitOpError() << "result #" << resultIndex << " is not backed by contiguous addressable storage";
|
||||
diagnostics.report(returnOp.getOperation(), [&](Operation*) {
|
||||
returnOp.emitOpError() << "result #" << resultIndex << " is not backed by contiguous addressable storage";
|
||||
});
|
||||
hasFailure = true;
|
||||
}
|
||||
}
|
||||
@@ -239,38 +280,50 @@ private:
|
||||
}
|
||||
|
||||
template <typename CoreOpTy>
|
||||
static LogicalResult verifyCoreOperands(CoreOpTy coreOp) {
|
||||
static LogicalResult verifyCoreOperands(CoreOpTy coreOp, pim::CappedDiagnosticReporter& diagnostics) {
|
||||
return walkPimCoreBlock(
|
||||
coreOp.getBody().front(), StaticValueKnowledge {}, [](Operation& op, const StaticValueKnowledge& knowledge) {
|
||||
coreOp.getBody().front(), StaticValueKnowledge {}, [&](Operation& op, const StaticValueKnowledge& knowledge) {
|
||||
bool hasFailure = false;
|
||||
if (!isSupportedCoreInstructionOp(&op)) {
|
||||
op.emitOpError("unsupported executable op reached PIM codegen verification");
|
||||
diagnostics.report(&op, [](Operation* illegalOp) {
|
||||
illegalOp->emitOpError("unsupported executable op reached PIM codegen verification");
|
||||
});
|
||||
hasFailure = true;
|
||||
}
|
||||
|
||||
for (auto [operandIndex, operand] : llvm::enumerate(op.getOperands())) {
|
||||
for (auto it : llvm::enumerate(op.getOperands())) {
|
||||
size_t operandIndex = it.index();
|
||||
Value operand = it.value();
|
||||
if (!isa<BaseMemRefType>(operand.getType()))
|
||||
continue;
|
||||
|
||||
auto resolvedAddress = resolveContiguousAddress(operand, knowledge);
|
||||
if (failed(resolvedAddress)) {
|
||||
op.emitOpError() << "operand #" << operandIndex << " is not backed by contiguous addressable storage";
|
||||
diagnostics.report(&op, [&](Operation* illegalOp) {
|
||||
illegalOp->emitOpError() << "operand #" << operandIndex
|
||||
<< " is not backed by contiguous addressable storage";
|
||||
});
|
||||
hasFailure = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (isExplicitHostOperand(&op, operandIndex)) {
|
||||
if (!isCodegenAddressableValue(operand)) {
|
||||
op.emitOpError() << "host operand #" << operandIndex
|
||||
<< " is not backed by contiguous addressable storage";
|
||||
if (!isCodegenAddressableValue(operand, knowledge)) {
|
||||
diagnostics.report(&op, [&](Operation* illegalOp) {
|
||||
illegalOp->emitOpError() << "host operand #" << operandIndex
|
||||
<< " is not backed by contiguous addressable storage";
|
||||
});
|
||||
hasFailure = true;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!isa<memref::AllocOp>(resolvedAddress->base.getDefiningOp())) {
|
||||
op.emitOpError() << "operand #" << operandIndex
|
||||
<< " must be backed by device-local memory; materialize host values with pim.memcp_hd";
|
||||
diagnostics.report(&op, [&](Operation* illegalOp) {
|
||||
illegalOp->emitOpError() << "operand #" << operandIndex
|
||||
<< " must be backed by device-local memory; materialize host values with "
|
||||
"pim.memcp_hd";
|
||||
});
|
||||
hasFailure = true;
|
||||
}
|
||||
}
|
||||
@@ -278,18 +331,20 @@ private:
|
||||
});
|
||||
}
|
||||
|
||||
static LogicalResult verifyAddressOnlyHostOp(Operation* op) {
|
||||
static LogicalResult verifyAddressOnlyHostOp(Operation* op, pim::CappedDiagnosticReporter& diagnostics) {
|
||||
if (auto subviewOp = dyn_cast<memref::SubViewOp>(op))
|
||||
return verifyAddressOnlyBase(op, subviewOp.getSource());
|
||||
return verifyAddressOnlyBase(op, subviewOp.getSource(), diagnostics);
|
||||
if (auto castOp = dyn_cast<memref::CastOp>(op))
|
||||
return verifyAddressOnlySource(op, castOp.getSource());
|
||||
return verifyAddressOnlySource(op, castOp.getSource(), diagnostics);
|
||||
if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(op))
|
||||
return verifyAddressOnlySource(op, collapseOp.getSrc());
|
||||
return verifyAddressOnlySource(op, collapseOp.getSrc(), diagnostics);
|
||||
if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(op))
|
||||
return verifyAddressOnlySource(op, expandOp.getSrc());
|
||||
return verifyAddressOnlySource(op, expandOp.getSrc(), diagnostics);
|
||||
if (auto copyOp = dyn_cast<memref::CopyOp>(op)) {
|
||||
if (!isBaseAddressableValue(copyOp.getSource()) || !isBaseAddressableValue(copyOp.getTarget())) {
|
||||
op->emitOpError("depends on a value that is not backed by addressable storage");
|
||||
diagnostics.report(op, [](Operation* illegalOp) {
|
||||
illegalOp->emitOpError("depends on a value that is not backed by addressable storage");
|
||||
});
|
||||
return failure();
|
||||
}
|
||||
return success();
|
||||
@@ -297,19 +352,24 @@ private:
|
||||
return success();
|
||||
}
|
||||
|
||||
static LogicalResult verifyAddressOnlySource(Operation* op, Value source) {
|
||||
static LogicalResult
|
||||
verifyAddressOnlySource(Operation* op, Value source, pim::CappedDiagnosticReporter& diagnostics) {
|
||||
if (isCodegenAddressableValue(source))
|
||||
return success();
|
||||
|
||||
op->emitOpError("depends on a value that is not backed by contiguous addressable storage");
|
||||
diagnostics.report(op, [](Operation* illegalOp) {
|
||||
illegalOp->emitOpError("depends on a value that is not backed by contiguous addressable storage");
|
||||
});
|
||||
return failure();
|
||||
}
|
||||
|
||||
static LogicalResult verifyAddressOnlyBase(Operation* op, Value source) {
|
||||
static LogicalResult verifyAddressOnlyBase(Operation* op, Value source, pim::CappedDiagnosticReporter& diagnostics) {
|
||||
if (isBaseAddressableValue(source))
|
||||
return success();
|
||||
|
||||
op->emitOpError("depends on a value that is not backed by addressable storage");
|
||||
diagnostics.report(op, [](Operation* illegalOp) {
|
||||
illegalOp->emitOpError("depends on a value that is not backed by addressable storage");
|
||||
});
|
||||
return failure();
|
||||
}
|
||||
};
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user