Compare commits
25 Commits
c15aba5d96
..
main
| Author | SHA1 | Date | |
|---|---|---|---|
| 5637c861b4 | |||
| 94157a8404 | |||
| 68a3521978 | |||
| e263e05f56 | |||
| 34c29fdec4 | |||
| aa088e2ba5 | |||
| 2836e759ab | |||
| 8071ebab0b | |||
| f1602c0550 | |||
| de0a2f4561 | |||
| 1c4a5bde76 | |||
| 78242e2887 | |||
| fe244d5aa1 | |||
| d09e76c8f9 | |||
| c5e608fa5b | |||
| 43f3ccdd21 | |||
| 8d95c604a6 | |||
| 55eda487dc | |||
| 061139aefb | |||
| ea61540e08 | |||
| 324178cba8 | |||
| e71ba07cd5 | |||
| 64a3805619 | |||
| 9f9e7c0892 | |||
| 03eab42971 |
@@ -114,7 +114,9 @@ Pass these on the `onnx-mlir` command line when compiling for PIM:
|
|||||||
run only the codegen tail.
|
run only the codegen tail.
|
||||||
- `--crossbar-size=<N>` / `--crossbar-count=<N>` — crossbar dimensions and
|
- `--crossbar-size=<N>` / `--crossbar-count=<N>` — crossbar dimensions and
|
||||||
per-core count.
|
per-core count.
|
||||||
- `--core-count=<N>` — number of cores (`-1` picks the minimum).
|
- `--core-count=<N>` — number of cores. Required for PIM compilation.
|
||||||
|
- `--pim-merge-scheduler={peft,dcp}` — scheduler used by the Spatial
|
||||||
|
merge-compute-nodes pass (default: `peft`).
|
||||||
- `--dcp-critical-window-size=<N>` — DCP coarsening window (0 = legacy).
|
- `--dcp-critical-window-size=<N>` — DCP coarsening window (0 = legacy).
|
||||||
- `--use-experimental-conv-impl` — alternative convolution lowering.
|
- `--use-experimental-conv-impl` — alternative convolution lowering.
|
||||||
- `--ignore-concat-error` — soft-fail corner case in `ConcatOp`.
|
- `--ignore-concat-error` — soft-fail corner case in `ConcatOp`.
|
||||||
@@ -129,7 +131,8 @@ Per-operation validation (from `validation/`):
|
|||||||
```
|
```
|
||||||
validate.py \
|
validate.py \
|
||||||
--raptor-path ../cmake-build-release/Release/bin/onnx-mlir \
|
--raptor-path ../cmake-build-release/Release/bin/onnx-mlir \
|
||||||
--onnx-include-dir ../onnx-mlir/include
|
--onnx-include-dir ../onnx-mlir/include \
|
||||||
|
--core-count 1000
|
||||||
```
|
```
|
||||||
|
|
||||||
End-to-end network validation (example: first 4 layers of YOLOv11n):
|
End-to-end network validation (example: first 4 layers of YOLOv11n):
|
||||||
|
|||||||
@@ -67,7 +67,7 @@ fn main() -> Result<()> {
|
|||||||
.lock()
|
.lock()
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.init(executor.cpu().num_core(), args.output.clone());
|
.init(executor.cpu().num_core(), args.output.clone());
|
||||||
executor.execute();
|
executor.execute()?;
|
||||||
dump_memory(executor, &args)?;
|
dump_memory(executor, &args)?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@@ -77,7 +77,7 @@ fn map_crossbars_to_cores<'c>(
|
|||||||
args: &Args,
|
args: &Args,
|
||||||
global_crossbars: &'c HashMap<String, Crossbar>,
|
global_crossbars: &'c HashMap<String, Crossbar>,
|
||||||
) -> Vec<Vec<&'c Crossbar>> {
|
) -> Vec<Vec<&'c Crossbar>> {
|
||||||
let mut res = Vec::new();
|
let mut res = vec![Vec::new()];
|
||||||
let num_cores = config.get("core_cnt").unwrap().as_i64().unwrap() as i32;
|
let num_cores = config.get("core_cnt").unwrap().as_i64().unwrap() as i32;
|
||||||
|
|
||||||
if let Some(folder) = args.folder.as_ref() {
|
if let Some(folder) = args.folder.as_ref() {
|
||||||
|
|||||||
@@ -312,7 +312,7 @@ fn append_record(
|
|||||||
29 => {
|
29 => {
|
||||||
inst_data_builder
|
inst_data_builder
|
||||||
.set_rd_u8(rd)
|
.set_rd_u8(rd)
|
||||||
.set_imm_core(r2_or_imm)
|
.set_imm_core(r2_or_imm + 1)
|
||||||
.set_imm_len(generic3)
|
.set_imm_len(generic3)
|
||||||
.set_offset_select_value(generic1, generic2);
|
.set_offset_select_value(generic1, generic2);
|
||||||
inst_builder.make_inst(send, inst_data_builder.build());
|
inst_builder.make_inst(send, inst_data_builder.build());
|
||||||
@@ -320,7 +320,7 @@ fn append_record(
|
|||||||
30 => {
|
30 => {
|
||||||
inst_data_builder
|
inst_data_builder
|
||||||
.set_rd_u8(rd)
|
.set_rd_u8(rd)
|
||||||
.set_imm_core(r2_or_imm)
|
.set_imm_core(r2_or_imm + 1)
|
||||||
.set_imm_len(generic3)
|
.set_imm_len(generic3)
|
||||||
.set_offset_select_value(generic1, generic2);
|
.set_offset_select_value(generic1, generic2);
|
||||||
inst_builder.make_inst(recv, inst_data_builder.build());
|
inst_builder.make_inst(recv, inst_data_builder.build());
|
||||||
@@ -366,23 +366,19 @@ fn binary_to_instructions(
|
|||||||
|
|
||||||
pub fn binary_to_executor<'a, 'b>(
|
pub fn binary_to_executor<'a, 'b>(
|
||||||
config: Value,
|
config: Value,
|
||||||
mut cores: impl Iterator<Item = &'b Vec<u8>>,
|
cores: impl Iterator<Item = &'b Vec<u8>>,
|
||||||
crossbars: Vec<Vec<&'a Crossbar>>,
|
crossbars: Vec<Vec<&'a Crossbar>>,
|
||||||
) -> Result<Executable<'a>> {
|
) -> Result<Executable<'a>> {
|
||||||
let core_cnt = config
|
let core_cnt = config
|
||||||
.get("core_cnt")
|
.get("core_cnt")
|
||||||
.context("missing core_cnt in config")?
|
.context("missing core_cnt in config")?
|
||||||
.as_i64()
|
.as_i64()
|
||||||
.context("core_cnt is not an integer")? as i32
|
.context("core_cnt is not an integer")? as i32;
|
||||||
- 1;
|
|
||||||
|
|
||||||
let cpu = CPU::new(core_cnt, crossbars);
|
let cpu = CPU::new(core_cnt, crossbars);
|
||||||
let mut core_insts_builder = CoreInstructionsBuilder::new(core_cnt as usize);
|
let mut core_insts_builder = CoreInstructionsBuilder::new(core_cnt as usize);
|
||||||
cores.next();
|
for (external_core_indx, core_bytes) in cores.enumerate() {
|
||||||
for core_indx in 1..=core_cnt {
|
let core_indx = external_core_indx as i32 + 1;
|
||||||
let core_bytes = cores
|
|
||||||
.next()
|
|
||||||
.unwrap_or_else(|| panic!("cores files less than {}", core_indx));
|
|
||||||
let instructions = binary_to_instructions(core_bytes, core_indx)?;
|
let instructions = binary_to_instructions(core_bytes, core_indx)?;
|
||||||
core_insts_builder.set_core(core_indx, instructions);
|
core_insts_builder.set_core(core_indx, instructions);
|
||||||
}
|
}
|
||||||
@@ -396,6 +392,7 @@ mod tests {
|
|||||||
HEADER_SIZE, InstructionRecord, MAGIC, RECORD_SIZE, VERSION, binary_to_instructions,
|
HEADER_SIZE, InstructionRecord, MAGIC, RECORD_SIZE, VERSION, binary_to_instructions,
|
||||||
};
|
};
|
||||||
use crate::{
|
use crate::{
|
||||||
|
functor_to_name,
|
||||||
instruction_set::{InstructionsBuilder, instruction_data::InstructionDataBuilder},
|
instruction_set::{InstructionsBuilder, instruction_data::InstructionDataBuilder},
|
||||||
json_to_instruction::json_isa::json_to_instruction,
|
json_to_instruction::json_isa::json_to_instruction,
|
||||||
};
|
};
|
||||||
@@ -490,7 +487,10 @@ mod tests {
|
|||||||
|
|
||||||
assert_eq!(json_instructions.len(), binary_instructions.len());
|
assert_eq!(json_instructions.len(), binary_instructions.len());
|
||||||
for (json_inst, binary_inst) in json_instructions.iter().zip(binary_instructions.iter()) {
|
for (json_inst, binary_inst) in json_instructions.iter().zip(binary_instructions.iter()) {
|
||||||
assert_eq!(json_inst.functor_name(), binary_inst.functor_name());
|
assert_eq!(
|
||||||
|
functor_to_name(json_inst.functor as usize),
|
||||||
|
functor_to_name(binary_inst.functor as usize)
|
||||||
|
);
|
||||||
assert_eq!(json_inst.data, binary_inst.data);
|
assert_eq!(json_inst.data, binary_inst.data);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -567,7 +567,7 @@ fn json_to_send(
|
|||||||
let (offset_select, offset_value) = json_to_offset(json.get("offset").unwrap());
|
let (offset_select, offset_value) = json_to_offset(json.get("offset").unwrap());
|
||||||
inst_data_builder
|
inst_data_builder
|
||||||
.set_rd(rd)
|
.set_rd(rd)
|
||||||
.set_imm_core(core)
|
.set_imm_core(core + 1)
|
||||||
.set_imm_len(size)
|
.set_imm_len(size)
|
||||||
.set_offset_select(offset_select)
|
.set_offset_select(offset_select)
|
||||||
.set_offset_value(offset_value);
|
.set_offset_value(offset_value);
|
||||||
@@ -588,7 +588,7 @@ fn json_to_recv(
|
|||||||
let (offset_select, offset_value) = json_to_offset(json.get("offset").unwrap());
|
let (offset_select, offset_value) = json_to_offset(json.get("offset").unwrap());
|
||||||
inst_data_builder
|
inst_data_builder
|
||||||
.set_rd(rd)
|
.set_rd(rd)
|
||||||
.set_imm_core(core)
|
.set_imm_core(core + 1)
|
||||||
.set_imm_len(size)
|
.set_imm_len(size)
|
||||||
.set_offset_select(offset_select)
|
.set_offset_select(offset_select)
|
||||||
.set_offset_value(offset_value);
|
.set_offset_value(offset_value);
|
||||||
|
|||||||
+17
-32
@@ -1,49 +1,34 @@
|
|||||||
use core::panic;
|
use serde_json::Value;
|
||||||
use std::io::{Read, Write};
|
use std::{fs::File, io::BufReader};
|
||||||
use std::{collections::HashMap, fs::File, io::BufReader};
|
|
||||||
|
|
||||||
use serde_json::{Deserializer, Map, Value};
|
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
CoreInstructionsBuilder, Executable,
|
CoreInstructionsBuilder, Executable,
|
||||||
cpu::{
|
cpu::{CPU, crossbar::Crossbar},
|
||||||
CPU,
|
instruction_set::{InstructionsBuilder, instruction_data::InstructionDataBuilder},
|
||||||
crossbar::{self, Crossbar},
|
json_to_instruction::json_isa,
|
||||||
},
|
|
||||||
instruction_set::{
|
|
||||||
InstructionsBuilder,
|
|
||||||
instruction_data::{self, InstructionData, InstructionDataBuilder},
|
|
||||||
},
|
|
||||||
json_to_instruction::{self, json_isa},
|
|
||||||
memory_manager::type_traits::TryToUsize,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
pub fn json_to_executor<'a, 'b>(
|
pub fn json_to_executor<'a, 'b>(
|
||||||
config: Value,
|
config: Value,
|
||||||
mut cores: &mut Vec<BufReader<File>>,
|
cores: &'b mut Vec<BufReader<File>>,
|
||||||
crossbars: Vec<Vec<&'a Crossbar>>,
|
crossbars: Vec<Vec<&'a Crossbar>>,
|
||||||
) -> Executable<'a> {
|
) -> Executable<'a> {
|
||||||
let cell_precision = config.get("cell_precision").unwrap().as_i64().unwrap() as i32;
|
let core_cnt = config.get("core_cnt").unwrap().as_i64().unwrap() as i32;
|
||||||
let core_cnt = config.get("core_cnt").unwrap().as_i64().unwrap() as i32 - 1;
|
|
||||||
let xbar_count = config.get("xbar_array_count").unwrap().as_i64().unwrap() as i32;
|
|
||||||
let xbar_size = config.get("xbar_size").unwrap().as_array().unwrap();
|
|
||||||
let rows_crossbar = xbar_size[0].as_i64().unwrap() as i32;
|
|
||||||
let column_corssbar = xbar_size[1].as_i64().unwrap() as i32;
|
|
||||||
|
|
||||||
let mut cpu = CPU::new(core_cnt, crossbars);
|
let cpu = CPU::new(core_cnt, crossbars);
|
||||||
let mut core_insts_builder = CoreInstructionsBuilder::new(core_cnt as usize);
|
let mut core_insts_builder = CoreInstructionsBuilder::new(core_cnt as usize);
|
||||||
// Note: cores[0] is intentionally empty and discarded
|
for (external_core_indx, json_core_reader) in cores.iter_mut().enumerate() {
|
||||||
for core_indx in 1..=core_cnt {
|
let core_indx = external_core_indx as i32 + 1;
|
||||||
let mut insts_builder = InstructionsBuilder::new();
|
let mut insts_builder = InstructionsBuilder::new();
|
||||||
let mut inst_data_builder = InstructionDataBuilder::new();
|
let mut inst_data_builder = InstructionDataBuilder::new();
|
||||||
inst_data_builder.set_core_indx(core_indx).fix_core_indx();
|
inst_data_builder.set_core_indx(core_indx).fix_core_indx();
|
||||||
let stream = Deserializer::from_reader(&mut cores[core_indx as usize]).into_iter::<Value>();
|
let json_core: Value = serde_json::from_reader(json_core_reader)
|
||||||
|
.unwrap_or_else(|err| panic!("failed to parse core{}: {}", external_core_indx, err));
|
||||||
for (i, json_inst_result) in stream.enumerate() {
|
let json_core_insts = json_core
|
||||||
let json_inst = json_inst_result.expect("Failed to parse instruction");
|
.as_array()
|
||||||
// Pass the single Value to your parser
|
.unwrap_or_else(|| panic!("core{} has not a list of instruction", external_core_indx));
|
||||||
json_isa::json_to_instruction(&mut insts_builder, &mut inst_data_builder, &json_inst);
|
for json_inst in json_core_insts {
|
||||||
drop(json_inst);
|
json_isa::json_to_instruction(&mut insts_builder, &mut inst_data_builder, json_inst);
|
||||||
}
|
}
|
||||||
core_insts_builder.set_core(core_indx, insts_builder.build());
|
core_insts_builder.set_core(core_indx, insts_builder.build());
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,2 +1,2 @@
|
|||||||
mod json_isa;
|
pub(crate) mod json_isa;
|
||||||
pub mod json_to_executor;
|
pub mod json_to_executor;
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
#![allow(unused)]
|
#![allow(unused)]
|
||||||
|
|
||||||
|
use anyhow::{Result, bail};
|
||||||
use std::{
|
use std::{
|
||||||
collections::{HashMap, HashSet},
|
collections::{HashMap, HashSet},
|
||||||
time::{Duration, SystemTime},
|
time::{Duration, SystemTime},
|
||||||
@@ -87,6 +88,11 @@ pub struct Executable<'a> {
|
|||||||
send_recv: SendRecv,
|
send_recv: SendRecv,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct DeadlockInfo {
|
||||||
|
cycle: String,
|
||||||
|
states: String,
|
||||||
|
}
|
||||||
|
|
||||||
fn print_status(core_instructions: &[CoreInstructions]) {
|
fn print_status(core_instructions: &[CoreInstructions]) {
|
||||||
let mut tot_instructions = 0;
|
let mut tot_instructions = 0;
|
||||||
let mut progress = 0;
|
let mut progress = 0;
|
||||||
@@ -118,7 +124,7 @@ impl<'a> Executable<'a> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn execute<'b>(&'b mut self)
|
pub fn execute<'b>(&'b mut self) -> Result<()>
|
||||||
where
|
where
|
||||||
'a: 'b,
|
'a: 'b,
|
||||||
{
|
{
|
||||||
@@ -153,7 +159,13 @@ impl<'a> Executable<'a> {
|
|||||||
}
|
}
|
||||||
if (now.elapsed().unwrap() > Duration::from_secs(5)) {
|
if (now.elapsed().unwrap() > Duration::from_secs(5)) {
|
||||||
print_status(cores_instructions);
|
print_status(cores_instructions);
|
||||||
check_cycle(cpu, cores_instructions, send_recv);
|
if let Some(deadlock) = detect_deadlock(cores_instructions) {
|
||||||
|
bail!(
|
||||||
|
"Deadlock cycle detected: {} [{}]",
|
||||||
|
deadlock.cycle,
|
||||||
|
deadlock.states
|
||||||
|
);
|
||||||
|
}
|
||||||
now = SystemTime::now();
|
now = SystemTime::now();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -178,8 +190,23 @@ impl<'a> Executable<'a> {
|
|||||||
}
|
}
|
||||||
print_status(cores_instructions);
|
print_status(cores_instructions);
|
||||||
|
|
||||||
|
if let Some(deadlock) = detect_deadlock(cores_instructions) {
|
||||||
|
bail!(
|
||||||
|
"Deadlock cycle detected: {} [{}]",
|
||||||
|
deadlock.cycle,
|
||||||
|
deadlock.states
|
||||||
|
);
|
||||||
|
}
|
||||||
|
if cores_instructions
|
||||||
|
.iter()
|
||||||
|
.any(|core_inst| core_inst.program_counter < core_inst.instructions.len())
|
||||||
|
{
|
||||||
|
bail!("Execution stalled with unfinished instructions");
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(feature = "profile_time")]
|
#[cfg(feature = "profile_time")]
|
||||||
TRACER.lock().unwrap().report();
|
TRACER.lock().unwrap().report();
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn cpu(&self) -> &CPU<'a> {
|
pub fn cpu(&self) -> &CPU<'a> {
|
||||||
@@ -201,11 +228,11 @@ impl<'a> Executable<'a> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn check_cycle(cpu: &mut CPU, cores_instructions: &[CoreInstructions], send_recv: &mut SendRecv) {
|
fn detect_deadlock(cores_instructions: &[CoreInstructions]) -> Option<DeadlockInfo> {
|
||||||
#[derive(Debug, PartialEq, Eq)]
|
#[derive(Debug, PartialEq, Eq)]
|
||||||
enum CoreState {
|
enum CoreState {
|
||||||
SendingTo(i32),
|
SendingTo(i32, i32),
|
||||||
ReceivingFrom(i32),
|
ReceivingFrom(i32, i32),
|
||||||
Working,
|
Working,
|
||||||
Halted,
|
Halted,
|
||||||
}
|
}
|
||||||
@@ -223,9 +250,9 @@ fn check_cycle(cpu: &mut CPU, cores_instructions: &[CoreInstructions], send_recv
|
|||||||
let (this_core, target_core) = data.get_core_immcore();
|
let (this_core, target_core) = data.get_core_immcore();
|
||||||
|
|
||||||
if isa_recv(functor_address) {
|
if isa_recv(functor_address) {
|
||||||
states.insert(this_core, CoreState::ReceivingFrom(target_core));
|
states.insert(this_core, CoreState::ReceivingFrom(target_core, data.imm_len()));
|
||||||
} else if isa_send(functor_address) {
|
} else if isa_send(functor_address) {
|
||||||
states.insert(this_core, CoreState::SendingTo(target_core));
|
states.insert(this_core, CoreState::SendingTo(target_core, data.imm_len()));
|
||||||
} else {
|
} else {
|
||||||
states.insert(this_core, CoreState::Working);
|
states.insert(this_core, CoreState::Working);
|
||||||
}
|
}
|
||||||
@@ -235,15 +262,15 @@ fn check_cycle(cpu: &mut CPU, cores_instructions: &[CoreInstructions], send_recv
|
|||||||
|
|
||||||
for (&core_id, state) in states.iter() {
|
for (&core_id, state) in states.iter() {
|
||||||
match state {
|
match state {
|
||||||
CoreState::SendingTo(target_core) => {
|
CoreState::SendingTo(target_core, size) => {
|
||||||
let target_state = states.get(target_core).unwrap_or(&CoreState::Halted);
|
let target_state = states.get(target_core).unwrap_or(&CoreState::Halted);
|
||||||
if target_state != &CoreState::ReceivingFrom(core_id) {
|
if target_state != &CoreState::ReceivingFrom(core_id, *size) {
|
||||||
wait_for.insert(core_id, *target_core);
|
wait_for.insert(core_id, *target_core);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
CoreState::ReceivingFrom(target_core) => {
|
CoreState::ReceivingFrom(target_core, size) => {
|
||||||
let target_state = states.get(target_core).unwrap_or(&CoreState::Halted);
|
let target_state = states.get(target_core).unwrap_or(&CoreState::Halted);
|
||||||
if target_state != &CoreState::SendingTo(core_id) {
|
if target_state != &CoreState::SendingTo(core_id, *size) {
|
||||||
wait_for.insert(core_id, *target_core);
|
wait_for.insert(core_id, *target_core);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -279,11 +306,33 @@ fn check_cycle(cpu: &mut CPU, cores_instructions: &[CoreInstructions], send_recv
|
|||||||
.collect::<Vec<_>>()
|
.collect::<Vec<_>>()
|
||||||
.join(" -> ");
|
.join(" -> ");
|
||||||
|
|
||||||
|
let cycle = cycle
|
||||||
|
.iter()
|
||||||
|
.copied()
|
||||||
|
.chain(std::iter::once(waiting_for))
|
||||||
|
.collect::<Vec<_>>();
|
||||||
let cycle_msg = format!("{} -> {}", cycle_str, waiting_for);
|
let cycle_msg = format!("{} -> {}", cycle_str, waiting_for);
|
||||||
|
let states_msg = cycle
|
||||||
|
.iter()
|
||||||
|
.filter_map(|core| {
|
||||||
|
states.get(core).map(|state| match state {
|
||||||
|
CoreState::SendingTo(target, size) => {
|
||||||
|
format!("core {} send {}B -> {}", core, size, target)
|
||||||
|
}
|
||||||
|
CoreState::ReceivingFrom(source, size) => {
|
||||||
|
format!("core {} recv {}B <- {}", core, size, source)
|
||||||
|
}
|
||||||
|
CoreState::Working => format!("core {} working", core),
|
||||||
|
CoreState::Halted => format!("core {} halted", core),
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.join(", ");
|
||||||
|
|
||||||
println!("Fatal: Deadlock cycle detected: {}", cycle_msg);
|
return Some(DeadlockInfo {
|
||||||
// bail!("Deadlock detected: {}", cycle_msg);
|
cycle: cycle_msg,
|
||||||
break; // Stop tracing
|
states: states_msg,
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
// Hit a known branch that didn't result in a cycle
|
// Hit a known branch that didn't result in a cycle
|
||||||
@@ -294,6 +343,7 @@ fn check_cycle(cpu: &mut CPU, cores_instructions: &[CoreInstructions], send_recv
|
|||||||
current_core = waiting_for;
|
current_core = waiting_for;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
None
|
||||||
}
|
}
|
||||||
|
|
||||||
fn handle_wait_sync<'a, 'b, 'c>(
|
fn handle_wait_sync<'a, 'b, 'c>(
|
||||||
|
|||||||
@@ -1,6 +1,11 @@
|
|||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
|
|
||||||
use pimcore::{Executable, cpu::CPU, instruction_set::{InstructionsBuilder, instruction_data::InstructionDataBuilder, isa::*}};
|
use pimcore::{
|
||||||
|
Executable,
|
||||||
|
cpu::crossbar::Crossbar,
|
||||||
|
instruction_set::{InstructionsBuilder, instruction_data::InstructionDataBuilder, isa::*},
|
||||||
|
memory_manager::CoreMemory,
|
||||||
|
};
|
||||||
|
|
||||||
fn simple_read(path: &Path) -> Vec<f32> {
|
fn simple_read(path: &Path) -> Vec<f32> {
|
||||||
if !path.exists() {
|
if !path.exists() {
|
||||||
@@ -17,14 +22,12 @@ fn simple_read(path: &Path) -> Vec<f32> {
|
|||||||
fn mvmul_f32(err: &str)
|
fn mvmul_f32(err: &str)
|
||||||
where
|
where
|
||||||
{
|
{
|
||||||
let mut cpu = CPU::new(0);
|
let matrix = simple_read(Path::new("tests/B.txt"));
|
||||||
cpu.reserve_crossbar(1, 1024 * size_of::<f32>(), 1024);
|
let mut crossbar = Crossbar::new(1024 * size_of::<f32>(), 1024, CoreMemory::new());
|
||||||
let (memory, crossbars) = cpu.host().get_memory_crossbar();
|
crossbar.execute_store(&matrix).unwrap();
|
||||||
let matrix = simple_read(Path::new("B.txt")) ;
|
let mut cpu = pimcore::cpu::CPU::new(0, vec![vec![&crossbar]]);
|
||||||
|
let (memory, _) = cpu.host().get_memory_crossbar();
|
||||||
|
let vector = simple_read(Path::new("tests/A.txt"));
|
||||||
crossbars.get_mut(0).unwrap().execute_store( &matrix).unwrap();
|
|
||||||
let vector = simple_read(Path::new("A.txt"));
|
|
||||||
memory.execute_store(0, &vector).unwrap();
|
memory.execute_store(0, &vector).unwrap();
|
||||||
|
|
||||||
let mut inst_builder = InstructionsBuilder::new();
|
let mut inst_builder = InstructionsBuilder::new();
|
||||||
@@ -57,7 +60,7 @@ where
|
|||||||
.cpu_mut()
|
.cpu_mut()
|
||||||
.host()
|
.host()
|
||||||
.load::<f32>(1024 * size_of::<f32>(), 1024*size_of::<f32>()).unwrap()[0].iter().zip(
|
.load::<f32>(1024 * size_of::<f32>(), 1024*size_of::<f32>()).unwrap()[0].iter().zip(
|
||||||
simple_read(Path::new("X.txt")) ).all(|(&a,b) : (&f32, f32)| {a-b < 0.001}),
|
simple_read(Path::new("tests/X.txt")) ).all(|(&a,b) : (&f32, f32)| {a-b < 0.001}),
|
||||||
"Wrong result for {}",
|
"Wrong result for {}",
|
||||||
err
|
err
|
||||||
);
|
);
|
||||||
@@ -69,5 +72,3 @@ fn mvmul_big_test() {
|
|||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,5 @@
|
|||||||
|
use pimcore::cpu::CPU;
|
||||||
|
|
||||||
|
pub fn empty_cpu(num_cores: usize) -> CPU<'static> {
|
||||||
|
CPU::new(num_cores, vec![Vec::new(); num_cores + 1])
|
||||||
|
}
|
||||||
@@ -1,51 +1,103 @@
|
|||||||
use std::{fs, io::BufReader, path::Path};
|
use std::{
|
||||||
|
fs::{self, File},
|
||||||
|
io::BufReader,
|
||||||
|
path::{Path, PathBuf},
|
||||||
|
};
|
||||||
|
|
||||||
use anyhow::{Context, Result};
|
use anyhow::{Context, Result};
|
||||||
use pimcore::json_to_instruction::json_to_executor;
|
use pimcore::{
|
||||||
|
cpu::crossbar::Crossbar,
|
||||||
|
json_to_instruction::json_to_executor,
|
||||||
|
memory_manager::CoreMemory,
|
||||||
|
};
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
|
|
||||||
fn collect_json_from_subfolders<P: AsRef<Path>>(root: P) -> Result<Vec<(Value, Vec<Value>)>> {
|
fn collect_examples<P: AsRef<Path>>(root: P) -> Result<Vec<PathBuf>> {
|
||||||
let mut result = Vec::new();
|
let mut result = Vec::new();
|
||||||
for entry in fs::read_dir(root)? {
|
for entry in fs::read_dir(root)? {
|
||||||
let entry = entry.context("Root not found")?;
|
let entry = entry.context("Root not found")?;
|
||||||
let path = entry.path();
|
let path = entry.path();
|
||||||
if path.is_dir() {
|
if path.is_dir() {
|
||||||
let mut cores = Vec::new();
|
result.push(path);
|
||||||
let mut config: Option<Value> = None;
|
|
||||||
for sub_entry in fs::read_dir(&path)
|
|
||||||
.with_context(|| format!("File {} not readable", path.display()))?
|
|
||||||
{
|
|
||||||
let sub_entry =
|
|
||||||
sub_entry.with_context(|| format!("File {} not readable", path.display()))?;
|
|
||||||
let sub_path = sub_entry.path();
|
|
||||||
if sub_path.is_file()
|
|
||||||
&& sub_path.extension().and_then(|s| s.to_str()) == Some("json")
|
|
||||||
{
|
|
||||||
let file = fs::File::open(&sub_path)
|
|
||||||
.with_context(|| format!("Subpath {} not opened", sub_path.display()))?;
|
|
||||||
let reader = BufReader::new(file);
|
|
||||||
let val: Value = serde_json::from_reader(reader).with_context(|| format!(
|
|
||||||
"Serde reader fail for subpath {}",
|
|
||||||
sub_path.display()
|
|
||||||
))?;
|
|
||||||
if sub_path.file_name().unwrap() == "config.json" {
|
|
||||||
config = Some(val);
|
|
||||||
} else {
|
|
||||||
cores.push(val);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
result.push((config.unwrap(), cores));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Ok(result)
|
Ok(result)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn core_sort_key(path: &Path) -> i32 {
|
||||||
|
let stem = path.file_stem().unwrap().to_str().unwrap();
|
||||||
|
stem[5..].parse::<i32>().unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn crossbar_sort_key(path: &Path) -> i32 {
|
||||||
|
let stem = path.file_stem().unwrap().to_str().unwrap();
|
||||||
|
stem[9..].parse::<i32>().unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn load_crossbars(folder: &Path, config: &Value) -> Result<Vec<Vec<Crossbar>>> {
|
||||||
|
let xbar_size = config.get("xbar_size").unwrap().as_array().unwrap();
|
||||||
|
let rows = xbar_size[0].as_i64().unwrap() as usize;
|
||||||
|
let cols = xbar_size[1].as_i64().unwrap() as usize;
|
||||||
|
let core_cnt = config.get("core_cnt").unwrap().as_i64().unwrap() as usize;
|
||||||
|
let mut owned_crossbars = Vec::with_capacity(core_cnt + 1);
|
||||||
|
owned_crossbars.push(Vec::new());
|
||||||
|
|
||||||
|
for core_idx in 0..core_cnt {
|
||||||
|
let core_folder = folder.join(format!("core_{core_idx}"));
|
||||||
|
let mut core_crossbars = Vec::new();
|
||||||
|
if core_folder.is_dir() {
|
||||||
|
let mut paths: Vec<_> = fs::read_dir(&core_folder)?
|
||||||
|
.map(|entry| entry.map(|entry| entry.path()))
|
||||||
|
.collect::<std::io::Result<Vec<_>>>()?;
|
||||||
|
paths.sort_by_cached_key(|path| crossbar_sort_key(path));
|
||||||
|
for path in paths {
|
||||||
|
if path.extension().and_then(|ext| ext.to_str()) != Some("bin") {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
let bytes = fs::read(&path)
|
||||||
|
.with_context(|| format!("failed to read crossbar {}", path.display()))?;
|
||||||
|
let mut crossbar = Crossbar::new(cols * 4, rows, CoreMemory::new());
|
||||||
|
crossbar.execute_store(&bytes)?;
|
||||||
|
core_crossbars.push(crossbar);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
owned_crossbars.push(core_crossbars);
|
||||||
|
}
|
||||||
|
Ok(owned_crossbars)
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn json_folder_tester() {
|
fn json_folder_tester() {
|
||||||
let examples = collect_json_from_subfolders("data").unwrap();
|
let examples = collect_examples("tests/data").unwrap();
|
||||||
for example in examples {
|
for folder in examples {
|
||||||
let (config, cores) = example;
|
let config_path = folder.join("config.json");
|
||||||
json_to_executor::json_to_executor(config, cores.iter()).execute();
|
let config_file = File::open(&config_path).unwrap();
|
||||||
|
let config: Value = serde_json::from_reader(BufReader::new(config_file)).unwrap();
|
||||||
|
|
||||||
|
let core_cnt = config.get("core_cnt").unwrap().as_i64().unwrap() as usize;
|
||||||
|
let mut core_paths: Vec<_> = fs::read_dir(&folder)
|
||||||
|
.unwrap()
|
||||||
|
.map(|entry| entry.unwrap().path())
|
||||||
|
.filter(|path| path.extension().and_then(|ext| ext.to_str()) == Some("json"))
|
||||||
|
.filter(|path| path.file_name().unwrap() != "config.json")
|
||||||
|
.collect();
|
||||||
|
core_paths.sort_by_cached_key(|path| core_sort_key(path));
|
||||||
|
assert_eq!(core_paths.len(), core_cnt);
|
||||||
|
|
||||||
|
let mut core_readers: Vec<_> = core_paths
|
||||||
|
.into_iter()
|
||||||
|
.map(|path| BufReader::new(File::open(path).unwrap()))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let owned_crossbars = load_crossbars(&folder, &config).unwrap();
|
||||||
|
let crossbars = owned_crossbars
|
||||||
|
.iter()
|
||||||
|
.map(|core_crossbars| core_crossbars.iter().collect())
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let mut executable = json_to_executor::json_to_executor(config, &mut core_readers, crossbars);
|
||||||
|
let memory = fs::read(folder.join("memory.bin")).unwrap();
|
||||||
|
executable.cpu_mut().host().execute_store(0, &memory).unwrap();
|
||||||
|
executable.execute();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,11 +1,17 @@
|
|||||||
|
mod common;
|
||||||
|
|
||||||
use pimcore::{Executable, cpu::CPU, instruction_set::{InstructionType, InstructionsBuilder, instruction_data::InstructionDataBuilder, isa::*}};
|
use pimcore::{
|
||||||
|
Executable,
|
||||||
|
instruction_set::{
|
||||||
|
InstructionType, InstructionsBuilder, instruction_data::InstructionDataBuilder, isa::*,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
#[should_panic(expected = "Function not found for the requested size") ]
|
#[should_panic(expected = "Function not found for the requested size") ]
|
||||||
fn wrong_size_place_holder() {
|
fn wrong_size_place_holder() {
|
||||||
let cpu = CPU::new(0);
|
let cpu = common::empty_cpu(0);
|
||||||
let mut inst_builder = InstructionsBuilder::new();
|
let mut inst_builder = InstructionsBuilder::new();
|
||||||
let mut idata_build = InstructionDataBuilder::new();
|
let mut idata_build = InstructionDataBuilder::new();
|
||||||
idata_build.set_core_indx(0).fix_core_indx();
|
idata_build.set_core_indx(0).fix_core_indx();
|
||||||
@@ -30,7 +36,7 @@ fn wrong_size_place_holder() {
|
|||||||
|
|
||||||
|
|
||||||
fn place_holder(inst : InstructionType) {
|
fn place_holder(inst : InstructionType) {
|
||||||
let mut cpu = CPU::new(0);
|
let mut cpu = common::empty_cpu(0);
|
||||||
let mut idata_build = InstructionDataBuilder::new();
|
let mut idata_build = InstructionDataBuilder::new();
|
||||||
idata_build.set_core_indx(0).fix_core_indx();
|
idata_build.set_core_indx(0).fix_core_indx();
|
||||||
inst(&mut cpu, idata_build.build()).unwrap();
|
inst(&mut cpu, idata_build.build()).unwrap();
|
||||||
|
|||||||
@@ -1,8 +1,10 @@
|
|||||||
|
mod common;
|
||||||
|
|
||||||
use pimcore::{
|
use pimcore::{
|
||||||
Executable,
|
Executable,
|
||||||
cpu::CPU,
|
cpu::crossbar::Crossbar,
|
||||||
instruction_set::{InstructionsBuilder, instruction_data::InstructionDataBuilder, isa::*},
|
instruction_set::{InstructionsBuilder, instruction_data::InstructionDataBuilder, isa::*},
|
||||||
memory_manager::{MemoryStorable, type_traits::UpcastDestTraits},
|
memory_manager::{CoreMemory, MemoryStorable, type_traits::UpcastDestTraits},
|
||||||
};
|
};
|
||||||
|
|
||||||
/// VVADD Test
|
/// VVADD Test
|
||||||
@@ -11,7 +13,7 @@ where
|
|||||||
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
|
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
|
||||||
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable,
|
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable,
|
||||||
{
|
{
|
||||||
let mut cpu = CPU::new(0);
|
let mut cpu = common::empty_cpu(0);
|
||||||
let buff: [F; _] = [
|
let buff: [F; _] = [
|
||||||
1.0.into(),
|
1.0.into(),
|
||||||
2.0.into(),
|
2.0.into(),
|
||||||
@@ -115,7 +117,7 @@ where
|
|||||||
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
|
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
|
||||||
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable,
|
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable,
|
||||||
{
|
{
|
||||||
let mut cpu = CPU::new(0);
|
let mut cpu = common::empty_cpu(0);
|
||||||
let buff: [F; _] = [
|
let buff: [F; _] = [
|
||||||
1.0.into(),
|
1.0.into(),
|
||||||
2.0.into(),
|
2.0.into(),
|
||||||
@@ -219,7 +221,7 @@ where
|
|||||||
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
|
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
|
||||||
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable,
|
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable,
|
||||||
{
|
{
|
||||||
let mut cpu = CPU::new(0);
|
let mut cpu = common::empty_cpu(0);
|
||||||
let buff: [F; _] = [
|
let buff: [F; _] = [
|
||||||
1.0.into(),
|
1.0.into(),
|
||||||
2.0.into(),
|
2.0.into(),
|
||||||
@@ -323,7 +325,7 @@ where
|
|||||||
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
|
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
|
||||||
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable,
|
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable,
|
||||||
{
|
{
|
||||||
let mut cpu = CPU::new(0);
|
let mut cpu = common::empty_cpu(0);
|
||||||
let buff: [F; _] = [
|
let buff: [F; _] = [
|
||||||
1.0.into(),
|
1.0.into(),
|
||||||
2.0.into(),
|
2.0.into(),
|
||||||
@@ -420,7 +422,7 @@ where
|
|||||||
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
|
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
|
||||||
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable,
|
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable,
|
||||||
{
|
{
|
||||||
let mut cpu = CPU::new(0);
|
let mut cpu = common::empty_cpu(0);
|
||||||
let buff: [F; _] = [
|
let buff: [F; _] = [
|
||||||
9.0.into(),
|
9.0.into(),
|
||||||
2.0.into(),
|
2.0.into(),
|
||||||
@@ -524,7 +526,7 @@ where
|
|||||||
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
|
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
|
||||||
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable,
|
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable,
|
||||||
{
|
{
|
||||||
let mut cpu = CPU::new(0);
|
let mut cpu = common::empty_cpu(0);
|
||||||
let buff: [F; _] = [
|
let buff: [F; _] = [
|
||||||
9.0.into(),
|
9.0.into(),
|
||||||
2.0.into(),
|
2.0.into(),
|
||||||
@@ -562,6 +564,7 @@ where
|
|||||||
vavg,
|
vavg,
|
||||||
idata_build
|
idata_build
|
||||||
.set_rdr1r2(3, 1, 1)
|
.set_rdr1r2(3, 1, 1)
|
||||||
|
.set_offset_select(1)
|
||||||
.set_imm_len(8 * size_of::<F>() as i32)
|
.set_imm_len(8 * size_of::<F>() as i32)
|
||||||
.build(),
|
.build(),
|
||||||
);
|
);
|
||||||
@@ -617,7 +620,7 @@ where
|
|||||||
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
|
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
|
||||||
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable,
|
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable,
|
||||||
{
|
{
|
||||||
let mut cpu = CPU::new(0);
|
let mut cpu = common::empty_cpu(0);
|
||||||
let buff: [F; _] = [
|
let buff: [F; _] = [
|
||||||
(-9.0).into(),
|
(-9.0).into(),
|
||||||
2.0.into(),
|
2.0.into(),
|
||||||
@@ -717,7 +720,7 @@ where
|
|||||||
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
|
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
|
||||||
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable + UpcastDestTraits<T>,
|
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable + UpcastDestTraits<T>,
|
||||||
{
|
{
|
||||||
let mut cpu = CPU::new(0);
|
let mut cpu = common::empty_cpu(0);
|
||||||
let buff: [F; _] = [
|
let buff: [F; _] = [
|
||||||
0.1.into(),
|
0.1.into(),
|
||||||
0.2.into(),
|
0.2.into(),
|
||||||
@@ -819,7 +822,7 @@ where
|
|||||||
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
|
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
|
||||||
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable + UpcastDestTraits<T>,
|
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable + UpcastDestTraits<T>,
|
||||||
{
|
{
|
||||||
let mut cpu = CPU::new(0);
|
let mut cpu = common::empty_cpu(0);
|
||||||
let buff: [F; _] = [
|
let buff: [F; _] = [
|
||||||
0.1.into(),
|
0.1.into(),
|
||||||
0.2.into(),
|
0.2.into(),
|
||||||
@@ -923,9 +926,6 @@ where
|
|||||||
M: From<f32> + std::fmt::Debug + PartialEq<M> + MemoryStorable,
|
M: From<f32> + std::fmt::Debug + PartialEq<M> + MemoryStorable,
|
||||||
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable + UpcastDestTraits<T>,
|
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable + UpcastDestTraits<T>,
|
||||||
{
|
{
|
||||||
let mut cpu = CPU::new(0);
|
|
||||||
cpu.reserve_crossbar(1, 4 * size_of::<M>(), 4);
|
|
||||||
let (memory, crossbars) = cpu.host().get_memory_crossbar();
|
|
||||||
let matrix: [M; _] = [
|
let matrix: [M; _] = [
|
||||||
1.0.into(),
|
1.0.into(),
|
||||||
2.0.into(),
|
2.0.into(),
|
||||||
@@ -944,7 +944,10 @@ where
|
|||||||
15.0.into(),
|
15.0.into(),
|
||||||
16.0.into(),
|
16.0.into(),
|
||||||
];
|
];
|
||||||
crossbars.get_mut(0).unwrap().execute_store( &matrix).unwrap();
|
let mut crossbar = Crossbar::new(4 * size_of::<M>(), 4, CoreMemory::new());
|
||||||
|
crossbar.execute_store(&matrix).unwrap();
|
||||||
|
let mut cpu = pimcore::cpu::CPU::new(0, vec![vec![&crossbar]]);
|
||||||
|
let (memory, _) = cpu.host().get_memory_crossbar();
|
||||||
let vector: [F; _] = [
|
let vector: [F; _] = [
|
||||||
1.0.into(),
|
1.0.into(),
|
||||||
2.0.into(),
|
2.0.into(),
|
||||||
@@ -1054,5 +1057,3 @@ fn mvmul_test() {
|
|||||||
mvmul_test_generic::<f64,f64,f64>("mvmul<f64,f64,f64>",1);
|
mvmul_test_generic::<f64,f64,f64>("mvmul<f64,f64,f64>",1);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,12 +1,13 @@
|
|||||||
|
mod common;
|
||||||
|
|
||||||
use pimcore::{
|
use pimcore::{
|
||||||
Executable, CoreInstructionsBuilder,
|
Executable, CoreInstructionsBuilder,
|
||||||
cpu::CPU,
|
|
||||||
instruction_set::{InstructionsBuilder, instruction_data::InstructionDataBuilder, isa::*},
|
instruction_set::{InstructionsBuilder, instruction_data::InstructionDataBuilder, isa::*},
|
||||||
};
|
};
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn ld_test() {
|
fn ld_test() {
|
||||||
let mut cpu = CPU::new(1);
|
let mut cpu = common::empty_cpu(1);
|
||||||
let mut core_instruction_builder = CoreInstructionsBuilder::new(1);
|
let mut core_instruction_builder = CoreInstructionsBuilder::new(1);
|
||||||
let buff: [f32; _] = [
|
let buff: [f32; _] = [
|
||||||
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
|
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
|
||||||
@@ -41,7 +42,7 @@ fn ld_test() {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn st_test() {
|
fn st_test() {
|
||||||
let mut cpu = CPU::new(1);
|
let mut cpu = common::empty_cpu(1);
|
||||||
let mut core_instruction_builder = CoreInstructionsBuilder::new(1);
|
let mut core_instruction_builder = CoreInstructionsBuilder::new(1);
|
||||||
let buff: [f32; _] = [
|
let buff: [f32; _] = [
|
||||||
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
|
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
|
||||||
@@ -76,7 +77,7 @@ fn st_test() {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn lldi_test() {
|
fn lldi_test() {
|
||||||
let cpu = CPU::new(1);
|
let cpu = common::empty_cpu(1);
|
||||||
let mut core_instruction_builder = CoreInstructionsBuilder::new(1);
|
let mut core_instruction_builder = CoreInstructionsBuilder::new(1);
|
||||||
let mut inst_builder = InstructionsBuilder::new();
|
let mut inst_builder = InstructionsBuilder::new();
|
||||||
let mut idata_build = InstructionDataBuilder::new();
|
let mut idata_build = InstructionDataBuilder::new();
|
||||||
@@ -106,7 +107,7 @@ fn lldi_test() {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn lmv_test() {
|
fn lmv_test() {
|
||||||
let mut cpu = CPU::new(1);
|
let mut cpu = common::empty_cpu(1);
|
||||||
let mut core_instruction_builder = CoreInstructionsBuilder::new(1);
|
let mut core_instruction_builder = CoreInstructionsBuilder::new(1);
|
||||||
let buff: [f32; _] = [
|
let buff: [f32; _] = [
|
||||||
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
|
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
|
||||||
@@ -148,7 +149,7 @@ fn lmv_test() {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn simple_send_recv_test() {
|
fn simple_send_recv_test() {
|
||||||
let mut cpu = CPU::new(2);
|
let mut cpu = common::empty_cpu(2);
|
||||||
let mut core_instruction_builder = CoreInstructionsBuilder::new(2);
|
let mut core_instruction_builder = CoreInstructionsBuilder::new(2);
|
||||||
let buff: [f32; _] = [
|
let buff: [f32; _] = [
|
||||||
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
|
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
|
||||||
@@ -207,7 +208,7 @@ fn simple_send_recv_test() {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn multiple_send_recv_test() {
|
fn multiple_send_recv_test() {
|
||||||
let mut cpu = CPU::new(4);
|
let mut cpu = common::empty_cpu(4);
|
||||||
let mut core_instruction_builder = CoreInstructionsBuilder::new(4);
|
let mut core_instruction_builder = CoreInstructionsBuilder::new(4);
|
||||||
let buff: [f32; _] = [
|
let buff: [f32; _] = [
|
||||||
1.0, 1.0, 1.0, 1.0, 1.0
|
1.0, 1.0, 1.0, 1.0, 1.0
|
||||||
@@ -226,7 +227,7 @@ fn multiple_send_recv_test() {
|
|||||||
];
|
];
|
||||||
cpu.core(4).execute_store(0, &buff).unwrap();
|
cpu.core(4).execute_store(0, &buff).unwrap();
|
||||||
|
|
||||||
let send_inst = |cpu :&mut CPU, core_instruction_builder: &mut CoreInstructionsBuilder, inst_builder: &mut InstructionsBuilder, from : i32, to : i32| {
|
let send_inst = |inst_builder: &mut InstructionsBuilder, from: i32, to: i32| {
|
||||||
let mut idata_build = InstructionDataBuilder::new();
|
let mut idata_build = InstructionDataBuilder::new();
|
||||||
idata_build.set_core_indx(from).fix_core_indx();
|
idata_build.set_core_indx(from).fix_core_indx();
|
||||||
inst_builder.make_inst(sldi, idata_build.set_rdimm(1, from*size_of::<f32>() as i32).build());
|
inst_builder.make_inst(sldi, idata_build.set_rdimm(1, from*size_of::<f32>() as i32).build());
|
||||||
@@ -240,7 +241,7 @@ fn multiple_send_recv_test() {
|
|||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
let recv_inst = |cpu :&mut CPU, core_instruction_builder: &mut CoreInstructionsBuilder, mut inst_builder: &mut InstructionsBuilder, to : i32, from : i32| {
|
let recv_inst = |inst_builder: &mut InstructionsBuilder, to: i32, from: i32| {
|
||||||
let mut idata_build = InstructionDataBuilder::new();
|
let mut idata_build = InstructionDataBuilder::new();
|
||||||
idata_build.set_core_indx(to).fix_core_indx();
|
idata_build.set_core_indx(to).fix_core_indx();
|
||||||
inst_builder.make_inst(sldi, idata_build.set_rdimm(1, from*size_of::<f32>() as i32).build());
|
inst_builder.make_inst(sldi, idata_build.set_rdimm(1, from*size_of::<f32>() as i32).build());
|
||||||
@@ -257,26 +258,26 @@ fn multiple_send_recv_test() {
|
|||||||
|
|
||||||
|
|
||||||
// 1 -> 3
|
// 1 -> 3
|
||||||
send_inst(&mut cpu,&mut core_instruction_builder,&mut inst_builder,1, 3);
|
send_inst(&mut inst_builder, 1, 3);
|
||||||
core_instruction_builder.set_core(1, inst_builder.build());
|
core_instruction_builder.set_core(1, inst_builder.build());
|
||||||
|
|
||||||
// 2 -> 3
|
// 2 -> 3
|
||||||
// 2 <- 4
|
// 2 <- 4
|
||||||
send_inst(&mut cpu,&mut core_instruction_builder,&mut inst_builder,2, 3);
|
send_inst(&mut inst_builder, 2, 3);
|
||||||
recv_inst(&mut cpu,&mut core_instruction_builder,&mut inst_builder,2, 4);
|
recv_inst(&mut inst_builder, 2, 4);
|
||||||
core_instruction_builder.set_core(2, inst_builder.build());
|
core_instruction_builder.set_core(2, inst_builder.build());
|
||||||
|
|
||||||
// 3 <- 2
|
// 3 <- 2
|
||||||
// 3 <- 4
|
// 3 <- 4
|
||||||
// 3 <- 1
|
// 3 <- 1
|
||||||
recv_inst(&mut cpu,&mut core_instruction_builder,&mut inst_builder,3, 2);
|
recv_inst(&mut inst_builder, 3, 2);
|
||||||
recv_inst(&mut cpu,&mut core_instruction_builder,&mut inst_builder,3, 4);
|
recv_inst(&mut inst_builder, 3, 4);
|
||||||
recv_inst(&mut cpu,&mut core_instruction_builder,&mut inst_builder,3, 1);
|
recv_inst(&mut inst_builder, 3, 1);
|
||||||
core_instruction_builder.set_core(3, inst_builder.build());
|
core_instruction_builder.set_core(3, inst_builder.build());
|
||||||
// 4 -> 2
|
// 4 -> 2
|
||||||
// 4 -> 3
|
// 4 -> 3
|
||||||
send_inst(&mut cpu,&mut core_instruction_builder,&mut inst_builder,4, 2);
|
send_inst(&mut inst_builder, 4, 2);
|
||||||
send_inst(&mut cpu,&mut core_instruction_builder,&mut inst_builder,4, 3);
|
send_inst(&mut inst_builder, 4, 3);
|
||||||
core_instruction_builder.set_core(4, inst_builder.build());
|
core_instruction_builder.set_core(4, inst_builder.build());
|
||||||
|
|
||||||
let mut executable = Executable::new(cpu, core_instruction_builder.build());
|
let mut executable = Executable::new(cpu, core_instruction_builder.build());
|
||||||
|
|||||||
Submodule backend-simulators/pim/pimsim-nn updated: 895e9892b0...6d3b898e6b
@@ -110,6 +110,14 @@ llvm::FailureOr<int64_t> resolveIndexValueImpl(mlir::Value value, const StaticVa
|
|||||||
return static_cast<int64_t>(static_cast<uint64_t>(*lhs) / static_cast<uint64_t>(*rhs));
|
return static_cast<int64_t>(static_cast<uint64_t>(*lhs) / static_cast<uint64_t>(*rhs));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (auto minOp = mlir::dyn_cast<mlir::arith::MinUIOp>(definingOp)) {
|
||||||
|
auto lhs = resolveIndexValueImpl(minOp.getLhs(), knowledge);
|
||||||
|
auto rhs = resolveIndexValueImpl(minOp.getRhs(), knowledge);
|
||||||
|
if (failed(lhs) || failed(rhs))
|
||||||
|
return mlir::failure();
|
||||||
|
return static_cast<int64_t>(std::min(static_cast<uint64_t>(*lhs), static_cast<uint64_t>(*rhs)));
|
||||||
|
}
|
||||||
|
|
||||||
if (auto remOp = mlir::dyn_cast<mlir::arith::RemUIOp>(definingOp)) {
|
if (auto remOp = mlir::dyn_cast<mlir::arith::RemUIOp>(definingOp)) {
|
||||||
auto lhs = resolveIndexValueImpl(remOp.getLhs(), knowledge);
|
auto lhs = resolveIndexValueImpl(remOp.getLhs(), knowledge);
|
||||||
auto rhs = resolveIndexValueImpl(remOp.getRhs(), knowledge);
|
auto rhs = resolveIndexValueImpl(remOp.getRhs(), knowledge);
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ bool isCoreStaticAddressOp(mlir::Operation* op) {
|
|||||||
mlir::arith::SubIOp,
|
mlir::arith::SubIOp,
|
||||||
mlir::arith::MulIOp,
|
mlir::arith::MulIOp,
|
||||||
mlir::arith::DivUIOp,
|
mlir::arith::DivUIOp,
|
||||||
|
mlir::arith::MinUIOp,
|
||||||
mlir::arith::RemUIOp,
|
mlir::arith::RemUIOp,
|
||||||
mlir::arith::IndexCastOp,
|
mlir::arith::IndexCastOp,
|
||||||
mlir::memref::AllocOp,
|
mlir::memref::AllocOp,
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
|
|
||||||
|
|
||||||
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
|
||||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|||||||
@@ -7,10 +7,34 @@
|
|||||||
#include "llvm/ADT/ArrayRef.h"
|
#include "llvm/ADT/ArrayRef.h"
|
||||||
#include "llvm/ADT/StringRef.h"
|
#include "llvm/ADT/StringRef.h"
|
||||||
|
|
||||||
|
#include <cstdint>
|
||||||
#include <system_error>
|
#include <system_error>
|
||||||
|
|
||||||
namespace onnx_mlir::pim {
|
namespace onnx_mlir::pim {
|
||||||
|
|
||||||
|
struct CappedDiagnosticReporter {
|
||||||
|
explicit CappedDiagnosticReporter(int64_t maxReportedFailures = 8) : maxReportedFailures(maxReportedFailures) {}
|
||||||
|
|
||||||
|
template <typename EmitFn>
|
||||||
|
void report(mlir::Operation* op, EmitFn&& emit) {
|
||||||
|
numFailures++;
|
||||||
|
if (numFailures <= maxReportedFailures)
|
||||||
|
emit(op);
|
||||||
|
}
|
||||||
|
|
||||||
|
void emitSuppressedSummary(mlir::Operation* op, llvm::StringRef failureDescription) const {
|
||||||
|
if (numFailures > maxReportedFailures)
|
||||||
|
op->emitError() << "suppressed " << (numFailures - maxReportedFailures) << " additional "
|
||||||
|
<< failureDescription;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool hasFailure() const { return numFailures != 0; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
int64_t maxReportedFailures;
|
||||||
|
int64_t numFailures = 0;
|
||||||
|
};
|
||||||
|
|
||||||
/// Emits a consistent diagnostic for target paths that require static shapes.
|
/// Emits a consistent diagnostic for target paths that require static shapes.
|
||||||
mlir::InFlightDiagnostic emitUnsupportedStaticShapeDiagnostic(mlir::Operation* op, llvm::StringRef valueDescription);
|
mlir::InFlightDiagnostic emitUnsupportedStaticShapeDiagnostic(mlir::Operation* op, llvm::StringRef valueDescription);
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +1,7 @@
|
|||||||
#include "src/Accelerators/PIM/Common/Support/ReportUtils.hpp"
|
|
||||||
|
|
||||||
#include "llvm/Support/Format.h"
|
#include "llvm/Support/Format.h"
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Common/Support/FileSystemUtils.hpp"
|
#include "src/Accelerators/PIM/Common/Support/FileSystemUtils.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Common/Support/ReportUtils.hpp"
|
||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
|||||||
@@ -1,10 +1,9 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "llvm/ADT/STLExtras.h"
|
|
||||||
#include "llvm/ADT/ArrayRef.h"
|
#include "llvm/ADT/ArrayRef.h"
|
||||||
|
#include "llvm/ADT/STLExtras.h"
|
||||||
#include "llvm/Support/raw_ostream.h"
|
#include "llvm/Support/raw_ostream.h"
|
||||||
|
|
||||||
#include <cstdint>
|
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
#include <limits>
|
#include <limits>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|||||||
@@ -20,38 +20,6 @@ using namespace mlir;
|
|||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
|
|
||||||
OnnxMlirCompilerErrorCodes writeHostCoreArtifacts(StringRef outputDirPath) {
|
|
||||||
std::error_code errorCode;
|
|
||||||
std::string outputHostCorePath = outputDirPath.str() + "/core_0.pim";
|
|
||||||
raw_fd_ostream hostFileStream(outputHostCorePath, errorCode, sys::fs::OF_None);
|
|
||||||
if (errorCode) {
|
|
||||||
errs() << "Error while opening host core file `" << outputHostCorePath << "`: " << errorCode.message() << '\n';
|
|
||||||
return InvalidOutputFileAccess;
|
|
||||||
}
|
|
||||||
|
|
||||||
pim_binary::writeHeader(hostFileStream);
|
|
||||||
pim_binary::InstructionRecord noop;
|
|
||||||
noop.opcode = pim_binary::Opcode::sldi;
|
|
||||||
pim_binary::writeInstructionRecord(hostFileStream, noop);
|
|
||||||
pim_binary::writeInstructionRecord(hostFileStream, noop);
|
|
||||||
pim_binary::patchInstructionCount(hostFileStream, 2);
|
|
||||||
hostFileStream.close();
|
|
||||||
|
|
||||||
if (pimEmitJson.getValue()) {
|
|
||||||
std::string outputHostJsonPath = outputDirPath.str() + "/core_0.json";
|
|
||||||
raw_fd_ostream hostJsonStream(outputHostJsonPath, errorCode);
|
|
||||||
if (errorCode) {
|
|
||||||
errs() << "Error while opening host core json file `" << outputHostJsonPath << "`: " << errorCode.message()
|
|
||||||
<< '\n';
|
|
||||||
return InvalidOutputFileAccess;
|
|
||||||
}
|
|
||||||
// The host core json contains two no-op-like instructions to satisfy pimsim-nn
|
|
||||||
hostJsonStream << "[{\"imm\":0,\"op\":\"sldi\",\"rd\":0},{\"imm\":0,\"op\":\"sldi\",\"rd\":0}]";
|
|
||||||
hostJsonStream.close();
|
|
||||||
}
|
|
||||||
return CompilerSuccess;
|
|
||||||
}
|
|
||||||
|
|
||||||
OnnxMlirCompilerErrorCodes
|
OnnxMlirCompilerErrorCodes
|
||||||
writeMemoryBinary(ModuleOp moduleOp, func::FuncOp funcOp, PimAcceleratorMemory& memory, StringRef outputDirPath) {
|
writeMemoryBinary(ModuleOp moduleOp, func::FuncOp funcOp, PimAcceleratorMemory& memory, StringRef outputDirPath) {
|
||||||
auto memoryFilePath = (outputDirPath + "/memory.bin").str();
|
auto memoryFilePath = (outputDirPath + "/memory.bin").str();
|
||||||
@@ -109,9 +77,6 @@ OnnxMlirCompilerErrorCodes writeConfigJson(func::FuncOp funcOp,
|
|||||||
json::Object configJson;
|
json::Object configJson;
|
||||||
|
|
||||||
configJson["core_cnt"] = maxCoreId + 1;
|
configJson["core_cnt"] = maxCoreId + 1;
|
||||||
configJson["adc_count"] = 16;
|
|
||||||
configJson["cell_precision"] = 2;
|
|
||||||
configJson["xbar_array_count"] = crossbarCountInCore.getValue();
|
|
||||||
configJson["xbar_size"] = {crossbarSize.getValue(), crossbarSize.getValue()};
|
configJson["xbar_size"] = {crossbarSize.getValue(), crossbarSize.getValue()};
|
||||||
configJson["array_group_map"] = std::move(xbarsPerArrayGroup);
|
configJson["array_group_map"] = std::move(xbarsPerArrayGroup);
|
||||||
|
|
||||||
|
|||||||
@@ -12,7 +12,6 @@ namespace onnx_mlir {
|
|||||||
|
|
||||||
class PimAcceleratorMemory;
|
class PimAcceleratorMemory;
|
||||||
|
|
||||||
OnnxMlirCompilerErrorCodes writeHostCoreArtifacts(llvm::StringRef outputDirPath);
|
|
||||||
OnnxMlirCompilerErrorCodes writeMemoryBinary(mlir::ModuleOp moduleOp,
|
OnnxMlirCompilerErrorCodes writeMemoryBinary(mlir::ModuleOp moduleOp,
|
||||||
mlir::func::FuncOp funcOp,
|
mlir::func::FuncOp funcOp,
|
||||||
PimAcceleratorMemory& memory,
|
PimAcceleratorMemory& memory,
|
||||||
|
|||||||
@@ -70,9 +70,7 @@ inline void writeUint32LE(llvm::raw_ostream& os, uint32_t value) {
|
|||||||
os.write(bytes.data(), bytes.size());
|
os.write(bytes.data(), bytes.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void writeInt32LE(llvm::raw_ostream& os, int32_t value) {
|
inline void writeInt32LE(llvm::raw_ostream& os, int32_t value) { writeUint32LE(os, static_cast<uint32_t>(value)); }
|
||||||
writeUint32LE(os, static_cast<uint32_t>(value));
|
|
||||||
}
|
|
||||||
|
|
||||||
inline void writeHeader(llvm::raw_ostream& os) {
|
inline void writeHeader(llvm::raw_ostream& os) {
|
||||||
os.write(kMagic, sizeof(kMagic));
|
os.write(kMagic, sizeof(kMagic));
|
||||||
@@ -235,9 +233,7 @@ inline InstructionRecord makeInstructionRecord(const llvm::json::Object& instruc
|
|||||||
case Opcode::sldi:
|
case Opcode::sldi:
|
||||||
case Opcode::saddi:
|
case Opcode::saddi:
|
||||||
case Opcode::smuli:
|
case Opcode::smuli:
|
||||||
case Opcode::lldi:
|
case Opcode::lldi: record.r2OrImm = getOptionalInt(instruction, "imm"); break;
|
||||||
record.r2OrImm = getOptionalInt(instruction, "imm");
|
|
||||||
break;
|
|
||||||
case Opcode::mvmul:
|
case Opcode::mvmul:
|
||||||
record.r2OrImm = getOptionalInt(instruction, "mbiw");
|
record.r2OrImm = getOptionalInt(instruction, "mbiw");
|
||||||
record.generic1 = getOptionalInt(instruction, "relu");
|
record.generic1 = getOptionalInt(instruction, "relu");
|
||||||
@@ -252,9 +248,7 @@ inline InstructionRecord makeInstructionRecord(const llvm::json::Object& instruc
|
|||||||
record.r2OrImm = getOptionalInt(instruction, "core");
|
record.r2OrImm = getOptionalInt(instruction, "core");
|
||||||
record.generic3 = getOptionalInt(instruction, "size");
|
record.generic3 = getOptionalInt(instruction, "size");
|
||||||
break;
|
break;
|
||||||
default:
|
default: record.r2OrImm = getOptionalInt(instruction, "rs2"); break;
|
||||||
record.r2OrImm = getOptionalInt(instruction, "rs2");
|
|
||||||
break;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (record.opcode != Opcode::mvmul && record.opcode != Opcode::setbw) {
|
if (record.opcode != Opcode::mvmul && record.opcode != Opcode::setbw) {
|
||||||
@@ -371,8 +365,7 @@ inline llvm::json::Object makeInstructionJson(const InstructionRecord& record) {
|
|||||||
break;
|
break;
|
||||||
case Opcode::wait:
|
case Opcode::wait:
|
||||||
case Opcode::sync:
|
case Opcode::sync:
|
||||||
case Opcode::nop:
|
case Opcode::nop: break;
|
||||||
break;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return instruction;
|
return instruction;
|
||||||
|
|||||||
@@ -367,7 +367,7 @@ void PimCodeGen::emitMemCopyOp(StringRef opName,
|
|||||||
instruction.generic1 = 0;
|
instruction.generic1 = 0;
|
||||||
instruction.generic2 = 0;
|
instruction.generic2 = 0;
|
||||||
instruction.generic3 = static_cast<int32_t>(size);
|
instruction.generic3 = static_cast<int32_t>(size);
|
||||||
(void)sizeFieldName;
|
(void) sizeFieldName;
|
||||||
emitInstruction(instruction);
|
emitInstruction(instruction);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -875,11 +875,6 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std::
|
|||||||
if (auto err = writeMemoryBinary(moduleOp, funcOp, memory, outputDirPath))
|
if (auto err = writeMemoryBinary(moduleOp, funcOp, memory, outputDirPath))
|
||||||
return err;
|
return err;
|
||||||
|
|
||||||
if (auto err = writeHostCoreArtifacts(outputDirPath))
|
|
||||||
return err;
|
|
||||||
|
|
||||||
// For each core, specify the number of crossbar per array group.
|
|
||||||
// This implementation always assigns one crossbar per group.
|
|
||||||
json::Object xbarsPerArrayGroup;
|
json::Object xbarsPerArrayGroup;
|
||||||
size_t maxCoreId = 0;
|
size_t maxCoreId = 0;
|
||||||
uint64_t nextBatchReportId = 0;
|
uint64_t nextBatchReportId = 0;
|
||||||
@@ -891,7 +886,7 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std::
|
|||||||
SmallDenseMap<memref::GlobalOp, MemEntry, 16> materializedHostGlobals =
|
SmallDenseMap<memref::GlobalOp, MemEntry, 16> materializedHostGlobals =
|
||||||
collectMaterializedHostGlobals(moduleOp, funcOp, memory);
|
collectMaterializedHostGlobals(moduleOp, funcOp, memory);
|
||||||
llvm::DenseMap<size_t, size_t> emittedCoreIds;
|
llvm::DenseMap<size_t, size_t> emittedCoreIds;
|
||||||
size_t nextEmittedCoreId = 1;
|
size_t nextEmittedCoreId = 0;
|
||||||
|
|
||||||
for (Operation* op : coreLikeOps) {
|
for (Operation* op : coreLikeOps) {
|
||||||
if (auto coreOp = dyn_cast<pim::PimCoreOp>(op)) {
|
if (auto coreOp = dyn_cast<pim::PimCoreOp>(op)) {
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||||
|
|
||||||
|
#include "llvm/Support/ErrorHandling.h"
|
||||||
|
|
||||||
#define DEBUG_TYPE "PimCompilerOptions"
|
#define DEBUG_TYPE "PimCompilerOptions"
|
||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
@@ -13,6 +15,14 @@ llvm::cl::opt<PimEmissionTargetType> pimEmissionTarget(
|
|||||||
llvm::cl::init(EmitPimCodegen),
|
llvm::cl::init(EmitPimCodegen),
|
||||||
llvm::cl::cat(OnnxMlirOptions));
|
llvm::cl::cat(OnnxMlirOptions));
|
||||||
|
|
||||||
|
llvm::cl::opt<PimMergeSchedulerType> pimMergeScheduler(
|
||||||
|
"pim-merge-scheduler",
|
||||||
|
llvm::cl::desc("Scheduler used by the Spatial merge-compute-nodes pass"),
|
||||||
|
llvm::cl::values(clEnumValN(MergeSchedulerPeft, "peft", "Use PEFT scheduling")),
|
||||||
|
llvm::cl::values(clEnumValN(MergeSchedulerDcp, "dcp", "Use the legacy DCP-inspired scheduler")),
|
||||||
|
llvm::cl::init(MergeSchedulerPeft),
|
||||||
|
llvm::cl::cat(OnnxMlirOptions));
|
||||||
|
|
||||||
llvm::cl::opt<bool>
|
llvm::cl::opt<bool>
|
||||||
pimOnlyCodegen("pim-only-codegen",
|
pimOnlyCodegen("pim-only-codegen",
|
||||||
llvm::cl::desc("Only generate code for PIM (assume input is already in bufferized PIM IR)"),
|
llvm::cl::desc("Only generate code for PIM (assume input is already in bufferized PIM IR)"),
|
||||||
@@ -30,19 +40,19 @@ llvm::cl::opt<bool> pimEmitJson("pim-emit-json",
|
|||||||
llvm::cl::cat(OnnxMlirOptions));
|
llvm::cl::cat(OnnxMlirOptions));
|
||||||
|
|
||||||
llvm::cl::opt<size_t>
|
llvm::cl::opt<size_t>
|
||||||
crossbarSize("crossbar-size", llvm::cl::desc("Width and heigth of a single crossbar"), llvm::cl::init(2));
|
crossbarSize("crossbar-size", llvm::cl::desc("Width and height of a single crossbar"), llvm::cl::init(2));
|
||||||
|
|
||||||
llvm::cl::opt<size_t>
|
llvm::cl::opt<size_t>
|
||||||
crossbarCountInCore("crossbar-count", llvm::cl::desc("Number of crossbars in each core"), llvm::cl::init(256));
|
crossbarCountInCore("crossbar-count", llvm::cl::desc("Number of crossbars in each core"), llvm::cl::init(256));
|
||||||
|
|
||||||
llvm::cl::opt<long> coresCount("core-count",
|
llvm::cl::opt<long> coresCount("core-count",
|
||||||
llvm::cl::desc("Number of cores in the chip. `-1` to use the minimum amount of cores."),
|
llvm::cl::desc("Number of cores in the chip. Required for PIM compilation."),
|
||||||
llvm::cl::init(-1));
|
llvm::cl::init(-1));
|
||||||
|
|
||||||
llvm::cl::opt<size_t> dcpCriticalWindowSize(
|
llvm::cl::opt<size_t> dcpCriticalWindowSize(
|
||||||
"dcp-critical-window-size",
|
"dcp-critical-window-size",
|
||||||
llvm::cl::desc("Number of lowest-slack virtual nodes considered by each DCP coarsening iteration. "
|
llvm::cl::desc("Number of lowest-slack virtual nodes considered by each DCP coarsening iteration. "
|
||||||
"Use 0 to run the legacy full-graph DCP analysis."),
|
"Use 0 to run the legacy full-graph DCP analysis. Only used by the DCP scheduler."),
|
||||||
llvm::cl::init(4000));
|
llvm::cl::init(4000));
|
||||||
|
|
||||||
llvm::cl::opt<bool>
|
llvm::cl::opt<bool>
|
||||||
@@ -50,4 +60,13 @@ llvm::cl::opt<bool>
|
|||||||
llvm::cl::desc("Ignore ConcatOp corner case: do not assert and do a simplification"),
|
llvm::cl::desc("Ignore ConcatOp corner case: do not assert and do a simplification"),
|
||||||
llvm::cl::init(false));
|
llvm::cl::init(false));
|
||||||
|
|
||||||
|
bool hasExplicitPimCoreCount() { return coresCount.getNumOccurrences() != 0; }
|
||||||
|
|
||||||
|
void verifyExplicitPimCoreCount() {
|
||||||
|
if (!hasExplicitPimCoreCount())
|
||||||
|
llvm::report_fatal_error("PIM compilation requires an explicit --core-count=<positive integer>");
|
||||||
|
if (coresCount.getValue() <= 0)
|
||||||
|
llvm::report_fatal_error("PIM compilation requires --core-count to be a positive integer");
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -20,8 +20,14 @@ typedef enum {
|
|||||||
EmitPimCodegen = 3
|
EmitPimCodegen = 3
|
||||||
} PimEmissionTargetType;
|
} PimEmissionTargetType;
|
||||||
|
|
||||||
|
typedef enum {
|
||||||
|
MergeSchedulerPeft = 0,
|
||||||
|
MergeSchedulerDcp = 1,
|
||||||
|
} PimMergeSchedulerType;
|
||||||
|
|
||||||
extern llvm::cl::OptionCategory OnnxMlirOptions;
|
extern llvm::cl::OptionCategory OnnxMlirOptions;
|
||||||
extern llvm::cl::opt<PimEmissionTargetType> pimEmissionTarget;
|
extern llvm::cl::opt<PimEmissionTargetType> pimEmissionTarget;
|
||||||
|
extern llvm::cl::opt<PimMergeSchedulerType> pimMergeScheduler;
|
||||||
|
|
||||||
extern llvm::cl::opt<bool> pimOnlyCodegen;
|
extern llvm::cl::opt<bool> pimOnlyCodegen;
|
||||||
extern llvm::cl::opt<bool> useExperimentalConvImpl;
|
extern llvm::cl::opt<bool> useExperimentalConvImpl;
|
||||||
@@ -32,6 +38,9 @@ extern llvm::cl::opt<size_t> crossbarCountInCore;
|
|||||||
extern llvm::cl::opt<long> coresCount;
|
extern llvm::cl::opt<long> coresCount;
|
||||||
extern llvm::cl::opt<size_t> dcpCriticalWindowSize;
|
extern llvm::cl::opt<size_t> dcpCriticalWindowSize;
|
||||||
|
|
||||||
|
bool hasExplicitPimCoreCount();
|
||||||
|
void verifyExplicitPimCoreCount();
|
||||||
|
|
||||||
// This option, by default set to false, will ignore an error when resolving a
|
// This option, by default set to false, will ignore an error when resolving a
|
||||||
// specific tiles of the operands of a concat. This specific case is when the
|
// specific tiles of the operands of a concat. This specific case is when the
|
||||||
// wanted tile is generated by two separate operands of the concat. If this is
|
// wanted tile is generated by two separate operands of the concat. If this is
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
|
|||||||
PassManager& pm,
|
PassManager& pm,
|
||||||
EmissionTargetType& emissionTarget,
|
EmissionTargetType& emissionTarget,
|
||||||
std::string outputNameNoExt) {
|
std::string outputNameNoExt) {
|
||||||
|
verifyExplicitPimCoreCount();
|
||||||
|
|
||||||
if (pimOnlyCodegen) {
|
if (pimOnlyCodegen) {
|
||||||
// Skip all the lowering passes and directly generate code for PIM.
|
// Skip all the lowering passes and directly generate code for PIM.
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ struct DenseWeightView {
|
|||||||
};
|
};
|
||||||
|
|
||||||
FailureOr<DenseWeightView> resolveDenseWeightView(ModuleOp moduleOp, mlir::Value weight) {
|
FailureOr<DenseWeightView> resolveDenseWeightView(ModuleOp moduleOp, mlir::Value weight) {
|
||||||
SmallVector<memref::SubViewOp> subviews;
|
SmallVector<Operation*> viewOps;
|
||||||
mlir::Value current = weight;
|
mlir::Value current = weight;
|
||||||
memref::GetGlobalOp getGlobalOp;
|
memref::GetGlobalOp getGlobalOp;
|
||||||
|
|
||||||
@@ -46,7 +46,7 @@ FailureOr<DenseWeightView> resolveDenseWeightView(ModuleOp moduleOp, mlir::Value
|
|||||||
if (auto subview = dyn_cast<memref::SubViewOp>(defOp)) {
|
if (auto subview = dyn_cast<memref::SubViewOp>(defOp)) {
|
||||||
if (!hasAllStaticSubviewParts(subview))
|
if (!hasAllStaticSubviewParts(subview))
|
||||||
return failure();
|
return failure();
|
||||||
subviews.push_back(subview);
|
viewOps.push_back(subview);
|
||||||
current = subview.getSource();
|
current = subview.getSource();
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@@ -54,6 +54,24 @@ FailureOr<DenseWeightView> resolveDenseWeightView(ModuleOp moduleOp, mlir::Value
|
|||||||
current = cast.getSource();
|
current = cast.getSource();
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
if (auto collapse = dyn_cast<memref::CollapseShapeOp>(defOp)) {
|
||||||
|
auto srcType = dyn_cast<MemRefType>(collapse.getSrc().getType());
|
||||||
|
auto resultType = dyn_cast<MemRefType>(collapse.getResult().getType());
|
||||||
|
if (!srcType || !resultType || !srcType.hasStaticShape() || !resultType.hasStaticShape())
|
||||||
|
return failure();
|
||||||
|
viewOps.push_back(collapse);
|
||||||
|
current = collapse.getSrc();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (auto expand = dyn_cast<memref::ExpandShapeOp>(defOp)) {
|
||||||
|
auto srcType = dyn_cast<MemRefType>(expand.getSrc().getType());
|
||||||
|
auto resultType = dyn_cast<MemRefType>(expand.getResult().getType());
|
||||||
|
if (!srcType || !resultType || !srcType.hasStaticShape() || !resultType.hasStaticShape())
|
||||||
|
return failure();
|
||||||
|
viewOps.push_back(expand);
|
||||||
|
current = expand.getSrc();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -70,7 +88,8 @@ FailureOr<DenseWeightView> resolveDenseWeightView(ModuleOp moduleOp, mlir::Value
|
|||||||
view.shape.assign(denseAttr.getType().getShape().begin(), denseAttr.getType().getShape().end());
|
view.shape.assign(denseAttr.getType().getShape().begin(), denseAttr.getType().getShape().end());
|
||||||
view.strides = computeRowMajorStrides(view.shape);
|
view.strides = computeRowMajorStrides(view.shape);
|
||||||
|
|
||||||
for (memref::SubViewOp subview : llvm::reverse(subviews)) {
|
for (Operation* viewOp : llvm::reverse(viewOps)) {
|
||||||
|
if (auto subview = dyn_cast<memref::SubViewOp>(viewOp)) {
|
||||||
SmallVector<int64_t> nextStrides;
|
SmallVector<int64_t> nextStrides;
|
||||||
nextStrides.reserve(subview.getStaticStrides().size());
|
nextStrides.reserve(subview.getStaticStrides().size());
|
||||||
for (auto [offset, stride, sourceStride] :
|
for (auto [offset, stride, sourceStride] :
|
||||||
@@ -80,6 +99,28 @@ FailureOr<DenseWeightView> resolveDenseWeightView(ModuleOp moduleOp, mlir::Value
|
|||||||
}
|
}
|
||||||
view.shape.assign(subview.getStaticSizes().begin(), subview.getStaticSizes().end());
|
view.shape.assign(subview.getStaticSizes().begin(), subview.getStaticSizes().end());
|
||||||
view.strides = std::move(nextStrides);
|
view.strides = std::move(nextStrides);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Collapse/expand are accepted only as contiguous static reshapes of a
|
||||||
|
// dense global view, so a row-major stride recomputation preserves layout.
|
||||||
|
if (auto collapse = dyn_cast<memref::CollapseShapeOp>(viewOp)) {
|
||||||
|
if (view.strides != computeRowMajorStrides(view.shape))
|
||||||
|
return failure();
|
||||||
|
auto resultType = cast<MemRefType>(collapse.getResult().getType());
|
||||||
|
view.shape.assign(resultType.getShape().begin(), resultType.getShape().end());
|
||||||
|
view.strides = computeRowMajorStrides(view.shape);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto expand = dyn_cast<memref::ExpandShapeOp>(viewOp)) {
|
||||||
|
if (view.strides != computeRowMajorStrides(view.shape))
|
||||||
|
return failure();
|
||||||
|
auto resultType = cast<MemRefType>(expand.getResult().getType());
|
||||||
|
view.shape.assign(resultType.getShape().begin(), resultType.getShape().end());
|
||||||
|
view.strides = computeRowMajorStrides(view.shape);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return view;
|
return view;
|
||||||
|
|||||||
@@ -100,18 +100,27 @@ DenseMap<HSliceId, DenseMap<CoreId, SmallVector<Value>>> tileMatrix(
|
|||||||
return tiles;
|
return tiles;
|
||||||
}
|
}
|
||||||
|
|
||||||
tensor::SplatOp
|
Value broadcastToVector(Value scalarToBroadcast, int64_t length, ConversionPatternRewriter& rewriter, Location loc) {
|
||||||
broadcastToVector(Value scalarToBroadcast, int64_t length, ConversionPatternRewriter& rewriter, Location loc) {
|
|
||||||
auto oldType = cast<RankedTensorType>(scalarToBroadcast.getType());
|
auto oldType = cast<RankedTensorType>(scalarToBroadcast.getType());
|
||||||
Type elementType = oldType.getElementType();
|
Type elementType = oldType.getElementType();
|
||||||
int64_t shape[2] = {1, length};
|
int64_t shape[2] = {1, length};
|
||||||
Type type = oldType.cloneWith(ArrayRef(shape), elementType);
|
Type type = oldType.cloneWith(ArrayRef(shape), elementType);
|
||||||
|
|
||||||
|
auto buildBroadcast = [&](Value input) -> Value {
|
||||||
auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0).getResult();
|
auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0).getResult();
|
||||||
SmallVector<Value> index(oldType.getRank(), zero);
|
SmallVector<Value> index(oldType.getRank(), zero);
|
||||||
auto elementValue = tensor::ExtractOp::create(rewriter, loc, scalarToBroadcast, index).getResult();
|
auto elementValue = tensor::ExtractOp::create(rewriter, loc, input, index).getResult();
|
||||||
|
|
||||||
return tensor::SplatOp::create(rewriter, loc, type, elementValue);
|
return tensor::SplatOp::create(rewriter, loc, type, elementValue);
|
||||||
|
};
|
||||||
|
|
||||||
|
if (isHostFoldableValue(scalarToBroadcast))
|
||||||
|
return buildBroadcast(scalarToBroadcast);
|
||||||
|
|
||||||
|
auto broadcastCompute =
|
||||||
|
createSpatCompute<1>(rewriter, loc, TypeRange {type}, {}, ValueRange {scalarToBroadcast}, [&](Value input) {
|
||||||
|
spatial::SpatYieldOp::create(rewriter, loc, buildBroadcast(input));
|
||||||
|
});
|
||||||
|
return broadcastCompute.getResult(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -136,7 +136,7 @@ tileMatrix(mlir::Value& matrixToTile,
|
|||||||
mlir::ConversionPatternRewriter& rewriter,
|
mlir::ConversionPatternRewriter& rewriter,
|
||||||
mlir::Location& loc);
|
mlir::Location& loc);
|
||||||
|
|
||||||
mlir::tensor::SplatOp broadcastToVector(mlir::Value scalarToBroadcast,
|
mlir::Value broadcastToVector(mlir::Value scalarToBroadcast,
|
||||||
int64_t length,
|
int64_t length,
|
||||||
mlir::ConversionPatternRewriter& rewriter,
|
mlir::ConversionPatternRewriter& rewriter,
|
||||||
mlir::Location loc);
|
mlir::Location loc);
|
||||||
|
|||||||
@@ -1,8 +1,12 @@
|
|||||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
#include "mlir/IR/BuiltinAttributes.h"
|
||||||
#include "mlir/IR/BuiltinTypes.h"
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/STLExtras.h"
|
||||||
|
#include "llvm/ADT/SmallBitVector.h"
|
||||||
#include "llvm/ADT/SmallPtrSet.h"
|
#include "llvm/ADT/SmallPtrSet.h"
|
||||||
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
@@ -18,6 +22,11 @@ static bool hasStaticUnitStrides(tensor::ExtractSliceOp extractSliceOp) {
|
|||||||
return llvm::all_of(extractSliceOp.getStaticStrides(), [](int64_t stride) { return stride == 1; });
|
return llvm::all_of(extractSliceOp.getStaticStrides(), [](int64_t stride) { return stride == 1; });
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static bool hasConstantIndices(tensor::ExtractOp extractOp) {
|
||||||
|
return llvm::all_of(extractOp.getIndices(),
|
||||||
|
[](Value index) { return isa_and_nonnull<arith::ConstantIndexOp>(index.getDefiningOp()); });
|
||||||
|
}
|
||||||
|
|
||||||
static bool isStaticTensorResult(Operation* op) {
|
static bool isStaticTensorResult(Operation* op) {
|
||||||
return llvm::all_of(op->getResultTypes(), [](Type type) {
|
return llvm::all_of(op->getResultTypes(), [](Type type) {
|
||||||
auto shapedType = dyn_cast<ShapedType>(type);
|
auto shapedType = dyn_cast<ShapedType>(type);
|
||||||
@@ -25,6 +34,167 @@ static bool isStaticTensorResult(Operation* op) {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static SmallVector<int64_t> computeRowMajorStrides(ArrayRef<int64_t> shape) {
|
||||||
|
SmallVector<int64_t> strides(shape.size(), 1);
|
||||||
|
for (int64_t dim = static_cast<int64_t>(shape.size()) - 2; dim >= 0; --dim)
|
||||||
|
strides[dim] = strides[dim + 1] * shape[dim + 1];
|
||||||
|
return strides;
|
||||||
|
}
|
||||||
|
|
||||||
|
static FailureOr<DenseElementsAttr> transposeDenseElements(DenseElementsAttr denseAttr, ArrayRef<int64_t> perms) {
|
||||||
|
auto tensorType = dyn_cast<RankedTensorType>(denseAttr.getType());
|
||||||
|
if (!tensorType)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
int64_t rank = tensorType.getRank();
|
||||||
|
if (static_cast<int64_t>(perms.size()) != rank)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
llvm::SmallBitVector seen(rank);
|
||||||
|
SmallVector<int64_t> transposedShape;
|
||||||
|
transposedShape.reserve(rank);
|
||||||
|
for (int64_t perm : perms) {
|
||||||
|
if (perm < 0 || perm >= rank || seen.test(perm))
|
||||||
|
return failure();
|
||||||
|
seen.set(perm);
|
||||||
|
transposedShape.push_back(tensorType.getShape()[perm]);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto transposedType = RankedTensorType::get(transposedShape, tensorType.getElementType(), tensorType.getEncoding());
|
||||||
|
if (denseAttr.isSplat())
|
||||||
|
return DenseElementsAttr::get(transposedType, denseAttr.getSplatValue<Attribute>());
|
||||||
|
|
||||||
|
SmallVector<Attribute> originalValues(denseAttr.getValues<Attribute>());
|
||||||
|
SmallVector<Attribute> transposedValues(originalValues.size());
|
||||||
|
SmallVector<int64_t> originalStrides = computeRowMajorStrides(tensorType.getShape());
|
||||||
|
SmallVector<int64_t> transposedStrides = computeRowMajorStrides(transposedShape);
|
||||||
|
SmallVector<int64_t> originalIndices(rank);
|
||||||
|
|
||||||
|
for (auto [linearIndex, value] : llvm::enumerate(originalValues)) {
|
||||||
|
int64_t remaining = static_cast<int64_t>(linearIndex);
|
||||||
|
for (int64_t dim = 0; dim < rank; ++dim) {
|
||||||
|
originalIndices[dim] = remaining / originalStrides[dim];
|
||||||
|
remaining %= originalStrides[dim];
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t transposedLinearIndex = 0;
|
||||||
|
for (int64_t dim = 0; dim < rank; ++dim)
|
||||||
|
transposedLinearIndex += originalIndices[perms[dim]] * transposedStrides[dim];
|
||||||
|
|
||||||
|
transposedValues[transposedLinearIndex] = value;
|
||||||
|
}
|
||||||
|
|
||||||
|
return DenseElementsAttr::get(transposedType, transposedValues);
|
||||||
|
}
|
||||||
|
|
||||||
|
static FailureOr<DenseElementsAttr> reshapeDenseElements(DenseElementsAttr denseAttr, RankedTensorType resultType) {
|
||||||
|
auto sourceType = dyn_cast<RankedTensorType>(denseAttr.getType());
|
||||||
|
if (!sourceType || !resultType || sourceType.getNumElements() != resultType.getNumElements())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
if (denseAttr.isSplat())
|
||||||
|
return DenseElementsAttr::get(resultType, denseAttr.getSplatValue<Attribute>());
|
||||||
|
|
||||||
|
SmallVector<Attribute> values(denseAttr.getValues<Attribute>());
|
||||||
|
return DenseElementsAttr::get(resultType, values);
|
||||||
|
}
|
||||||
|
|
||||||
|
static FailureOr<DenseElementsAttr> extractSliceDenseElements(DenseElementsAttr denseAttr,
|
||||||
|
tensor::ExtractSliceOp extractSliceOp) {
|
||||||
|
auto sourceType = dyn_cast<RankedTensorType>(denseAttr.getType());
|
||||||
|
auto resultType = dyn_cast<RankedTensorType>(extractSliceOp.getType());
|
||||||
|
if (!sourceType || !resultType || !sourceType.hasStaticShape() || !resultType.hasStaticShape())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
ArrayRef<int64_t> offsets = extractSliceOp.getStaticOffsets();
|
||||||
|
ArrayRef<int64_t> sizes = extractSliceOp.getStaticSizes();
|
||||||
|
ArrayRef<int64_t> strides = extractSliceOp.getStaticStrides();
|
||||||
|
if (llvm::any_of(offsets, [](int64_t value) { return ShapedType::isDynamic(value); })
|
||||||
|
|| llvm::any_of(sizes, [](int64_t value) { return ShapedType::isDynamic(value); })
|
||||||
|
|| llvm::any_of(strides, [](int64_t stride) { return ShapedType::isDynamic(stride) || stride != 1; }))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
if (denseAttr.isSplat())
|
||||||
|
return DenseElementsAttr::get(resultType, denseAttr.getSplatValue<Attribute>());
|
||||||
|
|
||||||
|
SmallVector<Attribute> sourceValues(denseAttr.getValues<Attribute>());
|
||||||
|
SmallVector<int64_t> sourceStrides = computeRowMajorStrides(sourceType.getShape());
|
||||||
|
SmallVector<int64_t> resultStrides = computeRowMajorStrides(resultType.getShape());
|
||||||
|
SmallVector<Attribute> resultValues;
|
||||||
|
resultValues.reserve(resultType.getNumElements());
|
||||||
|
|
||||||
|
for (int64_t linearIndex = 0; linearIndex < resultType.getNumElements(); ++linearIndex) {
|
||||||
|
int64_t remaining = linearIndex;
|
||||||
|
int64_t sourceLinearIndex = 0;
|
||||||
|
for (int64_t dim = 0; dim < resultType.getRank(); ++dim) {
|
||||||
|
const int64_t resultIndex = resultStrides.empty() ? 0 : remaining / resultStrides[dim];
|
||||||
|
remaining = resultStrides.empty() ? 0 : remaining % resultStrides[dim];
|
||||||
|
sourceLinearIndex += (offsets[dim] + resultIndex) * sourceStrides[dim];
|
||||||
|
}
|
||||||
|
resultValues.push_back(sourceValues[sourceLinearIndex]);
|
||||||
|
}
|
||||||
|
|
||||||
|
return DenseElementsAttr::get(resultType, resultValues);
|
||||||
|
}
|
||||||
|
|
||||||
|
static DenseElementsAttr getDirectDenseConstantAttr(Value value) {
|
||||||
|
if (auto constantOp = value.getDefiningOp<arith::ConstantOp>())
|
||||||
|
return dyn_cast<DenseElementsAttr>(constantOp.getValue());
|
||||||
|
if (auto constantOp = value.getDefiningOp<ONNXConstantOp>())
|
||||||
|
return dyn_cast_or_null<DenseElementsAttr>(constantOp.getValueAttr());
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
static DenseElementsAttr getHostFoldableDenseElementsAttrImpl(Value value, llvm::SmallPtrSetImpl<Operation*>& visited) {
|
||||||
|
auto* definingOp = value.getDefiningOp();
|
||||||
|
if (!definingOp || !visited.insert(definingOp).second)
|
||||||
|
return nullptr;
|
||||||
|
|
||||||
|
// Rebuild dense attributes through view-only host-foldable chains so later
|
||||||
|
// lowering stages can still recognize grouped/sliced constants.
|
||||||
|
if (auto denseAttr = getDirectDenseConstantAttr(value))
|
||||||
|
return denseAttr;
|
||||||
|
|
||||||
|
if (auto transposeOp = dyn_cast<ONNXTransposeOp>(definingOp)) {
|
||||||
|
auto inputAttr = getHostFoldableDenseElementsAttrImpl(transposeOp.getData(), visited);
|
||||||
|
if (!inputAttr)
|
||||||
|
return nullptr;
|
||||||
|
|
||||||
|
SmallVector<int64_t> perm;
|
||||||
|
perm.reserve(transposeOp.getPermAttr().size());
|
||||||
|
for (IntegerAttr attr : transposeOp.getPermAttr().getAsRange<IntegerAttr>())
|
||||||
|
perm.push_back(attr.getInt());
|
||||||
|
auto transposedAttr = transposeDenseElements(inputAttr, perm);
|
||||||
|
return succeeded(transposedAttr) ? *transposedAttr : nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(definingOp)) {
|
||||||
|
auto inputAttr = getHostFoldableDenseElementsAttrImpl(collapseShapeOp.getSrc(), visited);
|
||||||
|
if (!inputAttr)
|
||||||
|
return nullptr;
|
||||||
|
auto reshapedAttr = reshapeDenseElements(inputAttr, cast<RankedTensorType>(collapseShapeOp.getType()));
|
||||||
|
return succeeded(reshapedAttr) ? *reshapedAttr : nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(definingOp)) {
|
||||||
|
auto inputAttr = getHostFoldableDenseElementsAttrImpl(expandShapeOp.getSrc(), visited);
|
||||||
|
if (!inputAttr)
|
||||||
|
return nullptr;
|
||||||
|
auto reshapedAttr = reshapeDenseElements(inputAttr, cast<RankedTensorType>(expandShapeOp.getType()));
|
||||||
|
return succeeded(reshapedAttr) ? *reshapedAttr : nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(definingOp)) {
|
||||||
|
auto inputAttr = getHostFoldableDenseElementsAttrImpl(extractSliceOp.getSource(), visited);
|
||||||
|
if (!inputAttr)
|
||||||
|
return nullptr;
|
||||||
|
auto slicedAttr = extractSliceDenseElements(inputAttr, extractSliceOp);
|
||||||
|
return succeeded(slicedAttr) ? *slicedAttr : nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
static bool isHostFoldableOpImpl(Operation* op, llvm::SmallPtrSetImpl<Operation*>& visited) {
|
static bool isHostFoldableOpImpl(Operation* op, llvm::SmallPtrSetImpl<Operation*>& visited) {
|
||||||
if (!op || !visited.insert(op).second)
|
if (!op || !visited.insert(op).second)
|
||||||
return false;
|
return false;
|
||||||
@@ -32,6 +202,9 @@ static bool isHostFoldableOpImpl(Operation* op, llvm::SmallPtrSetImpl<Operation*
|
|||||||
if (isa<arith::ConstantOp, ONNXConstantOp, ONNXNoneOp>(op))
|
if (isa<arith::ConstantOp, ONNXConstantOp, ONNXNoneOp>(op))
|
||||||
return true;
|
return true;
|
||||||
|
|
||||||
|
if (auto extractOp = dyn_cast<tensor::ExtractOp>(op))
|
||||||
|
return hasConstantIndices(extractOp) && isHostFoldableValue(extractOp.getTensor());
|
||||||
|
|
||||||
if (!isStaticTensorResult(op))
|
if (!isStaticTensorResult(op))
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
@@ -47,6 +220,9 @@ static bool isHostFoldableOpImpl(Operation* op, llvm::SmallPtrSetImpl<Operation*
|
|||||||
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op))
|
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op))
|
||||||
return hasStaticUnitStrides(extractSliceOp) && isHostFoldableValue(extractSliceOp.getSource());
|
return hasStaticUnitStrides(extractSliceOp) && isHostFoldableValue(extractSliceOp.getSource());
|
||||||
|
|
||||||
|
if (auto splatOp = dyn_cast<tensor::SplatOp>(op))
|
||||||
|
return isHostFoldableValue(splatOp.getInput());
|
||||||
|
|
||||||
if (auto extractRowsOp = dyn_cast<spatial::SpatExtractRowsOp>(op))
|
if (auto extractRowsOp = dyn_cast<spatial::SpatExtractRowsOp>(op))
|
||||||
return isHostFoldableValue(extractRowsOp.getInput());
|
return isHostFoldableValue(extractRowsOp.getInput());
|
||||||
|
|
||||||
@@ -72,4 +248,9 @@ bool isHostFoldableOp(Operation* op) {
|
|||||||
return isHostFoldableOpImpl(op, visited);
|
return isHostFoldableOpImpl(op, visited);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
DenseElementsAttr getHostFoldableDenseElementsAttr(Value value) {
|
||||||
|
llvm::SmallPtrSet<Operation*, 8> visited;
|
||||||
|
return getHostFoldableDenseElementsAttrImpl(value, visited);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/IR/BuiltinAttributes.h"
|
||||||
#include "mlir/IR/Operation.h"
|
#include "mlir/IR/Operation.h"
|
||||||
#include "mlir/IR/Value.h"
|
#include "mlir/IR/Value.h"
|
||||||
|
|
||||||
@@ -9,4 +10,6 @@ bool isHostFoldableValue(mlir::Value value);
|
|||||||
|
|
||||||
bool isHostFoldableOp(mlir::Operation* op);
|
bool isHostFoldableOp(mlir::Operation* op);
|
||||||
|
|
||||||
|
mlir::DenseElementsAttr getHostFoldableDenseElementsAttr(mlir::Value value);
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostLegality.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostLegality.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
@@ -11,7 +12,7 @@ using namespace mlir;
|
|||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
|
|
||||||
LogicalResult verifyONNXToSpatialHostLegality(func::FuncOp funcOp) {
|
LogicalResult verifyONNXToSpatialHostLegality(func::FuncOp funcOp) {
|
||||||
bool hasFailure = false;
|
pim::CappedDiagnosticReporter diagnostics;
|
||||||
|
|
||||||
for (Operation& op : funcOp.getFunctionBody().front()) {
|
for (Operation& op : funcOp.getFunctionBody().front()) {
|
||||||
if (isa<func::ReturnOp, spatial::SpatCompute, spatial::SpatComputeBatch>(&op))
|
if (isa<func::ReturnOp, spatial::SpatCompute, spatial::SpatComputeBatch>(&op))
|
||||||
@@ -19,11 +20,15 @@ LogicalResult verifyONNXToSpatialHostLegality(func::FuncOp funcOp) {
|
|||||||
if (isHostFoldableOp(&op))
|
if (isHostFoldableOp(&op))
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
op.emitOpError("non-foldable top-level runtime op remains after ONNX-to-Spatial; lower it inside spat.compute");
|
diagnostics.report(&op, [](Operation* illegalOp) {
|
||||||
hasFailure = true;
|
illegalOp->emitOpError("non-foldable top-level runtime op remains after ONNX-to-Spatial; lower it inside "
|
||||||
|
"spat.compute");
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
return success(!hasFailure);
|
diagnostics.emitSuppressedSummary(funcOp, "ONNX-to-Spatial host legality failures");
|
||||||
|
|
||||||
|
return success(!diagnostics.hasFailure());
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -5,17 +5,15 @@
|
|||||||
#include "mlir/IR/IRMapping.h"
|
#include "mlir/IR/IRMapping.h"
|
||||||
#include "mlir/Pass/Pass.h"
|
#include "mlir/Pass/Pass.h"
|
||||||
#include "mlir/Pass/PassManager.h"
|
#include "mlir/Pass/PassManager.h"
|
||||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
|
||||||
#include "mlir/Transforms/Passes.h"
|
#include "mlir/Transforms/Passes.h"
|
||||||
|
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
|
||||||
|
|
||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
#include "llvm/Support/Debug.h"
|
|
||||||
|
|
||||||
#include "Common/Common.hpp"
|
#include "Common/Common.hpp"
|
||||||
#include "Common/PimCommon.hpp"
|
#include "Common/PimCommon.hpp"
|
||||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostLegality.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostLegality.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PostPatterns.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PostPatterns.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PrePatterns.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PrePatterns.hpp"
|
||||||
@@ -87,17 +85,68 @@ static void populateEmptyFunction(func::FuncOp funcOp) {
|
|||||||
returnOp.setOperand(index, computeResult);
|
returnOp.setOperand(index, computeResult);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void wrapTopLevelRuntimeTransposes(func::FuncOp funcOp) {
|
||||||
|
IRRewriter rewriter(funcOp.getContext());
|
||||||
|
Block& entryBlock = funcOp.getFunctionBody().front();
|
||||||
|
|
||||||
|
for (Operation& op : llvm::make_early_inc_range(entryBlock)) {
|
||||||
|
auto transposeOp = dyn_cast<ONNXTransposeOp>(&op);
|
||||||
|
if (!transposeOp || isHostFoldableOp(transposeOp))
|
||||||
|
continue;
|
||||||
|
|
||||||
|
// Transpose stays globally legal because constant/view-only cases are
|
||||||
|
// allowed on the host. Any residual runtime transpose must be sunk into
|
||||||
|
// spat.compute before the host legality check.
|
||||||
|
auto resultType = transposeOp.getResult().getType();
|
||||||
|
rewriter.setInsertionPoint(transposeOp);
|
||||||
|
auto computeOp = createSpatCompute<1>(
|
||||||
|
rewriter, transposeOp.getLoc(), TypeRange {resultType}, {}, ValueRange {transposeOp.getData()}, [&](Value input) {
|
||||||
|
Value transposed =
|
||||||
|
ONNXTransposeOp::create(rewriter, transposeOp.getLoc(), resultType, input, transposeOp.getPermAttr());
|
||||||
|
spatial::SpatYieldOp::create(rewriter, transposeOp.getLoc(), transposed);
|
||||||
|
});
|
||||||
|
rewriter.replaceOp(transposeOp, computeOp.getResult(0));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void ONNXToSpatialPass::runOnOperation() {
|
void ONNXToSpatialPass::runOnOperation() {
|
||||||
ModuleOp moduleOp = getOperation();
|
ModuleOp moduleOp = getOperation();
|
||||||
MLIRContext* ctx = &getContext();
|
MLIRContext* ctx = &getContext();
|
||||||
|
|
||||||
|
ConversionTarget preTarget(*ctx);
|
||||||
|
preTarget.addLegalDialect<spatial::SpatialDialect,
|
||||||
|
ONNXDialect,
|
||||||
|
tensor::TensorDialect,
|
||||||
|
arith::ArithDialect,
|
||||||
|
scf::SCFDialect>();
|
||||||
|
preTarget.addIllegalOp<ONNXConstantOp, ONNXFlattenOp>();
|
||||||
|
|
||||||
RewritePatternSet prePatterns(ctx);
|
RewritePatternSet prePatterns(ctx);
|
||||||
populatePrePatterns(prePatterns, ctx);
|
populatePrePatterns(prePatterns, ctx);
|
||||||
if (failed(applyPatternsGreedily(moduleOp, std::move(prePatterns))))
|
if (failed(applyPartialConversion(moduleOp, preTarget, std::move(prePatterns)))) {
|
||||||
moduleOp.emitWarning("failed to apply ONNX-to-Spatial pre-patterns; continuing");
|
moduleOp.emitError("failed to apply ONNX-to-Spatial pre-rewrites");
|
||||||
|
signalPassFailure();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
auto entryFunc = getPimEntryFunc(moduleOp);
|
auto entryFunc = getPimEntryFunc(moduleOp);
|
||||||
if (failed(entryFunc)) {
|
if (failed(entryFunc)) {
|
||||||
|
moduleOp.emitError("failed to locate the PIM entry function during ONNX-to-Spatial lowering");
|
||||||
|
signalPassFailure();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
RewritePatternSet matmulPatterns(ctx);
|
||||||
|
populateMatMulRewritePatterns(matmulPatterns, ctx);
|
||||||
|
walkAndApplyPatterns(moduleOp, std::move(matmulPatterns));
|
||||||
|
|
||||||
|
bool hasUnloweredMatMul = false;
|
||||||
|
moduleOp.walk([&](ONNXMatMulOp matmulOp) {
|
||||||
|
hasUnloweredMatMul = true;
|
||||||
|
matmulOp.emitOpError("remaining ONNX MatMul before the required ONNX-to-Spatial conversion");
|
||||||
|
});
|
||||||
|
if (hasUnloweredMatMul) {
|
||||||
|
moduleOp.emitError("failed to lower all ONNX MatMul ops before ONNX-to-Spatial conversion");
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@@ -130,31 +179,28 @@ void ONNXToSpatialPass::runOnOperation() {
|
|||||||
RewritePatternSet conversionPatterns(ctx);
|
RewritePatternSet conversionPatterns(ctx);
|
||||||
populateConversionPatterns(conversionPatterns, ctx);
|
populateConversionPatterns(conversionPatterns, ctx);
|
||||||
if (failed(applyPartialConversion(moduleOp, target, std::move(conversionPatterns)))) {
|
if (failed(applyPartialConversion(moduleOp, target, std::move(conversionPatterns)))) {
|
||||||
|
moduleOp.emitError("failed to convert required ONNX ops to Spatial ops");
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ConversionTarget earlyPostTarget(*ctx);
|
||||||
|
earlyPostTarget.addLegalDialect<spatial::SpatialDialect,
|
||||||
|
ONNXDialect,
|
||||||
|
tensor::TensorDialect,
|
||||||
|
arith::ArithDialect,
|
||||||
|
scf::SCFDialect>();
|
||||||
|
earlyPostTarget.addDynamicallyLegalOp<spatial::SpatComputeBatch>(
|
||||||
|
[](spatial::SpatComputeBatch batchOp) { return !requiresEarlyPostRewrite(batchOp); });
|
||||||
|
|
||||||
RewritePatternSet earlyPostPatterns(ctx);
|
RewritePatternSet earlyPostPatterns(ctx);
|
||||||
populateEarlyPostPatterns(earlyPostPatterns, ctx);
|
populateEarlyPostPatterns(earlyPostPatterns, ctx);
|
||||||
if (failed(applyPatternsGreedily(*entryFunc, std::move(earlyPostPatterns)))) {
|
if (failed(applyPartialConversion(*entryFunc, earlyPostTarget, std::move(earlyPostPatterns)))) {
|
||||||
|
moduleOp.emitError("failed to normalize single-lane spat.compute_batch ops before core assignment checks");
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (coresCount != -1) {
|
|
||||||
int computeOpsCount = 0;
|
|
||||||
for (Operation& op : entryFunc->getFunctionBody().front().getOperations())
|
|
||||||
if (isa<spatial::SpatCompute>(op))
|
|
||||||
computeOpsCount++;
|
|
||||||
|
|
||||||
if (computeOpsCount > coresCount) {
|
|
||||||
entryFunc->emitError() << "number of compute ops (" << computeOpsCount << ") exceeds the core count ("
|
|
||||||
<< coresCount << ")";
|
|
||||||
signalPassFailure();
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
PassManager cleanupPM(ctx);
|
PassManager cleanupPM(ctx);
|
||||||
cleanupPM.addPass(createCanonicalizerPass());
|
cleanupPM.addPass(createCanonicalizerPass());
|
||||||
if (failed(cleanupPM.run(moduleOp)))
|
if (failed(cleanupPM.run(moduleOp)))
|
||||||
@@ -162,14 +208,29 @@ void ONNXToSpatialPass::runOnOperation() {
|
|||||||
|
|
||||||
annotateWeightsConstants(*entryFunc);
|
annotateWeightsConstants(*entryFunc);
|
||||||
|
|
||||||
|
ConversionTarget postTarget(*ctx);
|
||||||
|
postTarget.addLegalDialect<spatial::SpatialDialect,
|
||||||
|
ONNXDialect,
|
||||||
|
tensor::TensorDialect,
|
||||||
|
arith::ArithDialect,
|
||||||
|
scf::SCFDialect>();
|
||||||
|
postTarget.addDynamicallyLegalOp<spatial::SpatCompute>(
|
||||||
|
[](spatial::SpatCompute computeOp) { return !requiresPostRewrite(computeOp); });
|
||||||
|
postTarget.addDynamicallyLegalOp<spatial::SpatComputeBatch>(
|
||||||
|
[](spatial::SpatComputeBatch computeOp) { return !requiresPostRewrite(computeOp); });
|
||||||
|
|
||||||
RewritePatternSet postPatterns(ctx);
|
RewritePatternSet postPatterns(ctx);
|
||||||
populatePostPatterns(postPatterns, ctx);
|
populatePostPatterns(postPatterns, ctx);
|
||||||
if (failed(applyPatternsGreedily(*entryFunc, std::move(postPatterns)))) {
|
if (failed(applyPartialConversion(*entryFunc, postTarget, std::move(postPatterns)))) {
|
||||||
|
moduleOp.emitError("failed to normalize weight-like Spatial compute operands before Spatial-to-PIM lowering");
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
wrapTopLevelRuntimeTransposes(*entryFunc);
|
||||||
|
|
||||||
if (failed(verifyONNXToSpatialHostLegality(*entryFunc))) {
|
if (failed(verifyONNXToSpatialHostLegality(*entryFunc))) {
|
||||||
|
moduleOp.emitError("ONNX-to-Spatial host legality verification failed");
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,6 +11,7 @@
|
|||||||
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
|
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
|
||||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
@@ -27,16 +28,6 @@ struct ConvToGemm : OpConversionPattern<ONNXConvOp> {
|
|||||||
ConversionPatternRewriter& rewriter) const override;
|
ConversionPatternRewriter& rewriter) const override;
|
||||||
};
|
};
|
||||||
|
|
||||||
static DenseElementsAttr getDenseConstantAttr(Value value) {
|
|
||||||
if (auto constantOp = value.getDefiningOp<arith::ConstantOp>())
|
|
||||||
return dyn_cast<DenseElementsAttr>(constantOp.getValue());
|
|
||||||
|
|
||||||
if (auto constantOp = value.getDefiningOp<ONNXConstantOp>())
|
|
||||||
return dyn_cast_or_null<DenseElementsAttr>(constantOp.getValueAttr());
|
|
||||||
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
static int64_t getI64FromArrayAttr(ArrayAttr arr, size_t idx) { return cast<IntegerAttr>(arr[idx]).getInt(); }
|
static int64_t getI64FromArrayAttr(ArrayAttr arr, size_t idx) { return cast<IntegerAttr>(arr[idx]).getInt(); }
|
||||||
|
|
||||||
static Value expandBiasIfNeeded(Value bias, ConversionPatternRewriter& rewriter, Location loc) {
|
static Value expandBiasIfNeeded(Value bias, ConversionPatternRewriter& rewriter, Location loc) {
|
||||||
@@ -355,49 +346,22 @@ static Value createCollectedConvOutput(ValueRange gemmRows,
|
|||||||
return collectComputeOp.getResult(0);
|
return collectComputeOp.getResult(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
static Value lowerSingleConvGroup(Value x,
|
||||||
|
Value w,
|
||||||
LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
Value b,
|
||||||
ONNXConvOpAdaptor convOpAdaptor,
|
RankedTensorType xType,
|
||||||
ConversionPatternRewriter& rewriter) const {
|
RankedTensorType wType,
|
||||||
Location loc = convOp.getLoc();
|
RankedTensorType outType,
|
||||||
Value x = convOpAdaptor.getX();
|
int64_t padHeightBegin,
|
||||||
Value w = convOpAdaptor.getW();
|
int64_t padHeightEnd,
|
||||||
Value b = convOpAdaptor.getB();
|
int64_t padWidthBegin,
|
||||||
|
int64_t padWidthEnd,
|
||||||
auto xType = cast<RankedTensorType>(x.getType());
|
int64_t strideHeight,
|
||||||
auto wType = cast<RankedTensorType>(w.getType());
|
int64_t strideWidth,
|
||||||
auto outType = cast<RankedTensorType>(convOp.getY().getType());
|
int64_t dilationHeight,
|
||||||
|
int64_t dilationWidth,
|
||||||
if (!xType.hasStaticShape()) {
|
ConversionPatternRewriter& rewriter,
|
||||||
pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv input");
|
Location loc) {
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
if (!wType.hasStaticShape()) {
|
|
||||||
pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv weight");
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
if (!outType.hasStaticShape()) {
|
|
||||||
pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv result");
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
if (xType.getRank() != 4) {
|
|
||||||
pim::emitUnsupportedRankDiagnostic(convOp, "conv input", xType.getRank(), {4});
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
if (wType.getRank() != 4) {
|
|
||||||
pim::emitUnsupportedRankDiagnostic(convOp, "conv weight", wType.getRank(), {4});
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
if (outType.getRank() != 4) {
|
|
||||||
pim::emitUnsupportedRankDiagnostic(convOp, "conv result", outType.getRank(), {4});
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
if (convOp.getGroup() != 1) {
|
|
||||||
convOp.emitOpError("only group=1 convolution is supported for Spatial lowering");
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
|
|
||||||
const int64_t batchSize = xType.getDimSize(0);
|
const int64_t batchSize = xType.getDimSize(0);
|
||||||
const int64_t numChannelsIn = xType.getDimSize(1);
|
const int64_t numChannelsIn = xType.getDimSize(1);
|
||||||
const int64_t xHeight = xType.getDimSize(2);
|
const int64_t xHeight = xType.getDimSize(2);
|
||||||
@@ -408,71 +372,6 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
|||||||
const int64_t outHeight = outType.getDimSize(2);
|
const int64_t outHeight = outType.getDimSize(2);
|
||||||
const int64_t outWidth = outType.getDimSize(3);
|
const int64_t outWidth = outType.getDimSize(3);
|
||||||
|
|
||||||
// Read optional conv attributes (ONNX defaults: stride=1, dilation=1, pad=0)
|
|
||||||
const auto stridesAttr = convOp.getStrides();
|
|
||||||
const auto dilationsAttr = convOp.getDilations();
|
|
||||||
const auto padsAttr = convOp.getPads();
|
|
||||||
|
|
||||||
if (stridesAttr && stridesAttr->size() != 2) {
|
|
||||||
convOp.emitOpError("requires exactly two stride values for Spatial lowering");
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
if (dilationsAttr && dilationsAttr->size() != 2) {
|
|
||||||
convOp.emitOpError("requires exactly two dilation values for Spatial lowering");
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
if (padsAttr && padsAttr->size() != 4) {
|
|
||||||
convOp.emitOpError("requires exactly four pad values for 2D Spatial lowering");
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
|
|
||||||
const int64_t strideHeight = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 0) : 1;
|
|
||||||
const int64_t strideWidth = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 1) : 1;
|
|
||||||
const int64_t dilationHeight = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 0) : 1;
|
|
||||||
const int64_t dilationWidth = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 1) : 1;
|
|
||||||
|
|
||||||
int64_t padHeightBegin = 0;
|
|
||||||
int64_t padHeightEnd = 0;
|
|
||||||
int64_t padWidthBegin = 0;
|
|
||||||
int64_t padWidthEnd = 0;
|
|
||||||
|
|
||||||
if (padsAttr) {
|
|
||||||
padHeightBegin = getI64FromArrayAttr(*padsAttr, 0);
|
|
||||||
padWidthBegin = getI64FromArrayAttr(*padsAttr, 1);
|
|
||||||
padHeightEnd = getI64FromArrayAttr(*padsAttr, 2);
|
|
||||||
padWidthEnd = getI64FromArrayAttr(*padsAttr, 3);
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
// Compute padding from auto_pad attribute
|
|
||||||
const auto autoPad = convOp.getAutoPad();
|
|
||||||
if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") {
|
|
||||||
const int64_t effectiveKernelH = (wHeight - 1) * dilationHeight + 1;
|
|
||||||
const int64_t effectiveKernelW = (wWidth - 1) * dilationWidth + 1;
|
|
||||||
const int64_t totalPadH =
|
|
||||||
std::max(static_cast<int64_t>(0), (outHeight - 1) * strideHeight + effectiveKernelH - xHeight);
|
|
||||||
const int64_t totalPadW =
|
|
||||||
std::max(static_cast<int64_t>(0), (outWidth - 1) * strideWidth + effectiveKernelW - xWidth);
|
|
||||||
|
|
||||||
if (autoPad == "SAME_UPPER") {
|
|
||||||
padHeightBegin = totalPadH / 2;
|
|
||||||
padHeightEnd = totalPadH - padHeightBegin;
|
|
||||||
padWidthBegin = totalPadW / 2;
|
|
||||||
padWidthEnd = totalPadW - padWidthBegin;
|
|
||||||
}
|
|
||||||
else { // SAME_LOWER
|
|
||||||
padHeightEnd = totalPadH / 2;
|
|
||||||
padHeightBegin = totalPadH - padHeightEnd;
|
|
||||||
padWidthEnd = totalPadW / 2;
|
|
||||||
padWidthBegin = totalPadW - padWidthEnd;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
else if (autoPad != "NOTSET" && autoPad != "VALID") {
|
|
||||||
convOp.emitOpError() << "unsupported auto_pad value `" << autoPad << "` for Spatial lowering";
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
// "NOTSET" or "VALID" -> all pads stay 0
|
|
||||||
}
|
|
||||||
|
|
||||||
// im2col layout (flipped with respect to the standard, so filters sit in B = crossbar):
|
// im2col layout (flipped with respect to the standard, so filters sit in B = crossbar):
|
||||||
// A (im2col): [numPatches, patchSize] -- one row per output spatial position
|
// A (im2col): [numPatches, patchSize] -- one row per output spatial position
|
||||||
// B (weights): [patchSize, cOut] -- W^T, stored in crossbar columns
|
// B (weights): [patchSize, cOut] -- W^T, stored in crossbar columns
|
||||||
@@ -492,7 +391,7 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
|||||||
const int64_t xbarSize = static_cast<int64_t>(crossbarSize.getValue());
|
const int64_t xbarSize = static_cast<int64_t>(crossbarSize.getValue());
|
||||||
const int64_t wMaxDim = std::max(patchSize, numChannelsOut);
|
const int64_t wMaxDim = std::max(patchSize, numChannelsOut);
|
||||||
const int64_t maxParallelPixels = std::max<int64_t>(1, xbarSize / wMaxDim);
|
const int64_t maxParallelPixels = std::max<int64_t>(1, xbarSize / wMaxDim);
|
||||||
auto wDenseAttr = getDenseConstantAttr(w);
|
auto wDenseAttr = getHostFoldableDenseElementsAttr(w);
|
||||||
|
|
||||||
// Prepare weight matrix W for crossbar storage:
|
// Prepare weight matrix W for crossbar storage:
|
||||||
// W: [numChannelsOut, numChannelsIn, wHeight, wWidth] -> [numChannelsOut, patchSize] -> [patchSize, numChannelsOut]
|
// W: [numChannelsOut, numChannelsIn, wHeight, wWidth] -> [numChannelsOut, patchSize] -> [patchSize, numChannelsOut]
|
||||||
@@ -513,7 +412,7 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
|||||||
DenseElementsAttr biasDenseAttr;
|
DenseElementsAttr biasDenseAttr;
|
||||||
if (hasB) {
|
if (hasB) {
|
||||||
gemmBias = b;
|
gemmBias = b;
|
||||||
biasDenseAttr = getDenseConstantAttr(b);
|
biasDenseAttr = getHostFoldableDenseElementsAttr(b);
|
||||||
biasMatrix = expandBiasIfNeeded(b, rewriter, loc);
|
biasMatrix = expandBiasIfNeeded(b, rewriter, loc);
|
||||||
}
|
}
|
||||||
const bool canPackWeightsAsConstants = static_cast<bool>(wDenseAttr);
|
const bool canPackWeightsAsConstants = static_cast<bool>(wDenseAttr);
|
||||||
@@ -589,9 +488,8 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
|||||||
rewriter.getBoolAttr(false))
|
rewriter.getBoolAttr(false))
|
||||||
.getY();
|
.getY();
|
||||||
|
|
||||||
rewriter.replaceOp(convOp,
|
return createCollectedConvOutput(ValueRange {gemmRows},
|
||||||
createCollectedConvOutput(ValueRange {gemmRows},
|
outType,
|
||||||
convOp.getType(),
|
|
||||||
gemmOutType,
|
gemmOutType,
|
||||||
nhwcType,
|
nhwcType,
|
||||||
outType,
|
outType,
|
||||||
@@ -599,8 +497,238 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
|||||||
numChannelsOut,
|
numChannelsOut,
|
||||||
effectiveMaxParallelPixels,
|
effectiveMaxParallelPixels,
|
||||||
rewriter,
|
rewriter,
|
||||||
|
loc);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
||||||
|
ONNXConvOpAdaptor convOpAdaptor,
|
||||||
|
ConversionPatternRewriter& rewriter) const {
|
||||||
|
Location loc = convOp.getLoc();
|
||||||
|
Value x = convOpAdaptor.getX();
|
||||||
|
Value w = convOpAdaptor.getW();
|
||||||
|
Value b = convOpAdaptor.getB();
|
||||||
|
|
||||||
|
auto xType = cast<RankedTensorType>(x.getType());
|
||||||
|
auto wType = cast<RankedTensorType>(w.getType());
|
||||||
|
auto outType = cast<RankedTensorType>(convOp.getY().getType());
|
||||||
|
|
||||||
|
if (!xType.hasStaticShape()) {
|
||||||
|
pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv input");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
if (!wType.hasStaticShape()) {
|
||||||
|
pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv weight");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
if (!outType.hasStaticShape()) {
|
||||||
|
pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv result");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
if (xType.getRank() != 4) {
|
||||||
|
pim::emitUnsupportedRankDiagnostic(convOp, "conv input", xType.getRank(), {4});
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
if (wType.getRank() != 4) {
|
||||||
|
pim::emitUnsupportedRankDiagnostic(convOp, "conv weight", wType.getRank(), {4});
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
if (outType.getRank() != 4) {
|
||||||
|
pim::emitUnsupportedRankDiagnostic(convOp, "conv result", outType.getRank(), {4});
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
if (convOp.getGroup() < 1) {
|
||||||
|
convOp.emitOpError("requires group >= 1 for Spatial lowering");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
const int64_t batchSize = xType.getDimSize(0);
|
||||||
|
const int64_t numChannelsIn = xType.getDimSize(1);
|
||||||
|
const int64_t xHeight = xType.getDimSize(2);
|
||||||
|
const int64_t xWidth = xType.getDimSize(3);
|
||||||
|
const int64_t numChannelsOut = wType.getDimSize(0);
|
||||||
|
const int64_t wHeight = wType.getDimSize(2);
|
||||||
|
const int64_t wWidth = wType.getDimSize(3);
|
||||||
|
const int64_t outHeight = outType.getDimSize(2);
|
||||||
|
const int64_t outWidth = outType.getDimSize(3);
|
||||||
|
const int64_t group = convOp.getGroup();
|
||||||
|
const bool hasB = !isa<ONNXNoneOp>(b.getDefiningOp());
|
||||||
|
|
||||||
|
if (numChannelsIn % group != 0) {
|
||||||
|
convOp.emitOpError() << "requires input channels " << numChannelsIn << " to be divisible by group " << group
|
||||||
|
<< " for Spatial lowering";
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
if (numChannelsOut % group != 0) {
|
||||||
|
convOp.emitOpError() << "requires output channels " << numChannelsOut << " to be divisible by group " << group
|
||||||
|
<< " for Spatial lowering";
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
const int64_t numChannelsInPerGroup = numChannelsIn / group;
|
||||||
|
const int64_t numChannelsOutPerGroup = numChannelsOut / group;
|
||||||
|
if (wType.getDimSize(1) != numChannelsInPerGroup) {
|
||||||
|
convOp.emitOpError() << "requires grouped conv weight input channels " << wType.getDimSize(1)
|
||||||
|
<< " to match input channels per group " << numChannelsInPerGroup << " for Spatial lowering";
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
if (wType.getDimSize(0) != numChannelsOut) {
|
||||||
|
convOp.emitOpError() << "requires weight output channels " << wType.getDimSize(0) << " to match result channels "
|
||||||
|
<< numChannelsOut << " for Spatial lowering";
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read optional conv attributes (ONNX defaults: stride=1, dilation=1, pad=0)
|
||||||
|
const auto stridesAttr = convOp.getStrides();
|
||||||
|
const auto dilationsAttr = convOp.getDilations();
|
||||||
|
const auto padsAttr = convOp.getPads();
|
||||||
|
|
||||||
|
if (stridesAttr && stridesAttr->size() != 2) {
|
||||||
|
convOp.emitOpError("requires exactly two stride values for Spatial lowering");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
if (dilationsAttr && dilationsAttr->size() != 2) {
|
||||||
|
convOp.emitOpError("requires exactly two dilation values for Spatial lowering");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
if (padsAttr && padsAttr->size() != 4) {
|
||||||
|
convOp.emitOpError("requires exactly four pad values for 2D Spatial lowering");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
const int64_t strideHeight = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 0) : 1;
|
||||||
|
const int64_t strideWidth = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 1) : 1;
|
||||||
|
const int64_t dilationHeight = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 0) : 1;
|
||||||
|
const int64_t dilationWidth = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 1) : 1;
|
||||||
|
|
||||||
|
int64_t padHeightBegin = 0;
|
||||||
|
int64_t padHeightEnd = 0;
|
||||||
|
int64_t padWidthBegin = 0;
|
||||||
|
int64_t padWidthEnd = 0;
|
||||||
|
|
||||||
|
if (padsAttr) {
|
||||||
|
padHeightBegin = getI64FromArrayAttr(*padsAttr, 0);
|
||||||
|
padWidthBegin = getI64FromArrayAttr(*padsAttr, 1);
|
||||||
|
padHeightEnd = getI64FromArrayAttr(*padsAttr, 2);
|
||||||
|
padWidthEnd = getI64FromArrayAttr(*padsAttr, 3);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
// Compute padding from auto_pad attribute
|
||||||
|
const auto autoPad = convOp.getAutoPad();
|
||||||
|
if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") {
|
||||||
|
const int64_t effectiveKernelH = (wHeight - 1) * dilationHeight + 1;
|
||||||
|
const int64_t effectiveKernelW = (wWidth - 1) * dilationWidth + 1;
|
||||||
|
const int64_t totalPadH =
|
||||||
|
std::max(static_cast<int64_t>(0), (outHeight - 1) * strideHeight + effectiveKernelH - xHeight);
|
||||||
|
const int64_t totalPadW =
|
||||||
|
std::max(static_cast<int64_t>(0), (outWidth - 1) * strideWidth + effectiveKernelW - xWidth);
|
||||||
|
|
||||||
|
if (autoPad == "SAME_UPPER") {
|
||||||
|
padHeightBegin = totalPadH / 2;
|
||||||
|
padHeightEnd = totalPadH - padHeightBegin;
|
||||||
|
padWidthBegin = totalPadW / 2;
|
||||||
|
padWidthEnd = totalPadW - padWidthBegin;
|
||||||
|
}
|
||||||
|
else { // SAME_LOWER
|
||||||
|
padHeightEnd = totalPadH / 2;
|
||||||
|
padHeightBegin = totalPadH - padHeightEnd;
|
||||||
|
padWidthEnd = totalPadW / 2;
|
||||||
|
padWidthBegin = totalPadW - padWidthEnd;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else if (autoPad != "NOTSET" && autoPad != "VALID") {
|
||||||
|
convOp.emitOpError() << "unsupported auto_pad value `" << autoPad << "` for Spatial lowering";
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
// "NOTSET" or "VALID" -> all pads stay 0
|
||||||
|
}
|
||||||
|
|
||||||
|
if (group == 1) {
|
||||||
|
rewriter.replaceOp(convOp,
|
||||||
|
lowerSingleConvGroup(x,
|
||||||
|
w,
|
||||||
|
b,
|
||||||
|
xType,
|
||||||
|
wType,
|
||||||
|
outType,
|
||||||
|
padHeightBegin,
|
||||||
|
padHeightEnd,
|
||||||
|
padWidthBegin,
|
||||||
|
padWidthEnd,
|
||||||
|
strideHeight,
|
||||||
|
strideWidth,
|
||||||
|
dilationHeight,
|
||||||
|
dilationWidth,
|
||||||
|
rewriter,
|
||||||
loc));
|
loc));
|
||||||
return success();
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<Value> xSlices = sliceTensor(x, /*axis=*/1, numChannelsInPerGroup, rewriter, loc);
|
||||||
|
SmallVector<Value> wSlices = sliceTensor(w, /*axis=*/0, numChannelsOutPerGroup, rewriter, loc);
|
||||||
|
SmallVector<Value> bSlices;
|
||||||
|
if (hasB) {
|
||||||
|
auto biasType = cast<RankedTensorType>(b.getType());
|
||||||
|
int64_t biasAxis = -1;
|
||||||
|
if (biasType.getRank() == 1)
|
||||||
|
biasAxis = 0;
|
||||||
|
else if (biasType.getRank() == 2)
|
||||||
|
biasAxis = biasType.getDimSize(0) != 1 ? 0 : 1;
|
||||||
|
else {
|
||||||
|
convOp.emitOpError() << "requires rank-1 or rank-2 bias for grouped convolution Spatial lowering, but got rank "
|
||||||
|
<< biasType.getRank();
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
bSlices = sliceTensor(b, biasAxis, numChannelsOutPerGroup, rewriter, loc);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (xSlices.size() != static_cast<size_t>(group) || wSlices.size() != static_cast<size_t>(group)
|
||||||
|
|| (hasB && bSlices.size() != static_cast<size_t>(group))) {
|
||||||
|
convOp.emitOpError("failed to partition grouped convolution operands for Spatial lowering");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<Value> groupResults;
|
||||||
|
groupResults.reserve(group);
|
||||||
|
auto groupOutType =
|
||||||
|
RankedTensorType::get({batchSize, numChannelsOutPerGroup, outHeight, outWidth}, outType.getElementType());
|
||||||
|
Value noBias = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());
|
||||||
|
for (int64_t groupId = 0; groupId < group; groupId++) {
|
||||||
|
Value groupX = xSlices[groupId];
|
||||||
|
Value groupW = wSlices[groupId];
|
||||||
|
Value groupB = hasB ? bSlices[groupId] : noBias;
|
||||||
|
groupResults.push_back(lowerSingleConvGroup(groupX,
|
||||||
|
groupW,
|
||||||
|
groupB,
|
||||||
|
cast<RankedTensorType>(groupX.getType()),
|
||||||
|
cast<RankedTensorType>(groupW.getType()),
|
||||||
|
groupOutType,
|
||||||
|
padHeightBegin,
|
||||||
|
padHeightEnd,
|
||||||
|
padWidthBegin,
|
||||||
|
padWidthEnd,
|
||||||
|
strideHeight,
|
||||||
|
strideWidth,
|
||||||
|
dilationHeight,
|
||||||
|
dilationWidth,
|
||||||
|
rewriter,
|
||||||
|
loc));
|
||||||
|
}
|
||||||
|
|
||||||
|
Value result;
|
||||||
|
if (llvm::all_of(groupResults, isHostFoldableValue)) {
|
||||||
|
result = createSpatConcat(rewriter, loc, /*axis=*/1, groupResults);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
auto concatCompute = createSpatCompute(rewriter, loc, TypeRange {outType}, {}, groupResults, [&](ValueRange args) {
|
||||||
|
spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, /*axis=*/1, args));
|
||||||
|
});
|
||||||
|
result = concatCompute.getResult(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter.replaceOp(convOp, result);
|
||||||
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
void populateConvPatterns(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.insert<ConvToGemm>(ctx); }
|
void populateConvPatterns(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.insert<ConvToGemm>(ctx); }
|
||||||
|
|||||||
@@ -502,9 +502,6 @@ LogicalResult GemmToSpatialComputeBatch::matchAndRewrite(ONNXGemmOp gemmOp,
|
|||||||
}
|
}
|
||||||
(void) bType;
|
(void) bType;
|
||||||
|
|
||||||
if (!isHostFoldableValue(b))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
Value sharedBias;
|
Value sharedBias;
|
||||||
if (hasC) {
|
if (hasC) {
|
||||||
auto scaledC = materializeScaledConstantTensor(c, gemmOpAdaptor.getBeta().convertToFloat(), rewriter, loc);
|
auto scaledC = materializeScaledConstantTensor(c, gemmOpAdaptor.getBeta().convertToFloat(), rewriter, loc);
|
||||||
|
|||||||
@@ -2,8 +2,12 @@
|
|||||||
#include "mlir/IR/BuiltinTypes.h"
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
#include "mlir/IR/PatternMatch.h"
|
#include "mlir/IR/PatternMatch.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/STLExtras.h"
|
||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
|
||||||
|
#include <functional>
|
||||||
|
#include <numeric>
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||||
@@ -19,6 +23,79 @@ static bool haveStaticPositiveShape(ArrayRef<int64_t> shape) {
|
|||||||
return llvm::all_of(shape, [](int64_t dim) { return dim > 0; });
|
return llvm::all_of(shape, [](int64_t dim) { return dim > 0; });
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static int64_t getStaticShapeElementCount(ArrayRef<int64_t> shape) {
|
||||||
|
return std::accumulate(shape.begin(), shape.end(), int64_t {1}, std::multiplies<int64_t> {});
|
||||||
|
}
|
||||||
|
|
||||||
|
static FailureOr<SmallVector<int64_t>> inferSupportedBatchShape(ArrayRef<int64_t> lhsBatchShape,
|
||||||
|
ArrayRef<int64_t> rhsBatchShape) {
|
||||||
|
if (lhsBatchShape.empty())
|
||||||
|
return SmallVector<int64_t>(rhsBatchShape.begin(), rhsBatchShape.end());
|
||||||
|
if (rhsBatchShape.empty())
|
||||||
|
return SmallVector<int64_t>(lhsBatchShape.begin(), lhsBatchShape.end());
|
||||||
|
if (!llvm::equal(lhsBatchShape, rhsBatchShape))
|
||||||
|
return failure();
|
||||||
|
return SmallVector<int64_t>(lhsBatchShape.begin(), lhsBatchShape.end());
|
||||||
|
}
|
||||||
|
|
||||||
|
static Value collapseBatchDims(Value value,
|
||||||
|
int64_t batchSize,
|
||||||
|
int64_t rows,
|
||||||
|
int64_t cols,
|
||||||
|
PatternRewriter& rewriter,
|
||||||
|
Location loc) {
|
||||||
|
auto type = cast<RankedTensorType>(value.getType());
|
||||||
|
if (type.getRank() == 2 || type.getRank() == 3)
|
||||||
|
return value;
|
||||||
|
|
||||||
|
auto collapsedType =
|
||||||
|
RankedTensorType::get({batchSize, rows, cols}, type.getElementType(), type.getEncoding());
|
||||||
|
SmallVector<ReassociationIndices> reassociation = {
|
||||||
|
ReassociationIndices {},
|
||||||
|
ReassociationIndices {static_cast<int64_t>(type.getRank() - 2)},
|
||||||
|
ReassociationIndices {static_cast<int64_t>(type.getRank() - 1)}
|
||||||
|
};
|
||||||
|
for (int64_t dim = 0; dim < type.getRank() - 2; ++dim)
|
||||||
|
reassociation.front().push_back(dim);
|
||||||
|
|
||||||
|
auto buildCollapsed = [&](Value input) -> Value {
|
||||||
|
return tensor::CollapseShapeOp::create(rewriter, loc, collapsedType, input, reassociation);
|
||||||
|
};
|
||||||
|
|
||||||
|
if (isHostFoldableValue(value))
|
||||||
|
return buildCollapsed(value);
|
||||||
|
|
||||||
|
auto collapseCompute =
|
||||||
|
createSpatCompute<1>(rewriter, loc, TypeRange {collapsedType}, {}, ValueRange {value}, [&](Value input) {
|
||||||
|
spatial::SpatYieldOp::create(rewriter, loc, buildCollapsed(input));
|
||||||
|
});
|
||||||
|
return collapseCompute.getResult(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
static Value expandBatchDims(Value value,
|
||||||
|
RankedTensorType outputType,
|
||||||
|
size_t batchRank,
|
||||||
|
PatternRewriter& rewriter,
|
||||||
|
Location loc) {
|
||||||
|
if (cast<RankedTensorType>(value.getType()) == outputType)
|
||||||
|
return value;
|
||||||
|
|
||||||
|
SmallVector<ReassociationIndices> reassociation = {
|
||||||
|
ReassociationIndices {},
|
||||||
|
ReassociationIndices {static_cast<int64_t>(batchRank)},
|
||||||
|
ReassociationIndices {static_cast<int64_t>(batchRank + 1)}
|
||||||
|
};
|
||||||
|
for (size_t dim = 0; dim < batchRank; ++dim)
|
||||||
|
reassociation.front().push_back(static_cast<int64_t>(dim));
|
||||||
|
|
||||||
|
auto expandCompute =
|
||||||
|
createSpatCompute<1>(rewriter, loc, TypeRange {outputType}, {}, ValueRange {value}, [&](Value input) {
|
||||||
|
Value expanded = tensor::ExpandShapeOp::create(rewriter, loc, outputType, input, reassociation);
|
||||||
|
spatial::SpatYieldOp::create(rewriter, loc, expanded);
|
||||||
|
});
|
||||||
|
return expandCompute.getResult(0);
|
||||||
|
}
|
||||||
|
|
||||||
static Value extractBatchMatrix(Value value,
|
static Value extractBatchMatrix(Value value,
|
||||||
int64_t batchIndex,
|
int64_t batchIndex,
|
||||||
int64_t batchSize,
|
int64_t batchSize,
|
||||||
@@ -62,13 +139,29 @@ static Value extractBatchMatrix(Value value,
|
|||||||
static Value transposeLastTwoDims(Value value, PatternRewriter& rewriter, Location loc) {
|
static Value transposeLastTwoDims(Value value, PatternRewriter& rewriter, Location loc) {
|
||||||
auto type = cast<RankedTensorType>(value.getType());
|
auto type = cast<RankedTensorType>(value.getType());
|
||||||
auto shape = type.getShape();
|
auto shape = type.getShape();
|
||||||
|
RankedTensorType transposedType;
|
||||||
|
SmallVector<int64_t> perm;
|
||||||
if (type.getRank() == 2) {
|
if (type.getRank() == 2) {
|
||||||
auto transposedType = RankedTensorType::get({shape[1], shape[0]}, type.getElementType());
|
transposedType = RankedTensorType::get({shape[1], shape[0]}, type.getElementType());
|
||||||
return ONNXTransposeOp::create(rewriter, loc, transposedType, value, rewriter.getI64ArrayAttr({1, 0}));
|
perm = {1, 0};
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
transposedType = RankedTensorType::get({shape[0], shape[2], shape[1]}, type.getElementType());
|
||||||
|
perm = {0, 2, 1};
|
||||||
}
|
}
|
||||||
|
|
||||||
auto transposedType = RankedTensorType::get({shape[0], shape[2], shape[1]}, type.getElementType());
|
auto buildTranspose = [&](Value input) -> Value {
|
||||||
return ONNXTransposeOp::create(rewriter, loc, transposedType, value, rewriter.getI64ArrayAttr({0, 2, 1}));
|
return ONNXTransposeOp::create(rewriter, loc, transposedType, input, rewriter.getI64ArrayAttr(perm));
|
||||||
|
};
|
||||||
|
|
||||||
|
if (isHostFoldableValue(value))
|
||||||
|
return buildTranspose(value);
|
||||||
|
|
||||||
|
auto transposeCompute =
|
||||||
|
createSpatCompute<1>(rewriter, loc, TypeRange {transposedType}, {}, ValueRange {value}, [&](Value input) {
|
||||||
|
spatial::SpatYieldOp::create(rewriter, loc, buildTranspose(input));
|
||||||
|
});
|
||||||
|
return transposeCompute.getResult(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
static Value transposeLastTwoDimsInCompute(Value value, PatternRewriter& rewriter, Location loc) {
|
static Value transposeLastTwoDimsInCompute(Value value, PatternRewriter& rewriter, Location loc) {
|
||||||
@@ -120,24 +213,25 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
|
|||||||
if (!lhsType || !rhsType || !outType || !lhsType.hasStaticShape() || !rhsType.hasStaticShape()
|
if (!lhsType || !rhsType || !outType || !lhsType.hasStaticShape() || !rhsType.hasStaticShape()
|
||||||
|| !outType.hasStaticShape())
|
|| !outType.hasStaticShape())
|
||||||
return failure();
|
return failure();
|
||||||
if ((lhsType.getRank() != 2 && lhsType.getRank() != 3) || (rhsType.getRank() != 2 && rhsType.getRank() != 3)
|
if (lhsType.getRank() < 2 || rhsType.getRank() < 2 || outType.getRank() < 2)
|
||||||
|| (outType.getRank() != 2 && outType.getRank() != 3))
|
|
||||||
return failure();
|
return failure();
|
||||||
if (!haveStaticPositiveShape(lhsType.getShape()) || !haveStaticPositiveShape(rhsType.getShape())
|
if (!haveStaticPositiveShape(lhsType.getShape()) || !haveStaticPositiveShape(rhsType.getShape())
|
||||||
|| !haveStaticPositiveShape(outType.getShape()))
|
|| !haveStaticPositiveShape(outType.getShape()))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
const int64_t lhsBatch = lhsType.getRank() == 3 ? lhsType.getDimSize(0) : 1;
|
SmallVector<int64_t> lhsBatchShape(lhsType.getShape().begin(), lhsType.getShape().end() - 2);
|
||||||
const int64_t rhsBatch = rhsType.getRank() == 3 ? rhsType.getDimSize(0) : 1;
|
SmallVector<int64_t> rhsBatchShape(rhsType.getShape().begin(), rhsType.getShape().end() - 2);
|
||||||
const int64_t batch = std::max(lhsBatch, rhsBatch);
|
auto batchShape = inferSupportedBatchShape(lhsBatchShape, rhsBatchShape);
|
||||||
|
if (failed(batchShape))
|
||||||
if ((lhsBatch != 1 && lhsBatch != batch) || (rhsBatch != 1 && rhsBatch != batch))
|
|
||||||
return failure();
|
return failure();
|
||||||
|
const int64_t lhsBatch = lhsBatchShape.empty() ? 1 : getStaticShapeElementCount(lhsBatchShape);
|
||||||
|
const int64_t rhsBatch = rhsBatchShape.empty() ? 1 : getStaticShapeElementCount(rhsBatchShape);
|
||||||
|
const int64_t batch = batchShape->empty() ? 1 : getStaticShapeElementCount(*batchShape);
|
||||||
|
|
||||||
const int64_t m = lhsType.getRank() == 3 ? lhsType.getDimSize(1) : lhsType.getDimSize(0);
|
const int64_t m = lhsType.getDimSize(lhsType.getRank() - 2);
|
||||||
const int64_t k = lhsType.getRank() == 3 ? lhsType.getDimSize(2) : lhsType.getDimSize(1);
|
const int64_t k = lhsType.getDimSize(lhsType.getRank() - 1);
|
||||||
const int64_t rhsK = rhsType.getRank() == 3 ? rhsType.getDimSize(1) : rhsType.getDimSize(0);
|
const int64_t rhsK = rhsType.getDimSize(rhsType.getRank() - 2);
|
||||||
const int64_t n = rhsType.getRank() == 3 ? rhsType.getDimSize(2) : rhsType.getDimSize(1);
|
const int64_t n = rhsType.getDimSize(rhsType.getRank() - 1);
|
||||||
if (k != rhsK)
|
if (k != rhsK)
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
@@ -146,15 +240,17 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
|
|||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
if (outType.getDimSize(0) != batch || outType.getDimSize(1) != m || outType.getDimSize(2) != n)
|
SmallVector<int64_t> outBatchShape(outType.getShape().begin(), outType.getShape().end() - 2);
|
||||||
|
if (!llvm::equal(outBatchShape, *batchShape) || outType.getDimSize(outType.getRank() - 2) != m
|
||||||
|
|| outType.getDimSize(outType.getRank() - 1) != n)
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
|
||||||
Location loc = matmulOp.getLoc();
|
Location loc = matmulOp.getLoc();
|
||||||
bool useTransposedForm = isHostFoldableValue(matmulOp.getA()) && !isHostFoldableValue(matmulOp.getB());
|
bool useTransposedForm = isHostFoldableValue(matmulOp.getA()) && !isHostFoldableValue(matmulOp.getB());
|
||||||
|
|
||||||
Value lhs = matmulOp.getA();
|
Value lhs = collapseBatchDims(matmulOp.getA(), lhsBatch, m, k, rewriter, loc);
|
||||||
Value rhs = matmulOp.getB();
|
Value rhs = collapseBatchDims(matmulOp.getB(), rhsBatch, k, n, rewriter, loc);
|
||||||
int64_t lhsBatchForGemm = lhsBatch;
|
int64_t lhsBatchForGemm = lhsBatch;
|
||||||
int64_t rhsBatchForGemm = rhsBatch;
|
int64_t rhsBatchForGemm = rhsBatch;
|
||||||
int64_t gemmM = m;
|
int64_t gemmM = m;
|
||||||
@@ -239,6 +335,7 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
Value result = concatValues(batchResults, /*axis=*/0, rewriter, loc);
|
Value result = concatValues(batchResults, /*axis=*/0, rewriter, loc);
|
||||||
|
result = expandBatchDims(result, outType, batchShape->size(), rewriter, loc);
|
||||||
rewriter.replaceOp(matmulOp, result);
|
rewriter.replaceOp(matmulOp, result);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,9 +1,10 @@
|
|||||||
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
@@ -22,53 +23,83 @@ static SmallVector<int64_t> permuteShape(ArrayRef<int64_t> shape, ArrayRef<int64
|
|||||||
return permutedShape;
|
return permutedShape;
|
||||||
}
|
}
|
||||||
|
|
||||||
static Value createSoftmaxCompute(Value input, ConversionPatternRewriter& rewriter, Location loc) {
|
static Value buildLoopSoftmaxSlice(Value input,
|
||||||
|
Value accumulator,
|
||||||
|
RankedTensorType inputType,
|
||||||
|
ArrayRef<Value> outerIndices,
|
||||||
|
ConversionPatternRewriter& rewriter,
|
||||||
|
Location loc) {
|
||||||
|
int64_t rank = inputType.getRank();
|
||||||
|
SmallVector<int64_t> sliceShape(static_cast<size_t>(rank - 1), 1);
|
||||||
|
sliceShape.push_back(inputType.getDimSize(rank - 1));
|
||||||
|
auto sliceType = RankedTensorType::get(sliceShape, inputType.getElementType(), inputType.getEncoding());
|
||||||
|
|
||||||
|
SmallVector<OpFoldResult> offsets;
|
||||||
|
SmallVector<OpFoldResult> sizes;
|
||||||
|
SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
|
||||||
|
offsets.reserve(rank);
|
||||||
|
sizes.reserve(rank);
|
||||||
|
|
||||||
|
for (Value outerIndex : outerIndices) {
|
||||||
|
offsets.push_back(outerIndex);
|
||||||
|
sizes.push_back(rewriter.getIndexAttr(1));
|
||||||
|
}
|
||||||
|
offsets.push_back(rewriter.getIndexAttr(0));
|
||||||
|
sizes.push_back(rewriter.getIndexAttr(inputType.getDimSize(rank - 1)));
|
||||||
|
|
||||||
|
Value inputSlice = tensor::ExtractSliceOp::create(rewriter, loc, sliceType, input, offsets, sizes, strides);
|
||||||
|
Value softmaxSlice = spatial::SpatSoftmaxOp::create(rewriter, loc, sliceType, inputSlice).getResult();
|
||||||
|
return tensor::InsertSliceOp::create(rewriter, loc, softmaxSlice, accumulator, offsets, sizes, strides);
|
||||||
|
}
|
||||||
|
|
||||||
|
static Value buildLoopSoftmaxNest(Value input,
|
||||||
|
Value accumulator,
|
||||||
|
RankedTensorType inputType,
|
||||||
|
int64_t axis,
|
||||||
|
SmallVectorImpl<Value>& outerIndices,
|
||||||
|
ConversionPatternRewriter& rewriter,
|
||||||
|
Location loc) {
|
||||||
|
if (axis == inputType.getRank() - 1)
|
||||||
|
return buildLoopSoftmaxSlice(input, accumulator, inputType, outerIndices, rewriter, loc);
|
||||||
|
|
||||||
|
Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
|
||||||
|
Value c1 = arith::ConstantIndexOp::create(rewriter, loc, 1);
|
||||||
|
Value cUpper = arith::ConstantIndexOp::create(rewriter, loc, inputType.getDimSize(axis));
|
||||||
|
|
||||||
|
auto loop = scf::ForOp::create(rewriter, loc, c0, cUpper, c1, ValueRange {accumulator});
|
||||||
|
rewriter.setInsertionPointToStart(loop.getBody());
|
||||||
|
|
||||||
|
Value loopIndex = loop.getInductionVar();
|
||||||
|
Value loopAccumulator = loop.getRegionIterArgs().front();
|
||||||
|
outerIndices.push_back(loopIndex);
|
||||||
|
Value updatedAccumulator =
|
||||||
|
buildLoopSoftmaxNest(input, loopAccumulator, inputType, axis + 1, outerIndices, rewriter, loc);
|
||||||
|
outerIndices.pop_back();
|
||||||
|
|
||||||
|
scf::YieldOp::create(rewriter, loc, updatedAccumulator);
|
||||||
|
rewriter.setInsertionPointAfter(loop);
|
||||||
|
return loop.getResult(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
static Value createLoopSoftmaxCompute(Value input, ConversionPatternRewriter& rewriter, Location loc) {
|
||||||
auto inputType = cast<RankedTensorType>(input.getType());
|
auto inputType = cast<RankedTensorType>(input.getType());
|
||||||
constexpr size_t numInputs = 1;
|
constexpr size_t numInputs = 1;
|
||||||
auto computeOp =
|
auto computeOp =
|
||||||
createSpatCompute<numInputs>(rewriter, loc, TypeRange {inputType}, {}, ValueRange {input}, [&](Value x) {
|
createSpatCompute<numInputs>(rewriter, loc, TypeRange {inputType}, {}, ValueRange {input}, [&](Value x) {
|
||||||
auto softmaxOp = spatial::SpatSoftmaxOp::create(rewriter, loc, inputType, x);
|
if (inputType.getRank() == 1) {
|
||||||
spatial::SpatYieldOp::create(rewriter, loc, softmaxOp.getResult());
|
Value softmax = spatial::SpatSoftmaxOp::create(rewriter, loc, inputType, x).getResult();
|
||||||
|
spatial::SpatYieldOp::create(rewriter, loc, softmax);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
Value outputInit = tensor::EmptyOp::create(rewriter, loc, inputType.getShape(), inputType.getElementType());
|
||||||
|
SmallVector<Value> outerIndices;
|
||||||
|
Value result = buildLoopSoftmaxNest(x, outputInit, inputType, /*axis=*/0, outerIndices, rewriter, loc);
|
||||||
|
spatial::SpatYieldOp::create(rewriter, loc, result);
|
||||||
});
|
});
|
||||||
return computeOp.getResult(0);
|
return computeOp.getResult(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
static Value concatValues(ValueRange inputs, int64_t axis, ConversionPatternRewriter& rewriter, Location loc) {
|
|
||||||
auto firstType = cast<RankedTensorType>(inputs.front().getType());
|
|
||||||
SmallVector<int64_t> outputShape(firstType.getShape().begin(), firstType.getShape().end());
|
|
||||||
int64_t concatDimSize = 0;
|
|
||||||
for (Value input : inputs)
|
|
||||||
concatDimSize += cast<RankedTensorType>(input.getType()).getDimSize(axis);
|
|
||||||
outputShape[axis] = concatDimSize;
|
|
||||||
auto resultType = RankedTensorType::get(outputShape, firstType.getElementType(), firstType.getEncoding());
|
|
||||||
|
|
||||||
if (llvm::all_of(inputs, isHostFoldableValue))
|
|
||||||
return createSpatConcat(rewriter, loc, axis, inputs);
|
|
||||||
|
|
||||||
auto concatCompute = createSpatCompute(rewriter, loc, TypeRange {resultType}, {}, inputs, [&](ValueRange args) {
|
|
||||||
spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, axis, args));
|
|
||||||
});
|
|
||||||
return concatCompute.getResult(0);
|
|
||||||
}
|
|
||||||
|
|
||||||
static Value
|
|
||||||
buildSoftmax(Value input, int64_t softmaxAxis, int64_t axis, ConversionPatternRewriter& rewriter, Location loc) {
|
|
||||||
auto inputType = cast<RankedTensorType>(input.getType());
|
|
||||||
if (axis == inputType.getRank())
|
|
||||||
return createSoftmaxCompute(input, rewriter, loc);
|
|
||||||
|
|
||||||
if (axis == softmaxAxis)
|
|
||||||
return buildSoftmax(input, softmaxAxis, axis + 1, rewriter, loc);
|
|
||||||
|
|
||||||
SmallVector<Value> slices = sliceTensor(input, axis, /*sliceSize=*/1, rewriter, loc);
|
|
||||||
SmallVector<Value> rebuiltSlices;
|
|
||||||
rebuiltSlices.reserve(slices.size());
|
|
||||||
for (Value slice : slices)
|
|
||||||
rebuiltSlices.push_back(buildSoftmax(slice, softmaxAxis, axis + 1, rewriter, loc));
|
|
||||||
|
|
||||||
return concatValues(rebuiltSlices, axis, rewriter, loc);
|
|
||||||
}
|
|
||||||
|
|
||||||
struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> {
|
struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> {
|
||||||
using OpConversionPattern::OpConversionPattern;
|
using OpConversionPattern::OpConversionPattern;
|
||||||
|
|
||||||
@@ -86,7 +117,7 @@ struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> {
|
|||||||
Value input = adaptor.getInput();
|
Value input = adaptor.getInput();
|
||||||
Value result;
|
Value result;
|
||||||
if (axis == inputType.getRank() - 1) {
|
if (axis == inputType.getRank() - 1) {
|
||||||
result = buildSoftmax(input, axis, /*axis=*/0, rewriter, softmaxOp.getLoc());
|
result = createLoopSoftmaxCompute(input, rewriter, softmaxOp.getLoc());
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
SmallVector<int64_t> permutation;
|
SmallVector<int64_t> permutation;
|
||||||
@@ -109,8 +140,7 @@ struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> {
|
|||||||
spatial::SpatYieldOp::create(rewriter, softmaxOp.getLoc(), transposed);
|
spatial::SpatYieldOp::create(rewriter, softmaxOp.getLoc(), transposed);
|
||||||
});
|
});
|
||||||
Value transposedInput = preTransposeCompute.getResult(0);
|
Value transposedInput = preTransposeCompute.getResult(0);
|
||||||
Value transposedResult = buildSoftmax(
|
Value transposedResult = createLoopSoftmaxCompute(transposedInput, rewriter, softmaxOp.getLoc());
|
||||||
transposedInput, /*softmaxAxis=*/inputType.getRank() - 1, /*axis=*/0, rewriter, softmaxOp.getLoc());
|
|
||||||
auto postTransposeCompute =
|
auto postTransposeCompute =
|
||||||
createSpatCompute<1>(rewriter, softmaxOp.getLoc(), TypeRange {inputType}, {}, transposedResult, [&](Value x) {
|
createSpatCompute<1>(rewriter, softmaxOp.getLoc(), TypeRange {inputType}, {}, transposedResult, [&](Value x) {
|
||||||
Value transposed = ONNXTransposeOp::create(
|
Value transposed = ONNXTransposeOp::create(
|
||||||
|
|||||||
@@ -80,6 +80,22 @@ static bool inferExpandReassociation(ArrayRef<int64_t> sourceShape,
|
|||||||
return sourceIdx == sourceShape.size() && resultIdx == resultShape.size();
|
return sourceIdx == sourceShape.size() && resultIdx == resultShape.size();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static SmallVector<ReassociationIndices> getCollapseTo1DReassociation(size_t rank) {
|
||||||
|
SmallVector<ReassociationIndices> reassociation(1);
|
||||||
|
reassociation.front().reserve(rank);
|
||||||
|
for (size_t dim = 0; dim < rank; ++dim)
|
||||||
|
reassociation.front().push_back(static_cast<int64_t>(dim));
|
||||||
|
return reassociation;
|
||||||
|
}
|
||||||
|
|
||||||
|
static SmallVector<ReassociationIndices> getExpandFrom1DReassociation(size_t rank) {
|
||||||
|
SmallVector<ReassociationIndices> reassociation(1);
|
||||||
|
reassociation.front().reserve(rank);
|
||||||
|
for (size_t dim = 0; dim < rank; ++dim)
|
||||||
|
reassociation.front().push_back(static_cast<int64_t>(dim));
|
||||||
|
return reassociation;
|
||||||
|
}
|
||||||
|
|
||||||
struct Reshape : OpConversionPattern<ONNXReshapeOp> {
|
struct Reshape : OpConversionPattern<ONNXReshapeOp> {
|
||||||
using OpConversionPattern::OpConversionPattern;
|
using OpConversionPattern::OpConversionPattern;
|
||||||
|
|
||||||
@@ -126,6 +142,23 @@ struct Reshape : OpConversionPattern<ONNXReshapeOp> {
|
|||||||
return tensor::ExpandShapeOp::create(rewriter, reshapeOp.getLoc(), resultType, data, reassociation);
|
return tensor::ExpandShapeOp::create(rewriter, reshapeOp.getLoc(), resultType, data, reassociation);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
if (sourceType.getNumElements() != resultType.getNumElements())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
return replaceWithReshape([&](Value data) -> Value {
|
||||||
|
Value reshaped = data;
|
||||||
|
if (sourceType.getRank() != 1) {
|
||||||
|
auto flatType = RankedTensorType::get({sourceType.getNumElements()}, sourceType.getElementType());
|
||||||
|
reshaped = tensor::CollapseShapeOp::create(
|
||||||
|
rewriter, reshapeOp.getLoc(), flatType, reshaped, getCollapseTo1DReassociation(sourceType.getRank()));
|
||||||
|
}
|
||||||
|
if (resultType.getRank() == 1)
|
||||||
|
return reshaped;
|
||||||
|
return tensor::ExpandShapeOp::create(
|
||||||
|
rewriter, reshapeOp.getLoc(), resultType, reshaped, getExpandFrom1DReassociation(resultType.getRank()))
|
||||||
|
.getResult();
|
||||||
|
});
|
||||||
|
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -1,10 +1,10 @@
|
|||||||
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
|
||||||
#include "llvm/ADT/STLExtras.h"
|
#include "llvm/ADT/STLExtras.h"
|
||||||
|
|
||||||
#include <algorithm>
|
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
@@ -15,42 +15,88 @@ using namespace mlir;
|
|||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
static Value
|
static Value buildNearestAsymmetricIndex(
|
||||||
extractSliceAt(Value input, int64_t axis, int64_t offset, ConversionPatternRewriter& rewriter, Location loc) {
|
Value outputIndex, int64_t inputDim, int64_t outputDim, ConversionPatternRewriter& rewriter, Location loc) {
|
||||||
auto inputType = cast<RankedTensorType>(input.getType());
|
Value cInputDim = arith::ConstantIndexOp::create(rewriter, loc, inputDim);
|
||||||
SmallVector<OpFoldResult> offsets(inputType.getRank(), rewriter.getIndexAttr(0));
|
Value cOutputDim = arith::ConstantIndexOp::create(rewriter, loc, outputDim);
|
||||||
SmallVector<OpFoldResult> sizes;
|
Value cInputDimLast = arith::ConstantIndexOp::create(rewriter, loc, inputDim - 1);
|
||||||
SmallVector<OpFoldResult> strides(inputType.getRank(), rewriter.getIndexAttr(1));
|
Value scaledIndex = arith::MulIOp::create(rewriter, loc, outputIndex, cInputDim);
|
||||||
sizes.reserve(inputType.getRank());
|
Value inputIndex = arith::DivUIOp::create(rewriter, loc, scaledIndex, cOutputDim);
|
||||||
for (int64_t dim : inputType.getShape())
|
return arith::MinUIOp::create(rewriter, loc, inputIndex, cInputDimLast);
|
||||||
sizes.push_back(rewriter.getIndexAttr(dim));
|
|
||||||
offsets[axis] = rewriter.getIndexAttr(offset);
|
|
||||||
sizes[axis] = rewriter.getIndexAttr(1);
|
|
||||||
return tensor::ExtractSliceOp::create(rewriter, loc, input, offsets, sizes, strides);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static int64_t nearestAsymmetricIndex(int64_t outputIndex, int64_t inputDim, int64_t outputDim) {
|
static Value buildNearestResizeLoop(Value input,
|
||||||
return std::min<int64_t>((outputIndex * inputDim) / outputDim, inputDim - 1);
|
RankedTensorType inputType,
|
||||||
}
|
RankedTensorType resultType,
|
||||||
|
|
||||||
static Value buildNearestResize(Value input,
|
|
||||||
ArrayRef<int64_t> inputShape,
|
|
||||||
ArrayRef<int64_t> outputShape,
|
|
||||||
int64_t axis,
|
|
||||||
ConversionPatternRewriter& rewriter,
|
ConversionPatternRewriter& rewriter,
|
||||||
Location loc) {
|
Location loc) {
|
||||||
if (axis == static_cast<int64_t>(outputShape.size()))
|
auto elemType = resultType.getElementType();
|
||||||
return input;
|
SmallVector<int64_t> unitShape(resultType.getRank(), 1);
|
||||||
|
auto unitTensorType = RankedTensorType::get(unitShape, elemType);
|
||||||
|
|
||||||
SmallVector<Value> slices;
|
SmallVector<OpFoldResult> unitSizes(resultType.getRank(), rewriter.getIndexAttr(1));
|
||||||
slices.reserve(outputShape[axis]);
|
SmallVector<OpFoldResult> unitStrides(resultType.getRank(), rewriter.getIndexAttr(1));
|
||||||
for (int64_t outputIndex = 0; outputIndex < outputShape[axis]; ++outputIndex) {
|
|
||||||
int64_t inputIndex = nearestAsymmetricIndex(outputIndex, inputShape[axis], outputShape[axis]);
|
|
||||||
Value slice = extractSliceAt(input, axis, inputIndex, rewriter, loc);
|
|
||||||
slices.push_back(buildNearestResize(slice, inputShape, outputShape, axis + 1, rewriter, loc));
|
|
||||||
}
|
|
||||||
|
|
||||||
return createSpatConcat(rewriter, loc, axis, slices);
|
Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
|
||||||
|
Value c1 = arith::ConstantIndexOp::create(rewriter, loc, 1);
|
||||||
|
Value cOutputN = arith::ConstantIndexOp::create(rewriter, loc, resultType.getDimSize(0));
|
||||||
|
Value cOutputC = arith::ConstantIndexOp::create(rewriter, loc, resultType.getDimSize(1));
|
||||||
|
Value cOutputH = arith::ConstantIndexOp::create(rewriter, loc, resultType.getDimSize(2));
|
||||||
|
Value cOutputW = arith::ConstantIndexOp::create(rewriter, loc, resultType.getDimSize(3));
|
||||||
|
|
||||||
|
Value outputInit = tensor::EmptyOp::create(rewriter, loc, resultType.getShape(), elemType);
|
||||||
|
|
||||||
|
auto batchLoop = scf::ForOp::create(rewriter, loc, c0, cOutputN, c1, ValueRange {outputInit});
|
||||||
|
rewriter.setInsertionPointToStart(batchLoop.getBody());
|
||||||
|
|
||||||
|
Value outputN = batchLoop.getInductionVar();
|
||||||
|
Value outputBatchAcc = batchLoop.getRegionIterArgs().front();
|
||||||
|
Value inputN = buildNearestAsymmetricIndex(outputN, inputType.getDimSize(0), resultType.getDimSize(0), rewriter, loc);
|
||||||
|
|
||||||
|
auto channelLoop = scf::ForOp::create(rewriter, loc, c0, cOutputC, c1, ValueRange {outputBatchAcc});
|
||||||
|
rewriter.setInsertionPointToStart(channelLoop.getBody());
|
||||||
|
|
||||||
|
Value outputC = channelLoop.getInductionVar();
|
||||||
|
Value outputChannelAcc = channelLoop.getRegionIterArgs().front();
|
||||||
|
Value inputC =
|
||||||
|
buildNearestAsymmetricIndex(outputC, inputType.getDimSize(1), resultType.getDimSize(1), rewriter, loc);
|
||||||
|
|
||||||
|
auto heightLoop = scf::ForOp::create(rewriter, loc, c0, cOutputH, c1, ValueRange {outputChannelAcc});
|
||||||
|
rewriter.setInsertionPointToStart(heightLoop.getBody());
|
||||||
|
|
||||||
|
Value outputH = heightLoop.getInductionVar();
|
||||||
|
Value outputHeightAcc = heightLoop.getRegionIterArgs().front();
|
||||||
|
Value inputH =
|
||||||
|
buildNearestAsymmetricIndex(outputH, inputType.getDimSize(2), resultType.getDimSize(2), rewriter, loc);
|
||||||
|
|
||||||
|
auto widthLoop = scf::ForOp::create(rewriter, loc, c0, cOutputW, c1, ValueRange {outputHeightAcc});
|
||||||
|
rewriter.setInsertionPointToStart(widthLoop.getBody());
|
||||||
|
|
||||||
|
Value outputW = widthLoop.getInductionVar();
|
||||||
|
Value outputWidthAcc = widthLoop.getRegionIterArgs().front();
|
||||||
|
Value inputW =
|
||||||
|
buildNearestAsymmetricIndex(outputW, inputType.getDimSize(3), resultType.getDimSize(3), rewriter, loc);
|
||||||
|
|
||||||
|
SmallVector<OpFoldResult> inputOffsets = {inputN, inputC, inputH, inputW};
|
||||||
|
Value inputSlice =
|
||||||
|
tensor::ExtractSliceOp::create(rewriter, loc, unitTensorType, input, inputOffsets, unitSizes, unitStrides);
|
||||||
|
|
||||||
|
SmallVector<OpFoldResult> outputOffsets = {outputN, outputC, outputH, outputW};
|
||||||
|
Value updatedOutput =
|
||||||
|
tensor::InsertSliceOp::create(rewriter, loc, inputSlice, outputWidthAcc, outputOffsets, unitSizes, unitStrides);
|
||||||
|
scf::YieldOp::create(rewriter, loc, updatedOutput);
|
||||||
|
|
||||||
|
rewriter.setInsertionPointAfter(widthLoop);
|
||||||
|
scf::YieldOp::create(rewriter, loc, widthLoop.getResult(0));
|
||||||
|
|
||||||
|
rewriter.setInsertionPointAfter(heightLoop);
|
||||||
|
scf::YieldOp::create(rewriter, loc, heightLoop.getResult(0));
|
||||||
|
|
||||||
|
rewriter.setInsertionPointAfter(channelLoop);
|
||||||
|
scf::YieldOp::create(rewriter, loc, channelLoop.getResult(0));
|
||||||
|
|
||||||
|
rewriter.setInsertionPointAfter(batchLoop);
|
||||||
|
return batchLoop.getResult(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
struct Resize : OpConversionPattern<ONNXResizeOp> {
|
struct Resize : OpConversionPattern<ONNXResizeOp> {
|
||||||
@@ -62,20 +108,22 @@ struct Resize : OpConversionPattern<ONNXResizeOp> {
|
|||||||
auto inputType = dyn_cast<RankedTensorType>(adaptor.getX().getType());
|
auto inputType = dyn_cast<RankedTensorType>(adaptor.getX().getType());
|
||||||
auto resultType = dyn_cast<RankedTensorType>(resizeOp.getY().getType());
|
auto resultType = dyn_cast<RankedTensorType>(resizeOp.getY().getType());
|
||||||
if (!inputType || !resultType || !inputType.hasStaticShape() || !resultType.hasStaticShape())
|
if (!inputType || !resultType || !inputType.hasStaticShape() || !resultType.hasStaticShape())
|
||||||
return failure();
|
return rewriter.notifyMatchFailure(resizeOp, "resize lowering requires static ranked tensor types.");
|
||||||
|
if (inputType.getRank() != 4 || resultType.getRank() != 4)
|
||||||
|
return rewriter.notifyMatchFailure(resizeOp, "resize lowering currently supports only rank-4 NCHW tensors.");
|
||||||
|
|
||||||
if (resizeOp.getMode() != "nearest" || resizeOp.getCoordinateTransformationMode() != "asymmetric"
|
if (resizeOp.getMode() != "nearest" || resizeOp.getCoordinateTransformationMode() != "asymmetric"
|
||||||
|| resizeOp.getNearestMode() != "floor")
|
|| resizeOp.getNearestMode() != "floor")
|
||||||
return failure();
|
return rewriter.notifyMatchFailure(
|
||||||
|
resizeOp, "resize lowering currently supports only nearest + asymmetric + floor.");
|
||||||
|
|
||||||
if (llvm::any_of(inputType.getShape(), [](int64_t dim) { return dim <= 0; })
|
if (llvm::any_of(inputType.getShape(), [](int64_t dim) { return dim <= 0; })
|
||||||
|| llvm::any_of(resultType.getShape(), [](int64_t dim) { return dim <= 0; }))
|
|| llvm::any_of(resultType.getShape(), [](int64_t dim) { return dim <= 0; }))
|
||||||
return failure();
|
return rewriter.notifyMatchFailure(resizeOp, "resize lowering requires positive static dimensions.");
|
||||||
|
|
||||||
auto computeOp =
|
auto computeOp =
|
||||||
createSpatCompute<1>(rewriter, resizeOp.getLoc(), TypeRange {resultType}, {}, adaptor.getX(), [&](Value x) {
|
createSpatCompute<1>(rewriter, resizeOp.getLoc(), TypeRange {resultType}, {}, adaptor.getX(), [&](Value x) {
|
||||||
Value result =
|
Value result = buildNearestResizeLoop(x, inputType, resultType, rewriter, resizeOp.getLoc());
|
||||||
buildNearestResize(x, inputType.getShape(), resultType.getShape(), /*axis=*/0, rewriter, resizeOp.getLoc());
|
|
||||||
spatial::SpatYieldOp::create(rewriter, resizeOp.getLoc(), result);
|
spatial::SpatYieldOp::create(rewriter, resizeOp.getLoc(), result);
|
||||||
});
|
});
|
||||||
rewriter.replaceOp(resizeOp, computeOp.getResults());
|
rewriter.replaceOp(resizeOp, computeOp.getResults());
|
||||||
|
|||||||
@@ -31,6 +31,21 @@ static bool isDirectConstantValue(Value value) {
|
|||||||
return isa_and_nonnull<arith::ConstantOp, ONNXConstantOp>(value.getDefiningOp());
|
return isa_and_nonnull<arith::ConstantOp, ONNXConstantOp>(value.getDefiningOp());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename ComputeOpTy>
|
||||||
|
static bool hasPromotableWeightLikeInputs(ComputeOpTy compute) {
|
||||||
|
Block& block = compute.getBody().front();
|
||||||
|
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
||||||
|
if (inputIdx >= block.getNumArguments())
|
||||||
|
continue;
|
||||||
|
if (!isWeightLikeComputeOperand(input))
|
||||||
|
continue;
|
||||||
|
if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(block.getArgument(inputIdx)))
|
||||||
|
continue;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
// Collapses one-lane batches so later phases do not carry batch-only structure unnecessarily.
|
// Collapses one-lane batches so later phases do not carry batch-only structure unnecessarily.
|
||||||
struct FoldSingleLaneComputeBatchPattern : OpRewritePattern<spatial::SpatComputeBatch> {
|
struct FoldSingleLaneComputeBatchPattern : OpRewritePattern<spatial::SpatComputeBatch> {
|
||||||
using OpRewritePattern<spatial::SpatComputeBatch>::OpRewritePattern;
|
using OpRewritePattern<spatial::SpatComputeBatch>::OpRewritePattern;
|
||||||
@@ -262,4 +277,10 @@ void annotateWeightsConstants(func::FuncOp funcOp) {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool requiresEarlyPostRewrite(spatial::SpatComputeBatch batchOp) { return batchOp.getLaneCount() == 1; }
|
||||||
|
|
||||||
|
bool requiresPostRewrite(spatial::SpatCompute computeOp) { return hasPromotableWeightLikeInputs(computeOp); }
|
||||||
|
|
||||||
|
bool requiresPostRewrite(spatial::SpatComputeBatch computeOp) { return hasPromotableWeightLikeInputs(computeOp); }
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -3,8 +3,16 @@
|
|||||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
#include "mlir/IR/MLIRContext.h"
|
#include "mlir/IR/MLIRContext.h"
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
bool requiresEarlyPostRewrite(spatial::SpatComputeBatch batchOp);
|
||||||
|
|
||||||
|
bool requiresPostRewrite(spatial::SpatCompute computeOp);
|
||||||
|
|
||||||
|
bool requiresPostRewrite(spatial::SpatComputeBatch computeOp);
|
||||||
|
|
||||||
void populateEarlyPostPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
void populateEarlyPostPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||||
|
|
||||||
void populatePostPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
void populatePostPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||||
|
|||||||
@@ -17,9 +17,7 @@ void populatePrePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* c
|
|||||||
patterns.add<convAddToConvWithBiasLeft>(ctx);
|
patterns.add<convAddToConvWithBiasLeft>(ctx);
|
||||||
patterns.add<convAddToConvWithBiasRight>(ctx);
|
patterns.add<convAddToConvWithBiasRight>(ctx);
|
||||||
patterns.add<matMulAddToGemm>(ctx);
|
patterns.add<matMulAddToGemm>(ctx);
|
||||||
patterns.add<matMulToGemm>(ctx);
|
|
||||||
patterns.add<removeFlattenSameShape>(ctx);
|
patterns.add<removeFlattenSameShape>(ctx);
|
||||||
populateMatMulRewritePatterns(patterns, ctx);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -202,6 +202,7 @@ void SpatialToGraphvizPass::runOnOperation() {
|
|||||||
|
|
||||||
auto entryFunc = getPimEntryFunc(module);
|
auto entryFunc = getPimEntryFunc(module);
|
||||||
if (failed(entryFunc)) {
|
if (failed(entryFunc)) {
|
||||||
|
module.emitError("failed to locate the PIM entry function for Spatial graph visualization");
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -138,12 +138,13 @@ static Value padHVectorInputToCrossbarSize(IRRewriter& rewriter, Location loc, V
|
|||||||
}
|
}
|
||||||
|
|
||||||
void SpatialToPimPass::runOnOperation() {
|
void SpatialToPimPass::runOnOperation() {
|
||||||
coreId = 1;
|
coreId = 0;
|
||||||
ModuleOp moduleOp = getOperation();
|
ModuleOp moduleOp = getOperation();
|
||||||
MLIRContext* ctx = moduleOp.getContext();
|
MLIRContext* ctx = moduleOp.getContext();
|
||||||
|
|
||||||
auto entryFunc = getPimEntryFunc(moduleOp);
|
auto entryFunc = getPimEntryFunc(moduleOp);
|
||||||
if (failed(entryFunc)) {
|
if (failed(entryFunc)) {
|
||||||
|
moduleOp.emitError("failed to locate the PIM entry function during Spatial-to-PIM lowering");
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@@ -169,26 +170,22 @@ void SpatialToPimPass::runOnOperation() {
|
|||||||
spatial::SpatChannelSendTensorBatchOp,
|
spatial::SpatChannelSendTensorBatchOp,
|
||||||
spatial::SpatExtractRowsOp>();
|
spatial::SpatExtractRowsOp>();
|
||||||
|
|
||||||
{
|
RewritePatternSet initialPatterns(ctx);
|
||||||
RewritePatternSet patterns(ctx);
|
populateWithGenerated(initialPatterns);
|
||||||
populateWithGenerated(patterns);
|
if (failed(applyPartialConversion(moduleOp, target, std::move(initialPatterns)))) {
|
||||||
|
moduleOp.emitError("failed to lower required Spatial ops to the initial PIM form");
|
||||||
if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
|
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
{
|
RewritePatternSet globalTensorPatterns(ctx);
|
||||||
RewritePatternSet patterns(ctx);
|
populateGlobalTensorMaterializationPatterns(globalTensorPatterns);
|
||||||
populateGlobalTensorMaterializationPatterns(patterns);
|
walkAndApplyPatterns(moduleOp, std::move(globalTensorPatterns));
|
||||||
|
|
||||||
walkAndApplyPatterns(moduleOp, std::move(patterns));
|
|
||||||
}
|
|
||||||
auto returnOp = cast<func::ReturnOp>(funcOp.front().getTerminator());
|
auto returnOp = cast<func::ReturnOp>(funcOp.front().getTerminator());
|
||||||
|
|
||||||
addReturnOutputBuffers(returnOp, rewriter, outputTensors);
|
addReturnOutputBuffers(returnOp, rewriter, outputTensors);
|
||||||
if (failed(allocateAndInitializeCoreLocalVariables(funcOp, rewriter))) {
|
if (failed(allocateAndInitializeCoreLocalVariables(funcOp, rewriter))) {
|
||||||
|
funcOp.emitOpError("failed to allocate or initialize core-local tensors during Spatial-to-PIM lowering");
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@@ -197,6 +194,7 @@ void SpatialToPimPass::runOnOperation() {
|
|||||||
for (auto computeOp : funcOp.getOps<spatial::SpatCompute>()) {
|
for (auto computeOp : funcOp.getOps<spatial::SpatCompute>()) {
|
||||||
markOpToRemove(computeOp);
|
markOpToRemove(computeOp);
|
||||||
if (failed(lowerComputeOp(computeOp, coreLoweringState, rewriter))) {
|
if (failed(lowerComputeOp(computeOp, coreLoweringState, rewriter))) {
|
||||||
|
computeOp.emitOpError("failed to lower spat.compute to pim.core");
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@@ -205,17 +203,16 @@ void SpatialToPimPass::runOnOperation() {
|
|||||||
for (auto computeBatchOp : funcOp.getOps<spatial::SpatComputeBatch>()) {
|
for (auto computeBatchOp : funcOp.getOps<spatial::SpatComputeBatch>()) {
|
||||||
markOpToRemove(computeBatchOp);
|
markOpToRemove(computeBatchOp);
|
||||||
if (failed(lowerComputeBatchOp(computeBatchOp, coreLoweringState, rewriter))) {
|
if (failed(lowerComputeBatchOp(computeBatchOp, coreLoweringState, rewriter))) {
|
||||||
|
computeBatchOp.emitOpError("failed to lower spat.compute_batch to pim.core_batch");
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
RewritePatternSet initialTensorPackingPatterns(ctx);
|
||||||
RewritePatternSet patterns(ctx);
|
populateTensorPackingPatterns(initialTensorPackingPatterns);
|
||||||
populateTensorPackingPatterns(patterns);
|
walkAndApplyPatterns(funcOp, std::move(initialTensorPackingPatterns));
|
||||||
walkAndApplyPatterns(funcOp, std::move(patterns));
|
|
||||||
eraseUnusedTensorPackingOps(funcOp, rewriter);
|
eraseUnusedTensorPackingOps(funcOp, rewriter);
|
||||||
}
|
|
||||||
|
|
||||||
SmallVector<spatial::SpatChannelReceiveOp> receiveOps;
|
SmallVector<spatial::SpatChannelReceiveOp> receiveOps;
|
||||||
for (auto op : funcOp.getOps<spatial::SpatChannelReceiveOp>())
|
for (auto op : funcOp.getOps<spatial::SpatChannelReceiveOp>())
|
||||||
@@ -229,7 +226,6 @@ void SpatialToPimPass::runOnOperation() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
|
||||||
RewritePatternSet coreBodyPatterns(ctx);
|
RewritePatternSet coreBodyPatterns(ctx);
|
||||||
populateWithGenerated(coreBodyPatterns);
|
populateWithGenerated(coreBodyPatterns);
|
||||||
FrozenRewritePatternSet frozenCoreBodyPatterns(std::move(coreBodyPatterns));
|
FrozenRewritePatternSet frozenCoreBodyPatterns(std::move(coreBodyPatterns));
|
||||||
@@ -238,6 +234,7 @@ void SpatialToPimPass::runOnOperation() {
|
|||||||
funcOp.walk([&](pim::PimCoreOp coreOp) { coreOps.push_back(coreOp); });
|
funcOp.walk([&](pim::PimCoreOp coreOp) { coreOps.push_back(coreOp); });
|
||||||
for (auto coreOp : coreOps) {
|
for (auto coreOp : coreOps) {
|
||||||
if (failed(applyFullConversion(coreOp.getOperation(), target, frozenCoreBodyPatterns))) {
|
if (failed(applyFullConversion(coreOp.getOperation(), target, frozenCoreBodyPatterns))) {
|
||||||
|
coreOp.emitOpError("failed to convert nested Spatial ops inside pim.core");
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@@ -247,11 +244,11 @@ void SpatialToPimPass::runOnOperation() {
|
|||||||
funcOp.walk([&](pim::PimCoreBatchOp coreBatchOp) { coreBatchOps.push_back(coreBatchOp); });
|
funcOp.walk([&](pim::PimCoreBatchOp coreBatchOp) { coreBatchOps.push_back(coreBatchOp); });
|
||||||
for (auto coreBatchOp : coreBatchOps) {
|
for (auto coreBatchOp : coreBatchOps) {
|
||||||
if (failed(applyFullConversion(coreBatchOp.getOperation(), target, frozenCoreBodyPatterns))) {
|
if (failed(applyFullConversion(coreBatchOp.getOperation(), target, frozenCoreBodyPatterns))) {
|
||||||
|
coreBatchOp.emitOpError("failed to convert nested Spatial ops inside pim.core_batch");
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
enlargeVMMOutTensorsToCrossbarSize(funcOp, rewriter);
|
enlargeVMMOutTensorsToCrossbarSize(funcOp, rewriter);
|
||||||
ReturnPathState returnPathState {outputTensors, operationsToRemove};
|
ReturnPathState returnPathState {outputTensors, operationsToRemove};
|
||||||
@@ -259,18 +256,16 @@ void SpatialToPimPass::runOnOperation() {
|
|||||||
|
|
||||||
SmallVector<Operation*> pendingRemovals(operationsToRemove.begin(), operationsToRemove.end());
|
SmallVector<Operation*> pendingRemovals(operationsToRemove.begin(), operationsToRemove.end());
|
||||||
if (failed(erasePendingOps(pendingRemovals, rewriter))) {
|
if (failed(erasePendingOps(pendingRemovals, rewriter))) {
|
||||||
|
funcOp.emitOpError("failed to erase obsolete Spatial ops after lowering to PIM");
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
RewritePatternSet finalTensorPackingPatterns(ctx);
|
||||||
RewritePatternSet patterns(ctx);
|
populateTensorPackingPatterns(finalTensorPackingPatterns);
|
||||||
populateTensorPackingPatterns(patterns);
|
walkAndApplyPatterns(funcOp, std::move(finalTensorPackingPatterns));
|
||||||
walkAndApplyPatterns(funcOp, std::move(patterns));
|
|
||||||
eraseUnusedTensorPackingOps(funcOp, rewriter);
|
eraseUnusedTensorPackingOps(funcOp, rewriter);
|
||||||
}
|
|
||||||
|
|
||||||
{
|
|
||||||
ConversionTarget communicationTarget(*ctx);
|
ConversionTarget communicationTarget(*ctx);
|
||||||
communicationTarget.addLegalDialect<PimDialect,
|
communicationTarget.addLegalDialect<PimDialect,
|
||||||
tensor::TensorDialect,
|
tensor::TensorDialect,
|
||||||
@@ -291,12 +286,13 @@ void SpatialToPimPass::runOnOperation() {
|
|||||||
RewritePatternSet communicationPatterns(ctx);
|
RewritePatternSet communicationPatterns(ctx);
|
||||||
populateChannelLoweringPatterns(communicationPatterns);
|
populateChannelLoweringPatterns(communicationPatterns);
|
||||||
if (failed(applyFullConversion(funcOp, communicationTarget, std::move(communicationPatterns)))) {
|
if (failed(applyFullConversion(funcOp, communicationTarget, std::move(communicationPatterns)))) {
|
||||||
|
funcOp.emitOpError("failed to lower Spatial communication ops to PIM communication ops");
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
if (failed(verifySpatialToPimBoundary(moduleOp))) {
|
if (failed(verifySpatialToPimBoundary(moduleOp))) {
|
||||||
|
moduleOp.emitError("Spatial-to-PIM boundary verification failed");
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/TensorPackingPatterns.hpp"
|
#include "src/Accelerators/PIM/Conversion/SpatialToPim/TensorPackingPatterns.hpp"
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
@@ -75,15 +74,13 @@ struct PackSpatialConcatInputsPattern final : OpRewritePattern<spatial::SpatConc
|
|||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
auto outputType = cast<ShapedType>(concatOp.getOutput().getType());
|
auto outputType = cast<ShapedType>(concatOp.getOutput().getType());
|
||||||
auto newConcat = pim::PimConcatOp::create(rewriter,
|
auto newConcat = pim::PimConcatOp::create(
|
||||||
|
rewriter,
|
||||||
concatOp.getLoc(),
|
concatOp.getLoc(),
|
||||||
concatOp.getOutput().getType(),
|
concatOp.getOutput().getType(),
|
||||||
concatOp.getAxisAttr(),
|
concatOp.getAxisAttr(),
|
||||||
ValueRange(packedInputs),
|
ValueRange(packedInputs),
|
||||||
tensor::EmptyOp::create(rewriter,
|
tensor::EmptyOp::create(rewriter, concatOp.getLoc(), outputType.getShape(), outputType.getElementType())
|
||||||
concatOp.getLoc(),
|
|
||||||
outputType.getShape(),
|
|
||||||
outputType.getElementType())
|
|
||||||
.getResult());
|
.getResult());
|
||||||
rewriter.replaceOp(concatOp, newConcat.getOutput());
|
rewriter.replaceOp(concatOp, newConcat.getOutput());
|
||||||
return success();
|
return success();
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
||||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "mlir/IR/PatternMatch.h"
|
#include "mlir/IR/PatternMatch.h"
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
|||||||
@@ -79,6 +79,7 @@ void PimBufferizationPass::runOnOperation() {
|
|||||||
return WalkResult::skip();
|
return WalkResult::skip();
|
||||||
});
|
});
|
||||||
if (hasFailed) {
|
if (hasFailed) {
|
||||||
|
moduleOp.emitError("failed to lower memref.copy-like ops inside PIM core bodies during bufferization");
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,15 +1,16 @@
|
|||||||
#include "src/Accelerators/PIM/Dialect/Pim/Transforms/StaticMemoryCoalescing/StaticMemoryCoalescing.hpp"
|
|
||||||
|
|
||||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||||
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||||
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
|
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
|
||||||
|
|
||||||
#include "llvm/ADT/DenseMap.h"
|
#include "llvm/ADT/DenseMap.h"
|
||||||
|
#include "llvm/ADT/STLExtras.h"
|
||||||
#include "llvm/ADT/SmallPtrSet.h"
|
#include "llvm/ADT/SmallPtrSet.h"
|
||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
#include "llvm/ADT/STLExtras.h"
|
|
||||||
|
|
||||||
#include <limits>
|
#include <limits>
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Pim/Transforms/StaticMemoryCoalescing/StaticMemoryCoalescing.hpp"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
@@ -29,9 +30,8 @@ static uint64_t getTypeSizeBytes(MemRefType type) {
|
|||||||
return static_cast<uint64_t>(type.getNumElements() * type.getElementTypeBitWidth() / 8);
|
return static_cast<uint64_t>(type.getNumElements() * type.getElementTypeBitWidth() / 8);
|
||||||
}
|
}
|
||||||
|
|
||||||
static FailureOr<uint64_t> getLastUseInstruction(memref::AllocOp allocOp,
|
static FailureOr<uint64_t>
|
||||||
Block& body,
|
getLastUseInstruction(memref::AllocOp allocOp, Block& body, const DenseMap<Operation*, uint64_t>& opOrder) {
|
||||||
const DenseMap<Operation*, uint64_t>& opOrder) {
|
|
||||||
uint64_t endInstruction = opOrder.lookup(allocOp);
|
uint64_t endInstruction = opOrder.lookup(allocOp);
|
||||||
SmallPtrSet<Operation*, 16> visited;
|
SmallPtrSet<Operation*, 16> visited;
|
||||||
SmallVector<Value> pendingValues;
|
SmallVector<Value> pendingValues;
|
||||||
@@ -45,9 +45,15 @@ static FailureOr<uint64_t> getLastUseInstruction(memref::AllocOp allocOp,
|
|||||||
if (!visited.insert(user).second)
|
if (!visited.insert(user).second)
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
if (isSupportedAliasOp(user)) {
|
if (isSupportedAliasOp(user))
|
||||||
for (Value result : user->getResults())
|
for (Value result : user->getResults())
|
||||||
pendingValues.push_back(result);
|
pendingValues.push_back(result);
|
||||||
|
|
||||||
|
if (auto forOp = dyn_cast<scf::ForOp>(user)) {
|
||||||
|
for (auto [index, initArg] : llvm::enumerate(forOp.getInitArgs())) {
|
||||||
|
if (initArg == value)
|
||||||
|
pendingValues.push_back(forOp.getResult(index));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (auto dpsOp = dyn_cast<DestinationStyleOpInterface>(user)) {
|
if (auto dpsOp = dyn_cast<DestinationStyleOpInterface>(user)) {
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
|
|
||||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||||
#include "mlir/IR/PatternMatch.h"
|
#include "mlir/IR/PatternMatch.h"
|
||||||
#include "mlir/IR/Operation.h"
|
|
||||||
|
|
||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
|
||||||
|
|||||||
+20
-19
@@ -45,9 +45,7 @@ struct CoalescingReportEntry {
|
|||||||
CoalescingReportRow row;
|
CoalescingReportRow row;
|
||||||
};
|
};
|
||||||
|
|
||||||
static std::string formatMemory(uint64_t bytes) {
|
static std::string formatMemory(uint64_t bytes) { return formatReportMemory(bytes); }
|
||||||
return formatReportMemory(bytes);
|
|
||||||
}
|
|
||||||
|
|
||||||
static SmallVector<int32_t> getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp) {
|
static SmallVector<int32_t> getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp) {
|
||||||
auto coreIdsAttr = coreBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName);
|
auto coreIdsAttr = coreBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName);
|
||||||
@@ -58,9 +56,10 @@ static SmallVector<int32_t> getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp) {
|
|||||||
static void printReportRow(raw_ostream& os, const CoalescingReportRow& row) {
|
static void printReportRow(raw_ostream& os, const CoalescingReportRow& row) {
|
||||||
llvm::SmallVector<ReportField, 4> fields = {
|
llvm::SmallVector<ReportField, 4> fields = {
|
||||||
{"Number of candidates", std::to_string(row.numCandidates)},
|
{"Number of candidates", std::to_string(row.numCandidates)},
|
||||||
{"Skipped allocations", std::to_string(row.numSkipped)},
|
{"Skipped allocations", std::to_string(row.numSkipped) },
|
||||||
{"Removed allocations", std::to_string(row.numRemoved)},
|
{"Removed allocations", std::to_string(row.numRemoved) },
|
||||||
{"Saved memory", formatMemory(row.savedBytes)}};
|
{"Saved memory", formatMemory(row.savedBytes) }
|
||||||
|
};
|
||||||
printReportFlatFields(os, fields);
|
printReportFlatFields(os, fields);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -87,10 +86,12 @@ static void emitReport(ArrayRef<CoalescingReportEntry> entries) {
|
|||||||
totalRow.savedBytes += entryTotal.savedBytes;
|
totalRow.savedBytes += entryTotal.savedBytes;
|
||||||
}
|
}
|
||||||
|
|
||||||
llvm::SmallVector<ReportField, 4> totalFields = {{"Number of candidates", std::to_string(totalRow.numCandidates)},
|
llvm::SmallVector<ReportField, 4> totalFields = {
|
||||||
{"Skipped allocations", std::to_string(totalRow.numSkipped)},
|
{"Number of candidates", std::to_string(totalRow.numCandidates)},
|
||||||
{"Removed allocations", std::to_string(totalRow.numRemoved)},
|
{"Skipped allocations", std::to_string(totalRow.numSkipped) },
|
||||||
{"Saved memory", formatMemory(totalRow.savedBytes)}};
|
{"Removed allocations", std::to_string(totalRow.numRemoved) },
|
||||||
|
{"Saved memory", formatMemory(totalRow.savedBytes) }
|
||||||
|
};
|
||||||
printReportTotalsBlock(os, totalFields);
|
printReportTotalsBlock(os, totalFields);
|
||||||
if (!entries.empty())
|
if (!entries.empty())
|
||||||
os << "\n";
|
os << "\n";
|
||||||
@@ -127,15 +128,17 @@ static void emitReport(ArrayRef<CoalescingReportEntry> entries) {
|
|||||||
if (sortedEntries[index].kind == CoalescingReportEntry::Kind::Batch) {
|
if (sortedEntries[index].kind == CoalescingReportEntry::Kind::Batch) {
|
||||||
llvm::SmallVector<ReportField, 4> perCoreFields = {
|
llvm::SmallVector<ReportField, 4> perCoreFields = {
|
||||||
{"Number of candidates", std::to_string(sortedEntries[index].row.numCandidates)},
|
{"Number of candidates", std::to_string(sortedEntries[index].row.numCandidates)},
|
||||||
{"Skipped allocations", std::to_string(sortedEntries[index].row.numSkipped)},
|
{"Skipped allocations", std::to_string(sortedEntries[index].row.numSkipped) },
|
||||||
{"Removed allocations", std::to_string(sortedEntries[index].row.numRemoved)},
|
{"Removed allocations", std::to_string(sortedEntries[index].row.numRemoved) },
|
||||||
{"Saved memory", formatMemory(sortedEntries[index].row.savedBytes)}};
|
{"Saved memory", formatMemory(sortedEntries[index].row.savedBytes) }
|
||||||
|
};
|
||||||
CoalescingReportRow totalRow = getTotalRow(sortedEntries[index]);
|
CoalescingReportRow totalRow = getTotalRow(sortedEntries[index]);
|
||||||
llvm::SmallVector<ReportField, 4> totalFields = {
|
llvm::SmallVector<ReportField, 4> totalFields = {
|
||||||
{"Number of candidates", std::to_string(totalRow.numCandidates)},
|
{"Number of candidates", std::to_string(totalRow.numCandidates)},
|
||||||
{"Skipped allocations", std::to_string(totalRow.numSkipped)},
|
{"Skipped allocations", std::to_string(totalRow.numSkipped) },
|
||||||
{"Removed allocations", std::to_string(totalRow.numRemoved)},
|
{"Removed allocations", std::to_string(totalRow.numRemoved) },
|
||||||
{"Saved memory", formatMemory(totalRow.savedBytes)}};
|
{"Saved memory", formatMemory(totalRow.savedBytes) }
|
||||||
|
};
|
||||||
printReportPerCoreAndTotalFields(os, perCoreFields, totalFields);
|
printReportPerCoreAndTotalFields(os, perCoreFields, totalFields);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
@@ -196,8 +199,6 @@ struct StaticMemoryCoalescingPass : PassWrapper<StaticMemoryCoalescingPass, Oper
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
std::unique_ptr<Pass> createPimStaticMemoryCoalescingPass() {
|
std::unique_ptr<Pass> createPimStaticMemoryCoalescingPass() { return std::make_unique<StaticMemoryCoalescingPass>(); }
|
||||||
return std::make_unique<StaticMemoryCoalescingPass>();
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -8,7 +8,14 @@ add_pim_library(SpatialOps
|
|||||||
SpatialOpsVerify.cpp
|
SpatialOpsVerify.cpp
|
||||||
SpatialOpsCanonicalization.cpp
|
SpatialOpsCanonicalization.cpp
|
||||||
Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp
|
Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp
|
||||||
|
Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp
|
||||||
|
Transforms/MergeComputeNodes/PostMergeCompaction.cpp
|
||||||
Transforms/MergeComputeNodes/RegularOpCompaction.cpp
|
Transforms/MergeComputeNodes/RegularOpCompaction.cpp
|
||||||
|
Transforms/MergeComputeNodes/Scheduling/ComputeGraph.cpp
|
||||||
|
Transforms/MergeComputeNodes/Scheduling/ComputeInstanceUtils.cpp
|
||||||
|
Transforms/MergeComputeNodes/Scheduling/DcpScheduler.cpp
|
||||||
|
Transforms/MergeComputeNodes/Scheduling/MergeSchedulingAnalysis.cpp
|
||||||
|
Transforms/MergeComputeNodes/Scheduling/PeftScheduler.cpp
|
||||||
Transforms/MergeComputeNodes/DCPGraph/Graph.cpp
|
Transforms/MergeComputeNodes/DCPGraph/Graph.cpp
|
||||||
Transforms/MergeComputeNodes/DCPGraph/GraphDebug.cpp
|
Transforms/MergeComputeNodes/DCPGraph/GraphDebug.cpp
|
||||||
Transforms/MergeComputeNodes/DCPGraph/GraphSupport.cpp
|
Transforms/MergeComputeNodes/DCPGraph/GraphSupport.cpp
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
#include "mlir/IR/Diagnostics.h"
|
#include "mlir/IR/Diagnostics.h"
|
||||||
#include "mlir/IR/OpDefinition.h"
|
#include "mlir/IR/OpDefinition.h"
|
||||||
#include "mlir/IR/TypeUtilities.h"
|
#include "mlir/IR/TypeUtilities.h"
|
||||||
|
#include "mlir/Support/LLVM.h"
|
||||||
|
|
||||||
#include "llvm/ADT/DenseSet.h"
|
#include "llvm/ADT/DenseSet.h"
|
||||||
#include "llvm/Support/LogicalResult.h"
|
#include "llvm/Support/LogicalResult.h"
|
||||||
@@ -338,6 +339,19 @@ LogicalResult SpatConcatOp::verify() {
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
LogicalResult verifyComputeResultsUses(Operation* op) {
|
||||||
|
if (!isa<SpatCompute, SpatComputeBatch>(op))
|
||||||
|
return op->emitError("verifyComputeResultUses: Op is not a SpatCompute/SpatComputeBatch operation");
|
||||||
|
if (!llvm::all_of(op->getResults(), [](Value result) {
|
||||||
|
return llvm::all_of(result.getUsers(), [](Operation* op) {
|
||||||
|
return !(op->getParentOfType<SpatCompute>() || op->getParentOfType<SpatComputeBatch>());
|
||||||
|
});
|
||||||
|
})) {
|
||||||
|
return op->emitError("ComputeResult used directly inside another Compute" );
|
||||||
|
}
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
LogicalResult SpatCompute::verify() {
|
LogicalResult SpatCompute::verify() {
|
||||||
auto& block = getBody().front();
|
auto& block = getBody().front();
|
||||||
if (block.mightHaveTerminator()) {
|
if (block.mightHaveTerminator()) {
|
||||||
@@ -375,7 +389,8 @@ LogicalResult SpatCompute::verify() {
|
|||||||
for (auto arg : block.getArguments())
|
for (auto arg : block.getArguments())
|
||||||
if (arg.use_empty())
|
if (arg.use_empty())
|
||||||
return emitError("ComputeOp block argument is not used");
|
return emitError("ComputeOp block argument is not used");
|
||||||
|
if (failed(verifyComputeResultsUses(this->getOperation())))
|
||||||
|
return failure();
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -465,8 +480,8 @@ LogicalResult SpatComputeBatch::verify() {
|
|||||||
return emitError("compute_batch coreIds attribute must be a dense i32 array");
|
return emitError("compute_batch coreIds attribute must be a dense i32 array");
|
||||||
if (coreIdsAttr.size() != static_cast<int64_t>(laneCountSz))
|
if (coreIdsAttr.size() != static_cast<int64_t>(laneCountSz))
|
||||||
return emitError("compute_batch coreIds array length must match laneCount");
|
return emitError("compute_batch coreIds array length must match laneCount");
|
||||||
if (llvm::any_of(coreIdsAttr.asArrayRef(), [](int32_t coreId) { return coreId <= 0; }))
|
if (llvm::any_of(coreIdsAttr.asArrayRef(), [](int32_t coreId) { return coreId < 0; }))
|
||||||
return emitError("compute_batch coreIds values must be positive");
|
return emitError("compute_batch coreIds values must be non-negative");
|
||||||
llvm::SmallDenseSet<int32_t, 8> seenCoreIds;
|
llvm::SmallDenseSet<int32_t, 8> seenCoreIds;
|
||||||
for (int32_t coreId : coreIdsAttr.asArrayRef())
|
for (int32_t coreId : coreIdsAttr.asArrayRef())
|
||||||
if (!seenCoreIds.insert(coreId).second)
|
if (!seenCoreIds.insert(coreId).second)
|
||||||
@@ -485,6 +500,8 @@ LogicalResult SpatComputeBatch::verify() {
|
|||||||
return emitError("body block argument type must match input type");
|
return emitError("body block argument type must match input type");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (failed(verifyComputeResultsUses(this->getOperation())))
|
||||||
|
return failure();
|
||||||
return verifyBatchBody(getOperation(), block, getResultTypes(), weightsPerLane);
|
return verifyBatchBody(getOperation(), block, getResultTypes(), weightsPerLane);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,802 +1,19 @@
|
|||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
||||||
#include "mlir/IR/OpDefinition.h"
|
|
||||||
#include "mlir/IR/Value.h"
|
|
||||||
#include "mlir/IR/ValueRange.h"
|
|
||||||
|
|
||||||
#include "llvm/ADT/DenseMap.h"
|
|
||||||
#include "llvm/ADT/STLExtras.h"
|
|
||||||
#include "llvm/Support/Casting.h"
|
|
||||||
#include "llvm/Support/FormatVariadic.h"
|
|
||||||
#include "llvm/Support/raw_ostream.h"
|
|
||||||
|
|
||||||
#include <algorithm>
|
|
||||||
#include <cstdlib>
|
|
||||||
#include <iterator>
|
|
||||||
#include <numeric>
|
|
||||||
#include <optional>
|
|
||||||
#include <queue>
|
|
||||||
#include <utility>
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
#include "DCPAnalysis.hpp"
|
#include "DCPAnalysis.hpp"
|
||||||
#include "Graph.hpp"
|
#include "../Scheduling/ComputeGraph.hpp"
|
||||||
|
#include "../Scheduling/DcpScheduler.hpp"
|
||||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||||
#include "src/Support/TypeUtilities.hpp"
|
|
||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
namespace spatial {
|
namespace spatial {
|
||||||
|
|
||||||
using namespace mlir;
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
using SpatCompute = onnx_mlir::spatial::SpatCompute;
|
|
||||||
using SpatComputeBatch = onnx_mlir::spatial::SpatComputeBatch;
|
|
||||||
|
|
||||||
bool isDcpCoarsenDebugEnabled() { return std::getenv("DCP_COARSEN_DEBUG") != nullptr; }
|
|
||||||
|
|
||||||
struct VirtualNode {
|
|
||||||
SmallVector<size_t, 4> originalComputeIndices;
|
|
||||||
Weight weight = 0;
|
|
||||||
CrossbarUsage crossbarUsage = 0;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct VirtualGraph {
|
|
||||||
std::vector<VirtualNode> nodes;
|
|
||||||
std::vector<IndexedEdge> edges;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct TimingInfo {
|
|
||||||
std::vector<Time> aest;
|
|
||||||
std::vector<Time> alst;
|
|
||||||
std::vector<size_t> topologicalOrder;
|
|
||||||
bool valid = false;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct WindowScheduleResult {
|
|
||||||
std::vector<std::vector<size_t>> mergeGroups;
|
|
||||||
CPU cpuCount = 0;
|
|
||||||
size_t mergedNodeCount = 0;
|
|
||||||
size_t maxMergeGroupSize = 0;
|
|
||||||
};
|
|
||||||
|
|
||||||
size_t getSchedulingCpuBudget() {
|
|
||||||
if (coresCount.getValue() > 0)
|
|
||||||
return static_cast<size_t>(coresCount.getValue());
|
|
||||||
return std::numeric_limits<size_t>::max();
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t getBatchChunkTargetCount(int32_t laneCount) {
|
|
||||||
assert(laneCount > 0 && "laneCount must be positive");
|
|
||||||
return std::min(static_cast<size_t>(laneCount), std::max<size_t>(1, getSchedulingCpuBudget()));
|
|
||||||
}
|
|
||||||
|
|
||||||
ComputeInstance getBatchChunkForIndex(SpatComputeBatch batch, size_t chunkIndex) {
|
|
||||||
size_t totalLanes = static_cast<size_t>(batch.getLaneCount());
|
|
||||||
size_t chunkCount = getBatchChunkTargetCount(batch.getLaneCount());
|
|
||||||
size_t baseChunkSize = totalLanes / chunkCount;
|
|
||||||
size_t largeChunkCount = totalLanes % chunkCount;
|
|
||||||
|
|
||||||
size_t laneStart = chunkIndex * baseChunkSize + std::min(chunkIndex, largeChunkCount);
|
|
||||||
size_t laneCount = baseChunkSize + (chunkIndex < largeChunkCount ? 1 : 0);
|
|
||||||
return {batch.getOperation(), static_cast<uint32_t>(laneStart), static_cast<uint32_t>(laneCount)};
|
|
||||||
}
|
|
||||||
|
|
||||||
ComputeInstance getBatchChunkForLane(SpatComputeBatch batch, uint32_t lane) {
|
|
||||||
size_t totalLanes = static_cast<size_t>(batch.getLaneCount());
|
|
||||||
size_t chunkCount = getBatchChunkTargetCount(batch.getLaneCount());
|
|
||||||
size_t baseChunkSize = totalLanes / chunkCount;
|
|
||||||
size_t largeChunkCount = totalLanes % chunkCount;
|
|
||||||
size_t largeChunkSpan = largeChunkCount * (baseChunkSize + 1);
|
|
||||||
|
|
||||||
size_t chunkIndex = 0;
|
|
||||||
if (static_cast<size_t>(lane) < largeChunkSpan)
|
|
||||||
chunkIndex = static_cast<size_t>(lane) / (baseChunkSize + 1);
|
|
||||||
else
|
|
||||||
chunkIndex = largeChunkCount + (static_cast<size_t>(lane) - largeChunkSpan) / baseChunkSize;
|
|
||||||
return getBatchChunkForIndex(batch, chunkIndex);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<IndexedEdge> aggregateEdges(ArrayRef<IndexedEdge> edges) {
|
|
||||||
llvm::DenseMap<std::pair<size_t, size_t>, Weight> edgeWeights;
|
|
||||||
for (auto [start, end, weight] : edges) {
|
|
||||||
size_t startIndex = static_cast<size_t>(start);
|
|
||||||
size_t endIndex = static_cast<size_t>(end);
|
|
||||||
if (startIndex == endIndex)
|
|
||||||
continue;
|
|
||||||
auto key = std::make_pair(startIndex, endIndex);
|
|
||||||
Weight edgeWeight = static_cast<Weight>(weight);
|
|
||||||
auto inserted = edgeWeights.try_emplace(key, edgeWeight);
|
|
||||||
if (!inserted.second)
|
|
||||||
inserted.first->second = std::max(inserted.first->second, edgeWeight);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<IndexedEdge> aggregatedEdges;
|
|
||||||
aggregatedEdges.reserve(edgeWeights.size());
|
|
||||||
for (auto [key, weight] : edgeWeights)
|
|
||||||
aggregatedEdges.push_back(
|
|
||||||
{static_cast<int64_t>(key.first), static_cast<int64_t>(key.second), static_cast<int64_t>(weight)});
|
|
||||||
llvm::sort(aggregatedEdges, [](const IndexedEdge& lhs, const IndexedEdge& rhs) {
|
|
||||||
if (std::get<0>(lhs) != std::get<0>(rhs))
|
|
||||||
return std::get<0>(lhs) < std::get<0>(rhs);
|
|
||||||
return std::get<1>(lhs) < std::get<1>(rhs);
|
|
||||||
});
|
|
||||||
return aggregatedEdges;
|
|
||||||
}
|
|
||||||
|
|
||||||
Weight getComputeBodyWeight(Region& body) {
|
|
||||||
constexpr Weight kOperationWeight = 100;
|
|
||||||
Weight numOperations = 0;
|
|
||||||
for (auto& block : body)
|
|
||||||
for ([[maybe_unused]] auto& op : block)
|
|
||||||
numOperations = checkedAdd(numOperations, static_cast<Weight>(1));
|
|
||||||
return checkedMultiply(numOperations, kOperationWeight);
|
|
||||||
}
|
|
||||||
|
|
||||||
CrossbarUsage getComputeBodyCrossbarUsage(Region& body) {
|
|
||||||
CrossbarUsage crossbarUsage = 0;
|
|
||||||
for (auto& block : body)
|
|
||||||
for (auto& op : block)
|
|
||||||
if (isa<SpatVMMOp>(op))
|
|
||||||
crossbarUsage = checkedAdd(crossbarUsage, static_cast<CrossbarUsage>(1));
|
|
||||||
return crossbarUsage;
|
|
||||||
}
|
|
||||||
|
|
||||||
Weight getComputeInstanceWeight(const ComputeInstance& instance) {
|
|
||||||
if (auto spatCompute = dyn_cast<SpatCompute>(instance.op))
|
|
||||||
return getSpatComputeWeight(spatCompute);
|
|
||||||
auto batch = cast<SpatComputeBatch>(instance.op);
|
|
||||||
return checkedMultiply(getComputeBodyWeight(batch.getBody()), static_cast<Weight>(instance.laneCount));
|
|
||||||
}
|
|
||||||
|
|
||||||
CrossbarUsage getComputeInstanceCrossbarUsage(const ComputeInstance& instance) {
|
|
||||||
if (auto spatCompute = dyn_cast<SpatCompute>(instance.op))
|
|
||||||
return getSpatComputeCrossbarUsage(spatCompute);
|
|
||||||
auto batch = cast<SpatComputeBatch>(instance.op);
|
|
||||||
return checkedMultiply(getComputeBodyCrossbarUsage(batch.getBody()), static_cast<CrossbarUsage>(instance.laneCount));
|
|
||||||
}
|
|
||||||
|
|
||||||
SmallVector<Value, 4> getComputeInstanceInputs(const ComputeInstance& instance) {
|
|
||||||
if (auto spatCompute = dyn_cast<SpatCompute>(instance.op))
|
|
||||||
return SmallVector<Value, 4>(spatCompute.getInputs().begin(), spatCompute.getInputs().end());
|
|
||||||
auto batch = cast<SpatComputeBatch>(instance.op);
|
|
||||||
SmallVector<Value, 4> inputs;
|
|
||||||
inputs.reserve(instance.laneCount);
|
|
||||||
for (uint32_t lane = instance.laneStart; lane < instance.laneStart + instance.laneCount; ++lane)
|
|
||||||
inputs.push_back(batch.getInputs()[lane]);
|
|
||||||
return inputs;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::optional<ComputeInstance> getOriginalComputeInstance(Value value) {
|
|
||||||
Operation* op = value.getDefiningOp();
|
|
||||||
if (!op)
|
|
||||||
return std::nullopt;
|
|
||||||
|
|
||||||
while (auto extract = dyn_cast<tensor::ExtractSliceOp>(op)) {
|
|
||||||
value = extract.getSource();
|
|
||||||
op = value.getDefiningOp();
|
|
||||||
if (!op)
|
|
||||||
return std::nullopt;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (auto spatCompute = dyn_cast<SpatCompute>(op))
|
|
||||||
return ComputeInstance {spatCompute.getOperation(), 0, 1};
|
|
||||||
if (auto batch = dyn_cast<SpatComputeBatch>(op))
|
|
||||||
return getBatchChunkForLane(batch, static_cast<uint32_t>(cast<OpResult>(value).getResultNumber()));
|
|
||||||
return std::nullopt;
|
|
||||||
}
|
|
||||||
|
|
||||||
SmallVector<ComputeInstance> collectComputeInstances(Operation* entryOp) {
|
|
||||||
SmallVector<ComputeInstance> instances;
|
|
||||||
auto isUsedAsWeightOnly = [](Operation* producerOp) {
|
|
||||||
if (producerOp->getNumResults() == 0)
|
|
||||||
return false;
|
|
||||||
for (Value result : producerOp->getResults()) {
|
|
||||||
if (result.use_empty())
|
|
||||||
return false;
|
|
||||||
for (Operation* user : result.getUsers()) {
|
|
||||||
if (auto compute = dyn_cast<SpatCompute>(user)) {
|
|
||||||
if (!llvm::is_contained(compute.getWeights(), result))
|
|
||||||
return false;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (auto batch = dyn_cast<SpatComputeBatch>(user)) {
|
|
||||||
if (!llvm::is_contained(batch.getWeights(), result))
|
|
||||||
return false;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
};
|
|
||||||
for (Region& region : entryOp->getRegions()) {
|
|
||||||
for (Block& block : region) {
|
|
||||||
for (Operation& op : block) {
|
|
||||||
if (auto spatCompute = dyn_cast<SpatCompute>(&op)) {
|
|
||||||
if (isUsedAsWeightOnly(spatCompute.getOperation()))
|
|
||||||
continue;
|
|
||||||
instances.push_back({spatCompute.getOperation(), 0, 1});
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (auto batch = dyn_cast<SpatComputeBatch>(&op)) {
|
|
||||||
if (isUsedAsWeightOnly(batch.getOperation()))
|
|
||||||
continue;
|
|
||||||
size_t chunkCount = getBatchChunkTargetCount(batch.getLaneCount());
|
|
||||||
for (size_t chunkIndex = 0; chunkIndex < chunkCount; ++chunkIndex)
|
|
||||||
instances.push_back(getBatchChunkForIndex(batch, chunkIndex));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return instances;
|
|
||||||
}
|
|
||||||
|
|
||||||
VirtualGraph buildInitialVirtualGraph(ArrayRef<ComputeInstance> computeInstances, ArrayRef<IndexedEdge> edges) {
|
|
||||||
VirtualGraph graph;
|
|
||||||
graph.nodes.reserve(computeInstances.size());
|
|
||||||
for (auto [index, computeInstance] : llvm::enumerate(computeInstances)) {
|
|
||||||
VirtualNode node;
|
|
||||||
node.originalComputeIndices.push_back(index);
|
|
||||||
node.weight = getComputeInstanceWeight(computeInstance);
|
|
||||||
node.crossbarUsage = getComputeInstanceCrossbarUsage(computeInstance);
|
|
||||||
graph.nodes.push_back(std::move(node));
|
|
||||||
}
|
|
||||||
graph.edges = aggregateEdges(edges);
|
|
||||||
return graph;
|
|
||||||
}
|
|
||||||
|
|
||||||
TimingInfo computeTiming(const VirtualGraph& graph) {
|
|
||||||
TimingInfo timing;
|
|
||||||
size_t nodeCount = graph.nodes.size();
|
|
||||||
timing.aest.assign(nodeCount, 0);
|
|
||||||
timing.alst.assign(nodeCount, 0);
|
|
||||||
timing.topologicalOrder.reserve(nodeCount);
|
|
||||||
|
|
||||||
std::vector<std::vector<std::pair<size_t, Weight>>> parents(nodeCount);
|
|
||||||
std::vector<std::vector<std::pair<size_t, Weight>>> children(nodeCount);
|
|
||||||
std::vector<size_t> incomingEdgeCount(nodeCount, 0);
|
|
||||||
|
|
||||||
for (auto [start, end, weight] : graph.edges) {
|
|
||||||
size_t startIndex = static_cast<size_t>(start);
|
|
||||||
size_t endIndex = static_cast<size_t>(end);
|
|
||||||
Weight edgeWeight = static_cast<Weight>(weight);
|
|
||||||
assert(startIndex < nodeCount && endIndex < nodeCount && "virtual edge endpoint out of range");
|
|
||||||
children[startIndex].push_back({endIndex, edgeWeight});
|
|
||||||
parents[endIndex].push_back({startIndex, edgeWeight});
|
|
||||||
incomingEdgeCount[endIndex]++;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto getVirtualNodeOrderKey = [&](size_t nodeIndex) {
|
|
||||||
const VirtualNode& node = graph.nodes[nodeIndex];
|
|
||||||
if (!node.originalComputeIndices.empty())
|
|
||||||
return node.originalComputeIndices.front();
|
|
||||||
return nodeIndex;
|
|
||||||
};
|
|
||||||
auto readyNodeGreater = [&](size_t lhs, size_t rhs) {
|
|
||||||
size_t lhsKey = getVirtualNodeOrderKey(lhs);
|
|
||||||
size_t rhsKey = getVirtualNodeOrderKey(rhs);
|
|
||||||
if (lhsKey != rhsKey)
|
|
||||||
return lhsKey > rhsKey;
|
|
||||||
return lhs > rhs;
|
|
||||||
};
|
|
||||||
std::priority_queue<size_t, std::vector<size_t>, decltype(readyNodeGreater)> readyNodes(readyNodeGreater);
|
|
||||||
for (size_t i = 0; i < nodeCount; ++i)
|
|
||||||
if (incomingEdgeCount[i] == 0)
|
|
||||||
readyNodes.push(i);
|
|
||||||
|
|
||||||
while (!readyNodes.empty()) {
|
|
||||||
size_t current = readyNodes.top();
|
|
||||||
readyNodes.pop();
|
|
||||||
timing.topologicalOrder.push_back(current);
|
|
||||||
for (auto [child, weight] : children[current]) {
|
|
||||||
(void) weight;
|
|
||||||
assert(incomingEdgeCount[child] > 0 && "incoming edge count underflow");
|
|
||||||
incomingEdgeCount[child]--;
|
|
||||||
if (incomingEdgeCount[child] == 0)
|
|
||||||
readyNodes.push(child);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (timing.topologicalOrder.size() != nodeCount)
|
|
||||||
return timing;
|
|
||||||
|
|
||||||
Time dcpl = 0;
|
|
||||||
for (size_t nodeIndex : timing.topologicalOrder) {
|
|
||||||
Time maxParentAest = 0;
|
|
||||||
for (auto [parent, transferCost] : parents[nodeIndex]) {
|
|
||||||
maxParentAest =
|
|
||||||
std::max(maxParentAest, addOrMax(addOrMax(timing.aest[parent], graph.nodes[parent].weight), transferCost));
|
|
||||||
}
|
|
||||||
timing.aest[nodeIndex] = maxParentAest;
|
|
||||||
dcpl = std::max(dcpl, addOrMax(maxParentAest, graph.nodes[nodeIndex].weight));
|
|
||||||
}
|
|
||||||
|
|
||||||
for (size_t nodeIndex : llvm::reverse(timing.topologicalOrder)) {
|
|
||||||
Time minAlst = std::numeric_limits<Time>::max();
|
|
||||||
if (children[nodeIndex].empty())
|
|
||||||
minAlst = subtractOrZero(dcpl, graph.nodes[nodeIndex].weight);
|
|
||||||
for (auto [child, transferCost] : children[nodeIndex]) {
|
|
||||||
minAlst =
|
|
||||||
std::min(minAlst, subtractOrZero(timing.alst[child], addOrMax(graph.nodes[nodeIndex].weight, transferCost)));
|
|
||||||
}
|
|
||||||
timing.alst[nodeIndex] = minAlst;
|
|
||||||
}
|
|
||||||
|
|
||||||
timing.valid = true;
|
|
||||||
return timing;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<std::vector<size_t>> buildUndirectedAdjacency(const VirtualGraph& graph) {
|
|
||||||
std::vector<std::vector<size_t>> adjacency(graph.nodes.size());
|
|
||||||
for (auto [start, end, weight] : graph.edges) {
|
|
||||||
(void) weight;
|
|
||||||
size_t startIndex = static_cast<size_t>(start);
|
|
||||||
size_t endIndex = static_cast<size_t>(end);
|
|
||||||
assert(startIndex < graph.nodes.size() && endIndex < graph.nodes.size() && "virtual edge endpoint out of range");
|
|
||||||
adjacency[startIndex].push_back(endIndex);
|
|
||||||
adjacency[endIndex].push_back(startIndex);
|
|
||||||
}
|
|
||||||
for (auto& neighbours : adjacency) {
|
|
||||||
llvm::sort(neighbours);
|
|
||||||
neighbours.erase(std::unique(neighbours.begin(), neighbours.end()), neighbours.end());
|
|
||||||
}
|
|
||||||
return adjacency;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<size_t> selectCriticalWindow(const VirtualGraph& graph, const TimingInfo& timing, size_t windowSize) {
|
|
||||||
std::vector<size_t> ranked(timing.aest.size());
|
|
||||||
std::iota(ranked.begin(), ranked.end(), 0);
|
|
||||||
auto isHigherPriority = [&](size_t lhs, size_t rhs) {
|
|
||||||
Time lhsSlack = slackOrZero(timing.aest[lhs], timing.alst[lhs]);
|
|
||||||
Time rhsSlack = slackOrZero(timing.aest[rhs], timing.alst[rhs]);
|
|
||||||
if (lhsSlack != rhsSlack)
|
|
||||||
return lhsSlack < rhsSlack;
|
|
||||||
if (timing.aest[lhs] != timing.aest[rhs])
|
|
||||||
return timing.aest[lhs] < timing.aest[rhs];
|
|
||||||
return lhs < rhs;
|
|
||||||
};
|
|
||||||
|
|
||||||
windowSize = std::min(windowSize, ranked.size());
|
|
||||||
if (windowSize == 0)
|
|
||||||
return {};
|
|
||||||
if (windowSize == ranked.size()) {
|
|
||||||
llvm::sort(ranked, isHigherPriority);
|
|
||||||
return ranked;
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t criticalPoolSize = std::min(ranked.size(), std::max(windowSize, windowSize * 2));
|
|
||||||
if (criticalPoolSize < ranked.size())
|
|
||||||
std::nth_element(
|
|
||||||
ranked.begin(), ranked.begin() + static_cast<std::ptrdiff_t>(criticalPoolSize), ranked.end(), isHigherPriority);
|
|
||||||
|
|
||||||
std::vector<char> inCriticalPool(ranked.size(), false);
|
|
||||||
for (size_t i = 0; i < criticalPoolSize; ++i)
|
|
||||||
inCriticalPool[ranked[i]] = true;
|
|
||||||
|
|
||||||
size_t seed = *std::min_element(ranked.begin(), ranked.end(), isHigherPriority);
|
|
||||||
std::vector<std::vector<size_t>> adjacency = buildUndirectedAdjacency(graph);
|
|
||||||
std::vector<size_t> selected;
|
|
||||||
std::vector<char> inWindow(ranked.size(), false);
|
|
||||||
selected.reserve(windowSize);
|
|
||||||
|
|
||||||
struct FrontierEntry {
|
|
||||||
size_t node;
|
|
||||||
};
|
|
||||||
auto frontierCompare = [&](FrontierEntry lhs, FrontierEntry rhs) { return isHigherPriority(rhs.node, lhs.node); };
|
|
||||||
std::priority_queue<FrontierEntry, std::vector<FrontierEntry>, decltype(frontierCompare)> frontier(frontierCompare);
|
|
||||||
|
|
||||||
auto addToWindow = [&](size_t node, const std::vector<char>& eligible) {
|
|
||||||
if (inWindow[node])
|
|
||||||
return;
|
|
||||||
inWindow[node] = true;
|
|
||||||
selected.push_back(node);
|
|
||||||
for (size_t neighbour : adjacency[node])
|
|
||||||
if (!inWindow[neighbour] && eligible[neighbour])
|
|
||||||
frontier.push({neighbour});
|
|
||||||
};
|
|
||||||
|
|
||||||
addToWindow(seed, inCriticalPool);
|
|
||||||
while (!frontier.empty() && selected.size() < windowSize) {
|
|
||||||
size_t node = frontier.top().node;
|
|
||||||
frontier.pop();
|
|
||||||
if (!inWindow[node])
|
|
||||||
addToWindow(node, inCriticalPool);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (selected.size() < windowSize) {
|
|
||||||
std::vector<char> anyNode(ranked.size(), true);
|
|
||||||
for (size_t node : selected)
|
|
||||||
for (size_t neighbour : adjacency[node])
|
|
||||||
if (!inWindow[neighbour])
|
|
||||||
frontier.push({neighbour});
|
|
||||||
while (!frontier.empty() && selected.size() < windowSize) {
|
|
||||||
size_t node = frontier.top().node;
|
|
||||||
frontier.pop();
|
|
||||||
if (!inWindow[node])
|
|
||||||
addToWindow(node, anyNode);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (selected.size() < windowSize) {
|
|
||||||
llvm::sort(ranked, isHigherPriority);
|
|
||||||
for (size_t node : ranked) {
|
|
||||||
if (selected.size() == windowSize)
|
|
||||||
break;
|
|
||||||
if (!inWindow[node]) {
|
|
||||||
inWindow[node] = true;
|
|
||||||
selected.push_back(node);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
llvm::sort(selected, isHigherPriority);
|
|
||||||
return selected;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<IndexedEdge> buildWindowEdges(const VirtualGraph& graph, const std::vector<int64_t>& nodeToWindowIndex) {
|
|
||||||
std::vector<IndexedEdge> windowEdges;
|
|
||||||
windowEdges.reserve(graph.edges.size());
|
|
||||||
for (auto [start, end, weight] : graph.edges) {
|
|
||||||
int64_t mappedStart = nodeToWindowIndex[static_cast<size_t>(start)];
|
|
||||||
int64_t mappedEnd = nodeToWindowIndex[static_cast<size_t>(end)];
|
|
||||||
if (mappedStart == -1 || mappedEnd == -1)
|
|
||||||
continue;
|
|
||||||
windowEdges.push_back({mappedStart, mappedEnd, weight});
|
|
||||||
}
|
|
||||||
return aggregateEdges(windowEdges);
|
|
||||||
}
|
|
||||||
|
|
||||||
WindowScheduleResult scheduleWindow(const VirtualGraph& graph, ArrayRef<size_t> selectedNodes, MLIRContext* context) {
|
|
||||||
std::vector<Weight> windowWeights;
|
|
||||||
std::vector<CrossbarUsage> windowCrossbarUsage;
|
|
||||||
std::vector<int64_t> windowNodeOrderKeys;
|
|
||||||
std::vector<int64_t> nodeToWindowIndex(graph.nodes.size(), -1);
|
|
||||||
windowWeights.reserve(selectedNodes.size());
|
|
||||||
windowCrossbarUsage.reserve(selectedNodes.size());
|
|
||||||
windowNodeOrderKeys.reserve(selectedNodes.size());
|
|
||||||
|
|
||||||
for (auto [windowIndex, nodeIndex] : llvm::enumerate(selectedNodes)) {
|
|
||||||
nodeToWindowIndex[nodeIndex] = static_cast<int64_t>(windowIndex);
|
|
||||||
windowWeights.push_back(graph.nodes[nodeIndex].weight);
|
|
||||||
windowCrossbarUsage.push_back(graph.nodes[nodeIndex].crossbarUsage);
|
|
||||||
windowNodeOrderKeys.push_back(static_cast<int64_t>(nodeIndex));
|
|
||||||
}
|
|
||||||
|
|
||||||
GraphDCP windowGraph(
|
|
||||||
windowWeights, buildWindowEdges(graph, nodeToWindowIndex), windowNodeOrderKeys, windowCrossbarUsage);
|
|
||||||
if (coresCount.getValue() > 0)
|
|
||||||
windowGraph.setMaxCpuCount(static_cast<int>(coresCount.getValue()));
|
|
||||||
windowGraph.setContext(context);
|
|
||||||
windowGraph.runDcp();
|
|
||||||
|
|
||||||
WindowScheduleResult result;
|
|
||||||
result.cpuCount = windowGraph.cpuCount();
|
|
||||||
for (CPU cpu = 0; cpu < windowGraph.cpuCount(); ++cpu) {
|
|
||||||
auto scheduledTasks = windowGraph.getScheduledTasks(cpu);
|
|
||||||
if (scheduledTasks.size() < 2)
|
|
||||||
continue;
|
|
||||||
|
|
||||||
result.mergedNodeCount += scheduledTasks.size();
|
|
||||||
result.maxMergeGroupSize = std::max(result.maxMergeGroupSize, scheduledTasks.size());
|
|
||||||
std::vector<size_t> mergeGroup;
|
|
||||||
mergeGroup.reserve(scheduledTasks.size());
|
|
||||||
for (const auto& task : scheduledTasks)
|
|
||||||
mergeGroup.push_back(selectedNodes[task.nodeIndex]);
|
|
||||||
result.mergeGroups.push_back(std::move(mergeGroup));
|
|
||||||
}
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool coarsenGraph(const VirtualGraph& graph,
|
|
||||||
ArrayRef<std::vector<size_t>> mergeGroups,
|
|
||||||
VirtualGraph& coarsenedGraph,
|
|
||||||
std::vector<size_t>& oldToNewNode) {
|
|
||||||
TimingInfo timing = computeTiming(graph);
|
|
||||||
std::vector<size_t> topologicalRank(graph.nodes.size());
|
|
||||||
std::iota(topologicalRank.begin(), topologicalRank.end(), 0);
|
|
||||||
if (timing.valid)
|
|
||||||
for (auto [rank, nodeIndex] : llvm::enumerate(timing.topologicalOrder))
|
|
||||||
topologicalRank[nodeIndex] = rank;
|
|
||||||
|
|
||||||
std::vector<std::vector<size_t>> orderedMergeGroups;
|
|
||||||
orderedMergeGroups.reserve(mergeGroups.size());
|
|
||||||
for (const auto& mergeGroup : mergeGroups) {
|
|
||||||
orderedMergeGroups.emplace_back(mergeGroup.begin(), mergeGroup.end());
|
|
||||||
std::stable_sort(orderedMergeGroups.back().begin(), orderedMergeGroups.back().end(), [&](size_t lhs, size_t rhs) {
|
|
||||||
if (topologicalRank[lhs] != topologicalRank[rhs])
|
|
||||||
return topologicalRank[lhs] < topologicalRank[rhs];
|
|
||||||
return lhs < rhs;
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<int64_t> nodeToMergeGroup(graph.nodes.size(), -1);
|
|
||||||
for (auto [groupIndex, mergeGroup] : llvm::enumerate(orderedMergeGroups)) {
|
|
||||||
if (mergeGroup.size() < 2)
|
|
||||||
continue;
|
|
||||||
for (size_t nodeIndex : mergeGroup) {
|
|
||||||
assert(nodeIndex < graph.nodes.size() && "merge group node out of range");
|
|
||||||
nodeToMergeGroup[nodeIndex] = static_cast<int64_t>(groupIndex);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<std::optional<size_t>> mergeGroupToNewNode(orderedMergeGroups.size());
|
|
||||||
std::vector<size_t> newNodeRank;
|
|
||||||
oldToNewNode.assign(graph.nodes.size(), 0);
|
|
||||||
bool mergedAny = false;
|
|
||||||
coarsenedGraph.nodes.clear();
|
|
||||||
coarsenedGraph.edges.clear();
|
|
||||||
coarsenedGraph.nodes.reserve(graph.nodes.size());
|
|
||||||
newNodeRank.reserve(graph.nodes.size());
|
|
||||||
|
|
||||||
for (size_t nodeIndex = 0; nodeIndex < graph.nodes.size(); ++nodeIndex) {
|
|
||||||
int64_t mergeGroupIndex = nodeToMergeGroup[nodeIndex];
|
|
||||||
if (mergeGroupIndex == -1) {
|
|
||||||
oldToNewNode[nodeIndex] = coarsenedGraph.nodes.size();
|
|
||||||
coarsenedGraph.nodes.push_back(graph.nodes[nodeIndex]);
|
|
||||||
newNodeRank.push_back(topologicalRank[nodeIndex]);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto& newNodeIndex = mergeGroupToNewNode[static_cast<size_t>(mergeGroupIndex)];
|
|
||||||
if (newNodeIndex.has_value()) {
|
|
||||||
oldToNewNode[nodeIndex] = *newNodeIndex;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
VirtualNode mergedNode;
|
|
||||||
for (size_t memberIndex : orderedMergeGroups[static_cast<size_t>(mergeGroupIndex)]) {
|
|
||||||
const VirtualNode& memberNode = graph.nodes[memberIndex];
|
|
||||||
mergedNode.originalComputeIndices.append(memberNode.originalComputeIndices.begin(),
|
|
||||||
memberNode.originalComputeIndices.end());
|
|
||||||
mergedNode.weight = addOrMax(mergedNode.weight, memberNode.weight);
|
|
||||||
mergedNode.crossbarUsage = addOrMax(mergedNode.crossbarUsage, memberNode.crossbarUsage);
|
|
||||||
}
|
|
||||||
std::sort(mergedNode.originalComputeIndices.begin(), mergedNode.originalComputeIndices.end());
|
|
||||||
|
|
||||||
mergedAny = true;
|
|
||||||
newNodeIndex = coarsenedGraph.nodes.size();
|
|
||||||
for (size_t memberIndex : orderedMergeGroups[static_cast<size_t>(mergeGroupIndex)])
|
|
||||||
oldToNewNode[memberIndex] = *newNodeIndex;
|
|
||||||
newNodeRank.push_back(topologicalRank[orderedMergeGroups[static_cast<size_t>(mergeGroupIndex)].front()]);
|
|
||||||
coarsenedGraph.nodes.push_back(std::move(mergedNode));
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!mergedAny)
|
|
||||||
return false;
|
|
||||||
|
|
||||||
std::vector<IndexedEdge> remappedEdges;
|
|
||||||
remappedEdges.reserve(graph.edges.size());
|
|
||||||
for (auto [start, end, weight] : graph.edges) {
|
|
||||||
size_t newStart = oldToNewNode[static_cast<size_t>(start)];
|
|
||||||
size_t newEnd = oldToNewNode[static_cast<size_t>(end)];
|
|
||||||
if (newStart == newEnd)
|
|
||||||
continue;
|
|
||||||
if (newNodeRank[newStart] >= newNodeRank[newEnd])
|
|
||||||
continue;
|
|
||||||
remappedEdges.push_back({static_cast<int64_t>(newStart), static_cast<int64_t>(newEnd), weight});
|
|
||||||
}
|
|
||||||
coarsenedGraph.edges = aggregateEdges(remappedEdges);
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
CPU getVirtualGraphMaxCpuCount() { return static_cast<CPU>(getSchedulingCpuBudget()); }
|
|
||||||
|
|
||||||
size_t getDcpCoarseningWindowSize(size_t nodeCount) {
|
|
||||||
size_t windowSize = std::min(dcpCriticalWindowSize.getValue(), nodeCount);
|
|
||||||
CPU maxCpuCount = std::max<CPU>(1, getVirtualGraphMaxCpuCount());
|
|
||||||
if (nodeCount > static_cast<size_t>(maxCpuCount))
|
|
||||||
windowSize = std::max(windowSize, std::min(nodeCount, static_cast<size_t>(maxCpuCount) + 1));
|
|
||||||
return windowSize;
|
|
||||||
}
|
|
||||||
|
|
||||||
DCPAnalysisResult buildResultFromVirtualGraph(const VirtualGraph& graph, ArrayRef<ComputeInstance> computeInstances) {
|
|
||||||
DCPAnalysisResult result;
|
|
||||||
|
|
||||||
TimingInfo timing = computeTiming(graph);
|
|
||||||
std::vector<size_t> virtualNodeOrder;
|
|
||||||
if (timing.valid) {
|
|
||||||
virtualNodeOrder = std::move(timing.topologicalOrder);
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
virtualNodeOrder.resize(graph.nodes.size());
|
|
||||||
std::iota(virtualNodeOrder.begin(), virtualNodeOrder.end(), 0);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<size_t> originalComputeToCpu(computeInstances.size(), 0);
|
|
||||||
for (auto [cpu, virtualNodeIndex] : llvm::enumerate(virtualNodeOrder)) {
|
|
||||||
const VirtualNode& virtualNode = graph.nodes[virtualNodeIndex];
|
|
||||||
for (size_t originalIndex : virtualNode.originalComputeIndices)
|
|
||||||
originalComputeToCpu[originalIndex] = cpu;
|
|
||||||
}
|
|
||||||
|
|
||||||
result.dominanceOrderCompute.reserve(computeInstances.size());
|
|
||||||
llvm::DenseMap<size_t, size_t> nextCpuSlot;
|
|
||||||
for (auto [originalIndex, computeInstance] : llvm::enumerate(computeInstances)) {
|
|
||||||
size_t cpu = originalComputeToCpu[originalIndex];
|
|
||||||
result.dominanceOrderCompute.push_back(computeInstance);
|
|
||||||
result.computeToCpuMap[computeInstance] = cpu;
|
|
||||||
result.computeToCpuSlotMap[computeInstance] = nextCpuSlot[cpu]++;
|
|
||||||
result.computeToAestMap[computeInstance] = originalIndex;
|
|
||||||
result.cpuToLastComputeMap[cpu] = computeInstance;
|
|
||||||
}
|
|
||||||
for (const auto& [cpu, lastCompute] : result.cpuToLastComputeMap)
|
|
||||||
result.isLastComputeOfCpu.insert(lastCompute);
|
|
||||||
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
DCPAnalysisResult buildResultFromScheduledGraph(GraphDCP& graphDCP, ArrayRef<ComputeInstance> computeInstances) {
|
|
||||||
DCPAnalysisResult result;
|
|
||||||
result.dominanceOrderCompute.assign(computeInstances.begin(), computeInstances.end());
|
|
||||||
|
|
||||||
for (CPU cpu = 0; cpu < graphDCP.cpuCount(); ++cpu) {
|
|
||||||
auto scheduledTasks = graphDCP.getScheduledTasks(cpu);
|
|
||||||
if (scheduledTasks.empty())
|
|
||||||
continue;
|
|
||||||
|
|
||||||
for (auto [slot, task] : llvm::enumerate(scheduledTasks)) {
|
|
||||||
ComputeInstance instance = computeInstances[task.nodeIndex];
|
|
||||||
result.computeToCpuMap[instance] = cpu;
|
|
||||||
result.computeToCpuSlotMap[instance] = slot;
|
|
||||||
result.computeToAestMap[instance] = static_cast<uint64_t>(task.aest);
|
|
||||||
}
|
|
||||||
result.cpuToLastComputeMap[cpu] = computeInstances[scheduledTasks.back().nodeIndex];
|
|
||||||
result.isLastComputeOfCpu.insert(computeInstances[scheduledTasks.back().nodeIndex]);
|
|
||||||
}
|
|
||||||
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
DCPAnalysisResult
|
|
||||||
runLegacyDcp(ArrayRef<ComputeInstance> computeInstances, ArrayRef<IndexedEdge> edges, MLIRContext* context) {
|
|
||||||
SmallVector<Weight> nodeWeights;
|
|
||||||
SmallVector<CrossbarUsage> nodeCrossbarUsage;
|
|
||||||
SmallVector<int64_t> nodeOrderKeys;
|
|
||||||
nodeWeights.reserve(computeInstances.size());
|
|
||||||
nodeCrossbarUsage.reserve(computeInstances.size());
|
|
||||||
nodeOrderKeys.reserve(computeInstances.size());
|
|
||||||
for (auto [index, instance] : llvm::enumerate(computeInstances)) {
|
|
||||||
nodeWeights.push_back(getComputeInstanceWeight(instance));
|
|
||||||
nodeCrossbarUsage.push_back(getComputeInstanceCrossbarUsage(instance));
|
|
||||||
nodeOrderKeys.push_back(static_cast<int64_t>(index));
|
|
||||||
}
|
|
||||||
|
|
||||||
GraphDCP graphDCP(nodeWeights, edges, nodeOrderKeys, nodeCrossbarUsage);
|
|
||||||
if (coresCount.getValue() > 0)
|
|
||||||
graphDCP.setMaxCpuCount(static_cast<int>(coresCount.getValue()));
|
|
||||||
graphDCP.setContext(context);
|
|
||||||
graphDCP.runDcp();
|
|
||||||
return buildResultFromScheduledGraph(graphDCP, computeInstances);
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
SpatCompute getOriginalSpatCompute(Operation* op) {
|
|
||||||
if (!op)
|
|
||||||
return {};
|
|
||||||
while (auto extract = dyn_cast<tensor::ExtractSliceOp>(op)) {
|
|
||||||
op = extract.getSource().getDefiningOp();
|
|
||||||
if (!op)
|
|
||||||
return {};
|
|
||||||
}
|
|
||||||
if (auto res = dyn_cast<SpatCompute>(op))
|
|
||||||
return res;
|
|
||||||
return {};
|
|
||||||
}
|
|
||||||
|
|
||||||
DCPAnalysisResult DCPAnalysis::run() {
|
DCPAnalysisResult DCPAnalysis::run() {
|
||||||
SmallVector<ComputeInstance> computeInstances = collectComputeInstances(entryOp);
|
ComputeGraph graph = buildComputeGraph(entryOp);
|
||||||
SmallVector<IndexedEdge, 10> edges;
|
DcpScheduleOptions options;
|
||||||
|
if (coresCount.getValue() > 0)
|
||||||
llvm::DenseMap<ComputeInstance, size_t> instanceToIndex;
|
options.processorCount = static_cast<size_t>(coresCount.getValue());
|
||||||
instanceToIndex.reserve(computeInstances.size());
|
options.criticalWindowSize = dcpCriticalWindowSize.getValue();
|
||||||
for (auto [index, instance] : llvm::enumerate(computeInstances))
|
options.allowFallbackForAutoCoreCount = true;
|
||||||
instanceToIndex[instance] = index;
|
return runDcpScheduler(graph, options, entryOp->getContext());
|
||||||
|
|
||||||
for (auto [indexEndEdge, computeInstance] : llvm::enumerate(computeInstances)) {
|
|
||||||
for (Value input : getComputeInstanceInputs(computeInstance)) {
|
|
||||||
if (auto producerInstance = getOriginalComputeInstance(input)) {
|
|
||||||
auto producerIt = instanceToIndex.find(*producerInstance);
|
|
||||||
assert(producerIt != instanceToIndex.end());
|
|
||||||
auto indexStartEdge = producerIt->second;
|
|
||||||
edges.push_back({static_cast<int64_t>(indexStartEdge),
|
|
||||||
static_cast<int64_t>(indexEndEdge),
|
|
||||||
static_cast<int64_t>(getSizeInBytes(cast<ShapedType>(input.getType())))});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (coresCount.getValue() > 0) {
|
|
||||||
size_t schedulingCpuBudget = getSchedulingCpuBudget();
|
|
||||||
bool needsExactScheduledBatches = llvm::any_of(computeInstances, [&](const ComputeInstance& instance) {
|
|
||||||
auto batch = dyn_cast<SpatComputeBatch>(instance.op);
|
|
||||||
return batch && static_cast<size_t>(batch.getLaneCount()) > schedulingCpuBudget;
|
|
||||||
});
|
|
||||||
if (needsExactScheduledBatches)
|
|
||||||
return runLegacyDcp(computeInstances, edges, entryOp->getContext());
|
|
||||||
}
|
|
||||||
|
|
||||||
if (dcpCriticalWindowSize.getValue() == 0)
|
|
||||||
return runLegacyDcp(computeInstances, edges, entryOp->getContext());
|
|
||||||
|
|
||||||
VirtualGraph virtualGraph = buildInitialVirtualGraph(computeInstances, edges);
|
|
||||||
size_t iteration = 0;
|
|
||||||
bool debugCoarsening = isDcpCoarsenDebugEnabled();
|
|
||||||
auto tryCoarsenSelectedNodes = [&](ArrayRef<size_t> selectedNodes) {
|
|
||||||
size_t oldNodeCount = virtualGraph.nodes.size();
|
|
||||||
WindowScheduleResult windowSchedule = scheduleWindow(virtualGraph, selectedNodes, entryOp->getContext());
|
|
||||||
if (windowSchedule.mergeGroups.empty()) {
|
|
||||||
if (debugCoarsening && oldNodeCount >= 200)
|
|
||||||
llvm::errs() << llvm::formatv("[DCP-COARSEN] iter={0} old={1} selected={2} windowCpus={3} "
|
|
||||||
"groups=0 mergedNodes=0 maxGroup=0 new={1} changed=0\n",
|
|
||||||
iteration,
|
|
||||||
oldNodeCount,
|
|
||||||
selectedNodes.size(),
|
|
||||||
windowSchedule.cpuCount);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
VirtualGraph coarsenedGraph;
|
|
||||||
std::vector<size_t> oldToNewNode;
|
|
||||||
if (!coarsenGraph(virtualGraph, windowSchedule.mergeGroups, coarsenedGraph, oldToNewNode))
|
|
||||||
return false;
|
|
||||||
if (debugCoarsening && (oldNodeCount >= 200 || coarsenedGraph.nodes.size() >= 200))
|
|
||||||
llvm::errs() << llvm::formatv("[DCP-COARSEN] iter={0} old={1} selected={2} windowCpus={3} "
|
|
||||||
"groups={4} mergedNodes={5} maxGroup={6} new={7} changed={8}\n",
|
|
||||||
iteration,
|
|
||||||
oldNodeCount,
|
|
||||||
selectedNodes.size(),
|
|
||||||
windowSchedule.cpuCount,
|
|
||||||
windowSchedule.mergeGroups.size(),
|
|
||||||
windowSchedule.mergedNodeCount,
|
|
||||||
windowSchedule.maxMergeGroupSize,
|
|
||||||
coarsenedGraph.nodes.size(),
|
|
||||||
oldNodeCount - coarsenedGraph.nodes.size());
|
|
||||||
virtualGraph = std::move(coarsenedGraph);
|
|
||||||
return true;
|
|
||||||
};
|
|
||||||
|
|
||||||
while (virtualGraph.nodes.size() > 1) {
|
|
||||||
if (virtualGraph.nodes.size() <= getSchedulingCpuBudget()) {
|
|
||||||
if (debugCoarsening && virtualGraph.nodes.size() >= 200)
|
|
||||||
llvm::errs() << llvm::formatv(
|
|
||||||
"[DCP-COARSEN] iter={0} old={1} stop=cpu-budget\n", iteration, virtualGraph.nodes.size());
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
iteration++;
|
|
||||||
TimingInfo timing = computeTiming(virtualGraph);
|
|
||||||
if (!timing.valid) {
|
|
||||||
if (debugCoarsening && virtualGraph.nodes.size() >= 200)
|
|
||||||
llvm::errs() << llvm::formatv(
|
|
||||||
"[DCP-COARSEN] iter={0} old={1} invalid-timing\n", iteration, virtualGraph.nodes.size());
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
SmallVector<size_t> selectedNodes;
|
|
||||||
auto criticalWindow =
|
|
||||||
selectCriticalWindow(virtualGraph, timing, getDcpCoarseningWindowSize(virtualGraph.nodes.size()));
|
|
||||||
selectedNodes.append(criticalWindow.begin(), criticalWindow.end());
|
|
||||||
|
|
||||||
if (selectedNodes.size() < 2) {
|
|
||||||
if (debugCoarsening && virtualGraph.nodes.size() >= 200)
|
|
||||||
llvm::errs() << llvm::formatv("[DCP-COARSEN] iter={0} old={1} selected={2} stop=small-window\n",
|
|
||||||
iteration,
|
|
||||||
virtualGraph.nodes.size(),
|
|
||||||
selectedNodes.size());
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (tryCoarsenSelectedNodes(selectedNodes))
|
|
||||||
continue;
|
|
||||||
if (debugCoarsening && virtualGraph.nodes.size() >= 200)
|
|
||||||
llvm::errs() << llvm::formatv(
|
|
||||||
"[DCP-COARSEN] iter={0} old={1} stop=no-merge\n", iteration, virtualGraph.nodes.size());
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
return buildResultFromVirtualGraph(virtualGraph, computeInstances);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace spatial
|
} // namespace spatial
|
||||||
|
|||||||
@@ -2,64 +2,27 @@
|
|||||||
|
|
||||||
#include "mlir/IR/Operation.h"
|
#include "mlir/IR/Operation.h"
|
||||||
|
|
||||||
#include "llvm/ADT/DenseMap.h"
|
#include "../Scheduling/MergeSchedule.hpp"
|
||||||
#include "llvm/ADT/DenseSet.h"
|
|
||||||
|
|
||||||
#include <cstdint>
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
|
||||||
|
|
||||||
// A scheduling identity that covers both spat.compute and scheduled shards of
|
|
||||||
// spat.compute_batch.
|
|
||||||
struct ComputeInstance {
|
|
||||||
mlir::Operation* op = nullptr;
|
|
||||||
uint32_t laneStart = 0;
|
|
||||||
uint32_t laneCount = 1;
|
|
||||||
|
|
||||||
bool operator==(const ComputeInstance& other) const {
|
|
||||||
return op == other.op && laneStart == other.laneStart && laneCount == other.laneCount;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
struct DCPAnalysisResult {
|
|
||||||
std::vector<ComputeInstance> dominanceOrderCompute;
|
|
||||||
llvm::DenseMap<ComputeInstance, size_t> computeToCpuMap;
|
|
||||||
llvm::DenseMap<ComputeInstance, size_t> computeToCpuSlotMap;
|
|
||||||
llvm::DenseMap<ComputeInstance, uint64_t> computeToAestMap;
|
|
||||||
llvm::DenseSet<ComputeInstance> isLastComputeOfCpu;
|
|
||||||
llvm::DenseMap<size_t, ComputeInstance> cpuToLastComputeMap;
|
|
||||||
};
|
|
||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
namespace spatial {
|
namespace spatial {
|
||||||
|
using DCPAnalysisResult = MergeScheduleResult;
|
||||||
|
|
||||||
struct DCPAnalysis {
|
struct DCPAnalysis {
|
||||||
private:
|
private:
|
||||||
DCPAnalysisResult result;
|
DCPAnalysisResult result;
|
||||||
mlir::Operation* entryOp;
|
mlir::Operation *entryOp;
|
||||||
DCPAnalysisResult run();
|
DCPAnalysisResult run();
|
||||||
|
|
||||||
public:
|
public:
|
||||||
DCPAnalysis(mlir::Operation* op)
|
DCPAnalysis(mlir::Operation *op)
|
||||||
: entryOp(op) {
|
: entryOp(op) {
|
||||||
result = run();
|
result = run();
|
||||||
}
|
}
|
||||||
DCPAnalysisResult& getResult() { return result; }
|
DCPAnalysisResult &getResult() { return result; }
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace spatial
|
} // namespace spatial
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|
||||||
namespace llvm {
|
using DCPAnalysisResult = onnx_mlir::spatial::DCPAnalysisResult;
|
||||||
template <>
|
|
||||||
struct DenseMapInfo<ComputeInstance> {
|
|
||||||
static ComputeInstance getEmptyKey() {
|
|
||||||
return {DenseMapInfo<mlir::Operation*>::getEmptyKey(), UINT32_MAX, UINT32_MAX};
|
|
||||||
}
|
|
||||||
static ComputeInstance getTombstoneKey() {
|
|
||||||
return {DenseMapInfo<mlir::Operation*>::getTombstoneKey(), UINT32_MAX, UINT32_MAX};
|
|
||||||
}
|
|
||||||
static unsigned getHashValue(const ComputeInstance& v) { return llvm::hash_combine(v.op, v.laneStart, v.laneCount); }
|
|
||||||
static bool isEqual(const ComputeInstance& a, const ComputeInstance& b) { return a == b; }
|
|
||||||
};
|
|
||||||
} // namespace llvm
|
|
||||||
|
|||||||
@@ -0,0 +1,636 @@
|
|||||||
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
|
#include "mlir/IR/IRMapping.h"
|
||||||
|
#include "mlir/IR/PatternMatch.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/DenseMap.h"
|
||||||
|
#include "llvm/ADT/DenseSet.h"
|
||||||
|
#include "llvm/ADT/STLExtras.h"
|
||||||
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
|
||||||
|
#include <cstddef>
|
||||||
|
#include <cstdint>
|
||||||
|
#include <functional>
|
||||||
|
#include <optional>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
|
#include "MaterializeMergeSchedule.hpp"
|
||||||
|
#include "Scheduling/ComputeInstanceUtils.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
namespace spatial {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
using SpatCompute = spatial::SpatCompute;
|
||||||
|
using ProducerValueRef = spatial::ProducerValueRef;
|
||||||
|
using spatial::getComputeInstanceInputs;
|
||||||
|
using spatial::getComputeInstanceOutputTypes;
|
||||||
|
using spatial::getComputeInstanceOutputValues;
|
||||||
|
using spatial::getComputeInstanceTemplateBlock;
|
||||||
|
using spatial::getComputeInstanceWeights;
|
||||||
|
using spatial::getProducerValueRef;
|
||||||
|
|
||||||
|
class MergeScheduleMaterializerImpl {
|
||||||
|
public:
|
||||||
|
explicit MergeScheduleMaterializerImpl(func::FuncOp funcOp)
|
||||||
|
: func(funcOp), loc(funcOp.getLoc()), returnOp(cast<func::ReturnOp>(funcOp.getBody().front().getTerminator())) {}
|
||||||
|
|
||||||
|
LogicalResult run(const MergeScheduleResult& scheduleResult, int64_t& nextChannelIdRef) {
|
||||||
|
schedule = &scheduleResult;
|
||||||
|
nextChannelId = &nextChannelIdRef;
|
||||||
|
|
||||||
|
collectScheduledTasks();
|
||||||
|
buildTaskIndex();
|
||||||
|
collectExternalInputsAndWeights();
|
||||||
|
planRemoteChannels();
|
||||||
|
planReceiveReordering();
|
||||||
|
createCpuComputeOps();
|
||||||
|
if (failed(cloneTaskBodies()))
|
||||||
|
return failure();
|
||||||
|
replaceExternalUses();
|
||||||
|
if (failed(eraseOldScheduledOps()))
|
||||||
|
return failure();
|
||||||
|
moveExternalUsersBeforeReturn();
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
struct ScheduledTask {
|
||||||
|
ComputeInstance computeInstance;
|
||||||
|
size_t cpu = 0;
|
||||||
|
size_t orderWithinCpu = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct ChannelInfo {
|
||||||
|
int64_t channelId = -1;
|
||||||
|
int32_t sourceCoreId = -1;
|
||||||
|
int32_t targetCoreId = -1;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct CpuProgram {
|
||||||
|
SpatCompute op;
|
||||||
|
DenseMap<Value, Value> externalInputMap;
|
||||||
|
DenseMap<Value, size_t> weightToIndex;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct RemoteSendInfo {
|
||||||
|
ChannelInfo channelInfo;
|
||||||
|
ComputeInstance consumer;
|
||||||
|
size_t inputIndex = 0;
|
||||||
|
size_t consumerOrder = 0;
|
||||||
|
size_t sourceOrder = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct RemoteReceiveEntry {
|
||||||
|
ChannelInfo channelInfo;
|
||||||
|
ComputeInstance consumer;
|
||||||
|
size_t inputIndex = 0;
|
||||||
|
size_t sourceOrder = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
static uint64_t getRemoteSendPairKey(const ChannelInfo& channelInfo) {
|
||||||
|
return (static_cast<uint64_t>(static_cast<uint32_t>(channelInfo.sourceCoreId)) << 32)
|
||||||
|
| static_cast<uint32_t>(channelInfo.targetCoreId);
|
||||||
|
}
|
||||||
|
|
||||||
|
void collectExternalUsers(Operation* op) {
|
||||||
|
if (!externalUsersToMove.insert(op).second)
|
||||||
|
return;
|
||||||
|
for (Value result : op->getResults()) {
|
||||||
|
for (Operation* user : result.getUsers()) {
|
||||||
|
if (oldComputeOps.contains(user) || isa<func::ReturnOp>(user))
|
||||||
|
continue;
|
||||||
|
collectExternalUsers(user);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void collectScheduledTasks() {
|
||||||
|
for (ComputeInstance scheduledInstance : schedule->dominanceOrderCompute) {
|
||||||
|
oldComputeOps.insert(scheduledInstance.op);
|
||||||
|
scheduledTasks.push_back({scheduledInstance,
|
||||||
|
schedule->computeToCpuMap.lookup(scheduledInstance),
|
||||||
|
schedule->computeToCpuSlotMap.lookup(scheduledInstance)});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void buildTaskIndex() {
|
||||||
|
auto markCpuSeen = [&](size_t cpu) {
|
||||||
|
if (seenCpus.insert(cpu).second)
|
||||||
|
orderedCpus.push_back(cpu);
|
||||||
|
};
|
||||||
|
|
||||||
|
for (const ScheduledTask& task : scheduledTasks) {
|
||||||
|
taskByComputeInstance[task.computeInstance] = task;
|
||||||
|
tasksByCpu[task.cpu].push_back(task);
|
||||||
|
markCpuSeen(task.cpu);
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::sort(orderedCpus);
|
||||||
|
for (size_t cpu : orderedCpus)
|
||||||
|
llvm::stable_sort(tasksByCpu[cpu], [&](const ScheduledTask& lhs, const ScheduledTask& rhs) {
|
||||||
|
return lhs.orderWithinCpu < rhs.orderWithinCpu;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
void collectExternalInputsAndWeights() {
|
||||||
|
for (size_t cpu : orderedCpus) {
|
||||||
|
for (const ScheduledTask& task : tasksByCpu[cpu]) {
|
||||||
|
auto& thisCpuWeights = cpuWeights[cpu];
|
||||||
|
auto& thisSeenWeights = seenWeightsByCpu[cpu];
|
||||||
|
auto taskWeights = getComputeInstanceWeights(task.computeInstance);
|
||||||
|
for (Value weight : taskWeights)
|
||||||
|
if (thisSeenWeights.insert(weight).second)
|
||||||
|
thisCpuWeights.push_back(weight);
|
||||||
|
|
||||||
|
auto taskInputs = getComputeInstanceInputs(task.computeInstance);
|
||||||
|
auto& remoteInputs = remoteInputsByTask[task.computeInstance];
|
||||||
|
remoteInputs.resize(taskInputs.size());
|
||||||
|
for (auto [inputIndex, input] : llvm::enumerate(taskInputs)) {
|
||||||
|
auto producerRef = getProducerValueRef(input);
|
||||||
|
if (producerRef) {
|
||||||
|
auto producerIt = taskByComputeInstance.find(producerRef->instance);
|
||||||
|
if (producerIt != taskByComputeInstance.end()) {
|
||||||
|
if (producerIt->second.cpu != cpu) {
|
||||||
|
ChannelInfo info {
|
||||||
|
(*nextChannelId)++,
|
||||||
|
static_cast<int32_t>(producerIt->second.cpu),
|
||||||
|
static_cast<int32_t>(cpu),
|
||||||
|
};
|
||||||
|
remoteInputs[inputIndex] = info;
|
||||||
|
auto& perResultChannels = remoteSendsByTask[producerRef->instance];
|
||||||
|
if (perResultChannels.empty())
|
||||||
|
perResultChannels.resize(getComputeInstanceOutputTypes(producerIt->second.computeInstance).size());
|
||||||
|
perResultChannels[producerRef->resultIndex].push_back(
|
||||||
|
{info, task.computeInstance, inputIndex, task.orderWithinCpu, 0});
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (seenExternalInputsByCpu[cpu].insert(input).second)
|
||||||
|
cpuExternalInputs[cpu].push_back(input);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto taskOutputs = getComputeInstanceOutputValues(task.computeInstance);
|
||||||
|
for (auto [resultIndex, output] : llvm::enumerate(taskOutputs)) {
|
||||||
|
bool hasExternalUser = false;
|
||||||
|
for (auto& use : output.getUses()) {
|
||||||
|
Operation* useOwner = use.getOwner();
|
||||||
|
if (oldComputeOps.contains(useOwner))
|
||||||
|
continue;
|
||||||
|
hasExternalUser = true;
|
||||||
|
if (!isa<func::ReturnOp>(useOwner))
|
||||||
|
collectExternalUsers(useOwner);
|
||||||
|
}
|
||||||
|
if (hasExternalUser)
|
||||||
|
cpuExternalOutputs[cpu].push_back({task.computeInstance, resultIndex});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void planRemoteChannels() {
|
||||||
|
for (size_t cpu : orderedCpus) {
|
||||||
|
DenseMap<uint64_t, size_t> nextSourceOrderByPair;
|
||||||
|
DenseMap<uint64_t, size_t> lastConsumerOrderByPair;
|
||||||
|
for (const ScheduledTask& task : tasksByCpu[cpu]) {
|
||||||
|
auto sendsIt = remoteSendsByTask.find(task.computeInstance);
|
||||||
|
if (sendsIt == remoteSendsByTask.end())
|
||||||
|
continue;
|
||||||
|
for (auto& sendInfos : sendsIt->second) {
|
||||||
|
for (RemoteSendInfo& sendInfo : sendInfos) {
|
||||||
|
uint64_t pairKey = getRemoteSendPairKey(sendInfo.channelInfo);
|
||||||
|
sendInfo.sourceOrder = nextSourceOrderByPair[pairKey]++;
|
||||||
|
auto [it, inserted] = lastConsumerOrderByPair.try_emplace(pairKey, sendInfo.consumerOrder);
|
||||||
|
if (!inserted) {
|
||||||
|
if (sendInfo.consumerOrder < it->second)
|
||||||
|
pairsNeedingReceiveReorder.insert(pairKey);
|
||||||
|
it->second = sendInfo.consumerOrder;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void planReceiveReordering() {
|
||||||
|
DenseMap<uint64_t, SmallVector<RemoteSendInfo*>> reorderedSendsByPair;
|
||||||
|
for (auto& taskSends : remoteSendsByTask) {
|
||||||
|
for (auto& sendInfos : taskSends.second) {
|
||||||
|
for (RemoteSendInfo& sendInfo : sendInfos) {
|
||||||
|
uint64_t pairKey = getRemoteSendPairKey(sendInfo.channelInfo);
|
||||||
|
if (pairsNeedingReceiveReorder.contains(pairKey))
|
||||||
|
reorderedSendsByPair[pairKey].push_back(&sendInfo);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto& pairSends : reorderedSendsByPair) {
|
||||||
|
llvm::stable_sort(pairSends.second, [](const RemoteSendInfo* lhs, const RemoteSendInfo* rhs) {
|
||||||
|
if (lhs->sourceOrder != rhs->sourceOrder)
|
||||||
|
return lhs->sourceOrder < rhs->sourceOrder;
|
||||||
|
return lhs->channelInfo.channelId < rhs->channelInfo.channelId;
|
||||||
|
});
|
||||||
|
for (RemoteSendInfo* sendInfo : pairSends.second) {
|
||||||
|
int64_t channelId = (*nextChannelId)++;
|
||||||
|
sendInfo->channelInfo.channelId = channelId;
|
||||||
|
auto remoteInputsIt = remoteInputsByTask.find(sendInfo->consumer);
|
||||||
|
assert(remoteInputsIt != remoteInputsByTask.end() && "missing remote input for reordered send");
|
||||||
|
assert(sendInfo->inputIndex < remoteInputsIt->second.size() && "remote input index out of range");
|
||||||
|
assert(remoteInputsIt->second[sendInfo->inputIndex] && "missing reordered remote input channel");
|
||||||
|
remoteInputsIt->second[sendInfo->inputIndex]->channelId = channelId;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const auto& taskSends : remoteSendsByTask) {
|
||||||
|
for (const auto& sendInfos : taskSends.second) {
|
||||||
|
for (const RemoteSendInfo& sendInfo : sendInfos) {
|
||||||
|
auto remoteInputsIt = remoteInputsByTask.find(sendInfo.consumer);
|
||||||
|
assert(remoteInputsIt != remoteInputsByTask.end() && "missing remote input for send");
|
||||||
|
assert(sendInfo.inputIndex < remoteInputsIt->second.size() && "remote input index out of range");
|
||||||
|
assert(remoteInputsIt->second[sendInfo.inputIndex] && "missing remote input channel");
|
||||||
|
remoteInputsIt->second[sendInfo.inputIndex] = sendInfo.channelInfo;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto& taskSends : remoteSendsByTask) {
|
||||||
|
for (const auto& sendInfos : taskSends.second) {
|
||||||
|
for (const RemoteSendInfo& sendInfo : sendInfos) {
|
||||||
|
uint64_t pairKey = getRemoteSendPairKey(sendInfo.channelInfo);
|
||||||
|
if (!pairsNeedingReceiveReorder.contains(pairKey))
|
||||||
|
continue;
|
||||||
|
size_t targetCpu = static_cast<size_t>(sendInfo.channelInfo.targetCoreId);
|
||||||
|
receiveQueuesByCpu[targetCpu][pairKey].push_back(
|
||||||
|
{sendInfo.channelInfo, sendInfo.consumer, sendInfo.inputIndex, sendInfo.sourceOrder});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto& cpuQueues : receiveQueuesByCpu) {
|
||||||
|
for (auto& pairQueue : cpuQueues.second) {
|
||||||
|
llvm::stable_sort(pairQueue.second, [](const RemoteReceiveEntry& lhs, const RemoteReceiveEntry& rhs) {
|
||||||
|
if (lhs.sourceOrder != rhs.sourceOrder)
|
||||||
|
return lhs.sourceOrder < rhs.sourceOrder;
|
||||||
|
return lhs.channelInfo.channelId < rhs.channelInfo.channelId;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void createCpuComputeOps() {
|
||||||
|
IRRewriter rewriter(func.getContext());
|
||||||
|
for (size_t cpu : orderedCpus) {
|
||||||
|
SmallVector<Value> operands;
|
||||||
|
operands.reserve(cpuWeights[cpu].size() + cpuExternalInputs[cpu].size());
|
||||||
|
llvm::append_range(operands, cpuWeights[cpu]);
|
||||||
|
llvm::append_range(operands, cpuExternalInputs[cpu]);
|
||||||
|
|
||||||
|
SmallVector<Type> resultTypes;
|
||||||
|
resultTypes.reserve(cpuExternalOutputs[cpu].size());
|
||||||
|
for (ProducerValueRef outputRef : cpuExternalOutputs[cpu]) {
|
||||||
|
ScheduledTask task = taskByComputeInstance.at(outputRef.instance);
|
||||||
|
resultTypes.push_back(getComputeInstanceOutputTypes(task.computeInstance)[outputRef.resultIndex]);
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter.setInsertionPoint(returnOp);
|
||||||
|
auto newCompute = SpatCompute::create(rewriter, loc, TypeRange(resultTypes), ValueRange(operands));
|
||||||
|
newCompute.getProperties().setOperandSegmentSizes(
|
||||||
|
{static_cast<int>(cpuWeights[cpu].size()), static_cast<int>(cpuExternalInputs[cpu].size())});
|
||||||
|
newCompute->setAttr(onnx_mlir::kCoreIdAttrName, rewriter.getI32IntegerAttr(static_cast<int32_t>(cpu)));
|
||||||
|
|
||||||
|
SmallVector<Type> blockArgTypes;
|
||||||
|
SmallVector<Location> blockArgLocs;
|
||||||
|
blockArgTypes.reserve(cpuExternalInputs[cpu].size());
|
||||||
|
blockArgLocs.reserve(cpuExternalInputs[cpu].size());
|
||||||
|
for (Value input : cpuExternalInputs[cpu]) {
|
||||||
|
blockArgTypes.push_back(input.getType());
|
||||||
|
blockArgLocs.push_back(loc);
|
||||||
|
}
|
||||||
|
Block* newBlock =
|
||||||
|
rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
|
||||||
|
|
||||||
|
CpuProgram program;
|
||||||
|
program.op = newCompute;
|
||||||
|
for (auto [weightIndex, weight] : llvm::enumerate(cpuWeights[cpu]))
|
||||||
|
program.weightToIndex[weight] = weightIndex;
|
||||||
|
for (auto [inputIndex, input] : llvm::enumerate(cpuExternalInputs[cpu]))
|
||||||
|
program.externalInputMap[input] = newBlock->getArgument(inputIndex);
|
||||||
|
for (auto [resultIndex, outputRef] : llvm::enumerate(cpuExternalOutputs[cpu])) {
|
||||||
|
ScheduledTask task = taskByComputeInstance.at(outputRef.instance);
|
||||||
|
oldToNewExternalValueMap[getComputeInstanceOutputValues(task.computeInstance)[outputRef.resultIndex]] =
|
||||||
|
newCompute.getResult(resultIndex);
|
||||||
|
}
|
||||||
|
cpuPrograms[cpu] = std::move(program);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
FailureOr<Value> receiveThroughInput(IRRewriter& rewriter,
|
||||||
|
size_t cpu,
|
||||||
|
DenseMap<uint64_t, size_t>& receiveQueueIndices,
|
||||||
|
DenseMap<ComputeInstance, SmallVector<Value>>& preReceivedInputsByTask,
|
||||||
|
const ChannelInfo& requestedChannelInfo,
|
||||||
|
ComputeInstance requestedConsumer,
|
||||||
|
size_t requestedInputIndex) {
|
||||||
|
uint64_t pairKey = getRemoteSendPairKey(requestedChannelInfo);
|
||||||
|
auto cpuQueuesIt = receiveQueuesByCpu.find(cpu);
|
||||||
|
if (cpuQueuesIt == receiveQueuesByCpu.end())
|
||||||
|
return failure();
|
||||||
|
auto queueIt = cpuQueuesIt->second.find(pairKey);
|
||||||
|
if (queueIt == cpuQueuesIt->second.end())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto& queue = queueIt->second;
|
||||||
|
size_t& queueIndex = receiveQueueIndices[pairKey];
|
||||||
|
while (queueIndex < queue.size()) {
|
||||||
|
const RemoteReceiveEntry& entry = queue[queueIndex++];
|
||||||
|
auto consumerTaskIt = taskByComputeInstance.find(entry.consumer);
|
||||||
|
if (consumerTaskIt == taskByComputeInstance.end())
|
||||||
|
return failure();
|
||||||
|
SmallVector<Value> consumerInputs = getComputeInstanceInputs(consumerTaskIt->second.computeInstance);
|
||||||
|
if (consumerInputs.size() <= entry.inputIndex)
|
||||||
|
return failure();
|
||||||
|
Type inputType = consumerInputs[entry.inputIndex].getType();
|
||||||
|
auto receive = spatial::SpatChannelReceiveOp::create(rewriter,
|
||||||
|
loc,
|
||||||
|
inputType,
|
||||||
|
rewriter.getI64IntegerAttr(entry.channelInfo.channelId),
|
||||||
|
rewriter.getI32IntegerAttr(entry.channelInfo.sourceCoreId),
|
||||||
|
rewriter.getI32IntegerAttr(entry.channelInfo.targetCoreId));
|
||||||
|
|
||||||
|
auto& receivedInputs = preReceivedInputsByTask[entry.consumer];
|
||||||
|
if (receivedInputs.size() <= entry.inputIndex)
|
||||||
|
receivedInputs.resize(entry.inputIndex + 1);
|
||||||
|
receivedInputs[entry.inputIndex] = receive.getResult();
|
||||||
|
|
||||||
|
if (entry.consumer == requestedConsumer && entry.inputIndex == requestedInputIndex)
|
||||||
|
return receive.getResult();
|
||||||
|
}
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
LogicalResult cloneTaskBodies() {
|
||||||
|
for (size_t cpu : orderedCpus) {
|
||||||
|
CpuProgram& program = cpuPrograms[cpu];
|
||||||
|
IRRewriter rewriter(func.getContext());
|
||||||
|
rewriter.setInsertionPointToEnd(&program.op.getBody().front());
|
||||||
|
DenseMap<uint64_t, size_t> receiveQueueIndices;
|
||||||
|
DenseMap<ComputeInstance, SmallVector<Value>> preReceivedInputsByTask;
|
||||||
|
|
||||||
|
auto lookupPreReceivedInput = [&](ComputeInstance consumer, size_t inputIndex) -> std::optional<Value> {
|
||||||
|
auto inputsIt = preReceivedInputsByTask.find(consumer);
|
||||||
|
if (inputsIt == preReceivedInputsByTask.end() || inputsIt->second.size() <= inputIndex)
|
||||||
|
return std::nullopt;
|
||||||
|
Value value = inputsIt->second[inputIndex];
|
||||||
|
if (!value)
|
||||||
|
return std::nullopt;
|
||||||
|
return value;
|
||||||
|
};
|
||||||
|
|
||||||
|
for (const ScheduledTask& task : tasksByCpu[cpu]) {
|
||||||
|
SmallVector<Value> taskInputs = getComputeInstanceInputs(task.computeInstance);
|
||||||
|
auto taskWeights = getComputeInstanceWeights(task.computeInstance);
|
||||||
|
Block& templateBlock = getComputeInstanceTemplateBlock(task.computeInstance);
|
||||||
|
|
||||||
|
SmallVector<Value> resolvedInputs;
|
||||||
|
resolvedInputs.reserve(taskInputs.size());
|
||||||
|
auto remoteInputsIt = remoteInputsByTask.find(task.computeInstance);
|
||||||
|
for (auto [inputIndex, input] : llvm::enumerate(taskInputs)) {
|
||||||
|
auto producerRef = getProducerValueRef(input);
|
||||||
|
if (producerRef) {
|
||||||
|
auto producerIt = taskByComputeInstance.find(producerRef->instance);
|
||||||
|
if (producerIt != taskByComputeInstance.end()) {
|
||||||
|
if (producerIt->second.cpu == cpu) {
|
||||||
|
auto producedIt = producedValuesByTask.find(producerRef->instance);
|
||||||
|
if (producedIt == producedValuesByTask.end() || producedIt->second.size() <= producerRef->resultIndex) {
|
||||||
|
task.computeInstance.op->emitOpError("missing local producer value during per-cpu merge materialization")
|
||||||
|
<< " consumerCpu=" << cpu << " producerCpu=" << producerIt->second.cpu
|
||||||
|
<< " producerLaneStart=" << producerRef->instance.laneStart
|
||||||
|
<< " producerLaneCount=" << producerRef->instance.laneCount;
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
resolvedInputs.push_back(producedIt->second[producerRef->resultIndex]);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
const ChannelInfo& channelInfo = *remoteInputsIt->second[inputIndex];
|
||||||
|
uint64_t pairKey = getRemoteSendPairKey(channelInfo);
|
||||||
|
if (pairsNeedingReceiveReorder.contains(pairKey)) {
|
||||||
|
if (std::optional<Value> preReceived = lookupPreReceivedInput(task.computeInstance, inputIndex)) {
|
||||||
|
resolvedInputs.push_back(*preReceived);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
FailureOr<Value> received = receiveThroughInput(rewriter,
|
||||||
|
cpu,
|
||||||
|
receiveQueueIndices,
|
||||||
|
preReceivedInputsByTask,
|
||||||
|
channelInfo,
|
||||||
|
task.computeInstance,
|
||||||
|
inputIndex);
|
||||||
|
if (failed(received)) {
|
||||||
|
task.computeInstance.op->emitOpError("failed to materialize reordered remote receive")
|
||||||
|
<< " consumerCpu=" << cpu << " sourceCoreId=" << channelInfo.sourceCoreId
|
||||||
|
<< " targetCoreId=" << channelInfo.targetCoreId << " channelId=" << channelInfo.channelId;
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
resolvedInputs.push_back(*received);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
auto receive =
|
||||||
|
spatial::SpatChannelReceiveOp::create(rewriter,
|
||||||
|
loc,
|
||||||
|
input.getType(),
|
||||||
|
rewriter.getI64IntegerAttr(channelInfo.channelId),
|
||||||
|
rewriter.getI32IntegerAttr(channelInfo.sourceCoreId),
|
||||||
|
rewriter.getI32IntegerAttr(channelInfo.targetCoreId));
|
||||||
|
resolvedInputs.push_back(receive.getResult());
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
resolvedInputs.push_back(program.externalInputMap.at(input));
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<Value> taskYieldValues;
|
||||||
|
rewriter.setInsertionPointToEnd(&program.op.getBody().front());
|
||||||
|
if (isa<SpatCompute>(task.computeInstance.op)) {
|
||||||
|
IRMapping mapper;
|
||||||
|
for (auto [argIndex, oldArg] : llvm::enumerate(templateBlock.getArguments()))
|
||||||
|
mapper.map(oldArg, resolvedInputs[argIndex]);
|
||||||
|
|
||||||
|
for (Operation& op : templateBlock) {
|
||||||
|
if (auto yield = dyn_cast<spatial::SpatYieldOp>(&op)) {
|
||||||
|
for (Value yieldOperand : yield.getOperands())
|
||||||
|
taskYieldValues.push_back(mapper.lookup(yieldOperand));
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
Operation* clonedOp = rewriter.clone(op, mapper);
|
||||||
|
if (auto oldWeightedMvmOp = dyn_cast<spatial::SpatMVMOp>(&op)) {
|
||||||
|
auto newWeightedMvmOp = cast<spatial::SpatMVMOp>(clonedOp);
|
||||||
|
Value weight = taskWeights[oldWeightedMvmOp.getWeightIndex()];
|
||||||
|
newWeightedMvmOp.setWeightIndex(program.weightToIndex.at(weight));
|
||||||
|
}
|
||||||
|
if (auto oldWeightedVmmOp = dyn_cast<spatial::SpatVMMOp>(&op)) {
|
||||||
|
auto newWeightedVmmOp = cast<spatial::SpatVMMOp>(clonedOp);
|
||||||
|
Value weight = taskWeights[oldWeightedVmmOp.getWeightIndex()];
|
||||||
|
newWeightedVmmOp.setWeightIndex(program.weightToIndex.at(weight));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
for (size_t laneOffset = 0; laneOffset < task.computeInstance.laneCount; ++laneOffset) {
|
||||||
|
IRMapping mapper;
|
||||||
|
if (templateBlock.getNumArguments() == 1)
|
||||||
|
mapper.map(templateBlock.getArgument(0), resolvedInputs[laneOffset]);
|
||||||
|
|
||||||
|
for (Operation& op : templateBlock) {
|
||||||
|
if (auto yield = dyn_cast<spatial::SpatYieldOp>(&op)) {
|
||||||
|
for (Value yieldOperand : yield.getOperands())
|
||||||
|
taskYieldValues.push_back(mapper.lookup(yieldOperand));
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
Operation* clonedOp = rewriter.clone(op, mapper);
|
||||||
|
if (auto oldWeightedMvmOp = dyn_cast<spatial::SpatMVMOp>(&op)) {
|
||||||
|
if (oldWeightedMvmOp.getWeightIndex() != 0) {
|
||||||
|
task.computeInstance.op->emitOpError(
|
||||||
|
"batched per-cpu merge materialization expects lane-local weight index 0");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
auto newWeightedMvmOp = cast<spatial::SpatMVMOp>(clonedOp);
|
||||||
|
newWeightedMvmOp.setWeightIndex(program.weightToIndex.at(taskWeights[laneOffset]));
|
||||||
|
}
|
||||||
|
if (auto oldWeightedVmmOp = dyn_cast<spatial::SpatVMMOp>(&op)) {
|
||||||
|
if (oldWeightedVmmOp.getWeightIndex() != 0) {
|
||||||
|
task.computeInstance.op->emitOpError(
|
||||||
|
"batched per-cpu merge materialization expects lane-local weight index 0");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
auto newWeightedVmmOp = cast<spatial::SpatVMMOp>(clonedOp);
|
||||||
|
newWeightedVmmOp.setWeightIndex(program.weightToIndex.at(taskWeights[laneOffset]));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
producedValuesByTask[task.computeInstance] = taskYieldValues;
|
||||||
|
if (auto sendsIt = remoteSendsByTask.find(task.computeInstance); sendsIt != remoteSendsByTask.end()) {
|
||||||
|
for (auto [resultIndex, sendInfos] : llvm::enumerate(sendsIt->second)) {
|
||||||
|
if (sendInfos.empty())
|
||||||
|
continue;
|
||||||
|
Value producedValue = taskYieldValues[resultIndex];
|
||||||
|
for (const RemoteSendInfo& sendInfo : sendInfos) {
|
||||||
|
spatial::SpatChannelSendOp::create(rewriter,
|
||||||
|
loc,
|
||||||
|
rewriter.getI64IntegerAttr(sendInfo.channelInfo.channelId),
|
||||||
|
rewriter.getI32IntegerAttr(sendInfo.channelInfo.sourceCoreId),
|
||||||
|
rewriter.getI32IntegerAttr(sendInfo.channelInfo.targetCoreId),
|
||||||
|
producedValue);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<Value> yieldValues;
|
||||||
|
yieldValues.reserve(cpuExternalOutputs[cpu].size());
|
||||||
|
for (ProducerValueRef outputRef : cpuExternalOutputs[cpu]) {
|
||||||
|
auto producedIt = producedValuesByTask.find(outputRef.instance);
|
||||||
|
if (producedIt == producedValuesByTask.end() || producedIt->second.size() <= outputRef.resultIndex) {
|
||||||
|
ScheduledTask task = taskByComputeInstance.at(outputRef.instance);
|
||||||
|
task.computeInstance.op->emitOpError("missing yielded external value during per-cpu merge materialization")
|
||||||
|
<< " cpu=" << cpu << " laneStart=" << outputRef.instance.laneStart;
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
yieldValues.push_back(producedIt->second[outputRef.resultIndex]);
|
||||||
|
}
|
||||||
|
spatial::SpatYieldOp::create(rewriter, loc, ValueRange(yieldValues));
|
||||||
|
}
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
void replaceExternalUses() {
|
||||||
|
for (auto [oldValue, newValue] : oldToNewExternalValueMap) {
|
||||||
|
for (auto& use : llvm::make_early_inc_range(oldValue.getUses()))
|
||||||
|
if (!oldComputeOps.contains(use.getOwner()))
|
||||||
|
use.assign(newValue);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
LogicalResult eraseOldScheduledOps() {
|
||||||
|
SmallVector<Operation*> orderedOpsToErase;
|
||||||
|
for (Operation& op : func.getBody().front())
|
||||||
|
if (oldComputeOps.contains(&op))
|
||||||
|
orderedOpsToErase.push_back(&op);
|
||||||
|
|
||||||
|
for (Operation* op : llvm::reverse(orderedOpsToErase)) {
|
||||||
|
SmallVector<Operation*> remainingUsers;
|
||||||
|
for (Value result : op->getResults())
|
||||||
|
for (Operation* user : result.getUsers())
|
||||||
|
remainingUsers.push_back(user);
|
||||||
|
if (!remainingUsers.empty()) {
|
||||||
|
InFlightDiagnostic diagnostic = op->emitOpError("still has uses during per-cpu merge cleanup")
|
||||||
|
<< "; erase-set=" << (oldComputeOps.contains(op) ? "yes" : "no");
|
||||||
|
for (Operation* user : remainingUsers) {
|
||||||
|
diagnostic.attachNote(user->getLoc())
|
||||||
|
<< "remaining user " << user->getName() << "; erase-set=" << (oldComputeOps.contains(user) ? "yes" : "no");
|
||||||
|
}
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
op->erase();
|
||||||
|
}
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
void moveExternalUsersBeforeReturn() {
|
||||||
|
SmallVector<Operation*> orderedUsersToMove;
|
||||||
|
for (Operation& op : func.getBody().front()) {
|
||||||
|
if (&op == returnOp.getOperation())
|
||||||
|
break;
|
||||||
|
if (externalUsersToMove.contains(&op))
|
||||||
|
orderedUsersToMove.push_back(&op);
|
||||||
|
}
|
||||||
|
for (Operation* op : orderedUsersToMove)
|
||||||
|
op->moveBefore(returnOp);
|
||||||
|
}
|
||||||
|
|
||||||
|
func::FuncOp func;
|
||||||
|
const MergeScheduleResult* schedule = nullptr;
|
||||||
|
int64_t* nextChannelId = nullptr;
|
||||||
|
Location loc;
|
||||||
|
func::ReturnOp returnOp;
|
||||||
|
|
||||||
|
SmallVector<ScheduledTask> scheduledTasks;
|
||||||
|
DenseSet<Operation*> oldComputeOps;
|
||||||
|
DenseMap<ComputeInstance, ScheduledTask> taskByComputeInstance;
|
||||||
|
DenseMap<size_t, SmallVector<ScheduledTask>> tasksByCpu;
|
||||||
|
SmallVector<size_t> orderedCpus;
|
||||||
|
DenseSet<size_t> seenCpus;
|
||||||
|
DenseSet<Operation*> externalUsersToMove;
|
||||||
|
DenseMap<ComputeInstance, SmallVector<SmallVector<RemoteSendInfo>>> remoteSendsByTask;
|
||||||
|
DenseMap<ComputeInstance, SmallVector<std::optional<ChannelInfo>>> remoteInputsByTask;
|
||||||
|
DenseMap<size_t, SmallVector<Value>> cpuExternalInputs;
|
||||||
|
DenseMap<size_t, SmallVector<Value>> cpuWeights;
|
||||||
|
DenseMap<size_t, SmallVector<ProducerValueRef>> cpuExternalOutputs;
|
||||||
|
DenseMap<size_t, DenseSet<Value>> seenExternalInputsByCpu;
|
||||||
|
DenseMap<size_t, DenseSet<Value>> seenWeightsByCpu;
|
||||||
|
DenseSet<uint64_t> pairsNeedingReceiveReorder;
|
||||||
|
DenseMap<size_t, DenseMap<uint64_t, SmallVector<RemoteReceiveEntry>>> receiveQueuesByCpu;
|
||||||
|
DenseMap<size_t, CpuProgram> cpuPrograms;
|
||||||
|
DenseMap<Value, Value> oldToNewExternalValueMap;
|
||||||
|
DenseMap<ComputeInstance, SmallVector<Value>> producedValuesByTask;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
LogicalResult
|
||||||
|
MergeScheduleMaterializer::run(func::FuncOp func, const MergeScheduleResult& schedule, int64_t& nextChannelId) {
|
||||||
|
return MergeScheduleMaterializerImpl(func).run(schedule, nextChannelId);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace spatial
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,18 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
|
#include "mlir/Support/LogicalResult.h"
|
||||||
|
|
||||||
|
#include "Scheduling/MergeSchedule.hpp"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
namespace spatial {
|
||||||
|
|
||||||
|
class MergeScheduleMaterializer {
|
||||||
|
public:
|
||||||
|
mlir::LogicalResult
|
||||||
|
run(mlir::func::FuncOp func, const MergeScheduleResult &schedule, int64_t &nextChannelId);
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace spatial
|
||||||
|
} // namespace onnx_mlir
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,459 @@
|
|||||||
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
#include "mlir/IR/IRMapping.h"
|
||||||
|
#include "mlir/IR/PatternMatch.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/DenseMap.h"
|
||||||
|
#include "llvm/ADT/DenseSet.h"
|
||||||
|
#include "llvm/ADT/Hashing.h"
|
||||||
|
#include "llvm/ADT/STLExtras.h"
|
||||||
|
#include "llvm/Support/FormatVariadic.h"
|
||||||
|
#include "llvm/Support/raw_ostream.h"
|
||||||
|
|
||||||
|
#include <chrono>
|
||||||
|
#include <cstdlib>
|
||||||
|
#include <limits>
|
||||||
|
#include <optional>
|
||||||
|
|
||||||
|
#include "PostMergeCompaction.hpp"
|
||||||
|
#include "RegularOpCompaction.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
using SpatCompute = spatial::SpatCompute;
|
||||||
|
using SpatComputeBatch = spatial::SpatComputeBatch;
|
||||||
|
|
||||||
|
bool isMergeProfilingEnabled() { return std::getenv("RAPTOR_PROFILE_MERGE") != nullptr; }
|
||||||
|
|
||||||
|
class ScopedMergePhaseTimer {
|
||||||
|
public:
|
||||||
|
explicit ScopedMergePhaseTimer(StringRef phaseName)
|
||||||
|
: enabled(isMergeProfilingEnabled()), phase(phaseName.str()) {
|
||||||
|
if (enabled)
|
||||||
|
start = std::chrono::steady_clock::now();
|
||||||
|
}
|
||||||
|
|
||||||
|
~ScopedMergePhaseTimer() {
|
||||||
|
if (!enabled)
|
||||||
|
return;
|
||||||
|
auto elapsed = std::chrono::steady_clock::now() - start;
|
||||||
|
double millis = std::chrono::duration<double, std::milli>(elapsed).count();
|
||||||
|
llvm::errs() << "[merge-profile] " << phase << ": " << llvm::formatv("{0:F3}", millis) << " ms\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
bool enabled = false;
|
||||||
|
std::string phase;
|
||||||
|
std::chrono::steady_clock::time_point start;
|
||||||
|
};
|
||||||
|
|
||||||
|
std::optional<int32_t> getComputeCoreId(SpatCompute compute) {
|
||||||
|
if (auto coreIdAttr = compute->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
|
||||||
|
return static_cast<int32_t>(coreIdAttr.getInt());
|
||||||
|
return std::nullopt;
|
||||||
|
}
|
||||||
|
|
||||||
|
static constexpr StringLiteral kRebatchPhaseAttrName = "_pim_rebatch_phase";
|
||||||
|
|
||||||
|
std::optional<uint64_t> getComputeRebatchPhase(SpatCompute compute) {
|
||||||
|
if (auto phaseAttr = compute->getAttrOfType<IntegerAttr>(kRebatchPhaseAttrName))
|
||||||
|
return static_cast<uint64_t>(phaseAttr.getInt());
|
||||||
|
return std::nullopt;
|
||||||
|
}
|
||||||
|
|
||||||
|
struct RebatchKey {
|
||||||
|
unsigned inputCount = 0;
|
||||||
|
unsigned resultCount = 0;
|
||||||
|
unsigned weightCount = 0;
|
||||||
|
uint64_t phase = 0;
|
||||||
|
bool hasPhase = false;
|
||||||
|
uint64_t structureHash = 0;
|
||||||
|
|
||||||
|
bool operator==(const RebatchKey& other) const {
|
||||||
|
return inputCount == other.inputCount && resultCount == other.resultCount && weightCount == other.weightCount
|
||||||
|
&& phase == other.phase && hasPhase == other.hasPhase && structureHash == other.structureHash;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct RebatchKeyInfo {
|
||||||
|
static inline RebatchKey getEmptyKey() { return {std::numeric_limits<unsigned>::max(), 0, 0, 0, false, 0}; }
|
||||||
|
|
||||||
|
static inline RebatchKey getTombstoneKey() { return {std::numeric_limits<unsigned>::max() - 1, 0, 0, 0, false, 0}; }
|
||||||
|
|
||||||
|
static unsigned getHashValue(const RebatchKey& key) {
|
||||||
|
return static_cast<unsigned>(
|
||||||
|
llvm::hash_combine(key.inputCount, key.resultCount, key.weightCount, key.phase, key.hasPhase, key.structureHash));
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool isEqual(const RebatchKey& lhs, const RebatchKey& rhs) { return lhs == rhs; }
|
||||||
|
};
|
||||||
|
|
||||||
|
uint64_t getTypeHash(Type type) { return reinterpret_cast<uintptr_t>(type.getAsOpaquePointer()); }
|
||||||
|
|
||||||
|
uint64_t getValueHash(Value value) { return reinterpret_cast<uintptr_t>(value.getAsOpaquePointer()); }
|
||||||
|
|
||||||
|
uint64_t getAttributeHash(Attribute attr) { return reinterpret_cast<uintptr_t>(attr.getAsOpaquePointer()); }
|
||||||
|
|
||||||
|
RebatchKey computeRebatchKey(SpatCompute compute) {
|
||||||
|
llvm::hash_code structureHash =
|
||||||
|
llvm::hash_combine(compute.getInputs().size(), compute.getResultTypes().size(), compute.getWeights().size());
|
||||||
|
|
||||||
|
for (Value weight : compute.getWeights())
|
||||||
|
structureHash = llvm::hash_combine(structureHash, getValueHash(weight));
|
||||||
|
if (std::optional<uint64_t> phase = getComputeRebatchPhase(compute))
|
||||||
|
structureHash = llvm::hash_combine(structureHash, *phase);
|
||||||
|
|
||||||
|
Block& body = compute.getBody().front();
|
||||||
|
structureHash = llvm::hash_combine(structureHash, body.getNumArguments());
|
||||||
|
for (BlockArgument arg : body.getArguments())
|
||||||
|
structureHash = llvm::hash_combine(structureHash, getTypeHash(arg.getType()));
|
||||||
|
|
||||||
|
for (Operation& op : body) {
|
||||||
|
structureHash = llvm::hash_combine(
|
||||||
|
structureHash, op.getName().getStringRef(), op.getNumOperands(), op.getNumResults(), op.getNumRegions());
|
||||||
|
for (Type type : op.getResultTypes())
|
||||||
|
structureHash = llvm::hash_combine(structureHash, getTypeHash(type));
|
||||||
|
for (NamedAttribute attr : op.getAttrs())
|
||||||
|
structureHash = llvm::hash_combine(structureHash, attr.getName().strref(), getAttributeHash(attr.getValue()));
|
||||||
|
}
|
||||||
|
|
||||||
|
std::optional<uint64_t> phase = getComputeRebatchPhase(compute);
|
||||||
|
return {static_cast<unsigned>(compute.getInputs().size()),
|
||||||
|
static_cast<unsigned>(compute.getResultTypes().size()),
|
||||||
|
static_cast<unsigned>(compute.getWeights().size()),
|
||||||
|
phase.value_or(0),
|
||||||
|
phase.has_value(),
|
||||||
|
static_cast<uint64_t>(structureHash)};
|
||||||
|
}
|
||||||
|
|
||||||
|
bool areEquivalentForRebatch(SpatCompute lhs, SpatCompute rhs) {
|
||||||
|
if (!lhs || !rhs)
|
||||||
|
return false;
|
||||||
|
if (lhs.getInputs().size() != rhs.getInputs().size())
|
||||||
|
return false;
|
||||||
|
if (lhs.getResultTypes() != rhs.getResultTypes())
|
||||||
|
return false;
|
||||||
|
if (lhs.getWeights().size() != rhs.getWeights().size())
|
||||||
|
return false;
|
||||||
|
if (getComputeRebatchPhase(lhs) != getComputeRebatchPhase(rhs))
|
||||||
|
return false;
|
||||||
|
if (!llvm::equal(lhs.getWeights(), rhs.getWeights()))
|
||||||
|
return false;
|
||||||
|
|
||||||
|
auto& lhsBlock = lhs.getBody().front();
|
||||||
|
auto& rhsBlock = rhs.getBody().front();
|
||||||
|
if (lhsBlock.getNumArguments() != rhsBlock.getNumArguments())
|
||||||
|
return false;
|
||||||
|
|
||||||
|
DenseMap<Value, Value> mappedValues;
|
||||||
|
for (auto [lhsArg, rhsArg] : llvm::zip(lhsBlock.getArguments(), rhsBlock.getArguments())) {
|
||||||
|
if (lhsArg.getType() != rhsArg.getType())
|
||||||
|
return false;
|
||||||
|
mappedValues[lhsArg] = rhsArg;
|
||||||
|
}
|
||||||
|
auto lhsIt = lhsBlock.begin();
|
||||||
|
auto rhsIt = rhsBlock.begin();
|
||||||
|
for (; lhsIt != lhsBlock.end() && rhsIt != rhsBlock.end(); ++lhsIt, ++rhsIt) {
|
||||||
|
Operation& lhsOp = *lhsIt;
|
||||||
|
Operation& rhsOp = *rhsIt;
|
||||||
|
|
||||||
|
if (lhsOp.getName() != rhsOp.getName())
|
||||||
|
return false;
|
||||||
|
if (lhsOp.getNumOperands() != rhsOp.getNumOperands())
|
||||||
|
return false;
|
||||||
|
if (lhsOp.getNumResults() != rhsOp.getNumResults())
|
||||||
|
return false;
|
||||||
|
if (lhsOp.getNumRegions() != 0 || rhsOp.getNumRegions() != 0)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
for (auto [lhsOperand, rhsOperand] : llvm::zip(lhsOp.getOperands(), rhsOp.getOperands())) {
|
||||||
|
auto mapped = mappedValues.find(lhsOperand);
|
||||||
|
if (mapped != mappedValues.end()) {
|
||||||
|
if (mapped->second != rhsOperand)
|
||||||
|
return false;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (lhsOperand != rhsOperand)
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto lhsReceive = dyn_cast<spatial::SpatChannelReceiveOp>(lhsOp)) {
|
||||||
|
auto rhsReceive = cast<spatial::SpatChannelReceiveOp>(rhsOp);
|
||||||
|
if (lhsReceive.getOutput().getType() != rhsReceive.getOutput().getType())
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
else if (auto lhsSend = dyn_cast<spatial::SpatChannelSendOp>(lhsOp)) {
|
||||||
|
auto rhsSend = cast<spatial::SpatChannelSendOp>(rhsOp);
|
||||||
|
if (lhsSend.getInput().getType() != rhsSend.getInput().getType())
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
else if (lhsOp.getAttrs() != rhsOp.getAttrs()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (lhsOp.getResultTypes() != rhsOp.getResultTypes())
|
||||||
|
return false;
|
||||||
|
for (auto [lhsResult, rhsResult] : llvm::zip(lhsOp.getResults(), rhsOp.getResults()))
|
||||||
|
mappedValues[lhsResult] = rhsResult;
|
||||||
|
}
|
||||||
|
|
||||||
|
return lhsIt == lhsBlock.end() && rhsIt == rhsBlock.end();
|
||||||
|
}
|
||||||
|
|
||||||
|
void rebatchEquivalentComputes(func::FuncOp funcOp) {
|
||||||
|
IRRewriter rewriter(funcOp.getContext());
|
||||||
|
SmallVector<SpatCompute> computes(funcOp.getOps<SpatCompute>());
|
||||||
|
DenseSet<Operation*> consumed;
|
||||||
|
DenseMap<Operation*, size_t> computeOrder;
|
||||||
|
DenseMap<RebatchKey, SmallVector<SpatCompute>, RebatchKeyInfo> candidatesByKey;
|
||||||
|
|
||||||
|
for (auto [index, compute] : llvm::enumerate(computes)) {
|
||||||
|
computeOrder[compute.getOperation()] = index;
|
||||||
|
if (compute.getInputs().size() <= 1 && compute.getResults().empty())
|
||||||
|
candidatesByKey[computeRebatchKey(compute)].push_back(compute);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (size_t index = 0; index < computes.size(); ++index) {
|
||||||
|
auto anchor = computes[index];
|
||||||
|
if (consumed.contains(anchor))
|
||||||
|
continue;
|
||||||
|
if (anchor.getInputs().size() > 1)
|
||||||
|
continue;
|
||||||
|
if (!anchor.getResults().empty())
|
||||||
|
continue;
|
||||||
|
|
||||||
|
SmallVector<SpatCompute> group {anchor};
|
||||||
|
llvm::SmallDenseSet<int32_t, 8> usedCoreIds;
|
||||||
|
if (auto coreId = getComputeCoreId(anchor))
|
||||||
|
usedCoreIds.insert(*coreId);
|
||||||
|
|
||||||
|
auto bucketIt = candidatesByKey.find(computeRebatchKey(anchor));
|
||||||
|
if (bucketIt == candidatesByKey.end())
|
||||||
|
continue;
|
||||||
|
|
||||||
|
for (auto candidate : bucketIt->second) {
|
||||||
|
if (computeOrder.lookup(candidate.getOperation()) <= index)
|
||||||
|
continue;
|
||||||
|
if (consumed.contains(candidate))
|
||||||
|
continue;
|
||||||
|
if (!areEquivalentForRebatch(anchor, candidate))
|
||||||
|
continue;
|
||||||
|
|
||||||
|
if (auto coreId = getComputeCoreId(candidate))
|
||||||
|
if (!usedCoreIds.insert(*coreId).second)
|
||||||
|
continue;
|
||||||
|
|
||||||
|
group.push_back(candidate);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (group.size() <= 1)
|
||||||
|
continue;
|
||||||
|
|
||||||
|
auto insertionAnchor = group.front();
|
||||||
|
if (llvm::all_of(group, [](SpatCompute compute) { return getComputeCoreId(compute).has_value(); })) {
|
||||||
|
llvm::stable_sort(
|
||||||
|
group, [](SpatCompute lhs, SpatCompute rhs) { return *getComputeCoreId(lhs) < *getComputeCoreId(rhs); });
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<Value> weights;
|
||||||
|
weights.reserve(group.size() * anchor.getWeights().size());
|
||||||
|
SmallVector<Value> inputs;
|
||||||
|
inputs.reserve(group.size() * anchor.getInputs().size());
|
||||||
|
SmallVector<int32_t> coreIds;
|
||||||
|
coreIds.reserve(group.size());
|
||||||
|
bool haveAllCoreIds = true;
|
||||||
|
for (auto compute : group) {
|
||||||
|
llvm::append_range(weights, compute.getWeights());
|
||||||
|
llvm::append_range(inputs, compute.getInputs());
|
||||||
|
auto coreIdAttr = compute->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName);
|
||||||
|
if (!coreIdAttr)
|
||||||
|
haveAllCoreIds = false;
|
||||||
|
else if (haveAllCoreIds)
|
||||||
|
coreIds.push_back(static_cast<int32_t>(coreIdAttr.getInt()));
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter.setInsertionPoint(insertionAnchor);
|
||||||
|
auto rebatched = SpatComputeBatch::create(rewriter,
|
||||||
|
insertionAnchor.getLoc(),
|
||||||
|
TypeRange {},
|
||||||
|
rewriter.getI32IntegerAttr(static_cast<int32_t>(group.size())),
|
||||||
|
ValueRange(weights),
|
||||||
|
ValueRange(inputs));
|
||||||
|
rebatched.getProperties().setOperandSegmentSizes(
|
||||||
|
{static_cast<int>(weights.size()), static_cast<int>(inputs.size())});
|
||||||
|
if (haveAllCoreIds)
|
||||||
|
rebatched->setAttr(onnx_mlir::kCoreIdsAttrName, rewriter.getDenseI32ArrayAttr(coreIds));
|
||||||
|
|
||||||
|
SmallVector<Type> blockArgTypes;
|
||||||
|
SmallVector<Location> blockArgLocs;
|
||||||
|
for (BlockArgument arg : anchor.getBody().front().getArguments()) {
|
||||||
|
blockArgTypes.push_back(arg.getType());
|
||||||
|
blockArgLocs.push_back(arg.getLoc());
|
||||||
|
}
|
||||||
|
auto* newBlock =
|
||||||
|
rewriter.createBlock(&rebatched.getBody(), rebatched.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
|
||||||
|
rewriter.setInsertionPointToEnd(newBlock);
|
||||||
|
|
||||||
|
IRMapping mapper;
|
||||||
|
auto& anchorBlock = anchor.getBody().front();
|
||||||
|
for (auto [oldArg, newArg] : llvm::zip(anchorBlock.getArguments(), newBlock->getArguments()))
|
||||||
|
mapper.map(oldArg, newArg);
|
||||||
|
auto opIts = llvm::map_to_vector(group, [](SpatCompute compute) { return compute.getBody().front().begin(); });
|
||||||
|
for (Operation& anchorOp : anchorBlock) {
|
||||||
|
if (auto receiveOp = dyn_cast<spatial::SpatChannelReceiveOp>(&anchorOp)) {
|
||||||
|
struct BatchReceiveEntry {
|
||||||
|
uint64_t channelId = 0;
|
||||||
|
uint32_t sourceCoreId = 0;
|
||||||
|
uint32_t targetCoreId = 0;
|
||||||
|
};
|
||||||
|
SmallVector<BatchReceiveEntry> entries;
|
||||||
|
entries.reserve(group.size());
|
||||||
|
for (auto [groupIndex, compute] : llvm::enumerate(group)) {
|
||||||
|
auto groupReceive = cast<spatial::SpatChannelReceiveOp>(&*opIts[groupIndex]);
|
||||||
|
entries.push_back(
|
||||||
|
{groupReceive.getChannelId(), groupReceive.getSourceCoreId(), groupReceive.getTargetCoreId()});
|
||||||
|
++opIts[groupIndex];
|
||||||
|
}
|
||||||
|
SmallVector<int64_t> channelIds;
|
||||||
|
SmallVector<int32_t> sourceCoreIds;
|
||||||
|
SmallVector<int32_t> targetCoreIds;
|
||||||
|
channelIds.reserve(group.size());
|
||||||
|
sourceCoreIds.reserve(group.size());
|
||||||
|
targetCoreIds.reserve(group.size());
|
||||||
|
for (const BatchReceiveEntry& entry : entries) {
|
||||||
|
channelIds.push_back(static_cast<int64_t>(entry.channelId));
|
||||||
|
sourceCoreIds.push_back(static_cast<int32_t>(entry.sourceCoreId));
|
||||||
|
targetCoreIds.push_back(static_cast<int32_t>(entry.targetCoreId));
|
||||||
|
}
|
||||||
|
auto batchReceive = spatial::SpatChannelReceiveBatchOp::create(rewriter,
|
||||||
|
receiveOp.getLoc(),
|
||||||
|
receiveOp.getOutput().getType(),
|
||||||
|
rewriter.getDenseI64ArrayAttr(channelIds),
|
||||||
|
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
|
||||||
|
rewriter.getDenseI32ArrayAttr(targetCoreIds));
|
||||||
|
mapper.map(receiveOp.getOutput(), batchReceive.getOutput());
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto sendOp = dyn_cast<spatial::SpatChannelSendOp>(&anchorOp)) {
|
||||||
|
struct BatchSendEntry {
|
||||||
|
uint64_t channelId = 0;
|
||||||
|
uint32_t sourceCoreId = 0;
|
||||||
|
uint32_t targetCoreId = 0;
|
||||||
|
};
|
||||||
|
SmallVector<BatchSendEntry> entries;
|
||||||
|
entries.reserve(group.size());
|
||||||
|
for (auto [groupIndex, compute] : llvm::enumerate(group)) {
|
||||||
|
auto groupSend = cast<spatial::SpatChannelSendOp>(&*opIts[groupIndex]);
|
||||||
|
entries.push_back({groupSend.getChannelId(), groupSend.getSourceCoreId(), groupSend.getTargetCoreId()});
|
||||||
|
++opIts[groupIndex];
|
||||||
|
}
|
||||||
|
SmallVector<int64_t> channelIds;
|
||||||
|
SmallVector<int32_t> sourceCoreIds;
|
||||||
|
SmallVector<int32_t> targetCoreIds;
|
||||||
|
channelIds.reserve(group.size());
|
||||||
|
sourceCoreIds.reserve(group.size());
|
||||||
|
targetCoreIds.reserve(group.size());
|
||||||
|
for (const BatchSendEntry& entry : entries) {
|
||||||
|
channelIds.push_back(static_cast<int64_t>(entry.channelId));
|
||||||
|
sourceCoreIds.push_back(static_cast<int32_t>(entry.sourceCoreId));
|
||||||
|
targetCoreIds.push_back(static_cast<int32_t>(entry.targetCoreId));
|
||||||
|
}
|
||||||
|
spatial::SpatChannelSendBatchOp::create(rewriter,
|
||||||
|
sendOp.getLoc(),
|
||||||
|
rewriter.getDenseI64ArrayAttr(channelIds),
|
||||||
|
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
|
||||||
|
rewriter.getDenseI32ArrayAttr(targetCoreIds),
|
||||||
|
mapper.lookup(sendOp.getInput()));
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (isa<spatial::SpatYieldOp>(anchorOp)) {
|
||||||
|
for (auto& opIt : opIts)
|
||||||
|
++opIt;
|
||||||
|
spatial::SpatYieldOp::create(rewriter, anchorOp.getLoc(), ValueRange {});
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
Operation* cloned = rewriter.clone(anchorOp, mapper);
|
||||||
|
for (auto [originalResult, clonedResult] : llvm::zip(anchorOp.getResults(), cloned->getResults()))
|
||||||
|
mapper.map(originalResult, clonedResult);
|
||||||
|
for (auto& opIt : opIts)
|
||||||
|
++opIt;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto compute : group) {
|
||||||
|
compute->removeAttr(kRebatchPhaseAttrName);
|
||||||
|
consumed.insert(compute);
|
||||||
|
rewriter.eraseOp(compute);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto compute : funcOp.getOps<SpatCompute>())
|
||||||
|
compute->removeAttr(kRebatchPhaseAttrName);
|
||||||
|
}
|
||||||
|
|
||||||
|
void cleanupDeadPackingOps(func::FuncOp funcOp) {
|
||||||
|
auto eraseUnusedOps = [&](auto tag) {
|
||||||
|
using OpTy = decltype(tag);
|
||||||
|
SmallVector<OpTy> ops;
|
||||||
|
funcOp.walk([&](OpTy op) { ops.push_back(op); });
|
||||||
|
for (auto op : llvm::reverse(ops))
|
||||||
|
if (op->use_empty())
|
||||||
|
op.erase();
|
||||||
|
};
|
||||||
|
eraseUnusedOps(tensor::ExtractSliceOp {});
|
||||||
|
eraseUnusedOps(spatial::SpatConcatOp {});
|
||||||
|
eraseUnusedOps(spatial::SpatExtractRowsOp {});
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
LogicalResult runPostMergeCompactionPipeline(func::FuncOp funcOp, int64_t& nextChannelId) {
|
||||||
|
{
|
||||||
|
ScopedMergePhaseTimer timer("order-bilateral-channel-ops");
|
||||||
|
orderBilateralChannelOps(funcOp);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
ScopedMergePhaseTimer timer("rebatch-equivalent-computes");
|
||||||
|
rebatchEquivalentComputes(funcOp);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
ScopedMergePhaseTimer timer("compact-scalar-channel-runs-1");
|
||||||
|
compactScalarChannelRuns(funcOp, nextChannelId);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
ScopedMergePhaseTimer timer("compact-batch-channel-runs-1");
|
||||||
|
compactBatchChannelRuns(funcOp);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
ScopedMergePhaseTimer timer("compact-regular-op-runs");
|
||||||
|
compactRegularOpRuns(funcOp);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
ScopedMergePhaseTimer timer("compact-row-wise-wvmm-runs");
|
||||||
|
compactRowWiseWvmmRuns(funcOp);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
ScopedMergePhaseTimer timer("compact-scalar-channel-runs-2");
|
||||||
|
compactScalarChannelRuns(funcOp, nextChannelId);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
ScopedMergePhaseTimer timer("compact-batch-channel-runs-2");
|
||||||
|
compactBatchChannelRuns(funcOp);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
ScopedMergePhaseTimer timer("cleanup-dead-packing-ops");
|
||||||
|
cleanupDeadPackingOps(funcOp);
|
||||||
|
}
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,12 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
|
#include "mlir/Support/LogicalResult.h"
|
||||||
|
|
||||||
|
#include <cstdint>
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
mlir::LogicalResult runPostMergeCompactionPipeline(mlir::func::FuncOp funcOp, int64_t &nextChannelId);
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -7,12 +7,13 @@
|
|||||||
#include "mlir/IR/Value.h"
|
#include "mlir/IR/Value.h"
|
||||||
#include "mlir/Support/LLVM.h"
|
#include "mlir/Support/LLVM.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/DenseMap.h"
|
||||||
|
#include "llvm/ADT/DenseSet.h"
|
||||||
#include "llvm/ADT/STLExtras.h"
|
#include "llvm/ADT/STLExtras.h"
|
||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
|
||||||
#include <tuple>
|
|
||||||
|
|
||||||
#include "RegularOpCompaction.hpp"
|
#include "RegularOpCompaction.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/TensorPackingPatterns.hpp"
|
#include "src/Accelerators/PIM/Conversion/SpatialToPim/TensorPackingPatterns.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
|
||||||
@@ -42,6 +43,47 @@ struct RegularChunk {
|
|||||||
Value output;
|
Value output;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct RegularCompactionResult {
|
||||||
|
bool changed = false;
|
||||||
|
Operation* resumeAfter = nullptr;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename OpTy>
|
||||||
|
struct ConsecutiveRun {
|
||||||
|
SmallVector<OpTy> ops;
|
||||||
|
Block::iterator end;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename OpTy, typename Predicate>
|
||||||
|
static ConsecutiveRun<OpTy>
|
||||||
|
collectConsecutiveRun(Block::iterator start, Block::iterator blockEnd, Predicate predicate) {
|
||||||
|
ConsecutiveRun<OpTy> run;
|
||||||
|
run.end = start;
|
||||||
|
while (run.end != blockEnd) {
|
||||||
|
auto current = dyn_cast<OpTy>(&*run.end);
|
||||||
|
if (!current || !predicate(current))
|
||||||
|
break;
|
||||||
|
run.ops.push_back(current);
|
||||||
|
++run.end;
|
||||||
|
}
|
||||||
|
return run;
|
||||||
|
}
|
||||||
|
|
||||||
|
static uint64_t getEndpointKey(uint32_t sourceCoreId, uint32_t targetCoreId) {
|
||||||
|
return (static_cast<uint64_t>(sourceCoreId) << 32) | static_cast<uint64_t>(targetCoreId);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void appendChannelAttrs(SmallVectorImpl<int64_t>& channelIds,
|
||||||
|
SmallVectorImpl<int32_t>& sourceCoreIds,
|
||||||
|
SmallVectorImpl<int32_t>& targetCoreIds,
|
||||||
|
uint64_t channelId,
|
||||||
|
uint32_t sourceCoreId,
|
||||||
|
uint32_t targetCoreId) {
|
||||||
|
channelIds.push_back(static_cast<int64_t>(channelId));
|
||||||
|
sourceCoreIds.push_back(static_cast<int32_t>(sourceCoreId));
|
||||||
|
targetCoreIds.push_back(static_cast<int32_t>(targetCoreId));
|
||||||
|
}
|
||||||
|
|
||||||
static spatial::SpatConcatOp getContiguousConcatUse(ValueRange values, unsigned& startOperandIndex) {
|
static spatial::SpatConcatOp getContiguousConcatUse(ValueRange values, unsigned& startOperandIndex) {
|
||||||
if (values.empty() || !values.front().hasOneUse())
|
if (values.empty() || !values.front().hasOneUse())
|
||||||
return {};
|
return {};
|
||||||
@@ -168,6 +210,17 @@ static bool areEquivalentRegularChunks(const RegularChunk& lhs, const RegularChu
|
|||||||
[](auto pair) { return areEquivalentRegularSteps(std::get<0>(pair), std::get<1>(pair)); });
|
[](auto pair) { return areEquivalentRegularSteps(std::get<0>(pair), std::get<1>(pair)); });
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static bool isForwardedChannelPayload(Value value, Block& block) {
|
||||||
|
Operation* op = value.getDefiningOp();
|
||||||
|
if (!op || op->getBlock() != &block)
|
||||||
|
return true;
|
||||||
|
|
||||||
|
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op))
|
||||||
|
return isForwardedChannelPayload(extractSliceOp.getSource(), block);
|
||||||
|
|
||||||
|
return isa<spatial::SpatChannelReceiveOp, spatial::SpatChannelReceiveTensorOp>(op);
|
||||||
|
}
|
||||||
|
|
||||||
static FailureOr<RegularChunk> analyzeRegularChunk(spatial::SpatVMMOp startOp) {
|
static FailureOr<RegularChunk> analyzeRegularChunk(spatial::SpatVMMOp startOp) {
|
||||||
RegularChunk chunk;
|
RegularChunk chunk;
|
||||||
chunk.startOp = startOp.getOperation();
|
chunk.startOp = startOp.getOperation();
|
||||||
@@ -202,9 +255,10 @@ static FailureOr<RegularChunk> analyzeRegularChunk(spatial::SpatVMMOp startOp) {
|
|||||||
return chunk;
|
return chunk;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void compactRegularChunkRun(IRRewriter& rewriter, ArrayRef<RegularChunk> run) {
|
static RegularCompactionResult compactRegularChunkRun(IRRewriter& rewriter, ArrayRef<RegularChunk> run) {
|
||||||
assert(!run.empty() && "expected a non-empty regular chunk run");
|
assert(!run.empty() && "expected a non-empty regular chunk run");
|
||||||
const RegularChunk& anchorChunk = run.front();
|
const RegularChunk& anchorChunk = run.front();
|
||||||
|
RegularCompactionResult result;
|
||||||
|
|
||||||
SmallVector<Value> inputs;
|
SmallVector<Value> inputs;
|
||||||
inputs.reserve(run.size());
|
inputs.reserve(run.size());
|
||||||
@@ -214,7 +268,7 @@ static void compactRegularChunkRun(IRRewriter& rewriter, ArrayRef<RegularChunk>
|
|||||||
rewriter.setInsertionPoint(anchorChunk.startOp);
|
rewriter.setInsertionPoint(anchorChunk.startOp);
|
||||||
Value packedInput = createPackedTensorForValues(ValueRange(inputs), rewriter, anchorChunk.startOp->getLoc());
|
Value packedInput = createPackedTensorForValues(ValueRange(inputs), rewriter, anchorChunk.startOp->getLoc());
|
||||||
if (!packedInput)
|
if (!packedInput)
|
||||||
return;
|
return result;
|
||||||
|
|
||||||
auto inputType = cast<RankedTensorType>(anchorChunk.input.getType());
|
auto inputType = cast<RankedTensorType>(anchorChunk.input.getType());
|
||||||
auto outputType = cast<RankedTensorType>(anchorChunk.output.getType());
|
auto outputType = cast<RankedTensorType>(anchorChunk.output.getType());
|
||||||
@@ -317,10 +371,79 @@ static void compactRegularChunkRun(IRRewriter& rewriter, ArrayRef<RegularChunk>
|
|||||||
llvm::append_range(opsToErase, chunk.ops);
|
llvm::append_range(opsToErase, chunk.ops);
|
||||||
for (Operation* op : llvm::reverse(opsToErase))
|
for (Operation* op : llvm::reverse(opsToErase))
|
||||||
rewriter.eraseOp(op);
|
rewriter.eraseOp(op);
|
||||||
|
|
||||||
|
result.changed = true;
|
||||||
|
result.resumeAfter = loop.getOperation()->getNextNode();
|
||||||
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
void orderBilateralChannelOps(func::FuncOp funcOp) {
|
||||||
|
for (auto compute : funcOp.getOps<spatial::SpatCompute>()) {
|
||||||
|
auto coreIdAttr = compute->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName);
|
||||||
|
if (!coreIdAttr)
|
||||||
|
continue;
|
||||||
|
|
||||||
|
int32_t coreId = static_cast<int32_t>(coreIdAttr.getInt());
|
||||||
|
Block& block = compute.getBody().front();
|
||||||
|
SmallVector<std::pair<spatial::SpatChannelReceiveOp, Operation*>> moves;
|
||||||
|
DenseMap<uint64_t, Operation*> firstForwardedSendByEndpoint;
|
||||||
|
|
||||||
|
for (Operation& op : block) {
|
||||||
|
if (auto sendOp = dyn_cast<spatial::SpatChannelSendOp>(&op)) {
|
||||||
|
if (sendOp.getSourceCoreId() == static_cast<uint32_t>(coreId)
|
||||||
|
&& isForwardedChannelPayload(sendOp.getInput(), block)) {
|
||||||
|
uint64_t key = getEndpointKey(sendOp.getSourceCoreId(), sendOp.getTargetCoreId());
|
||||||
|
firstForwardedSendByEndpoint.try_emplace(key, sendOp.getOperation());
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto receiveOp = dyn_cast<spatial::SpatChannelReceiveOp>(&op);
|
||||||
|
if (!receiveOp || receiveOp.getTargetCoreId() != static_cast<uint32_t>(coreId)
|
||||||
|
|| receiveOp.getSourceCoreId() >= static_cast<uint32_t>(coreId)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
uint64_t key = getEndpointKey(static_cast<uint32_t>(coreId), receiveOp.getSourceCoreId());
|
||||||
|
auto firstMatchingSend = firstForwardedSendByEndpoint.find(key);
|
||||||
|
if (firstMatchingSend != firstForwardedSendByEndpoint.end())
|
||||||
|
moves.push_back({receiveOp, firstMatchingSend->second});
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto [receiveOp, insertionPoint] : moves)
|
||||||
|
receiveOp->moveBefore(insertionPoint);
|
||||||
|
|
||||||
|
for (auto it = block.begin(); it != block.end();) {
|
||||||
|
auto receiveOp = dyn_cast<spatial::SpatChannelReceiveOp>(&*it);
|
||||||
|
if (!receiveOp || receiveOp.getSourceCoreId() >= static_cast<uint32_t>(coreId)) {
|
||||||
|
++it;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
Type outputType = receiveOp.getOutput().getType();
|
||||||
|
auto run = collectConsecutiveRun<spatial::SpatChannelReceiveOp>(
|
||||||
|
it, block.end(), [&](spatial::SpatChannelReceiveOp current) {
|
||||||
|
return current.getOutput().getType() == outputType
|
||||||
|
&& current.getSourceCoreId() < static_cast<uint32_t>(coreId);
|
||||||
|
});
|
||||||
|
|
||||||
|
if (run.ops.size() > 1) {
|
||||||
|
SmallVector<spatial::SpatChannelReceiveOp> sorted(run.ops);
|
||||||
|
llvm::stable_sort(sorted, [](spatial::SpatChannelReceiveOp lhs, spatial::SpatChannelReceiveOp rhs) {
|
||||||
|
return lhs.getSourceCoreId() > rhs.getSourceCoreId();
|
||||||
|
});
|
||||||
|
Block::iterator insertIt = run.end;
|
||||||
|
for (auto op : sorted)
|
||||||
|
op->moveBefore(&block, insertIt);
|
||||||
|
}
|
||||||
|
|
||||||
|
it = run.end;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
|
void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
|
||||||
IRRewriter rewriter(funcOp.getContext());
|
IRRewriter rewriter(funcOp.getContext());
|
||||||
|
|
||||||
@@ -329,18 +452,23 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
|
|||||||
for (auto it = block.begin(); it != block.end();) {
|
for (auto it = block.begin(); it != block.end();) {
|
||||||
auto receiveOp = dyn_cast<spatial::SpatChannelReceiveOp>(&*it);
|
auto receiveOp = dyn_cast<spatial::SpatChannelReceiveOp>(&*it);
|
||||||
if (receiveOp) {
|
if (receiveOp) {
|
||||||
SmallVector<spatial::SpatChannelReceiveOp> run;
|
|
||||||
Type outputType = receiveOp.getOutput().getType();
|
Type outputType = receiveOp.getOutput().getType();
|
||||||
auto runIt = it;
|
auto run = collectConsecutiveRun<spatial::SpatChannelReceiveOp>(
|
||||||
while (runIt != block.end()) {
|
it, block.end(), [&](spatial::SpatChannelReceiveOp current) {
|
||||||
auto current = dyn_cast<spatial::SpatChannelReceiveOp>(&*runIt);
|
return current.getOutput().getType() == outputType;
|
||||||
if (!current || current.getOutput().getType() != outputType)
|
});
|
||||||
|
|
||||||
|
bool hasRepeatedEndpoint = false;
|
||||||
|
DenseSet<uint64_t> seenEndpoints;
|
||||||
|
for (auto op : run.ops) {
|
||||||
|
uint64_t endpointKey = getEndpointKey(op.getSourceCoreId(), op.getTargetCoreId());
|
||||||
|
if (!seenEndpoints.insert(endpointKey).second) {
|
||||||
|
hasRepeatedEndpoint = true;
|
||||||
break;
|
break;
|
||||||
run.push_back(current);
|
}
|
||||||
++runIt;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (run.size() > 1) {
|
if (run.ops.size() > 1 && !hasRepeatedEndpoint) {
|
||||||
struct ReceiveEntry {
|
struct ReceiveEntry {
|
||||||
spatial::SpatChannelReceiveOp op;
|
spatial::SpatChannelReceiveOp op;
|
||||||
size_t originalIndex = 0;
|
size_t originalIndex = 0;
|
||||||
@@ -349,13 +477,9 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
|
|||||||
uint64_t channelId = 0;
|
uint64_t channelId = 0;
|
||||||
};
|
};
|
||||||
SmallVector<ReceiveEntry> sortedEntries;
|
SmallVector<ReceiveEntry> sortedEntries;
|
||||||
sortedEntries.reserve(run.size());
|
sortedEntries.reserve(run.ops.size());
|
||||||
for (auto [originalIndex, op] : llvm::enumerate(run))
|
for (auto [originalIndex, op] : llvm::enumerate(run.ops))
|
||||||
sortedEntries.push_back({op, originalIndex, op.getSourceCoreId(), op.getTargetCoreId(), op.getChannelId()});
|
sortedEntries.push_back({op, originalIndex, op.getSourceCoreId(), op.getTargetCoreId(), op.getChannelId()});
|
||||||
llvm::stable_sort(sortedEntries, [](const ReceiveEntry& lhs, const ReceiveEntry& rhs) {
|
|
||||||
return std::tuple(lhs.sourceCoreId, lhs.targetCoreId, lhs.channelId)
|
|
||||||
< std::tuple(rhs.sourceCoreId, rhs.targetCoreId, rhs.channelId);
|
|
||||||
});
|
|
||||||
|
|
||||||
SmallVector<int64_t> channelIds;
|
SmallVector<int64_t> channelIds;
|
||||||
SmallVector<int32_t> sourceCoreIds;
|
SmallVector<int32_t> sourceCoreIds;
|
||||||
@@ -364,13 +488,11 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
|
|||||||
sourceCoreIds.reserve(sortedEntries.size());
|
sourceCoreIds.reserve(sortedEntries.size());
|
||||||
targetCoreIds.reserve(sortedEntries.size());
|
targetCoreIds.reserve(sortedEntries.size());
|
||||||
for (ReceiveEntry& entry : sortedEntries) {
|
for (ReceiveEntry& entry : sortedEntries) {
|
||||||
(void) entry;
|
appendChannelAttrs(
|
||||||
channelIds.push_back(nextChannelId++);
|
channelIds, sourceCoreIds, targetCoreIds, entry.channelId, entry.sourceCoreId, entry.targetCoreId);
|
||||||
sourceCoreIds.push_back(static_cast<int32_t>(entry.sourceCoreId));
|
|
||||||
targetCoreIds.push_back(static_cast<int32_t>(entry.targetCoreId));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
auto rowType = cast<RankedTensorType>(run.front().getOutput().getType());
|
auto rowType = cast<RankedTensorType>(run.ops.front().getOutput().getType());
|
||||||
auto fallbackPackedType = getPackedTensorType(rowType, static_cast<int64_t>(sortedEntries.size()));
|
auto fallbackPackedType = getPackedTensorType(rowType, static_cast<int64_t>(sortedEntries.size()));
|
||||||
SmallVector<Value> sortedOutputs;
|
SmallVector<Value> sortedOutputs;
|
||||||
sortedOutputs.reserve(sortedEntries.size());
|
sortedOutputs.reserve(sortedEntries.size());
|
||||||
@@ -383,10 +505,10 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
|
|||||||
concatOp ? getPackedConcatSliceType(concatOp, concatStartIndex, static_cast<unsigned>(sortedOutputs.size()))
|
concatOp ? getPackedConcatSliceType(concatOp, concatStartIndex, static_cast<unsigned>(sortedOutputs.size()))
|
||||||
: RankedTensorType {};
|
: RankedTensorType {};
|
||||||
auto packedType = concatPackedType ? concatPackedType : fallbackPackedType;
|
auto packedType = concatPackedType ? concatPackedType : fallbackPackedType;
|
||||||
rewriter.setInsertionPoint(run.front());
|
rewriter.setInsertionPoint(run.ops.front());
|
||||||
auto compactReceive =
|
auto compactReceive =
|
||||||
spatial::SpatChannelReceiveTensorOp::create(rewriter,
|
spatial::SpatChannelReceiveTensorOp::create(rewriter,
|
||||||
run.front().getLoc(),
|
run.ops.front().getLoc(),
|
||||||
packedType,
|
packedType,
|
||||||
rewriter.getDenseI64ArrayAttr(channelIds),
|
rewriter.getDenseI64ArrayAttr(channelIds),
|
||||||
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
|
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
|
||||||
@@ -403,7 +525,7 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
|
|||||||
entry.op.getOutput().replaceAllUsesWith(extractPackedChunk(
|
entry.op.getOutput().replaceAllUsesWith(extractPackedChunk(
|
||||||
compactReceive.getOutput(), rowType, static_cast<unsigned>(sortedIndex), rewriter, entry.op.getLoc()));
|
compactReceive.getOutput(), rowType, static_cast<unsigned>(sortedIndex), rewriter, entry.op.getLoc()));
|
||||||
}
|
}
|
||||||
for (auto op : run)
|
for (auto op : run.ops)
|
||||||
rewriter.eraseOp(op);
|
rewriter.eraseOp(op);
|
||||||
|
|
||||||
it = compactReceive->getIterator();
|
it = compactReceive->getIterator();
|
||||||
@@ -414,18 +536,13 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
|
|||||||
|
|
||||||
auto sendOp = dyn_cast<spatial::SpatChannelSendOp>(&*it);
|
auto sendOp = dyn_cast<spatial::SpatChannelSendOp>(&*it);
|
||||||
if (sendOp) {
|
if (sendOp) {
|
||||||
SmallVector<spatial::SpatChannelSendOp> run;
|
|
||||||
Type inputType = sendOp.getInput().getType();
|
Type inputType = sendOp.getInput().getType();
|
||||||
auto runIt = it;
|
auto run =
|
||||||
while (runIt != block.end()) {
|
collectConsecutiveRun<spatial::SpatChannelSendOp>(it, block.end(), [&](spatial::SpatChannelSendOp current) {
|
||||||
auto current = dyn_cast<spatial::SpatChannelSendOp>(&*runIt);
|
return current.getInput().getType() == inputType;
|
||||||
if (!current || current.getInput().getType() != inputType)
|
});
|
||||||
break;
|
|
||||||
run.push_back(current);
|
|
||||||
++runIt;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (run.size() > 1) {
|
if (run.ops.size() > 1) {
|
||||||
struct SendEntry {
|
struct SendEntry {
|
||||||
spatial::SpatChannelSendOp op;
|
spatial::SpatChannelSendOp op;
|
||||||
uint32_t sourceCoreId = 0;
|
uint32_t sourceCoreId = 0;
|
||||||
@@ -433,13 +550,9 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
|
|||||||
uint64_t channelId = 0;
|
uint64_t channelId = 0;
|
||||||
};
|
};
|
||||||
SmallVector<SendEntry> sortedEntries;
|
SmallVector<SendEntry> sortedEntries;
|
||||||
sortedEntries.reserve(run.size());
|
sortedEntries.reserve(run.ops.size());
|
||||||
for (auto op : run)
|
for (auto op : run.ops)
|
||||||
sortedEntries.push_back({op, op.getSourceCoreId(), op.getTargetCoreId(), op.getChannelId()});
|
sortedEntries.push_back({op, op.getSourceCoreId(), op.getTargetCoreId(), op.getChannelId()});
|
||||||
llvm::stable_sort(sortedEntries, [](const SendEntry& lhs, const SendEntry& rhs) {
|
|
||||||
return std::tuple(lhs.sourceCoreId, lhs.targetCoreId, lhs.channelId)
|
|
||||||
< std::tuple(rhs.sourceCoreId, rhs.targetCoreId, rhs.channelId);
|
|
||||||
});
|
|
||||||
|
|
||||||
SmallVector<int64_t> channelIds;
|
SmallVector<int64_t> channelIds;
|
||||||
SmallVector<int32_t> sourceCoreIds;
|
SmallVector<int32_t> sourceCoreIds;
|
||||||
@@ -450,26 +563,24 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
|
|||||||
targetCoreIds.reserve(sortedEntries.size());
|
targetCoreIds.reserve(sortedEntries.size());
|
||||||
inputs.reserve(sortedEntries.size());
|
inputs.reserve(sortedEntries.size());
|
||||||
for (SendEntry& entry : sortedEntries) {
|
for (SendEntry& entry : sortedEntries) {
|
||||||
(void) entry;
|
appendChannelAttrs(
|
||||||
channelIds.push_back(nextChannelId++);
|
channelIds, sourceCoreIds, targetCoreIds, entry.channelId, entry.sourceCoreId, entry.targetCoreId);
|
||||||
sourceCoreIds.push_back(static_cast<int32_t>(entry.sourceCoreId));
|
|
||||||
targetCoreIds.push_back(static_cast<int32_t>(entry.targetCoreId));
|
|
||||||
inputs.push_back(entry.op.getInput());
|
inputs.push_back(entry.op.getInput());
|
||||||
}
|
}
|
||||||
|
|
||||||
rewriter.setInsertionPoint(run.front());
|
rewriter.setInsertionPoint(run.ops.front());
|
||||||
Value packedInput = createPackedTensorForValues(ValueRange(inputs), rewriter, run.front().getLoc());
|
Value packedInput = createPackedTensorForValues(ValueRange(inputs), rewriter, run.ops.front().getLoc());
|
||||||
if (packedInput) {
|
if (packedInput) {
|
||||||
spatial::SpatChannelSendTensorOp::create(rewriter,
|
spatial::SpatChannelSendTensorOp::create(rewriter,
|
||||||
run.front().getLoc(),
|
run.ops.front().getLoc(),
|
||||||
rewriter.getDenseI64ArrayAttr(channelIds),
|
rewriter.getDenseI64ArrayAttr(channelIds),
|
||||||
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
|
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
|
||||||
rewriter.getDenseI32ArrayAttr(targetCoreIds),
|
rewriter.getDenseI32ArrayAttr(targetCoreIds),
|
||||||
packedInput);
|
packedInput);
|
||||||
for (auto op : run)
|
for (auto op : run.ops)
|
||||||
rewriter.eraseOp(op);
|
rewriter.eraseOp(op);
|
||||||
|
|
||||||
it = runIt;
|
it = run.end;
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -488,32 +599,27 @@ void compactBatchChannelRuns(func::FuncOp funcOp) {
|
|||||||
for (auto it = block.begin(); it != block.end();) {
|
for (auto it = block.begin(); it != block.end();) {
|
||||||
auto receiveOp = dyn_cast<spatial::SpatChannelReceiveBatchOp>(&*it);
|
auto receiveOp = dyn_cast<spatial::SpatChannelReceiveBatchOp>(&*it);
|
||||||
if (receiveOp) {
|
if (receiveOp) {
|
||||||
SmallVector<spatial::SpatChannelReceiveBatchOp> run;
|
|
||||||
Type outputType = receiveOp.getOutput().getType();
|
Type outputType = receiveOp.getOutput().getType();
|
||||||
auto runIt = it;
|
auto run = collectConsecutiveRun<spatial::SpatChannelReceiveBatchOp>(
|
||||||
while (runIt != block.end()) {
|
it, block.end(), [&](spatial::SpatChannelReceiveBatchOp current) {
|
||||||
auto current = dyn_cast<spatial::SpatChannelReceiveBatchOp>(&*runIt);
|
return current.getOutput().getType() == outputType;
|
||||||
if (!current || current.getOutput().getType() != outputType)
|
});
|
||||||
break;
|
|
||||||
run.push_back(current);
|
|
||||||
++runIt;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (run.size() > 1) {
|
if (run.ops.size() > 1) {
|
||||||
SmallVector<int64_t> channelIds;
|
SmallVector<int64_t> channelIds;
|
||||||
SmallVector<int32_t> sourceCoreIds;
|
SmallVector<int32_t> sourceCoreIds;
|
||||||
SmallVector<int32_t> targetCoreIds;
|
SmallVector<int32_t> targetCoreIds;
|
||||||
for (auto op : run) {
|
for (auto op : run.ops) {
|
||||||
llvm::append_range(channelIds, op.getChannelIds());
|
llvm::append_range(channelIds, op.getChannelIds());
|
||||||
llvm::append_range(sourceCoreIds, op.getSourceCoreIds());
|
llvm::append_range(sourceCoreIds, op.getSourceCoreIds());
|
||||||
llvm::append_range(targetCoreIds, op.getTargetCoreIds());
|
llvm::append_range(targetCoreIds, op.getTargetCoreIds());
|
||||||
}
|
}
|
||||||
|
|
||||||
auto rowType = cast<RankedTensorType>(run.front().getOutput().getType());
|
auto rowType = cast<RankedTensorType>(run.ops.front().getOutput().getType());
|
||||||
auto fallbackPackedType = getPackedTensorType(rowType, static_cast<int64_t>(run.size()));
|
auto fallbackPackedType = getPackedTensorType(rowType, static_cast<int64_t>(run.ops.size()));
|
||||||
SmallVector<Value> outputs;
|
SmallVector<Value> outputs;
|
||||||
outputs.reserve(run.size());
|
outputs.reserve(run.ops.size());
|
||||||
for (auto op : run)
|
for (auto op : run.ops)
|
||||||
outputs.push_back(op.getOutput());
|
outputs.push_back(op.getOutput());
|
||||||
|
|
||||||
unsigned concatStartIndex = 0;
|
unsigned concatStartIndex = 0;
|
||||||
@@ -522,10 +628,10 @@ void compactBatchChannelRuns(func::FuncOp funcOp) {
|
|||||||
concatOp ? getPackedConcatSliceType(concatOp, concatStartIndex, static_cast<unsigned>(outputs.size()))
|
concatOp ? getPackedConcatSliceType(concatOp, concatStartIndex, static_cast<unsigned>(outputs.size()))
|
||||||
: RankedTensorType {};
|
: RankedTensorType {};
|
||||||
auto packedType = concatPackedType ? concatPackedType : fallbackPackedType;
|
auto packedType = concatPackedType ? concatPackedType : fallbackPackedType;
|
||||||
rewriter.setInsertionPoint(run.front());
|
rewriter.setInsertionPoint(run.ops.front());
|
||||||
auto compactReceive =
|
auto compactReceive =
|
||||||
spatial::SpatChannelReceiveTensorBatchOp::create(rewriter,
|
spatial::SpatChannelReceiveTensorBatchOp::create(rewriter,
|
||||||
run.front().getLoc(),
|
run.ops.front().getLoc(),
|
||||||
packedType,
|
packedType,
|
||||||
rewriter.getDenseI64ArrayAttr(channelIds),
|
rewriter.getDenseI64ArrayAttr(channelIds),
|
||||||
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
|
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
|
||||||
@@ -535,11 +641,11 @@ void compactBatchChannelRuns(func::FuncOp funcOp) {
|
|||||||
concatOp, concatStartIndex, static_cast<unsigned>(outputs.size()), compactReceive.getOutput(), rewriter);
|
concatOp, concatStartIndex, static_cast<unsigned>(outputs.size()), compactReceive.getOutput(), rewriter);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
for (auto [index, op] : llvm::enumerate(run))
|
for (auto [index, op] : llvm::enumerate(run.ops))
|
||||||
op.getOutput().replaceAllUsesWith(extractPackedChunk(
|
op.getOutput().replaceAllUsesWith(extractPackedChunk(
|
||||||
compactReceive.getOutput(), rowType, static_cast<unsigned>(index), rewriter, op.getLoc()));
|
compactReceive.getOutput(), rowType, static_cast<unsigned>(index), rewriter, op.getLoc()));
|
||||||
}
|
}
|
||||||
for (auto op : run)
|
for (auto op : run.ops)
|
||||||
rewriter.eraseOp(op);
|
rewriter.eraseOp(op);
|
||||||
|
|
||||||
it = compactReceive->getIterator();
|
it = compactReceive->getIterator();
|
||||||
@@ -550,43 +656,38 @@ void compactBatchChannelRuns(func::FuncOp funcOp) {
|
|||||||
|
|
||||||
auto sendOp = dyn_cast<spatial::SpatChannelSendBatchOp>(&*it);
|
auto sendOp = dyn_cast<spatial::SpatChannelSendBatchOp>(&*it);
|
||||||
if (sendOp) {
|
if (sendOp) {
|
||||||
SmallVector<spatial::SpatChannelSendBatchOp> run;
|
|
||||||
Type inputType = sendOp.getInput().getType();
|
Type inputType = sendOp.getInput().getType();
|
||||||
auto runIt = it;
|
auto run = collectConsecutiveRun<spatial::SpatChannelSendBatchOp>(
|
||||||
while (runIt != block.end()) {
|
it, block.end(), [&](spatial::SpatChannelSendBatchOp current) {
|
||||||
auto current = dyn_cast<spatial::SpatChannelSendBatchOp>(&*runIt);
|
return current.getInput().getType() == inputType;
|
||||||
if (!current || current.getInput().getType() != inputType)
|
});
|
||||||
break;
|
|
||||||
run.push_back(current);
|
|
||||||
++runIt;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (run.size() > 1) {
|
if (run.ops.size() > 1) {
|
||||||
SmallVector<int64_t> channelIds;
|
SmallVector<int64_t> channelIds;
|
||||||
SmallVector<int32_t> sourceCoreIds;
|
SmallVector<int32_t> sourceCoreIds;
|
||||||
SmallVector<int32_t> targetCoreIds;
|
SmallVector<int32_t> targetCoreIds;
|
||||||
SmallVector<Value> inputs;
|
SmallVector<Value> inputs;
|
||||||
inputs.reserve(run.size());
|
inputs.reserve(run.ops.size());
|
||||||
for (auto op : run) {
|
for (auto op : run.ops) {
|
||||||
llvm::append_range(channelIds, op.getChannelIds());
|
llvm::append_range(channelIds, op.getChannelIds());
|
||||||
llvm::append_range(sourceCoreIds, op.getSourceCoreIds());
|
llvm::append_range(sourceCoreIds, op.getSourceCoreIds());
|
||||||
llvm::append_range(targetCoreIds, op.getTargetCoreIds());
|
llvm::append_range(targetCoreIds, op.getTargetCoreIds());
|
||||||
inputs.push_back(op.getInput());
|
inputs.push_back(op.getInput());
|
||||||
}
|
}
|
||||||
|
|
||||||
rewriter.setInsertionPoint(run.front());
|
rewriter.setInsertionPoint(run.ops.front());
|
||||||
Value packedInput = createPackedTensorForValues(ValueRange(inputs), rewriter, run.front().getLoc());
|
Value packedInput = createPackedTensorForValues(ValueRange(inputs), rewriter, run.ops.front().getLoc());
|
||||||
if (packedInput) {
|
if (packedInput) {
|
||||||
spatial::SpatChannelSendTensorBatchOp::create(rewriter,
|
spatial::SpatChannelSendTensorBatchOp::create(rewriter,
|
||||||
run.front().getLoc(),
|
run.ops.front().getLoc(),
|
||||||
rewriter.getDenseI64ArrayAttr(channelIds),
|
rewriter.getDenseI64ArrayAttr(channelIds),
|
||||||
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
|
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
|
||||||
rewriter.getDenseI32ArrayAttr(targetCoreIds),
|
rewriter.getDenseI32ArrayAttr(targetCoreIds),
|
||||||
packedInput);
|
packedInput);
|
||||||
for (auto op : run)
|
for (auto op : run.ops)
|
||||||
rewriter.eraseOp(op);
|
rewriter.eraseOp(op);
|
||||||
|
|
||||||
it = runIt;
|
it = run.end;
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -614,8 +715,9 @@ void compactRegularOpRuns(func::FuncOp funcOp) {
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auto anchorEndIt = std::next(it, static_cast<std::ptrdiff_t>(anchorChunk->ops.size()));
|
||||||
SmallVector<RegularChunk> run {*anchorChunk};
|
SmallVector<RegularChunk> run {*anchorChunk};
|
||||||
auto runIt = std::next(it, static_cast<std::ptrdiff_t>(anchorChunk->ops.size()));
|
auto runIt = anchorEndIt;
|
||||||
while (runIt != block.end()) {
|
while (runIt != block.end()) {
|
||||||
auto candidateStart = dyn_cast<spatial::SpatVMMOp>(&*runIt);
|
auto candidateStart = dyn_cast<spatial::SpatVMMOp>(&*runIt);
|
||||||
if (!candidateStart)
|
if (!candidateStart)
|
||||||
@@ -630,12 +732,26 @@ void compactRegularOpRuns(func::FuncOp funcOp) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (run.size() <= 1) {
|
if (run.size() <= 1) {
|
||||||
++it;
|
it = anchorEndIt;
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
compactRegularChunkRun(rewriter, run);
|
size_t originalOpCount = 0;
|
||||||
it = runIt;
|
for (const RegularChunk& chunk : run)
|
||||||
|
originalOpCount += chunk.ops.size();
|
||||||
|
|
||||||
|
RegularCompactionResult result = compactRegularChunkRun(rewriter, run);
|
||||||
|
if (result.changed) {
|
||||||
|
assert(originalOpCount > anchorChunk->ops.size() && "successful regular compaction must consume the run");
|
||||||
|
if (!result.resumeAfter) {
|
||||||
|
it = block.end();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
it = result.resumeAfter->getIterator();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
it = anchorEndIt;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -666,37 +782,32 @@ void compactRowWiseWvmmRuns(func::FuncOp funcOp) {
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
SmallVector<spatial::SpatVMMOp> run;
|
|
||||||
auto runIt = it;
|
|
||||||
int64_t expectedRow = static_cast<int64_t>(rowResult.getResultNumber());
|
int64_t expectedRow = static_cast<int64_t>(rowResult.getResultNumber());
|
||||||
while (runIt != block.end()) {
|
auto run = collectConsecutiveRun<spatial::SpatVMMOp>(it, block.end(), [&](spatial::SpatVMMOp current) {
|
||||||
auto current = dyn_cast<spatial::SpatVMMOp>(&*runIt);
|
if (current.getWeightIndex() != wvmmOp.getWeightIndex()
|
||||||
if (!current || current.getWeightIndex() != wvmmOp.getWeightIndex()
|
|
||||||
|| current.getInput().getDefiningOp<spatial::SpatExtractRowsOp>() != extractRowsOp
|
|| current.getInput().getDefiningOp<spatial::SpatExtractRowsOp>() != extractRowsOp
|
||||||
|| current.getInput().getType() != wvmmOp.getInput().getType()
|
|| current.getInput().getType() != wvmmOp.getInput().getType()
|
||||||
|| current.getOutput().getType() != wvmmOp.getOutput().getType()) {
|
|| current.getOutput().getType() != wvmmOp.getOutput().getType())
|
||||||
break;
|
return false;
|
||||||
}
|
|
||||||
|
|
||||||
auto currentRow = dyn_cast<OpResult>(current.getInput());
|
auto currentRow = dyn_cast<OpResult>(current.getInput());
|
||||||
if (!currentRow || currentRow.getResultNumber() != static_cast<unsigned>(expectedRow))
|
if (!currentRow || currentRow.getResultNumber() != static_cast<unsigned>(expectedRow))
|
||||||
break;
|
return false;
|
||||||
|
|
||||||
run.push_back(current);
|
|
||||||
++expectedRow;
|
++expectedRow;
|
||||||
++runIt;
|
return true;
|
||||||
}
|
});
|
||||||
|
|
||||||
if (run.size() <= 1) {
|
if (run.ops.size() <= 1) {
|
||||||
++it;
|
++it;
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!run.front().getOutput().hasOneUse()) {
|
if (!run.ops.front().getOutput().hasOneUse()) {
|
||||||
++it;
|
++it;
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
auto concatUse = run.front().getOutput().getUses().begin();
|
auto concatUse = run.ops.front().getOutput().getUses().begin();
|
||||||
auto concatOp = dyn_cast<spatial::SpatConcatOp>(concatUse->getOwner());
|
auto concatOp = dyn_cast<spatial::SpatConcatOp>(concatUse->getOwner());
|
||||||
if (!concatOp) {
|
if (!concatOp) {
|
||||||
++it;
|
++it;
|
||||||
@@ -705,7 +816,7 @@ void compactRowWiseWvmmRuns(func::FuncOp funcOp) {
|
|||||||
|
|
||||||
unsigned concatStartIndex = concatUse->getOperandNumber();
|
unsigned concatStartIndex = concatUse->getOperandNumber();
|
||||||
bool validConcatRun = true;
|
bool validConcatRun = true;
|
||||||
for (auto [index, op] : llvm::enumerate(run)) {
|
for (auto [index, op] : llvm::enumerate(run.ops)) {
|
||||||
if (!op.getOutput().hasOneUse()) {
|
if (!op.getOutput().hasOneUse()) {
|
||||||
validConcatRun = false;
|
validConcatRun = false;
|
||||||
break;
|
break;
|
||||||
@@ -736,17 +847,17 @@ void compactRowWiseWvmmRuns(func::FuncOp funcOp) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
int64_t firstRow = static_cast<int64_t>(rowResult.getResultNumber());
|
int64_t firstRow = static_cast<int64_t>(rowResult.getResultNumber());
|
||||||
int64_t runLength = static_cast<int64_t>(run.size());
|
int64_t runLength = static_cast<int64_t>(run.ops.size());
|
||||||
auto packedType = RankedTensorType::get({runLength, outputCols}, outputType.getElementType());
|
auto packedType = RankedTensorType::get({runLength, outputCols}, outputType.getElementType());
|
||||||
|
|
||||||
rewriter.setInsertionPoint(run.front());
|
rewriter.setInsertionPoint(run.ops.front());
|
||||||
auto zero = arith::ConstantIndexOp::create(rewriter, run.front().getLoc(), 0);
|
auto zero = arith::ConstantIndexOp::create(rewriter, run.ops.front().getLoc(), 0);
|
||||||
auto upper = arith::ConstantIndexOp::create(rewriter, run.front().getLoc(), runLength);
|
auto upper = arith::ConstantIndexOp::create(rewriter, run.ops.front().getLoc(), runLength);
|
||||||
auto step = arith::ConstantIndexOp::create(rewriter, run.front().getLoc(), 1);
|
auto step = arith::ConstantIndexOp::create(rewriter, run.ops.front().getLoc(), 1);
|
||||||
auto packedInit =
|
auto packedInit =
|
||||||
tensor::EmptyOp::create(rewriter, run.front().getLoc(), packedType.getShape(), packedType.getElementType());
|
tensor::EmptyOp::create(rewriter, run.ops.front().getLoc(), packedType.getShape(), packedType.getElementType());
|
||||||
auto loop =
|
auto loop =
|
||||||
scf::ForOp::create(rewriter, run.front().getLoc(), zero, upper, step, ValueRange {packedInit.getResult()});
|
scf::ForOp::create(rewriter, run.ops.front().getLoc(), zero, upper, step, ValueRange {packedInit.getResult()});
|
||||||
|
|
||||||
{
|
{
|
||||||
OpBuilder::InsertionGuard guard(rewriter);
|
OpBuilder::InsertionGuard guard(rewriter);
|
||||||
@@ -757,41 +868,41 @@ void compactRowWiseWvmmRuns(func::FuncOp funcOp) {
|
|||||||
|
|
||||||
Value sourceRow = iv;
|
Value sourceRow = iv;
|
||||||
if (firstRow != 0) {
|
if (firstRow != 0) {
|
||||||
auto firstRowValue = arith::ConstantIndexOp::create(rewriter, run.front().getLoc(), firstRow);
|
auto firstRowValue = arith::ConstantIndexOp::create(rewriter, run.ops.front().getLoc(), firstRow);
|
||||||
sourceRow = arith::AddIOp::create(rewriter, run.front().getLoc(), iv, firstRowValue);
|
sourceRow = arith::AddIOp::create(rewriter, run.ops.front().getLoc(), iv, firstRowValue);
|
||||||
}
|
}
|
||||||
|
|
||||||
SmallVector<OpFoldResult> extractOffsets = {sourceRow, rewriter.getIndexAttr(0)};
|
SmallVector<OpFoldResult> extractOffsets = {sourceRow, rewriter.getIndexAttr(0)};
|
||||||
SmallVector<OpFoldResult> extractSizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(inputCols)};
|
SmallVector<OpFoldResult> extractSizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(inputCols)};
|
||||||
SmallVector<OpFoldResult> extractStrides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
SmallVector<OpFoldResult> extractStrides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||||
auto extractedRow = tensor::ExtractSliceOp::create(rewriter,
|
auto extractedRow = tensor::ExtractSliceOp::create(rewriter,
|
||||||
run.front().getLoc(),
|
run.ops.front().getLoc(),
|
||||||
inputType,
|
inputType,
|
||||||
extractRowsOp.getInput(),
|
extractRowsOp.getInput(),
|
||||||
extractOffsets,
|
extractOffsets,
|
||||||
extractSizes,
|
extractSizes,
|
||||||
extractStrides);
|
extractStrides);
|
||||||
auto loopWvmm = spatial::SpatVMMOp::create(
|
auto loopWvmm = spatial::SpatVMMOp::create(
|
||||||
rewriter, run.front().getLoc(), outputType, wvmmOp.getWeightIndex(), extractedRow.getResult());
|
rewriter, run.ops.front().getLoc(), outputType, wvmmOp.getWeightIndex(), extractedRow.getResult());
|
||||||
|
|
||||||
SmallVector<OpFoldResult> insertOffsets = {iv, rewriter.getIndexAttr(0)};
|
SmallVector<OpFoldResult> insertOffsets = {iv, rewriter.getIndexAttr(0)};
|
||||||
SmallVector<OpFoldResult> insertSizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(outputCols)};
|
SmallVector<OpFoldResult> insertSizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(outputCols)};
|
||||||
SmallVector<OpFoldResult> insertStrides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
SmallVector<OpFoldResult> insertStrides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||||
auto inserted = tensor::InsertSliceOp::create(
|
auto inserted = tensor::InsertSliceOp::create(
|
||||||
rewriter, run.front().getLoc(), loopWvmm.getResult(), acc, insertOffsets, insertSizes, insertStrides);
|
rewriter, run.ops.front().getLoc(), loopWvmm.getResult(), acc, insertOffsets, insertSizes, insertStrides);
|
||||||
scf::YieldOp::create(rewriter, run.front().getLoc(), inserted.getResult());
|
scf::YieldOp::create(rewriter, run.ops.front().getLoc(), inserted.getResult());
|
||||||
}
|
}
|
||||||
|
|
||||||
SmallVector<Value> newConcatInputs;
|
SmallVector<Value> newConcatInputs;
|
||||||
newConcatInputs.reserve(concatOp.getInputs().size() - run.size() + 1);
|
newConcatInputs.reserve(concatOp.getInputs().size() - run.ops.size() + 1);
|
||||||
for (auto [operandIndex, operand] : llvm::enumerate(concatOp.getInputs())) {
|
for (auto [operandIndex, operand] : llvm::enumerate(concatOp.getInputs())) {
|
||||||
if (operandIndex == concatStartIndex)
|
if (operandIndex == concatStartIndex)
|
||||||
newConcatInputs.push_back(loop.getResult(0));
|
newConcatInputs.push_back(loop.getResult(0));
|
||||||
if (operandIndex < concatStartIndex || operandIndex >= concatStartIndex + run.size())
|
if (operandIndex < concatStartIndex || operandIndex >= concatStartIndex + run.ops.size())
|
||||||
newConcatInputs.push_back(operand);
|
newConcatInputs.push_back(operand);
|
||||||
}
|
}
|
||||||
rewriter.modifyOpInPlace(concatOp, [&] { concatOp->setOperands(newConcatInputs); });
|
rewriter.modifyOpInPlace(concatOp, [&] { concatOp->setOperands(newConcatInputs); });
|
||||||
for (auto op : run)
|
for (auto op : run.ops)
|
||||||
rewriter.eraseOp(op);
|
rewriter.eraseOp(op);
|
||||||
|
|
||||||
it = loop->getIterator();
|
it = loop->getIterator();
|
||||||
|
|||||||
@@ -6,6 +6,7 @@
|
|||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
void orderBilateralChannelOps(mlir::func::FuncOp funcOp);
|
||||||
void compactScalarChannelRuns(mlir::func::FuncOp funcOp, int64_t& nextChannelId);
|
void compactScalarChannelRuns(mlir::func::FuncOp funcOp, int64_t& nextChannelId);
|
||||||
void compactBatchChannelRuns(mlir::func::FuncOp funcOp);
|
void compactBatchChannelRuns(mlir::func::FuncOp funcOp);
|
||||||
void compactRegularOpRuns(mlir::func::FuncOp funcOp);
|
void compactRegularOpRuns(mlir::func::FuncOp funcOp);
|
||||||
|
|||||||
@@ -0,0 +1,189 @@
|
|||||||
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
|
#include "mlir/IR/Operation.h"
|
||||||
|
#include "mlir/IR/Value.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/STLExtras.h"
|
||||||
|
#include "llvm/Support/Casting.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <limits>
|
||||||
|
#include <optional>
|
||||||
|
#include <queue>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "ComputeGraph.hpp"
|
||||||
|
#include "src/Support/TypeUtilities.hpp"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
namespace spatial {
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
Weight getComputeBodyWeight(Region &body) {
|
||||||
|
constexpr Weight kOperationWeight = 100;
|
||||||
|
Weight numOperations = 0;
|
||||||
|
for (auto &block : body)
|
||||||
|
for ([[maybe_unused]] auto &op : block)
|
||||||
|
numOperations = checkedAdd(numOperations, static_cast<Weight>(1));
|
||||||
|
return checkedMultiply(numOperations, kOperationWeight);
|
||||||
|
}
|
||||||
|
|
||||||
|
CrossbarUsage getComputeBodyCrossbarUsage(Region &body) {
|
||||||
|
CrossbarUsage crossbarUsage = 0;
|
||||||
|
for (auto &block : body)
|
||||||
|
for (auto &op : block)
|
||||||
|
if (isa<SpatVMMOp>(op))
|
||||||
|
crossbarUsage = checkedAdd(crossbarUsage, static_cast<CrossbarUsage>(1));
|
||||||
|
return crossbarUsage;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool isUsedAsWeightOnly(Operation *producerOp) {
|
||||||
|
if (producerOp->getNumResults() == 0)
|
||||||
|
return false;
|
||||||
|
for (Value result : producerOp->getResults()) {
|
||||||
|
if (result.use_empty())
|
||||||
|
return false;
|
||||||
|
for (Operation *user : result.getUsers()) {
|
||||||
|
if (auto compute = dyn_cast<SpatCompute>(user)) {
|
||||||
|
if (!llvm::is_contained(compute.getWeights(), result))
|
||||||
|
return false;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (auto batch = dyn_cast<SpatComputeBatch>(user)) {
|
||||||
|
if (!llvm::is_contained(batch.getWeights(), result))
|
||||||
|
return false;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<ComputeGraphEdge> aggregateEdges(llvm::ArrayRef<ComputeGraphEdge> edges) {
|
||||||
|
llvm::DenseMap<std::pair<size_t, size_t>, Weight> edgeWeights;
|
||||||
|
for (const ComputeGraphEdge &edge : edges) {
|
||||||
|
if (edge.source == edge.target)
|
||||||
|
continue;
|
||||||
|
auto inserted = edgeWeights.try_emplace({edge.source, edge.target}, edge.transferCost);
|
||||||
|
if (!inserted.second)
|
||||||
|
inserted.first->second = std::max(inserted.first->second, edge.transferCost);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<ComputeGraphEdge> aggregatedEdges;
|
||||||
|
aggregatedEdges.reserve(edgeWeights.size());
|
||||||
|
for (const auto &[key, weight] : edgeWeights)
|
||||||
|
aggregatedEdges.push_back({key.first, key.second, weight});
|
||||||
|
llvm::sort(aggregatedEdges, [](const ComputeGraphEdge &lhs, const ComputeGraphEdge &rhs) {
|
||||||
|
if (lhs.source != rhs.source)
|
||||||
|
return lhs.source < rhs.source;
|
||||||
|
return lhs.target < rhs.target;
|
||||||
|
});
|
||||||
|
return aggregatedEdges;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
Weight getComputeInstanceWeight(const ComputeInstance &instance) {
|
||||||
|
if (auto spatCompute = dyn_cast<SpatCompute>(instance.op))
|
||||||
|
return getSpatComputeWeight(spatCompute);
|
||||||
|
auto batch = cast<SpatComputeBatch>(instance.op);
|
||||||
|
return checkedMultiply(getComputeBodyWeight(batch.getBody()), static_cast<Weight>(instance.laneCount));
|
||||||
|
}
|
||||||
|
|
||||||
|
CrossbarUsage getComputeInstanceCrossbarUsage(const ComputeInstance &instance) {
|
||||||
|
if (auto spatCompute = dyn_cast<SpatCompute>(instance.op))
|
||||||
|
return getSpatComputeCrossbarUsage(spatCompute);
|
||||||
|
auto batch = cast<SpatComputeBatch>(instance.op);
|
||||||
|
return checkedMultiply(getComputeBodyCrossbarUsage(batch.getBody()),
|
||||||
|
static_cast<CrossbarUsage>(instance.laneCount));
|
||||||
|
}
|
||||||
|
|
||||||
|
ComputeGraph buildComputeGraph(Operation *entryOp) {
|
||||||
|
ComputeGraph graph;
|
||||||
|
|
||||||
|
for (Region ®ion : entryOp->getRegions()) {
|
||||||
|
for (Block &block : region) {
|
||||||
|
for (Operation &op : block) {
|
||||||
|
if (auto spatCompute = dyn_cast<SpatCompute>(&op)) {
|
||||||
|
if (isUsedAsWeightOnly(spatCompute.getOperation()))
|
||||||
|
continue;
|
||||||
|
ComputeInstance instance {spatCompute.getOperation(), 0, 1};
|
||||||
|
size_t index = graph.nodes.size();
|
||||||
|
graph.nodes.push_back({instance, getComputeInstanceWeight(instance), getComputeInstanceCrossbarUsage(instance), index});
|
||||||
|
graph.instanceToIndex[instance] = index;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (auto batch = dyn_cast<SpatComputeBatch>(&op)) {
|
||||||
|
if (isUsedAsWeightOnly(batch.getOperation()))
|
||||||
|
continue;
|
||||||
|
size_t chunkCount = getBatchChunkTargetCount(batch.getLaneCount());
|
||||||
|
for (size_t chunkIndex = 0; chunkIndex < chunkCount; ++chunkIndex) {
|
||||||
|
ComputeInstance instance = getBatchChunkForIndex(batch, chunkIndex);
|
||||||
|
size_t index = graph.nodes.size();
|
||||||
|
graph.nodes.push_back(
|
||||||
|
{instance, getComputeInstanceWeight(instance), getComputeInstanceCrossbarUsage(instance), index});
|
||||||
|
graph.instanceToIndex[instance] = index;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::SmallVector<ComputeGraphEdge, 16> rawEdges;
|
||||||
|
for (const auto &[targetIndex, node] : llvm::enumerate(graph.nodes)) {
|
||||||
|
for (Value input : getComputeInstanceInputs(node.instance)) {
|
||||||
|
auto producerInstance = getComputeProducerInstance(input);
|
||||||
|
if (!producerInstance)
|
||||||
|
continue;
|
||||||
|
auto producerIt = graph.instanceToIndex.find(*producerInstance);
|
||||||
|
if (producerIt == graph.instanceToIndex.end())
|
||||||
|
continue;
|
||||||
|
rawEdges.push_back(
|
||||||
|
{producerIt->second, targetIndex, static_cast<Weight>(getSizeInBytes(cast<ShapedType>(input.getType())))});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<ComputeGraphEdge> aggregatedEdges = aggregateEdges(rawEdges);
|
||||||
|
graph.edges.append(aggregatedEdges.begin(), aggregatedEdges.end());
|
||||||
|
graph.successors.assign(graph.nodes.size(), {});
|
||||||
|
graph.predecessors.assign(graph.nodes.size(), {});
|
||||||
|
for (const ComputeGraphEdge &edge : graph.edges) {
|
||||||
|
graph.successors[edge.source].push_back({edge.target, edge.transferCost});
|
||||||
|
graph.predecessors[edge.target].push_back({edge.source, edge.transferCost});
|
||||||
|
}
|
||||||
|
|
||||||
|
return graph;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool verifyAcyclic(const ComputeGraph &graph) {
|
||||||
|
std::vector<size_t> remainingParents(graph.nodes.size(), 0);
|
||||||
|
std::queue<size_t> readyNodes;
|
||||||
|
for (size_t node = 0; node < graph.nodes.size(); ++node) {
|
||||||
|
remainingParents[node] = graph.predecessors[node].size();
|
||||||
|
if (remainingParents[node] == 0)
|
||||||
|
readyNodes.push(node);
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t visited = 0;
|
||||||
|
while (!readyNodes.empty()) {
|
||||||
|
size_t node = readyNodes.front();
|
||||||
|
readyNodes.pop();
|
||||||
|
++visited;
|
||||||
|
for (const auto &[child, weight] : graph.successors[node]) {
|
||||||
|
(void) weight;
|
||||||
|
assert(remainingParents[child] > 0 && "remaining parent count underflow");
|
||||||
|
if (--remainingParents[child] == 0)
|
||||||
|
readyNodes.push(child);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return visited == graph.nodes.size();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace spatial
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,49 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/IR/Operation.h"
|
||||||
|
#include "mlir/IR/Value.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/DenseMap.h"
|
||||||
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
|
||||||
|
#include <cstddef>
|
||||||
|
#include <optional>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "../DCPGraph/Utils.hpp"
|
||||||
|
#include "ComputeInstance.hpp"
|
||||||
|
#include "ComputeInstanceUtils.hpp"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
namespace spatial {
|
||||||
|
|
||||||
|
struct ComputeGraphNode {
|
||||||
|
ComputeInstance instance;
|
||||||
|
Weight weight = 0;
|
||||||
|
CrossbarUsage crossbarUsage = 0;
|
||||||
|
size_t originalOrder = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct ComputeGraphEdge {
|
||||||
|
size_t source = 0;
|
||||||
|
size_t target = 0;
|
||||||
|
Weight transferCost = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct ComputeGraph {
|
||||||
|
llvm::SmallVector<ComputeGraphNode> nodes;
|
||||||
|
llvm::SmallVector<ComputeGraphEdge> edges;
|
||||||
|
std::vector<std::vector<std::pair<size_t, Weight>>> successors;
|
||||||
|
std::vector<std::vector<std::pair<size_t, Weight>>> predecessors;
|
||||||
|
llvm::DenseMap<ComputeInstance, size_t> instanceToIndex;
|
||||||
|
};
|
||||||
|
|
||||||
|
ComputeGraph buildComputeGraph(mlir::Operation *entryOp);
|
||||||
|
bool verifyAcyclic(const ComputeGraph &graph);
|
||||||
|
|
||||||
|
Weight getComputeInstanceWeight(const ComputeInstance &instance);
|
||||||
|
CrossbarUsage getComputeInstanceCrossbarUsage(const ComputeInstance &instance);
|
||||||
|
|
||||||
|
} // namespace spatial
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,45 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/IR/Operation.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/DenseMapInfo.h"
|
||||||
|
#include "llvm/ADT/Hashing.h"
|
||||||
|
|
||||||
|
#include <cstdint>
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
namespace spatial {
|
||||||
|
|
||||||
|
struct ComputeInstance {
|
||||||
|
mlir::Operation *op = nullptr;
|
||||||
|
uint32_t laneStart = 0;
|
||||||
|
uint32_t laneCount = 1;
|
||||||
|
|
||||||
|
bool operator==(const ComputeInstance &other) const {
|
||||||
|
return op == other.op && laneStart == other.laneStart && laneCount == other.laneCount;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace spatial
|
||||||
|
} // namespace onnx_mlir
|
||||||
|
|
||||||
|
using ComputeInstance = onnx_mlir::spatial::ComputeInstance;
|
||||||
|
|
||||||
|
namespace llvm {
|
||||||
|
template <>
|
||||||
|
struct DenseMapInfo<onnx_mlir::spatial::ComputeInstance> {
|
||||||
|
static onnx_mlir::spatial::ComputeInstance getEmptyKey() {
|
||||||
|
return {DenseMapInfo<mlir::Operation *>::getEmptyKey(), UINT32_MAX, UINT32_MAX};
|
||||||
|
}
|
||||||
|
static onnx_mlir::spatial::ComputeInstance getTombstoneKey() {
|
||||||
|
return {DenseMapInfo<mlir::Operation *>::getTombstoneKey(), UINT32_MAX, UINT32_MAX};
|
||||||
|
}
|
||||||
|
static unsigned getHashValue(const onnx_mlir::spatial::ComputeInstance &value) {
|
||||||
|
return llvm::hash_combine(value.op, value.laneStart, value.laneCount);
|
||||||
|
}
|
||||||
|
static bool isEqual(const onnx_mlir::spatial::ComputeInstance &lhs,
|
||||||
|
const onnx_mlir::spatial::ComputeInstance &rhs) {
|
||||||
|
return lhs == rhs;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace llvm
|
||||||
+151
@@ -0,0 +1,151 @@
|
|||||||
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
|
||||||
|
#include <limits>
|
||||||
|
|
||||||
|
#include "ComputeInstanceUtils.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
namespace spatial {
|
||||||
|
|
||||||
|
size_t getSchedulingCpuBudget() {
|
||||||
|
if (coresCount.getValue() > 0)
|
||||||
|
return static_cast<size_t>(coresCount.getValue());
|
||||||
|
return std::numeric_limits<size_t>::max();
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t getBatchChunkTargetCount(int32_t laneCount) {
|
||||||
|
assert(laneCount > 0 && "laneCount must be positive");
|
||||||
|
return std::min(static_cast<size_t>(laneCount), std::max<size_t>(1, getSchedulingCpuBudget()));
|
||||||
|
}
|
||||||
|
|
||||||
|
ComputeInstance getBatchChunkForIndex(SpatComputeBatch batch, size_t chunkIndex) {
|
||||||
|
size_t totalLanes = static_cast<size_t>(batch.getLaneCount());
|
||||||
|
size_t chunkCount = getBatchChunkTargetCount(batch.getLaneCount());
|
||||||
|
size_t baseChunkSize = totalLanes / chunkCount;
|
||||||
|
size_t largeChunkCount = totalLanes % chunkCount;
|
||||||
|
|
||||||
|
size_t laneStart = chunkIndex * baseChunkSize + std::min(chunkIndex, largeChunkCount);
|
||||||
|
size_t laneCount = baseChunkSize + (chunkIndex < largeChunkCount ? 1 : 0);
|
||||||
|
return {batch.getOperation(), static_cast<uint32_t>(laneStart), static_cast<uint32_t>(laneCount)};
|
||||||
|
}
|
||||||
|
|
||||||
|
ComputeInstance getBatchChunkForLane(SpatComputeBatch batch, uint32_t lane) {
|
||||||
|
size_t totalLanes = static_cast<size_t>(batch.getLaneCount());
|
||||||
|
size_t chunkCount = getBatchChunkTargetCount(batch.getLaneCount());
|
||||||
|
size_t baseChunkSize = totalLanes / chunkCount;
|
||||||
|
size_t largeChunkCount = totalLanes % chunkCount;
|
||||||
|
size_t largeChunkSpan = largeChunkCount * (baseChunkSize + 1);
|
||||||
|
|
||||||
|
size_t chunkIndex = 0;
|
||||||
|
if (static_cast<size_t>(lane) < largeChunkSpan)
|
||||||
|
chunkIndex = static_cast<size_t>(lane) / (baseChunkSize + 1);
|
||||||
|
else
|
||||||
|
chunkIndex = largeChunkCount + (static_cast<size_t>(lane) - largeChunkSpan) / baseChunkSize;
|
||||||
|
return getBatchChunkForIndex(batch, chunkIndex);
|
||||||
|
}
|
||||||
|
|
||||||
|
SpatCompute getOriginalSpatCompute(Operation *op) {
|
||||||
|
if (!op)
|
||||||
|
return {};
|
||||||
|
|
||||||
|
while (auto extract = dyn_cast<tensor::ExtractSliceOp>(op)) {
|
||||||
|
op = extract.getSource().getDefiningOp();
|
||||||
|
if (!op)
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
return dyn_cast<SpatCompute>(op);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::optional<ProducerValueRef> getProducerValueRef(Value value) {
|
||||||
|
Operation *op = value.getDefiningOp();
|
||||||
|
if (!op)
|
||||||
|
return std::nullopt;
|
||||||
|
|
||||||
|
//TODO Extract Slice is not the only global non compute operation. There are other legal op
|
||||||
|
while (auto extract = dyn_cast<tensor::ExtractSliceOp>(op)) {
|
||||||
|
value = extract.getSource();
|
||||||
|
op = value.getDefiningOp();
|
||||||
|
if (!op)
|
||||||
|
return std::nullopt;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto compute = dyn_cast<SpatCompute>(op)) {
|
||||||
|
return ProducerValueRef {
|
||||||
|
ComputeInstance {compute.getOperation(), 0, 1},
|
||||||
|
static_cast<size_t>(cast<OpResult>(value).getResultNumber())
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto batch = dyn_cast<SpatComputeBatch>(op)) {
|
||||||
|
uint32_t lane = static_cast<uint32_t>(cast<OpResult>(value).getResultNumber());
|
||||||
|
ComputeInstance instance = getBatchChunkForLane(batch, lane);
|
||||||
|
size_t resultIndex = static_cast<size_t>(lane - instance.laneStart);
|
||||||
|
return ProducerValueRef {instance, resultIndex};
|
||||||
|
}
|
||||||
|
|
||||||
|
return std::nullopt;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::optional<ComputeInstance> getComputeProducerInstance(Value value) {
|
||||||
|
if (std::optional<ProducerValueRef> producer = getProducerValueRef(value))
|
||||||
|
return producer->instance;
|
||||||
|
return std::nullopt;
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::SmallVector<Value, 4> getComputeInstanceInputs(const ComputeInstance &instance) {
|
||||||
|
if (auto compute = dyn_cast<SpatCompute>(instance.op))
|
||||||
|
return llvm::SmallVector<Value, 4>(compute.getInputs().begin(), compute.getInputs().end());
|
||||||
|
|
||||||
|
auto batch = cast<SpatComputeBatch>(instance.op);
|
||||||
|
llvm::SmallVector<Value, 4> inputs;
|
||||||
|
inputs.reserve(instance.laneCount);
|
||||||
|
for (uint32_t lane = instance.laneStart; lane < instance.laneStart + instance.laneCount; ++lane)
|
||||||
|
if (!batch.getInputs().empty())
|
||||||
|
inputs.push_back(batch.getInputs()[lane]);
|
||||||
|
return inputs;
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::SmallVector<Value, 4> getComputeInstanceWeights(const ComputeInstance &instance) {
|
||||||
|
if (auto compute = dyn_cast<SpatCompute>(instance.op))
|
||||||
|
return llvm::SmallVector<Value, 4>(compute.getWeights().begin(), compute.getWeights().end());
|
||||||
|
|
||||||
|
auto batch = cast<SpatComputeBatch>(instance.op);
|
||||||
|
llvm::SmallVector<Value, 4> weights;
|
||||||
|
weights.reserve(instance.laneCount);
|
||||||
|
for (uint32_t lane = instance.laneStart; lane < instance.laneStart + instance.laneCount; ++lane)
|
||||||
|
weights.push_back(batch.getWeights()[lane]);
|
||||||
|
return weights;
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::SmallVector<Value, 4> getComputeInstanceOutputValues(const ComputeInstance &instance) {
|
||||||
|
if (auto compute = dyn_cast<SpatCompute>(instance.op))
|
||||||
|
return llvm::SmallVector<Value, 4>(compute.getResults().begin(), compute.getResults().end());
|
||||||
|
|
||||||
|
auto batch = cast<SpatComputeBatch>(instance.op);
|
||||||
|
llvm::SmallVector<Value, 4> outputs;
|
||||||
|
outputs.reserve(instance.laneCount);
|
||||||
|
for (uint32_t lane = instance.laneStart; lane < instance.laneStart + instance.laneCount; ++lane)
|
||||||
|
if (!batch.getOutputs().empty())
|
||||||
|
outputs.push_back(batch.getOutputs()[lane]);
|
||||||
|
return outputs;
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::SmallVector<Type, 4> getComputeInstanceOutputTypes(const ComputeInstance &instance) {
|
||||||
|
llvm::SmallVector<Type, 4> outputTypes;
|
||||||
|
for (Value output : getComputeInstanceOutputValues(instance))
|
||||||
|
outputTypes.push_back(output.getType());
|
||||||
|
return outputTypes;
|
||||||
|
}
|
||||||
|
|
||||||
|
Block &getComputeInstanceTemplateBlock(const ComputeInstance &instance) {
|
||||||
|
if (auto compute = dyn_cast<SpatCompute>(instance.op))
|
||||||
|
return compute.getBody().front();
|
||||||
|
return cast<SpatComputeBatch>(instance.op).getBody().front();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace spatial
|
||||||
|
} // namespace onnx_mlir
|
||||||
+40
@@ -0,0 +1,40 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/IR/Block.h"
|
||||||
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
|
#include "mlir/IR/Operation.h"
|
||||||
|
#include "mlir/IR/Value.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
|
||||||
|
#include <cstddef>
|
||||||
|
#include <optional>
|
||||||
|
|
||||||
|
#include "ComputeInstance.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
namespace spatial {
|
||||||
|
|
||||||
|
struct ProducerValueRef {
|
||||||
|
ComputeInstance instance;
|
||||||
|
size_t resultIndex = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
size_t getSchedulingCpuBudget();
|
||||||
|
size_t getBatchChunkTargetCount(int32_t laneCount);
|
||||||
|
ComputeInstance getBatchChunkForIndex(SpatComputeBatch batch, size_t chunkIndex);
|
||||||
|
ComputeInstance getBatchChunkForLane(SpatComputeBatch batch, uint32_t lane);
|
||||||
|
|
||||||
|
SpatCompute getOriginalSpatCompute(mlir::Operation *op);
|
||||||
|
std::optional<ProducerValueRef> getProducerValueRef(mlir::Value value);
|
||||||
|
std::optional<ComputeInstance> getComputeProducerInstance(mlir::Value value);
|
||||||
|
|
||||||
|
llvm::SmallVector<mlir::Value, 4> getComputeInstanceInputs(const ComputeInstance &instance);
|
||||||
|
llvm::SmallVector<mlir::Value, 4> getComputeInstanceWeights(const ComputeInstance &instance);
|
||||||
|
llvm::SmallVector<mlir::Value, 4> getComputeInstanceOutputValues(const ComputeInstance &instance);
|
||||||
|
llvm::SmallVector<mlir::Type, 4> getComputeInstanceOutputTypes(const ComputeInstance &instance);
|
||||||
|
mlir::Block &getComputeInstanceTemplateBlock(const ComputeInstance &instance);
|
||||||
|
|
||||||
|
} // namespace spatial
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,720 @@
|
|||||||
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
#include "llvm/Support/FormatVariadic.h"
|
||||||
|
#include "llvm/Support/raw_ostream.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <cstdlib>
|
||||||
|
#include <limits>
|
||||||
|
#include <numeric>
|
||||||
|
#include <optional>
|
||||||
|
#include <queue>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "DcpScheduler.hpp"
|
||||||
|
#include "../DCPGraph/Graph.hpp"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
namespace spatial {
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
bool isDcpCoarsenDebugEnabled() { return std::getenv("DCP_COARSEN_DEBUG") != nullptr; }
|
||||||
|
|
||||||
|
struct VirtualNode {
|
||||||
|
llvm::SmallVector<size_t, 4> originalNodeIndices;
|
||||||
|
Weight weight = 0;
|
||||||
|
CrossbarUsage crossbarUsage = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct VirtualGraph {
|
||||||
|
std::vector<VirtualNode> nodes;
|
||||||
|
std::vector<IndexedEdge> edges;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct TimingInfo {
|
||||||
|
std::vector<Time> aest;
|
||||||
|
std::vector<Time> alst;
|
||||||
|
std::vector<size_t> topologicalOrder;
|
||||||
|
bool valid = false;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct WindowScheduleResult {
|
||||||
|
std::vector<std::vector<size_t>> mergeGroups;
|
||||||
|
CPU cpuCount = 0;
|
||||||
|
size_t mergedNodeCount = 0;
|
||||||
|
size_t maxMergeGroupSize = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
size_t getSchedulingCpuBudget(const DcpScheduleOptions &options) {
|
||||||
|
if (options.processorCount > 0)
|
||||||
|
return options.processorCount;
|
||||||
|
return std::numeric_limits<size_t>::max();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<IndexedEdge> aggregateEdges(llvm::ArrayRef<IndexedEdge> edges) {
|
||||||
|
llvm::DenseMap<std::pair<size_t, size_t>, Weight> edgeWeights;
|
||||||
|
for (auto [start, end, weight] : edges) {
|
||||||
|
size_t startIndex = static_cast<size_t>(start);
|
||||||
|
size_t endIndex = static_cast<size_t>(end);
|
||||||
|
if (startIndex == endIndex)
|
||||||
|
continue;
|
||||||
|
auto key = std::make_pair(startIndex, endIndex);
|
||||||
|
Weight edgeWeight = static_cast<Weight>(weight);
|
||||||
|
auto inserted = edgeWeights.try_emplace(key, edgeWeight);
|
||||||
|
if (!inserted.second)
|
||||||
|
inserted.first->second = std::max(inserted.first->second, edgeWeight);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<IndexedEdge> aggregatedEdges;
|
||||||
|
aggregatedEdges.reserve(edgeWeights.size());
|
||||||
|
for (auto [key, weight] : edgeWeights)
|
||||||
|
aggregatedEdges.push_back(
|
||||||
|
{static_cast<int64_t>(key.first), static_cast<int64_t>(key.second), static_cast<int64_t>(weight)});
|
||||||
|
llvm::sort(aggregatedEdges, [](const IndexedEdge &lhs, const IndexedEdge &rhs) {
|
||||||
|
if (std::get<0>(lhs) != std::get<0>(rhs))
|
||||||
|
return std::get<0>(lhs) < std::get<0>(rhs);
|
||||||
|
return std::get<1>(lhs) < std::get<1>(rhs);
|
||||||
|
});
|
||||||
|
return aggregatedEdges;
|
||||||
|
}
|
||||||
|
|
||||||
|
VirtualGraph buildInitialVirtualGraph(const ComputeGraph &graph) {
|
||||||
|
VirtualGraph virtualGraph;
|
||||||
|
virtualGraph.nodes.reserve(graph.nodes.size());
|
||||||
|
for (auto [index, node] : llvm::enumerate(graph.nodes)) {
|
||||||
|
VirtualNode virtualNode;
|
||||||
|
virtualNode.originalNodeIndices.push_back(index);
|
||||||
|
virtualNode.weight = node.weight;
|
||||||
|
virtualNode.crossbarUsage = node.crossbarUsage;
|
||||||
|
virtualGraph.nodes.push_back(std::move(virtualNode));
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<IndexedEdge> edges;
|
||||||
|
edges.reserve(graph.edges.size());
|
||||||
|
for (const ComputeGraphEdge &edge : graph.edges)
|
||||||
|
edges.push_back(
|
||||||
|
{static_cast<int64_t>(edge.source), static_cast<int64_t>(edge.target), static_cast<int64_t>(edge.transferCost)});
|
||||||
|
virtualGraph.edges = aggregateEdges(edges);
|
||||||
|
return virtualGraph;
|
||||||
|
}
|
||||||
|
|
||||||
|
TimingInfo computeTiming(const VirtualGraph &graph) {
|
||||||
|
TimingInfo timing;
|
||||||
|
size_t nodeCount = graph.nodes.size();
|
||||||
|
timing.aest.assign(nodeCount, 0);
|
||||||
|
timing.alst.assign(nodeCount, 0);
|
||||||
|
timing.topologicalOrder.reserve(nodeCount);
|
||||||
|
|
||||||
|
std::vector<std::vector<std::pair<size_t, Weight>>> parents(nodeCount);
|
||||||
|
std::vector<std::vector<std::pair<size_t, Weight>>> children(nodeCount);
|
||||||
|
std::vector<size_t> incomingEdgeCount(nodeCount, 0);
|
||||||
|
|
||||||
|
for (auto [start, end, weight] : graph.edges) {
|
||||||
|
size_t startIndex = static_cast<size_t>(start);
|
||||||
|
size_t endIndex = static_cast<size_t>(end);
|
||||||
|
Weight edgeWeight = static_cast<Weight>(weight);
|
||||||
|
assert(startIndex < nodeCount && endIndex < nodeCount && "virtual edge endpoint out of range");
|
||||||
|
children[startIndex].push_back({endIndex, edgeWeight});
|
||||||
|
parents[endIndex].push_back({startIndex, edgeWeight});
|
||||||
|
incomingEdgeCount[endIndex]++;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto getVirtualNodeOrderKey = [&](size_t nodeIndex) {
|
||||||
|
const VirtualNode &node = graph.nodes[nodeIndex];
|
||||||
|
if (!node.originalNodeIndices.empty())
|
||||||
|
return node.originalNodeIndices.front();
|
||||||
|
return nodeIndex;
|
||||||
|
};
|
||||||
|
auto readyNodeGreater = [&](size_t lhs, size_t rhs) {
|
||||||
|
size_t lhsKey = getVirtualNodeOrderKey(lhs);
|
||||||
|
size_t rhsKey = getVirtualNodeOrderKey(rhs);
|
||||||
|
if (lhsKey != rhsKey)
|
||||||
|
return lhsKey > rhsKey;
|
||||||
|
return lhs > rhs;
|
||||||
|
};
|
||||||
|
std::priority_queue<size_t, std::vector<size_t>, decltype(readyNodeGreater)> readyNodes(readyNodeGreater);
|
||||||
|
for (size_t i = 0; i < nodeCount; ++i)
|
||||||
|
if (incomingEdgeCount[i] == 0)
|
||||||
|
readyNodes.push(i);
|
||||||
|
|
||||||
|
while (!readyNodes.empty()) {
|
||||||
|
size_t current = readyNodes.top();
|
||||||
|
readyNodes.pop();
|
||||||
|
timing.topologicalOrder.push_back(current);
|
||||||
|
for (auto [child, weight] : children[current]) {
|
||||||
|
(void) weight;
|
||||||
|
assert(incomingEdgeCount[child] > 0 && "incoming edge count underflow");
|
||||||
|
incomingEdgeCount[child]--;
|
||||||
|
if (incomingEdgeCount[child] == 0)
|
||||||
|
readyNodes.push(child);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (timing.topologicalOrder.size() != nodeCount)
|
||||||
|
return timing;
|
||||||
|
|
||||||
|
Time dcpl = 0;
|
||||||
|
for (size_t nodeIndex : timing.topologicalOrder) {
|
||||||
|
Time maxParentAest = 0;
|
||||||
|
for (auto [parent, transferCost] : parents[nodeIndex]) {
|
||||||
|
maxParentAest =
|
||||||
|
std::max(maxParentAest, addOrMax(addOrMax(timing.aest[parent], graph.nodes[parent].weight), transferCost));
|
||||||
|
}
|
||||||
|
timing.aest[nodeIndex] = maxParentAest;
|
||||||
|
dcpl = std::max(dcpl, addOrMax(maxParentAest, graph.nodes[nodeIndex].weight));
|
||||||
|
}
|
||||||
|
|
||||||
|
for (size_t nodeIndex : llvm::reverse(timing.topologicalOrder)) {
|
||||||
|
Time minAlst = std::numeric_limits<Time>::max();
|
||||||
|
if (children[nodeIndex].empty())
|
||||||
|
minAlst = subtractOrZero(dcpl, graph.nodes[nodeIndex].weight);
|
||||||
|
for (auto [child, transferCost] : children[nodeIndex]) {
|
||||||
|
minAlst =
|
||||||
|
std::min(minAlst, subtractOrZero(timing.alst[child], addOrMax(graph.nodes[nodeIndex].weight, transferCost)));
|
||||||
|
}
|
||||||
|
timing.alst[nodeIndex] = minAlst;
|
||||||
|
}
|
||||||
|
|
||||||
|
timing.valid = true;
|
||||||
|
return timing;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<std::vector<size_t>> buildUndirectedAdjacency(const VirtualGraph &graph) {
|
||||||
|
std::vector<std::vector<size_t>> adjacency(graph.nodes.size());
|
||||||
|
for (auto [start, end, weight] : graph.edges) {
|
||||||
|
(void) weight;
|
||||||
|
size_t startIndex = static_cast<size_t>(start);
|
||||||
|
size_t endIndex = static_cast<size_t>(end);
|
||||||
|
assert(startIndex < graph.nodes.size() && endIndex < graph.nodes.size() && "virtual edge endpoint out of range");
|
||||||
|
adjacency[startIndex].push_back(endIndex);
|
||||||
|
adjacency[endIndex].push_back(startIndex);
|
||||||
|
}
|
||||||
|
for (auto &neighbours : adjacency) {
|
||||||
|
llvm::sort(neighbours);
|
||||||
|
neighbours.erase(std::unique(neighbours.begin(), neighbours.end()), neighbours.end());
|
||||||
|
}
|
||||||
|
return adjacency;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<size_t> selectCriticalWindow(const VirtualGraph &graph, const TimingInfo &timing, size_t windowSize) {
|
||||||
|
std::vector<size_t> ranked(timing.aest.size());
|
||||||
|
std::iota(ranked.begin(), ranked.end(), 0);
|
||||||
|
auto isHigherPriority = [&](size_t lhs, size_t rhs) {
|
||||||
|
Time lhsSlack = slackOrZero(timing.aest[lhs], timing.alst[lhs]);
|
||||||
|
Time rhsSlack = slackOrZero(timing.aest[rhs], timing.alst[rhs]);
|
||||||
|
if (lhsSlack != rhsSlack)
|
||||||
|
return lhsSlack < rhsSlack;
|
||||||
|
if (timing.aest[lhs] != timing.aest[rhs])
|
||||||
|
return timing.aest[lhs] < timing.aest[rhs];
|
||||||
|
return lhs < rhs;
|
||||||
|
};
|
||||||
|
|
||||||
|
windowSize = std::min(windowSize, ranked.size());
|
||||||
|
if (windowSize == 0)
|
||||||
|
return {};
|
||||||
|
if (windowSize == ranked.size()) {
|
||||||
|
llvm::sort(ranked, isHigherPriority);
|
||||||
|
return ranked;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t criticalPoolSize = std::min(ranked.size(), std::max(windowSize, windowSize * 2));
|
||||||
|
if (criticalPoolSize < ranked.size())
|
||||||
|
std::nth_element(
|
||||||
|
ranked.begin(), ranked.begin() + static_cast<std::ptrdiff_t>(criticalPoolSize), ranked.end(), isHigherPriority);
|
||||||
|
|
||||||
|
std::vector<char> inCriticalPool(ranked.size(), false);
|
||||||
|
for (size_t i = 0; i < criticalPoolSize; ++i)
|
||||||
|
inCriticalPool[ranked[i]] = true;
|
||||||
|
|
||||||
|
size_t seed = *std::min_element(ranked.begin(), ranked.end(), isHigherPriority);
|
||||||
|
std::vector<std::vector<size_t>> adjacency = buildUndirectedAdjacency(graph);
|
||||||
|
std::vector<size_t> selected;
|
||||||
|
std::vector<char> inWindow(ranked.size(), false);
|
||||||
|
selected.reserve(windowSize);
|
||||||
|
|
||||||
|
struct FrontierEntry {
|
||||||
|
size_t node;
|
||||||
|
};
|
||||||
|
auto frontierCompare = [&](FrontierEntry lhs, FrontierEntry rhs) { return isHigherPriority(rhs.node, lhs.node); };
|
||||||
|
std::priority_queue<FrontierEntry, std::vector<FrontierEntry>, decltype(frontierCompare)> frontier(frontierCompare);
|
||||||
|
|
||||||
|
auto addToWindow = [&](size_t node, const std::vector<char> &eligible) {
|
||||||
|
if (inWindow[node])
|
||||||
|
return;
|
||||||
|
inWindow[node] = true;
|
||||||
|
selected.push_back(node);
|
||||||
|
for (size_t neighbour : adjacency[node])
|
||||||
|
if (!inWindow[neighbour] && eligible[neighbour])
|
||||||
|
frontier.push({neighbour});
|
||||||
|
};
|
||||||
|
|
||||||
|
addToWindow(seed, inCriticalPool);
|
||||||
|
while (!frontier.empty() && selected.size() < windowSize) {
|
||||||
|
size_t node = frontier.top().node;
|
||||||
|
frontier.pop();
|
||||||
|
if (!inWindow[node])
|
||||||
|
addToWindow(node, inCriticalPool);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (selected.size() < windowSize) {
|
||||||
|
std::vector<char> anyNode(ranked.size(), true);
|
||||||
|
for (size_t node : selected)
|
||||||
|
for (size_t neighbour : adjacency[node])
|
||||||
|
if (!inWindow[neighbour])
|
||||||
|
frontier.push({neighbour});
|
||||||
|
while (!frontier.empty() && selected.size() < windowSize) {
|
||||||
|
size_t node = frontier.top().node;
|
||||||
|
frontier.pop();
|
||||||
|
if (!inWindow[node])
|
||||||
|
addToWindow(node, anyNode);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (selected.size() < windowSize) {
|
||||||
|
llvm::sort(ranked, isHigherPriority);
|
||||||
|
for (size_t node : ranked) {
|
||||||
|
if (selected.size() == windowSize)
|
||||||
|
break;
|
||||||
|
if (!inWindow[node]) {
|
||||||
|
inWindow[node] = true;
|
||||||
|
selected.push_back(node);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::sort(selected, isHigherPriority);
|
||||||
|
return selected;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<IndexedEdge> buildWindowEdges(const VirtualGraph &graph, const std::vector<int64_t> &nodeToWindowIndex) {
|
||||||
|
std::vector<IndexedEdge> windowEdges;
|
||||||
|
windowEdges.reserve(graph.edges.size());
|
||||||
|
for (auto [start, end, weight] : graph.edges) {
|
||||||
|
int64_t mappedStart = nodeToWindowIndex[static_cast<size_t>(start)];
|
||||||
|
int64_t mappedEnd = nodeToWindowIndex[static_cast<size_t>(end)];
|
||||||
|
if (mappedStart == -1 || mappedEnd == -1)
|
||||||
|
continue;
|
||||||
|
windowEdges.push_back({mappedStart, mappedEnd, weight});
|
||||||
|
}
|
||||||
|
return aggregateEdges(windowEdges);
|
||||||
|
}
|
||||||
|
|
||||||
|
WindowScheduleResult scheduleWindow(const VirtualGraph &graph,
|
||||||
|
llvm::ArrayRef<size_t> selectedNodes,
|
||||||
|
const DcpScheduleOptions &options,
|
||||||
|
mlir::MLIRContext *context) {
|
||||||
|
std::vector<Weight> windowWeights;
|
||||||
|
std::vector<CrossbarUsage> windowCrossbarUsage;
|
||||||
|
std::vector<int64_t> windowNodeOrderKeys;
|
||||||
|
std::vector<int64_t> nodeToWindowIndex(graph.nodes.size(), -1);
|
||||||
|
windowWeights.reserve(selectedNodes.size());
|
||||||
|
windowCrossbarUsage.reserve(selectedNodes.size());
|
||||||
|
windowNodeOrderKeys.reserve(selectedNodes.size());
|
||||||
|
|
||||||
|
for (auto [windowIndex, nodeIndex] : llvm::enumerate(selectedNodes)) {
|
||||||
|
nodeToWindowIndex[nodeIndex] = static_cast<int64_t>(windowIndex);
|
||||||
|
windowWeights.push_back(graph.nodes[nodeIndex].weight);
|
||||||
|
windowCrossbarUsage.push_back(graph.nodes[nodeIndex].crossbarUsage);
|
||||||
|
windowNodeOrderKeys.push_back(static_cast<int64_t>(nodeIndex));
|
||||||
|
}
|
||||||
|
|
||||||
|
GraphDCP windowGraph(
|
||||||
|
windowWeights, buildWindowEdges(graph, nodeToWindowIndex), windowNodeOrderKeys, windowCrossbarUsage);
|
||||||
|
if (options.processorCount > 0)
|
||||||
|
windowGraph.setMaxCpuCount(static_cast<int>(options.processorCount));
|
||||||
|
windowGraph.setContext(context);
|
||||||
|
windowGraph.runDcp();
|
||||||
|
|
||||||
|
WindowScheduleResult result;
|
||||||
|
result.cpuCount = windowGraph.cpuCount();
|
||||||
|
for (CPU cpu = 0; cpu < windowGraph.cpuCount(); ++cpu) {
|
||||||
|
auto scheduledTasks = windowGraph.getScheduledTasks(cpu);
|
||||||
|
if (scheduledTasks.size() < 2)
|
||||||
|
continue;
|
||||||
|
|
||||||
|
result.mergedNodeCount += scheduledTasks.size();
|
||||||
|
result.maxMergeGroupSize = std::max(result.maxMergeGroupSize, scheduledTasks.size());
|
||||||
|
std::vector<size_t> mergeGroup;
|
||||||
|
mergeGroup.reserve(scheduledTasks.size());
|
||||||
|
for (const auto &task : scheduledTasks)
|
||||||
|
mergeGroup.push_back(selectedNodes[task.nodeIndex]);
|
||||||
|
result.mergeGroups.push_back(std::move(mergeGroup));
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool coarsenGraph(const VirtualGraph &graph,
|
||||||
|
llvm::ArrayRef<std::vector<size_t>> mergeGroups,
|
||||||
|
VirtualGraph &coarsenedGraph,
|
||||||
|
std::vector<size_t> &oldToNewNode) {
|
||||||
|
TimingInfo timing = computeTiming(graph);
|
||||||
|
std::vector<size_t> topologicalRank(graph.nodes.size());
|
||||||
|
std::iota(topologicalRank.begin(), topologicalRank.end(), 0);
|
||||||
|
if (timing.valid)
|
||||||
|
for (auto [rank, nodeIndex] : llvm::enumerate(timing.topologicalOrder))
|
||||||
|
topologicalRank[nodeIndex] = rank;
|
||||||
|
|
||||||
|
std::vector<std::vector<size_t>> orderedMergeGroups;
|
||||||
|
orderedMergeGroups.reserve(mergeGroups.size());
|
||||||
|
for (const auto &mergeGroup : mergeGroups) {
|
||||||
|
orderedMergeGroups.emplace_back(mergeGroup.begin(), mergeGroup.end());
|
||||||
|
std::stable_sort(orderedMergeGroups.back().begin(), orderedMergeGroups.back().end(), [&](size_t lhs, size_t rhs) {
|
||||||
|
if (topologicalRank[lhs] != topologicalRank[rhs])
|
||||||
|
return topologicalRank[lhs] < topologicalRank[rhs];
|
||||||
|
return lhs < rhs;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<int64_t> nodeToMergeGroup(graph.nodes.size(), -1);
|
||||||
|
for (auto [groupIndex, mergeGroup] : llvm::enumerate(orderedMergeGroups)) {
|
||||||
|
if (mergeGroup.size() < 2)
|
||||||
|
continue;
|
||||||
|
for (size_t nodeIndex : mergeGroup) {
|
||||||
|
assert(nodeIndex < graph.nodes.size() && "merge group node out of range");
|
||||||
|
nodeToMergeGroup[nodeIndex] = static_cast<int64_t>(groupIndex);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<std::optional<size_t>> mergeGroupToNewNode(orderedMergeGroups.size());
|
||||||
|
std::vector<size_t> newNodeRank;
|
||||||
|
oldToNewNode.assign(graph.nodes.size(), 0);
|
||||||
|
bool mergedAny = false;
|
||||||
|
coarsenedGraph.nodes.clear();
|
||||||
|
coarsenedGraph.edges.clear();
|
||||||
|
coarsenedGraph.nodes.reserve(graph.nodes.size());
|
||||||
|
newNodeRank.reserve(graph.nodes.size());
|
||||||
|
|
||||||
|
for (size_t nodeIndex = 0; nodeIndex < graph.nodes.size(); ++nodeIndex) {
|
||||||
|
int64_t mergeGroupIndex = nodeToMergeGroup[nodeIndex];
|
||||||
|
if (mergeGroupIndex == -1) {
|
||||||
|
oldToNewNode[nodeIndex] = coarsenedGraph.nodes.size();
|
||||||
|
coarsenedGraph.nodes.push_back(graph.nodes[nodeIndex]);
|
||||||
|
newNodeRank.push_back(topologicalRank[nodeIndex]);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto &newNodeIndex = mergeGroupToNewNode[static_cast<size_t>(mergeGroupIndex)];
|
||||||
|
if (newNodeIndex.has_value()) {
|
||||||
|
oldToNewNode[nodeIndex] = *newNodeIndex;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
VirtualNode mergedNode;
|
||||||
|
for (size_t memberIndex : orderedMergeGroups[static_cast<size_t>(mergeGroupIndex)]) {
|
||||||
|
const VirtualNode &memberNode = graph.nodes[memberIndex];
|
||||||
|
mergedNode.originalNodeIndices.append(memberNode.originalNodeIndices.begin(), memberNode.originalNodeIndices.end());
|
||||||
|
mergedNode.weight = addOrMax(mergedNode.weight, memberNode.weight);
|
||||||
|
mergedNode.crossbarUsage = addOrMax(mergedNode.crossbarUsage, memberNode.crossbarUsage);
|
||||||
|
}
|
||||||
|
std::sort(mergedNode.originalNodeIndices.begin(), mergedNode.originalNodeIndices.end());
|
||||||
|
|
||||||
|
mergedAny = true;
|
||||||
|
newNodeIndex = coarsenedGraph.nodes.size();
|
||||||
|
for (size_t memberIndex : orderedMergeGroups[static_cast<size_t>(mergeGroupIndex)])
|
||||||
|
oldToNewNode[memberIndex] = *newNodeIndex;
|
||||||
|
newNodeRank.push_back(topologicalRank[orderedMergeGroups[static_cast<size_t>(mergeGroupIndex)].front()]);
|
||||||
|
coarsenedGraph.nodes.push_back(std::move(mergedNode));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!mergedAny)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
std::vector<IndexedEdge> remappedEdges;
|
||||||
|
remappedEdges.reserve(graph.edges.size());
|
||||||
|
for (auto [start, end, weight] : graph.edges) {
|
||||||
|
size_t newStart = oldToNewNode[static_cast<size_t>(start)];
|
||||||
|
size_t newEnd = oldToNewNode[static_cast<size_t>(end)];
|
||||||
|
if (newStart == newEnd)
|
||||||
|
continue;
|
||||||
|
if (newNodeRank[newStart] >= newNodeRank[newEnd])
|
||||||
|
continue;
|
||||||
|
remappedEdges.push_back({static_cast<int64_t>(newStart), static_cast<int64_t>(newEnd), weight});
|
||||||
|
}
|
||||||
|
coarsenedGraph.edges = aggregateEdges(remappedEdges);
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t getDcpCoarseningWindowSize(size_t nodeCount, const DcpScheduleOptions &options) {
|
||||||
|
size_t windowSize = std::min(options.criticalWindowSize, nodeCount);
|
||||||
|
CPU maxCpuCount = std::max<CPU>(1, static_cast<CPU>(getSchedulingCpuBudget(options)));
|
||||||
|
if (nodeCount > static_cast<size_t>(maxCpuCount))
|
||||||
|
windowSize = std::max(windowSize, std::min(nodeCount, static_cast<size_t>(maxCpuCount) + 1));
|
||||||
|
return windowSize;
|
||||||
|
}
|
||||||
|
|
||||||
|
void assignFeasibleAest(const ComputeGraph &graph, MergeScheduleResult &result) {
|
||||||
|
llvm::DenseMap<ComputeInstance, size_t> nodeIndexByInstance;
|
||||||
|
nodeIndexByInstance.reserve(graph.nodes.size());
|
||||||
|
for (auto [nodeIndex, node] : llvm::enumerate(graph.nodes))
|
||||||
|
nodeIndexByInstance[node.instance] = nodeIndex;
|
||||||
|
|
||||||
|
struct ScheduledEdge {
|
||||||
|
size_t target = 0;
|
||||||
|
Time delay = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
std::vector<std::vector<ScheduledEdge>> scheduledChildren(graph.nodes.size());
|
||||||
|
std::vector<size_t> incomingEdgeCount(graph.nodes.size(), 0);
|
||||||
|
for (const ComputeGraphEdge &edge : graph.edges) {
|
||||||
|
const ComputeInstance sourceInstance = graph.nodes[edge.source].instance;
|
||||||
|
const ComputeInstance targetInstance = graph.nodes[edge.target].instance;
|
||||||
|
const size_t sourceCpu = result.computeToCpuMap.lookup(sourceInstance);
|
||||||
|
const size_t targetCpu = result.computeToCpuMap.lookup(targetInstance);
|
||||||
|
|
||||||
|
Time delay = graph.nodes[edge.source].weight;
|
||||||
|
if (sourceCpu != targetCpu)
|
||||||
|
delay = addOrMax(delay, edge.transferCost);
|
||||||
|
|
||||||
|
scheduledChildren[edge.source].push_back({edge.target, delay});
|
||||||
|
incomingEdgeCount[edge.target]++;
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::DenseMap<size_t, std::vector<std::pair<size_t, size_t>>> tasksByCpu;
|
||||||
|
for (const ComputeGraphNode &node : graph.nodes) {
|
||||||
|
size_t cpu = result.computeToCpuMap.lookup(node.instance);
|
||||||
|
size_t slot = result.computeToCpuSlotMap.lookup(node.instance);
|
||||||
|
tasksByCpu[cpu].push_back({slot, nodeIndexByInstance.lookup(node.instance)});
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto &entry : tasksByCpu) {
|
||||||
|
auto &scheduledTasks = entry.second;
|
||||||
|
llvm::sort(scheduledTasks, [](const auto &lhs, const auto &rhs) {
|
||||||
|
if (lhs.first != rhs.first)
|
||||||
|
return lhs.first < rhs.first;
|
||||||
|
return lhs.second < rhs.second;
|
||||||
|
});
|
||||||
|
|
||||||
|
for (size_t i = 1; i < scheduledTasks.size(); ++i) {
|
||||||
|
size_t sourceIndex = scheduledTasks[i - 1].second;
|
||||||
|
size_t targetIndex = scheduledTasks[i].second;
|
||||||
|
scheduledChildren[sourceIndex].push_back({targetIndex, graph.nodes[sourceIndex].weight});
|
||||||
|
incomingEdgeCount[targetIndex]++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
auto readyNodeGreater = [&](size_t lhs, size_t rhs) {
|
||||||
|
if (graph.nodes[lhs].originalOrder != graph.nodes[rhs].originalOrder)
|
||||||
|
return graph.nodes[lhs].originalOrder > graph.nodes[rhs].originalOrder;
|
||||||
|
return lhs > rhs;
|
||||||
|
};
|
||||||
|
std::priority_queue<size_t, std::vector<size_t>, decltype(readyNodeGreater)> readyNodes(readyNodeGreater);
|
||||||
|
for (size_t nodeIndex = 0; nodeIndex < graph.nodes.size(); ++nodeIndex)
|
||||||
|
if (incomingEdgeCount[nodeIndex] == 0)
|
||||||
|
readyNodes.push(nodeIndex);
|
||||||
|
|
||||||
|
std::vector<Time> startTimes(graph.nodes.size(), 0);
|
||||||
|
size_t processedNodeCount = 0;
|
||||||
|
while (!readyNodes.empty()) {
|
||||||
|
size_t sourceIndex = readyNodes.top();
|
||||||
|
readyNodes.pop();
|
||||||
|
processedNodeCount++;
|
||||||
|
|
||||||
|
for (const ScheduledEdge &edge : scheduledChildren[sourceIndex]) {
|
||||||
|
startTimes[edge.target] = std::max(startTimes[edge.target], addOrMax(startTimes[sourceIndex], edge.delay));
|
||||||
|
assert(incomingEdgeCount[edge.target] > 0 && "scheduled incoming edge count underflow");
|
||||||
|
incomingEdgeCount[edge.target]--;
|
||||||
|
if (incomingEdgeCount[edge.target] == 0)
|
||||||
|
readyNodes.push(edge.target);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (processedNodeCount != graph.nodes.size())
|
||||||
|
llvm::report_fatal_error("merge scheduling: coarsened DCP schedule is cyclic");
|
||||||
|
|
||||||
|
for (auto [nodeIndex, node] : llvm::enumerate(graph.nodes))
|
||||||
|
result.computeToAestMap[node.instance] = startTimes[nodeIndex];
|
||||||
|
}
|
||||||
|
|
||||||
|
MergeScheduleResult buildResultFromVirtualGraph(const VirtualGraph &graph, const ComputeGraph &originalGraph) {
|
||||||
|
MergeScheduleResult result;
|
||||||
|
|
||||||
|
TimingInfo timing = computeTiming(graph);
|
||||||
|
std::vector<size_t> virtualNodeOrder;
|
||||||
|
if (timing.valid)
|
||||||
|
virtualNodeOrder = std::move(timing.topologicalOrder);
|
||||||
|
else {
|
||||||
|
virtualNodeOrder.resize(graph.nodes.size());
|
||||||
|
std::iota(virtualNodeOrder.begin(), virtualNodeOrder.end(), 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<size_t> originalNodeToCpu(originalGraph.nodes.size(), 0);
|
||||||
|
for (auto [cpu, virtualNodeIndex] : llvm::enumerate(virtualNodeOrder)) {
|
||||||
|
const VirtualNode &virtualNode = graph.nodes[virtualNodeIndex];
|
||||||
|
for (size_t originalIndex : virtualNode.originalNodeIndices)
|
||||||
|
originalNodeToCpu[originalIndex] = cpu;
|
||||||
|
}
|
||||||
|
|
||||||
|
result.dominanceOrderCompute.reserve(originalGraph.nodes.size());
|
||||||
|
llvm::DenseMap<size_t, size_t> nextCpuSlot;
|
||||||
|
for (auto [originalIndex, node] : llvm::enumerate(originalGraph.nodes)) {
|
||||||
|
size_t cpu = originalNodeToCpu[originalIndex];
|
||||||
|
result.dominanceOrderCompute.push_back(node.instance);
|
||||||
|
result.computeToCpuMap[node.instance] = cpu;
|
||||||
|
result.computeToCpuSlotMap[node.instance] = nextCpuSlot[cpu]++;
|
||||||
|
result.cpuToLastComputeMap[cpu] = node.instance;
|
||||||
|
}
|
||||||
|
for (const auto &[cpu, lastCompute] : result.cpuToLastComputeMap)
|
||||||
|
result.isLastComputeOfCpu.insert(lastCompute);
|
||||||
|
assignFeasibleAest(originalGraph, result);
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
MergeScheduleResult buildResultFromScheduledGraph(GraphDCP &graphDCP, const ComputeGraph &graph) {
|
||||||
|
MergeScheduleResult result;
|
||||||
|
result.dominanceOrderCompute.reserve(graph.nodes.size());
|
||||||
|
for (const ComputeGraphNode &node : graph.nodes)
|
||||||
|
result.dominanceOrderCompute.push_back(node.instance);
|
||||||
|
|
||||||
|
for (CPU cpu = 0; cpu < graphDCP.cpuCount(); ++cpu) {
|
||||||
|
auto scheduledTasks = graphDCP.getScheduledTasks(cpu);
|
||||||
|
if (scheduledTasks.empty())
|
||||||
|
continue;
|
||||||
|
|
||||||
|
for (auto [slot, task] : llvm::enumerate(scheduledTasks)) {
|
||||||
|
const ComputeInstance instance = graph.nodes[task.nodeIndex].instance;
|
||||||
|
result.computeToCpuMap[instance] = cpu;
|
||||||
|
result.computeToCpuSlotMap[instance] = slot;
|
||||||
|
result.computeToAestMap[instance] = static_cast<uint64_t>(task.aest);
|
||||||
|
}
|
||||||
|
|
||||||
|
const ComputeInstance lastInstance = graph.nodes[scheduledTasks.back().nodeIndex].instance;
|
||||||
|
result.cpuToLastComputeMap[cpu] = lastInstance;
|
||||||
|
result.isLastComputeOfCpu.insert(lastInstance);
|
||||||
|
}
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
MergeScheduleResult runLegacyDcp(const ComputeGraph &graph, const DcpScheduleOptions &options, mlir::MLIRContext *context) {
|
||||||
|
llvm::SmallVector<Weight> nodeWeights;
|
||||||
|
llvm::SmallVector<CrossbarUsage> nodeCrossbarUsage;
|
||||||
|
llvm::SmallVector<int64_t> nodeOrderKeys;
|
||||||
|
llvm::SmallVector<IndexedEdge> edges;
|
||||||
|
nodeWeights.reserve(graph.nodes.size());
|
||||||
|
nodeCrossbarUsage.reserve(graph.nodes.size());
|
||||||
|
nodeOrderKeys.reserve(graph.nodes.size());
|
||||||
|
edges.reserve(graph.edges.size());
|
||||||
|
|
||||||
|
for (const ComputeGraphNode &node : graph.nodes) {
|
||||||
|
nodeWeights.push_back(node.weight);
|
||||||
|
nodeCrossbarUsage.push_back(node.crossbarUsage);
|
||||||
|
nodeOrderKeys.push_back(static_cast<int64_t>(node.originalOrder));
|
||||||
|
}
|
||||||
|
for (const ComputeGraphEdge &edge : graph.edges) {
|
||||||
|
edges.push_back(
|
||||||
|
{static_cast<int64_t>(edge.source), static_cast<int64_t>(edge.target), static_cast<int64_t>(edge.transferCost)});
|
||||||
|
}
|
||||||
|
|
||||||
|
GraphDCP graphDCP(nodeWeights, edges, nodeOrderKeys, nodeCrossbarUsage);
|
||||||
|
if (options.processorCount > 0)
|
||||||
|
graphDCP.setMaxCpuCount(static_cast<int>(options.processorCount));
|
||||||
|
graphDCP.setContext(context);
|
||||||
|
graphDCP.runDcp();
|
||||||
|
return buildResultFromScheduledGraph(graphDCP, graph);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool needsExactScheduledBatches(const ComputeGraph &graph, const DcpScheduleOptions &options) {
|
||||||
|
if (options.processorCount == 0 || !options.allowFallbackForAutoCoreCount)
|
||||||
|
return false;
|
||||||
|
size_t schedulingCpuBudget = getSchedulingCpuBudget(options);
|
||||||
|
return llvm::any_of(graph.nodes, [&](const ComputeGraphNode &node) {
|
||||||
|
auto batch = dyn_cast<SpatComputeBatch>(node.instance.op);
|
||||||
|
return batch && static_cast<size_t>(batch.getLaneCount()) > schedulingCpuBudget;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
MergeScheduleResult
|
||||||
|
runDcpScheduler(const ComputeGraph &graph, const DcpScheduleOptions &options, mlir::MLIRContext *context) {
|
||||||
|
if (needsExactScheduledBatches(graph, options))
|
||||||
|
return runLegacyDcp(graph, options, context);
|
||||||
|
|
||||||
|
if (options.criticalWindowSize == 0)
|
||||||
|
return runLegacyDcp(graph, options, context);
|
||||||
|
|
||||||
|
VirtualGraph virtualGraph = buildInitialVirtualGraph(graph);
|
||||||
|
size_t iteration = 0;
|
||||||
|
bool debugCoarsening = isDcpCoarsenDebugEnabled();
|
||||||
|
auto tryCoarsenSelectedNodes = [&](llvm::ArrayRef<size_t> selectedNodes) {
|
||||||
|
size_t oldNodeCount = virtualGraph.nodes.size();
|
||||||
|
WindowScheduleResult windowSchedule = scheduleWindow(virtualGraph, selectedNodes, options, context);
|
||||||
|
if (windowSchedule.mergeGroups.empty()) {
|
||||||
|
if (debugCoarsening && oldNodeCount >= 200)
|
||||||
|
llvm::errs() << llvm::formatv("[DCP-COARSEN] iter={0} old={1} selected={2} windowCpus={3} "
|
||||||
|
"groups=0 mergedNodes=0 maxGroup=0 new={1} changed=0\n",
|
||||||
|
iteration,
|
||||||
|
oldNodeCount,
|
||||||
|
selectedNodes.size(),
|
||||||
|
windowSchedule.cpuCount);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
VirtualGraph coarsenedGraph;
|
||||||
|
std::vector<size_t> oldToNewNode;
|
||||||
|
if (!coarsenGraph(virtualGraph, windowSchedule.mergeGroups, coarsenedGraph, oldToNewNode))
|
||||||
|
return false;
|
||||||
|
if (debugCoarsening && (oldNodeCount >= 200 || coarsenedGraph.nodes.size() >= 200))
|
||||||
|
llvm::errs() << llvm::formatv("[DCP-COARSEN] iter={0} old={1} selected={2} windowCpus={3} "
|
||||||
|
"groups={4} mergedNodes={5} maxGroup={6} new={7} changed={8}\n",
|
||||||
|
iteration,
|
||||||
|
oldNodeCount,
|
||||||
|
selectedNodes.size(),
|
||||||
|
windowSchedule.cpuCount,
|
||||||
|
windowSchedule.mergeGroups.size(),
|
||||||
|
windowSchedule.mergedNodeCount,
|
||||||
|
windowSchedule.maxMergeGroupSize,
|
||||||
|
coarsenedGraph.nodes.size(),
|
||||||
|
oldNodeCount - coarsenedGraph.nodes.size());
|
||||||
|
virtualGraph = std::move(coarsenedGraph);
|
||||||
|
return true;
|
||||||
|
};
|
||||||
|
|
||||||
|
while (virtualGraph.nodes.size() > 1) {
|
||||||
|
if (virtualGraph.nodes.size() <= getSchedulingCpuBudget(options)) {
|
||||||
|
if (debugCoarsening && virtualGraph.nodes.size() >= 200)
|
||||||
|
llvm::errs() << llvm::formatv(
|
||||||
|
"[DCP-COARSEN] iter={0} old={1} stop=cpu-budget\n", iteration, virtualGraph.nodes.size());
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
iteration++;
|
||||||
|
TimingInfo timing = computeTiming(virtualGraph);
|
||||||
|
if (!timing.valid) {
|
||||||
|
if (debugCoarsening && virtualGraph.nodes.size() >= 200)
|
||||||
|
llvm::errs() << llvm::formatv(
|
||||||
|
"[DCP-COARSEN] iter={0} old={1} invalid-timing\n", iteration, virtualGraph.nodes.size());
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::SmallVector<size_t> selectedNodes;
|
||||||
|
auto criticalWindow =
|
||||||
|
selectCriticalWindow(virtualGraph, timing, getDcpCoarseningWindowSize(virtualGraph.nodes.size(), options));
|
||||||
|
selectedNodes.append(criticalWindow.begin(), criticalWindow.end());
|
||||||
|
|
||||||
|
if (selectedNodes.size() < 2) {
|
||||||
|
if (debugCoarsening && virtualGraph.nodes.size() >= 200)
|
||||||
|
llvm::errs() << llvm::formatv("[DCP-COARSEN] iter={0} old={1} selected={2} stop=small-window\n",
|
||||||
|
iteration,
|
||||||
|
virtualGraph.nodes.size(),
|
||||||
|
selectedNodes.size());
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (tryCoarsenSelectedNodes(selectedNodes))
|
||||||
|
continue;
|
||||||
|
if (debugCoarsening && virtualGraph.nodes.size() >= 200)
|
||||||
|
llvm::errs() << llvm::formatv(
|
||||||
|
"[DCP-COARSEN] iter={0} old={1} stop=no-merge\n", iteration, virtualGraph.nodes.size());
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
return buildResultFromVirtualGraph(virtualGraph, graph);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace spatial
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,21 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/IR/MLIRContext.h"
|
||||||
|
|
||||||
|
#include "ComputeGraph.hpp"
|
||||||
|
#include "MergeSchedule.hpp"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
namespace spatial {
|
||||||
|
|
||||||
|
struct DcpScheduleOptions {
|
||||||
|
size_t processorCount = 0;
|
||||||
|
size_t criticalWindowSize = 0;
|
||||||
|
bool allowFallbackForAutoCoreCount = true;
|
||||||
|
};
|
||||||
|
|
||||||
|
MergeScheduleResult
|
||||||
|
runDcpScheduler(const ComputeGraph &graph, const DcpScheduleOptions &options, mlir::MLIRContext *context);
|
||||||
|
|
||||||
|
} // namespace spatial
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,25 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "llvm/ADT/DenseMap.h"
|
||||||
|
#include "llvm/ADT/DenseSet.h"
|
||||||
|
|
||||||
|
#include <cstddef>
|
||||||
|
#include <cstdint>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "ComputeInstance.hpp"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
namespace spatial {
|
||||||
|
|
||||||
|
struct MergeScheduleResult {
|
||||||
|
std::vector<ComputeInstance> dominanceOrderCompute;
|
||||||
|
llvm::DenseMap<ComputeInstance, size_t> computeToCpuMap;
|
||||||
|
llvm::DenseMap<ComputeInstance, size_t> computeToCpuSlotMap;
|
||||||
|
llvm::DenseMap<ComputeInstance, uint64_t> computeToAestMap;
|
||||||
|
llvm::DenseSet<ComputeInstance> isLastComputeOfCpu;
|
||||||
|
llvm::DenseMap<size_t, ComputeInstance> cpuToLastComputeMap;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace spatial
|
||||||
|
} // namespace onnx_mlir
|
||||||
+139
@@ -0,0 +1,139 @@
|
|||||||
|
#include "llvm/ADT/STLExtras.h"
|
||||||
|
#include "llvm/ADT/DenseMap.h"
|
||||||
|
#include "llvm/Support/ErrorHandling.h"
|
||||||
|
#include "llvm/Support/FormatVariadic.h"
|
||||||
|
|
||||||
|
#include <limits>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "ComputeGraph.hpp"
|
||||||
|
#include "../DCPGraph/DCPAnalysis.hpp"
|
||||||
|
#include "DcpScheduler.hpp"
|
||||||
|
#include "MergeSchedulingAnalysis.hpp"
|
||||||
|
#include "PeftScheduler.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
namespace spatial {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
MergeSchedulerKind getSchedulerKind() {
|
||||||
|
switch (pimMergeScheduler.getValue()) {
|
||||||
|
case MergeSchedulerPeft:
|
||||||
|
return MergeSchedulerKind::Peft;
|
||||||
|
case MergeSchedulerDcp:
|
||||||
|
return MergeSchedulerKind::Dcp;
|
||||||
|
}
|
||||||
|
llvm_unreachable("unknown merge scheduler kind");
|
||||||
|
}
|
||||||
|
|
||||||
|
void verifySchedule(const ComputeGraph &graph, const MergeScheduleResult &result, CrossbarUsage crossbarCapacity) {
|
||||||
|
llvm::DenseMap<size_t, std::vector<std::pair<size_t, size_t>>> tasksByCpu;
|
||||||
|
tasksByCpu.reserve(result.cpuToLastComputeMap.size());
|
||||||
|
|
||||||
|
for (size_t nodeIndex = 0; nodeIndex < graph.nodes.size(); ++nodeIndex) {
|
||||||
|
const ComputeInstance instance = graph.nodes[nodeIndex].instance;
|
||||||
|
if (!result.computeToCpuMap.count(instance))
|
||||||
|
llvm::report_fatal_error("merge scheduling: missing CPU assignment");
|
||||||
|
if (!result.computeToCpuSlotMap.count(instance))
|
||||||
|
llvm::report_fatal_error("merge scheduling: missing CPU slot assignment");
|
||||||
|
if (!result.computeToAestMap.count(instance))
|
||||||
|
llvm::report_fatal_error("merge scheduling: missing start time");
|
||||||
|
|
||||||
|
tasksByCpu[result.computeToCpuMap.lookup(instance)].push_back(
|
||||||
|
{result.computeToCpuSlotMap.lookup(instance), nodeIndex});
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto &entry : tasksByCpu) {
|
||||||
|
auto &scheduledTasks = entry.second;
|
||||||
|
llvm::sort(scheduledTasks, [](const auto &lhs, const auto &rhs) {
|
||||||
|
if (lhs.first != rhs.first)
|
||||||
|
return lhs.first < rhs.first;
|
||||||
|
return lhs.second < rhs.second;
|
||||||
|
});
|
||||||
|
|
||||||
|
CrossbarUsage usedCrossbars = 0;
|
||||||
|
for (size_t slot = 0; slot < scheduledTasks.size(); ++slot) {
|
||||||
|
if (scheduledTasks[slot].first != slot)
|
||||||
|
llvm::report_fatal_error("merge scheduling: CPU slots are not contiguous");
|
||||||
|
usedCrossbars = addOrMax(usedCrossbars, graph.nodes[scheduledTasks[slot].second].crossbarUsage);
|
||||||
|
if (usedCrossbars > crossbarCapacity)
|
||||||
|
llvm::report_fatal_error("merge scheduling: CPU crossbar capacity exceeded");
|
||||||
|
}
|
||||||
|
|
||||||
|
const ComputeInstance expectedLast = graph.nodes[scheduledTasks.back().second].instance;
|
||||||
|
auto lastIt = result.cpuToLastComputeMap.find(entry.first);
|
||||||
|
if (lastIt == result.cpuToLastComputeMap.end() || !(lastIt->second == expectedLast))
|
||||||
|
llvm::report_fatal_error("merge scheduling: cpuToLastComputeMap does not match slot order");
|
||||||
|
if (!result.isLastComputeOfCpu.count(expectedLast))
|
||||||
|
llvm::report_fatal_error("merge scheduling: missing last-compute marker");
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const ComputeGraphEdge &edge : graph.edges) {
|
||||||
|
const ComputeInstance source = graph.nodes[edge.source].instance;
|
||||||
|
const ComputeInstance target = graph.nodes[edge.target].instance;
|
||||||
|
const size_t sourceCpu = result.computeToCpuMap.lookup(source);
|
||||||
|
const size_t targetCpu = result.computeToCpuMap.lookup(target);
|
||||||
|
const size_t sourceSlot = result.computeToCpuSlotMap.lookup(source);
|
||||||
|
const size_t targetSlot = result.computeToCpuSlotMap.lookup(target);
|
||||||
|
const Time sourceStart = static_cast<Time>(result.computeToAestMap.lookup(source));
|
||||||
|
const Time targetStart = static_cast<Time>(result.computeToAestMap.lookup(target));
|
||||||
|
if (sourceCpu == targetCpu && sourceSlot >= targetSlot)
|
||||||
|
llvm::report_fatal_error("merge scheduling: same-CPU dependency order is invalid");
|
||||||
|
|
||||||
|
Time earliestTargetStart = addOrMax(sourceStart, graph.nodes[edge.source].weight);
|
||||||
|
if (sourceCpu != targetCpu)
|
||||||
|
earliestTargetStart = addOrMax(earliestTargetStart, edge.transferCost);
|
||||||
|
if (targetStart < earliestTargetStart) {
|
||||||
|
std::string message = llvm::formatv("merge scheduling: dependency legality failed between tasks {0} and {1}",
|
||||||
|
graph.nodes[edge.source].originalOrder,
|
||||||
|
graph.nodes[edge.target].originalOrder)
|
||||||
|
.str();
|
||||||
|
llvm::report_fatal_error(llvm::StringRef(message));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
MergeSchedulingAnalysis::MergeSchedulingAnalysis(mlir::Operation *op)
|
||||||
|
: entryOp(op) {
|
||||||
|
result = run();
|
||||||
|
}
|
||||||
|
|
||||||
|
MergeScheduleResult MergeSchedulingAnalysis::run() {
|
||||||
|
verifyExplicitPimCoreCount();
|
||||||
|
ComputeGraph graph = buildComputeGraph(entryOp);
|
||||||
|
if (!verifyAcyclic(graph))
|
||||||
|
llvm::report_fatal_error("merge scheduling: compute graph is cyclic");
|
||||||
|
|
||||||
|
MergeSchedulingOptions options;
|
||||||
|
options.kind = getSchedulerKind();
|
||||||
|
if (coresCount.getValue() > 0)
|
||||||
|
options.processorCount = static_cast<size_t>(coresCount.getValue());
|
||||||
|
|
||||||
|
MergeScheduleResult schedule;
|
||||||
|
if (options.kind == MergeSchedulerKind::Peft) {
|
||||||
|
schedule = runPeftScheduler(
|
||||||
|
graph,
|
||||||
|
PeftScheduleOptions {options.processorCount, static_cast<CrossbarUsage>(crossbarCountInCore.getValue()),
|
||||||
|
entryOp->getContext()});
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
schedule = runDcpScheduler(
|
||||||
|
graph,
|
||||||
|
DcpScheduleOptions {
|
||||||
|
options.processorCount,
|
||||||
|
dcpCriticalWindowSize.getValue(),
|
||||||
|
options.allowDcpFallbackForAutoCoreCount
|
||||||
|
},
|
||||||
|
entryOp->getContext());
|
||||||
|
}
|
||||||
|
|
||||||
|
verifySchedule(graph, schedule, static_cast<CrossbarUsage>(crossbarCountInCore.getValue()));
|
||||||
|
return schedule;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace spatial
|
||||||
|
} // namespace onnx_mlir
|
||||||
+36
@@ -0,0 +1,36 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/IR/Operation.h"
|
||||||
|
|
||||||
|
#include <cstddef>
|
||||||
|
|
||||||
|
#include "MergeSchedule.hpp"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
namespace spatial {
|
||||||
|
|
||||||
|
enum class MergeSchedulerKind {
|
||||||
|
Dcp,
|
||||||
|
Peft,
|
||||||
|
};
|
||||||
|
|
||||||
|
struct MergeSchedulingOptions {
|
||||||
|
MergeSchedulerKind kind = MergeSchedulerKind::Peft;
|
||||||
|
size_t processorCount = 0;
|
||||||
|
bool allowDcpFallbackForAutoCoreCount = true;
|
||||||
|
};
|
||||||
|
|
||||||
|
class MergeSchedulingAnalysis {
|
||||||
|
public:
|
||||||
|
explicit MergeSchedulingAnalysis(mlir::Operation *op);
|
||||||
|
MergeScheduleResult &getResult() { return result; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
mlir::Operation *entryOp = nullptr;
|
||||||
|
MergeScheduleResult result;
|
||||||
|
|
||||||
|
MergeScheduleResult run();
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace spatial
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,303 @@
|
|||||||
|
#include "mlir/IR/Threading.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/STLExtras.h"
|
||||||
|
#include "llvm/Support/ErrorHandling.h"
|
||||||
|
#include "llvm/Support/FormatVariadic.h"
|
||||||
|
|
||||||
|
#include <limits>
|
||||||
|
#include <queue>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "PeftScheduler.hpp"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
namespace spatial {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
struct ScheduledTask {
|
||||||
|
size_t processor = std::numeric_limits<size_t>::max();
|
||||||
|
Time startTime = 0;
|
||||||
|
Time endTime = 0;
|
||||||
|
size_t slot = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
std::vector<std::vector<size_t>> buildReverseLevels(const ComputeGraph& graph) {
|
||||||
|
std::vector<size_t> remainingSuccessors(graph.nodes.size(), 0);
|
||||||
|
std::queue<size_t> readySinks;
|
||||||
|
std::vector<std::vector<size_t>> reverseLevels;
|
||||||
|
|
||||||
|
for (size_t node = 0; node < graph.nodes.size(); ++node) {
|
||||||
|
remainingSuccessors[node] = graph.successors[node].size();
|
||||||
|
if (remainingSuccessors[node] == 0)
|
||||||
|
readySinks.push(node);
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t levelizedCount = 0;
|
||||||
|
while (!readySinks.empty()) {
|
||||||
|
size_t levelSize = readySinks.size();
|
||||||
|
std::vector<size_t> levelNodes;
|
||||||
|
levelNodes.reserve(levelSize);
|
||||||
|
for (size_t i = 0; i < levelSize; ++i) {
|
||||||
|
size_t node = readySinks.front();
|
||||||
|
readySinks.pop();
|
||||||
|
levelNodes.push_back(node);
|
||||||
|
++levelizedCount;
|
||||||
|
for (const auto& [pred, weight] : graph.predecessors[node]) {
|
||||||
|
assert(remainingSuccessors[pred] > 0 && "remaining successor count underflow");
|
||||||
|
if (--remainingSuccessors[pred] == 0)
|
||||||
|
readySinks.push(pred);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
reverseLevels.push_back(std::move(levelNodes));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (levelizedCount != graph.nodes.size())
|
||||||
|
llvm::report_fatal_error("PEFT scheduler: compute graph is cyclic or malformed");
|
||||||
|
|
||||||
|
return reverseLevels;
|
||||||
|
}
|
||||||
|
|
||||||
|
void verifyOctTableSize(size_t nodeCount, size_t processorCount) {
|
||||||
|
constexpr size_t kMaxOctTableBytes = 1ull << 30;
|
||||||
|
if (nodeCount == 0 || processorCount == 0)
|
||||||
|
return;
|
||||||
|
if (processorCount > std::numeric_limits<size_t>::max() / sizeof(Time))
|
||||||
|
llvm::report_fatal_error("PEFT scheduler: OCT table size overflow");
|
||||||
|
size_t rowBytes = processorCount * sizeof(Time);
|
||||||
|
if (nodeCount > std::numeric_limits<size_t>::max() / rowBytes)
|
||||||
|
llvm::report_fatal_error("PEFT scheduler: OCT table size overflow");
|
||||||
|
size_t totalBytes = nodeCount * rowBytes;
|
||||||
|
if (totalBytes > kMaxOctTableBytes) {
|
||||||
|
std::string message = llvm::formatv("PEFT scheduler: OCT table would require {0} MiB, exceeding the 1024 MiB guard",
|
||||||
|
totalBytes / (1024 * 1024))
|
||||||
|
.str();
|
||||||
|
llvm::report_fatal_error(llvm::StringRef(message));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftScheduleOptions& options) {
|
||||||
|
const size_t nodeCount = graph.nodes.size();
|
||||||
|
const size_t processorCount = options.processorCount;
|
||||||
|
if (processorCount == 0)
|
||||||
|
llvm::report_fatal_error("PEFT scheduler: processor count must be positive");
|
||||||
|
|
||||||
|
verifyOctTableSize(nodeCount, processorCount);
|
||||||
|
std::vector<std::vector<size_t>> reverseLevels = buildReverseLevels(graph);
|
||||||
|
|
||||||
|
// MOCK: Replace this with your actual heterogeneous cost lookup.
|
||||||
|
// If graph.nodes[task] is modified to hold a vector of weights per processor, access it here.
|
||||||
|
auto getComputeCost = [&](size_t task, size_t processor) -> Time { return graph.nodes[task].weight; };
|
||||||
|
|
||||||
|
std::vector<Time> oct(nodeCount * processorCount, 0);
|
||||||
|
std::vector<Time> minOctPlusComp(nodeCount, 0);
|
||||||
|
|
||||||
|
// 1. O(P(E+V)) Heterogeneous OCT Calculation
|
||||||
|
for (const std::vector<size_t>& levelNodes : reverseLevels) {
|
||||||
|
auto computeNodeOct = [&](size_t levelIndex) {
|
||||||
|
size_t task = levelNodes[levelIndex];
|
||||||
|
std::vector<Time> maxVals(processorCount, 0);
|
||||||
|
|
||||||
|
for (const auto& [succ, comm] : graph.successors[task]) {
|
||||||
|
Time valDifferentCpu = addOrMax(minOctPlusComp[succ], comm);
|
||||||
|
for (size_t processor = 0; processor < processorCount; ++processor) {
|
||||||
|
Time valSameCpu = addOrMax(oct[succ * processorCount + processor], getComputeCost(succ, processor));
|
||||||
|
Time bestSucc = std::min(valSameCpu, valDifferentCpu);
|
||||||
|
maxVals[processor] = std::max(maxVals[processor], bestSucc);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Time minForPreds = std::numeric_limits<Time>::max();
|
||||||
|
for (size_t processor = 0; processor < processorCount; ++processor) {
|
||||||
|
oct[task * processorCount + processor] = maxVals[processor];
|
||||||
|
minForPreds = std::min(minForPreds, addOrMax(maxVals[processor], getComputeCost(task, processor)));
|
||||||
|
}
|
||||||
|
minOctPlusComp[task] = minForPreds == std::numeric_limits<Time>::max() ? 0 : minForPreds;
|
||||||
|
};
|
||||||
|
|
||||||
|
if (options.context != nullptr)
|
||||||
|
mlir::parallelFor(options.context, 0, levelNodes.size(), computeNodeOct);
|
||||||
|
else
|
||||||
|
for (size_t i = 0; i < levelNodes.size(); ++i)
|
||||||
|
computeNodeOct(i);
|
||||||
|
}
|
||||||
|
|
||||||
|
struct RankEntry {
|
||||||
|
long double rank = 0.0L;
|
||||||
|
size_t node = 0;
|
||||||
|
size_t originalOrder = 0;
|
||||||
|
};
|
||||||
|
std::vector<RankEntry> ranks(nodeCount);
|
||||||
|
auto computeRank = [&](size_t node) {
|
||||||
|
long double rank = 0.0L;
|
||||||
|
for (size_t processor = 0; processor < processorCount; ++processor)
|
||||||
|
rank += static_cast<long double>(oct[node * processorCount + processor]);
|
||||||
|
ranks[node] = {rank, node, graph.nodes[node].originalOrder};
|
||||||
|
};
|
||||||
|
|
||||||
|
if (options.context != nullptr)
|
||||||
|
mlir::parallelFor(options.context, 0, nodeCount, computeRank);
|
||||||
|
else
|
||||||
|
for (size_t node = 0; node < nodeCount; ++node)
|
||||||
|
computeRank(node);
|
||||||
|
|
||||||
|
auto readyCompare = [&](size_t lhs, size_t rhs) {
|
||||||
|
const RankEntry& lhsRank = ranks[lhs];
|
||||||
|
const RankEntry& rhsRank = ranks[rhs];
|
||||||
|
if (lhsRank.rank != rhsRank.rank)
|
||||||
|
return lhsRank.rank < rhsRank.rank;
|
||||||
|
if (lhsRank.originalOrder != rhsRank.originalOrder)
|
||||||
|
return lhsRank.originalOrder > rhsRank.originalOrder;
|
||||||
|
return lhs > rhs;
|
||||||
|
};
|
||||||
|
|
||||||
|
std::vector<int> remainingParents(nodeCount, 0);
|
||||||
|
std::priority_queue<size_t, std::vector<size_t>, decltype(readyCompare)> readyQueue(readyCompare);
|
||||||
|
for (size_t node = 0; node < nodeCount; ++node) {
|
||||||
|
remainingParents[node] = graph.predecessors[node].size();
|
||||||
|
if (remainingParents[node] == 0)
|
||||||
|
readyQueue.push(node);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<char> scheduled(nodeCount, false);
|
||||||
|
std::vector<CrossbarUsage> processorCrossbars(processorCount, 0);
|
||||||
|
std::vector<ScheduledTask> schedules(nodeCount);
|
||||||
|
std::vector<std::vector<size_t>> tasksByProcessor(processorCount);
|
||||||
|
|
||||||
|
size_t scheduledCount = 0;
|
||||||
|
while (!readyQueue.empty()) {
|
||||||
|
size_t task = readyQueue.top();
|
||||||
|
readyQueue.pop();
|
||||||
|
if (scheduled[task])
|
||||||
|
continue;
|
||||||
|
|
||||||
|
size_t bestProcessor = std::numeric_limits<size_t>::max();
|
||||||
|
Time bestEst = 0;
|
||||||
|
Time bestEft = 0;
|
||||||
|
Time bestOeft = std::numeric_limits<Time>::max();
|
||||||
|
bool crossbarRejected = false;
|
||||||
|
|
||||||
|
for (size_t processor = 0; processor < processorCount; ++processor) {
|
||||||
|
if (graph.nodes[task].crossbarUsage != 0
|
||||||
|
&& addOrMax(processorCrossbars[processor], graph.nodes[task].crossbarUsage) > options.crossbarCapacity) {
|
||||||
|
crossbarRejected = true;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
Time dataReady = 0;
|
||||||
|
for (const auto& [pred, comm] : graph.predecessors[task]) {
|
||||||
|
const ScheduledTask& predSchedule = schedules[pred];
|
||||||
|
Time commPenalty = predSchedule.processor == processor ? 0 : comm;
|
||||||
|
dataReady = std::max(dataReady, addOrMax(predSchedule.endTime, commPenalty));
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. PEFT Gap-Filling EST Calculation (Maintains optimal scheduling math)
|
||||||
|
Time compWeight = getComputeCost(task, processor);
|
||||||
|
Time est = dataReady;
|
||||||
|
Time currentEnd = 0;
|
||||||
|
bool foundGap = false;
|
||||||
|
|
||||||
|
for (size_t schedTaskIndex : tasksByProcessor[processor]) {
|
||||||
|
const ScheduledTask& schedTask = schedules[schedTaskIndex];
|
||||||
|
Time gapStart = std::max(currentEnd, dataReady);
|
||||||
|
|
||||||
|
if (addOrMax(gapStart, compWeight) <= schedTask.startTime) {
|
||||||
|
est = gapStart;
|
||||||
|
foundGap = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
currentEnd = schedTask.endTime;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!foundGap)
|
||||||
|
est = std::max(currentEnd, dataReady);
|
||||||
|
|
||||||
|
Time eft = addOrMax(est, compWeight);
|
||||||
|
Time oeft = addOrMax(eft, oct[task * processorCount + processor]);
|
||||||
|
|
||||||
|
if (oeft < bestOeft || (oeft == bestOeft && eft < bestEft)
|
||||||
|
|| (oeft == bestOeft && eft == bestEft && est < bestEst)
|
||||||
|
|| (oeft == bestOeft && eft == bestEft && est == bestEst && processor < bestProcessor)) {
|
||||||
|
bestProcessor = processor;
|
||||||
|
bestEst = est;
|
||||||
|
bestEft = eft;
|
||||||
|
bestOeft = oeft;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (bestProcessor == std::numeric_limits<size_t>::max()) {
|
||||||
|
if (crossbarRejected) {
|
||||||
|
std::string message =
|
||||||
|
llvm::formatv("PEFT scheduler: no valid processor for task {0}; crossbar capacity {1} is exhausted",
|
||||||
|
graph.nodes[task].originalOrder,
|
||||||
|
options.crossbarCapacity)
|
||||||
|
.str();
|
||||||
|
llvm::report_fatal_error(llvm::StringRef(message));
|
||||||
|
}
|
||||||
|
std::string message = llvm::formatv("PEFT scheduler: no valid processor for task {0} with {1} processors",
|
||||||
|
graph.nodes[task].originalOrder,
|
||||||
|
processorCount)
|
||||||
|
.str();
|
||||||
|
llvm::report_fatal_error(llvm::StringRef(message));
|
||||||
|
}
|
||||||
|
|
||||||
|
schedules[task] = {bestProcessor, bestEst, bestEft, 0};
|
||||||
|
scheduled[task] = true;
|
||||||
|
++scheduledCount;
|
||||||
|
processorCrossbars[bestProcessor] = addOrMax(processorCrossbars[bestProcessor], graph.nodes[task].crossbarUsage);
|
||||||
|
|
||||||
|
// 3. CRITICAL FIX: Topological Append
|
||||||
|
// Because the readyQueue pops in strict topological order, simply pushing to the
|
||||||
|
// back guarantees the Monoliths will be physically generated cycle-free.
|
||||||
|
// The hardware will still benefit from the processor assignment chosen by PEFT.
|
||||||
|
tasksByProcessor[bestProcessor].push_back(task);
|
||||||
|
|
||||||
|
for (const auto& [child, weight] : graph.successors[task]) {
|
||||||
|
(void) weight;
|
||||||
|
assert(remainingParents[child] > 0 && "remaining parent count underflow");
|
||||||
|
if (--remainingParents[child] == 0)
|
||||||
|
readyQueue.push(child);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (scheduledCount != nodeCount)
|
||||||
|
llvm::report_fatal_error("PEFT scheduler: failed to schedule every compute node");
|
||||||
|
|
||||||
|
// 4. Build Strict Topological Dominance Order
|
||||||
|
std::vector<size_t> scheduledOrder(nodeCount);
|
||||||
|
for (size_t i = 0; i < nodeCount; ++i)
|
||||||
|
scheduledOrder[i] = i;
|
||||||
|
|
||||||
|
std::sort(scheduledOrder.begin(), scheduledOrder.end(), [&](size_t a, size_t b) {
|
||||||
|
return graph.nodes[a].originalOrder < graph.nodes[b].originalOrder;
|
||||||
|
});
|
||||||
|
|
||||||
|
// 5. Populate Final Result
|
||||||
|
MergeScheduleResult result;
|
||||||
|
result.dominanceOrderCompute.reserve(nodeCount);
|
||||||
|
|
||||||
|
for (size_t task : scheduledOrder)
|
||||||
|
result.dominanceOrderCompute.push_back(graph.nodes[task].instance);
|
||||||
|
|
||||||
|
for (size_t processor = 0; processor < processorCount; ++processor) {
|
||||||
|
size_t currentSlot = 0;
|
||||||
|
for (size_t task : tasksByProcessor[processor]) {
|
||||||
|
const ComputeInstance instance = graph.nodes[task].instance;
|
||||||
|
result.computeToCpuMap[instance] = processor;
|
||||||
|
result.computeToCpuSlotMap[instance] = currentSlot++;
|
||||||
|
result.computeToAestMap[instance] = schedules[task].startTime;
|
||||||
|
}
|
||||||
|
if (!tasksByProcessor[processor].empty()) {
|
||||||
|
const ComputeInstance lastInstance = graph.nodes[tasksByProcessor[processor].back()].instance;
|
||||||
|
result.cpuToLastComputeMap[processor] = lastInstance;
|
||||||
|
result.isLastComputeOfCpu.insert(lastInstance);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
} // namespace spatial
|
||||||
|
} // namespace onnx_mlir
|
||||||
|
|
||||||
@@ -0,0 +1,20 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/IR/MLIRContext.h"
|
||||||
|
|
||||||
|
#include "ComputeGraph.hpp"
|
||||||
|
#include "MergeSchedule.hpp"
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
namespace spatial {
|
||||||
|
|
||||||
|
struct PeftScheduleOptions {
|
||||||
|
size_t processorCount = 0;
|
||||||
|
CrossbarUsage crossbarCapacity = 0;
|
||||||
|
mlir::MLIRContext *context = nullptr;
|
||||||
|
};
|
||||||
|
|
||||||
|
MergeScheduleResult runPeftScheduler(const ComputeGraph &graph, const PeftScheduleOptions &options);
|
||||||
|
|
||||||
|
} // namespace spatial
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -24,9 +24,12 @@ struct EmitPimCodePass : PassWrapper<EmitPimCodePass, OperationPass<ModuleOp>> {
|
|||||||
createDirectory(pimDir);
|
createDirectory(pimDir);
|
||||||
|
|
||||||
int compiler_error_code = compileToPimCode(moduleOp, pimDir);
|
int compiler_error_code = compileToPimCode(moduleOp, pimDir);
|
||||||
if (compiler_error_code != CompilerSuccess)
|
if (compiler_error_code != CompilerSuccess) {
|
||||||
|
moduleOp.emitError() << "failed to emit PIM simulator code artifacts; compiler error code "
|
||||||
|
<< compiler_error_code;
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
}
|
}
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|||||||
@@ -32,14 +32,16 @@ struct HostConstantFoldingPass : PassWrapper<HostConstantFoldingPass, OperationP
|
|||||||
}
|
}
|
||||||
|
|
||||||
void runOnOperation() override {
|
void runOnOperation() override {
|
||||||
|
ModuleOp moduleOp = getOperation();
|
||||||
GreedyRewriteConfig config;
|
GreedyRewriteConfig config;
|
||||||
config.enableFolding();
|
config.enableFolding();
|
||||||
if (failed(applyPatternsGreedily(getOperation(), *patterns, config))) {
|
if (failed(applyPatternsGreedily(moduleOp, *patterns, config))) {
|
||||||
|
moduleOp.emitError("PIM host constant folding failed in the greedy rewrite driver");
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
dumpModule(getOperation(), "pim3_folded");
|
dumpModule(moduleOp, "pim3_folded");
|
||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<const FrozenRewritePatternSet> patterns;
|
std::shared_ptr<const FrozenRewritePatternSet> patterns;
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||||
|
|
||||||
#include "../Common.hpp"
|
#include "../Common.hpp"
|
||||||
#include "../Patterns.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
|
|
||||||
@@ -66,6 +66,83 @@ static Value buildSubviewChunk(const StaticSubviewInfo& info,
|
|||||||
return memref::SubViewOp::create(rewriter, loc, info.source, chunkOffsets, chunkSizes, chunkStrides);
|
return memref::SubViewOp::create(rewriter, loc, info.source, chunkOffsets, chunkSizes, chunkStrides);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static SmallVector<Value> delinearizeIndexValue(Value linearIndex,
|
||||||
|
ArrayRef<int64_t> shape,
|
||||||
|
ArrayRef<int64_t> strides,
|
||||||
|
PatternRewriter& rewriter) {
|
||||||
|
SmallVector<Value> indices;
|
||||||
|
indices.reserve(shape.size());
|
||||||
|
|
||||||
|
Value remaining = linearIndex;
|
||||||
|
for (auto [_dim, stride] : llvm::enumerate(strides)) {
|
||||||
|
auto cStride = arith::ConstantIndexOp::create(rewriter, linearIndex.getLoc(), stride);
|
||||||
|
Value index = arith::DivUIOp::create(rewriter, linearIndex.getLoc(), remaining, cStride);
|
||||||
|
indices.push_back(index);
|
||||||
|
remaining = arith::RemUIOp::create(rewriter, linearIndex.getLoc(), remaining, cStride);
|
||||||
|
}
|
||||||
|
|
||||||
|
return indices;
|
||||||
|
}
|
||||||
|
|
||||||
|
static OpFoldResult addDynamicOffset(OpFoldResult baseOffset, Value extraOffset, PatternRewriter& rewriter) {
|
||||||
|
if (auto attr = dyn_cast<Attribute>(baseOffset)) {
|
||||||
|
auto integerAttr = cast<IntegerAttr>(attr);
|
||||||
|
if (integerAttr.getInt() == 0)
|
||||||
|
return extraOffset;
|
||||||
|
|
||||||
|
auto cst = arith::ConstantIndexOp::create(rewriter, extraOffset.getLoc(), integerAttr.getInt());
|
||||||
|
return arith::AddIOp::create(rewriter, extraOffset.getLoc(), cst, extraOffset).getResult();
|
||||||
|
}
|
||||||
|
|
||||||
|
auto value = cast<Value>(baseOffset);
|
||||||
|
return arith::AddIOp::create(rewriter, value.getLoc(), value, extraOffset).getResult();
|
||||||
|
}
|
||||||
|
|
||||||
|
static Value buildDynamicSubviewChunk(const StaticSubviewInfo& info,
|
||||||
|
ArrayRef<Value> outerIndices,
|
||||||
|
Location loc,
|
||||||
|
PatternRewriter& rewriter) {
|
||||||
|
SmallVector<OpFoldResult> chunkOffsets;
|
||||||
|
SmallVector<OpFoldResult> chunkSizes;
|
||||||
|
SmallVector<OpFoldResult> chunkStrides;
|
||||||
|
chunkOffsets.reserve(info.offsets.size());
|
||||||
|
chunkSizes.reserve(info.sizes.size());
|
||||||
|
chunkStrides.reserve(info.strides.size());
|
||||||
|
|
||||||
|
for (size_t dim = 0; dim < info.sizes.size(); ++dim) {
|
||||||
|
if (dim + 1 < info.sizes.size()) {
|
||||||
|
assert(info.strides[dim] == 1 && "loop-based subview rewrite requires unit strides");
|
||||||
|
chunkOffsets.push_back(addDynamicOffset(info.offsets[dim], outerIndices[dim], rewriter));
|
||||||
|
chunkSizes.push_back(rewriter.getIndexAttr(1));
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
chunkOffsets.push_back(info.offsets[dim]);
|
||||||
|
chunkSizes.push_back(rewriter.getIndexAttr(info.sizes.back()));
|
||||||
|
}
|
||||||
|
chunkStrides.push_back(rewriter.getIndexAttr(info.strides[dim]));
|
||||||
|
}
|
||||||
|
|
||||||
|
return memref::SubViewOp::create(rewriter, loc, info.source, chunkOffsets, chunkSizes, chunkStrides);
|
||||||
|
}
|
||||||
|
|
||||||
|
static Value buildContiguousChunk(
|
||||||
|
Value source, ArrayRef<int64_t> copyShape, ArrayRef<Value> outerIndices, Location loc, PatternRewriter& rewriter) {
|
||||||
|
SmallVector<OpFoldResult> chunkOffsets;
|
||||||
|
SmallVector<OpFoldResult> chunkSizes;
|
||||||
|
SmallVector<OpFoldResult> chunkStrides;
|
||||||
|
chunkOffsets.reserve(copyShape.size());
|
||||||
|
chunkSizes.reserve(copyShape.size());
|
||||||
|
chunkStrides.reserve(copyShape.size());
|
||||||
|
|
||||||
|
for (size_t dim = 0; dim < copyShape.size(); ++dim) {
|
||||||
|
chunkOffsets.push_back(dim + 1 < copyShape.size() ? OpFoldResult(outerIndices[dim]) : rewriter.getIndexAttr(0));
|
||||||
|
chunkSizes.push_back(rewriter.getIndexAttr(dim + 1 < copyShape.size() ? 1 : copyShape.back()));
|
||||||
|
chunkStrides.push_back(rewriter.getIndexAttr(1));
|
||||||
|
}
|
||||||
|
|
||||||
|
return memref::SubViewOp::create(rewriter, loc, source, chunkOffsets, chunkSizes, chunkStrides);
|
||||||
|
}
|
||||||
|
|
||||||
template <typename CopyOp, typename CreateCopyOp>
|
template <typename CopyOp, typename CreateCopyOp>
|
||||||
static LogicalResult rewriteSubviewCopyLikeOp(CopyOp copyOp,
|
static LogicalResult rewriteSubviewCopyLikeOp(CopyOp copyOp,
|
||||||
Value dst,
|
Value dst,
|
||||||
@@ -73,6 +150,7 @@ static LogicalResult rewriteSubviewCopyLikeOp(CopyOp copyOp,
|
|||||||
int64_t dstOffset,
|
int64_t dstOffset,
|
||||||
int64_t srcOffset,
|
int64_t srcOffset,
|
||||||
int64_t size,
|
int64_t size,
|
||||||
|
bool allowLoopRewrite,
|
||||||
PatternRewriter& rewriter,
|
PatternRewriter& rewriter,
|
||||||
CreateCopyOp createCopyOp) {
|
CreateCopyOp createCopyOp) {
|
||||||
auto srcSubview = getStaticSubviewInfo(src);
|
auto srcSubview = getStaticSubviewInfo(src);
|
||||||
@@ -114,6 +192,27 @@ static LogicalResult rewriteSubviewCopyLikeOp(CopyOp copyOp,
|
|||||||
auto outerStrides = computeRowMajorStrides(outerShape);
|
auto outerStrides = computeRowMajorStrides(outerShape);
|
||||||
const int64_t numSlices = outerShape.empty() ? 1 : getNumElements(outerShape);
|
const int64_t numSlices = outerShape.empty() ? 1 : getNumElements(outerShape);
|
||||||
|
|
||||||
|
if (allowLoopRewrite && numSlices > 1 && srcOffset == 0 && dstOffset == 0
|
||||||
|
&& sourceType.getRank() == static_cast<int64_t>(copyShape.size())
|
||||||
|
&& dstType.getRank() == static_cast<int64_t>(copyShape.size())) {
|
||||||
|
auto c0 = arith::ConstantIndexOp::create(rewriter, copyOp.getLoc(), 0);
|
||||||
|
auto cUpper = arith::ConstantIndexOp::create(rewriter, copyOp.getLoc(), numSlices);
|
||||||
|
auto cStep = arith::ConstantIndexOp::create(rewriter, copyOp.getLoc(), 1);
|
||||||
|
|
||||||
|
auto loop = scf::ForOp::create(rewriter, copyOp.getLoc(), c0, cUpper, cStep, ValueRange {});
|
||||||
|
rewriter.setInsertionPointToStart(loop.getBody());
|
||||||
|
|
||||||
|
SmallVector<Value> outerIndices =
|
||||||
|
outerShape.empty() ? SmallVector<Value> {}
|
||||||
|
: delinearizeIndexValue(loop.getInductionVar(), outerShape, outerStrides, rewriter);
|
||||||
|
Value chunkDst = splitDst ? buildDynamicSubviewChunk(*dstSubview, outerIndices, copyOp.getLoc(), rewriter)
|
||||||
|
: buildContiguousChunk(dst, copyShape, outerIndices, copyOp.getLoc(), rewriter);
|
||||||
|
Value chunkSrc = splitSrc ? buildDynamicSubviewChunk(*srcSubview, outerIndices, copyOp.getLoc(), rewriter)
|
||||||
|
: buildContiguousChunk(src, copyShape, outerIndices, copyOp.getLoc(), rewriter);
|
||||||
|
createCopyOp(cast<MemRefType>(chunkDst.getType()), chunkDst, chunkSrc, 0, 0, sliceBytes);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
rewriter.setInsertionPoint(copyOp);
|
rewriter.setInsertionPoint(copyOp);
|
||||||
for (int64_t linearIndex = 0; linearIndex < numSlices; ++linearIndex) {
|
for (int64_t linearIndex = 0; linearIndex < numSlices; ++linearIndex) {
|
||||||
SmallVector<int64_t> outerIndices =
|
SmallVector<int64_t> outerIndices =
|
||||||
@@ -143,6 +242,7 @@ struct RewriteCoreSubviewCopyPattern final : OpRewritePattern<pim::PimMemCopyOp>
|
|||||||
copyOp.getTargetOffset(),
|
copyOp.getTargetOffset(),
|
||||||
copyOp.getSourceOffset(),
|
copyOp.getSourceOffset(),
|
||||||
copyOp.getSize(),
|
copyOp.getSize(),
|
||||||
|
/*allowLoopRewrite=*/true,
|
||||||
rewriter,
|
rewriter,
|
||||||
[&](
|
[&](
|
||||||
MemRefType resultType, Value dst, Value src, int64_t dstByteOffset, int64_t srcByteOffset, int64_t sliceBytes) {
|
MemRefType resultType, Value dst, Value src, int64_t dstByteOffset, int64_t srcByteOffset, int64_t sliceBytes) {
|
||||||
@@ -175,6 +275,7 @@ struct RewriteHostSubviewLoadPattern final : OpRewritePattern<pim::PimMemCopyHos
|
|||||||
copyOp.getDeviceTargetOffset(),
|
copyOp.getDeviceTargetOffset(),
|
||||||
copyOp.getHostSourceOffset(),
|
copyOp.getHostSourceOffset(),
|
||||||
copyOp.getSize(),
|
copyOp.getSize(),
|
||||||
|
/*allowLoopRewrite=*/true,
|
||||||
rewriter,
|
rewriter,
|
||||||
[&](
|
[&](
|
||||||
MemRefType resultType, Value dst, Value src, int64_t dstByteOffset, int64_t srcByteOffset, int64_t sliceBytes) {
|
MemRefType resultType, Value dst, Value src, int64_t dstByteOffset, int64_t srcByteOffset, int64_t sliceBytes) {
|
||||||
@@ -207,6 +308,7 @@ struct RewriteHostSubviewStorePattern final : OpRewritePattern<pim::PimMemCopyDe
|
|||||||
copyOp.getHostTargetOffset(),
|
copyOp.getHostTargetOffset(),
|
||||||
copyOp.getDeviceSourceOffset(),
|
copyOp.getDeviceSourceOffset(),
|
||||||
copyOp.getSize(),
|
copyOp.getSize(),
|
||||||
|
/*allowLoopRewrite=*/false,
|
||||||
rewriter,
|
rewriter,
|
||||||
[&](
|
[&](
|
||||||
MemRefType resultType, Value dst, Value src, int64_t dstByteOffset, int64_t srcByteOffset, int64_t sliceBytes) {
|
MemRefType resultType, Value dst, Value src, int64_t dstByteOffset, int64_t srcByteOffset, int64_t sliceBytes) {
|
||||||
|
|||||||
@@ -160,6 +160,7 @@ struct MaterializeHostConstantsPass : PassWrapper<MaterializeHostConstantsPass,
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (hasFailure) {
|
if (hasFailure) {
|
||||||
|
moduleOp.emitError("PIM host-constant materialization failed; see diagnostics above");
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,10 +6,11 @@
|
|||||||
|
|
||||||
#include "llvm/ADT/STLExtras.h"
|
#include "llvm/ADT/STLExtras.h"
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
|
||||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
|
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
@@ -67,6 +68,14 @@ static bool isCodegenAddressableValue(Value value) {
|
|||||||
|| isa<memref::AllocOp, memref::GetGlobalOp>(resolvedAddress->base.getDefiningOp());
|
|| isa<memref::AllocOp, memref::GetGlobalOp>(resolvedAddress->base.getDefiningOp());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static bool isCodegenAddressableValue(Value value, const StaticValueKnowledge& knowledge) {
|
||||||
|
auto resolvedAddress = resolveContiguousAddress(value, knowledge);
|
||||||
|
if (failed(resolvedAddress))
|
||||||
|
return false;
|
||||||
|
return isa<BlockArgument>(resolvedAddress->base)
|
||||||
|
|| isa<memref::AllocOp, memref::GetGlobalOp>(resolvedAddress->base.getDefiningOp());
|
||||||
|
}
|
||||||
|
|
||||||
static bool isConstantGlobalView(Value value) {
|
static bool isConstantGlobalView(Value value) {
|
||||||
while (true) {
|
while (true) {
|
||||||
Operation* defOp = value.getDefiningOp();
|
Operation* defOp = value.getDefiningOp();
|
||||||
@@ -88,6 +97,22 @@ static bool isConstantGlobalView(Value value) {
|
|||||||
value = cast.getSource();
|
value = cast.getSource();
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
if (auto collapse = dyn_cast<memref::CollapseShapeOp>(defOp)) {
|
||||||
|
auto srcType = dyn_cast<MemRefType>(collapse.getSrc().getType());
|
||||||
|
auto resultType = dyn_cast<MemRefType>(collapse.getResult().getType());
|
||||||
|
if (!srcType || !resultType || !srcType.hasStaticShape() || !resultType.hasStaticShape())
|
||||||
|
return false;
|
||||||
|
value = collapse.getSrc();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (auto expand = dyn_cast<memref::ExpandShapeOp>(defOp)) {
|
||||||
|
auto srcType = dyn_cast<MemRefType>(expand.getSrc().getType());
|
||||||
|
auto resultType = dyn_cast<MemRefType>(expand.getResult().getType());
|
||||||
|
if (!srcType || !resultType || !srcType.hasStaticShape() || !resultType.hasStaticShape())
|
||||||
|
return false;
|
||||||
|
value = expand.getSrc();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -144,14 +169,15 @@ struct VerificationPass : PassWrapper<VerificationPass, OperationPass<ModuleOp>>
|
|||||||
|
|
||||||
void runOnOperation() override {
|
void runOnOperation() override {
|
||||||
ModuleOp moduleOp = getOperation();
|
ModuleOp moduleOp = getOperation();
|
||||||
bool hasFailure = false;
|
pim::CappedDiagnosticReporter diagnostics;
|
||||||
|
|
||||||
moduleOp.walk([&](Operation* op) {
|
moduleOp.walk([&](Operation* op) {
|
||||||
if (op->getDialect()->getNamespace() != "spat")
|
if (op->getDialect()->getNamespace() != "spat")
|
||||||
return;
|
return;
|
||||||
|
|
||||||
op->emitError("illegal Spatial operation reached PIM codegen verification");
|
diagnostics.report(op, [](Operation* illegalOp) {
|
||||||
hasFailure = true;
|
illegalOp->emitError("illegal Spatial operation reached PIM codegen verification");
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) {
|
for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) {
|
||||||
@@ -160,49 +186,56 @@ struct VerificationPass : PassWrapper<VerificationPass, OperationPass<ModuleOp>>
|
|||||||
|
|
||||||
for (Operation& op : funcOp.getBody().front().getOperations()) {
|
for (Operation& op : funcOp.getBody().front().getOperations()) {
|
||||||
if (auto coreOp = dyn_cast<pim::PimCoreOp>(&op)) {
|
if (auto coreOp = dyn_cast<pim::PimCoreOp>(&op)) {
|
||||||
if (failed(verifyCoreWeights(moduleOp, coreOp)) || failed(verifyCoreOperands(coreOp)))
|
(void) verifyCoreWeights(moduleOp, coreOp, diagnostics);
|
||||||
hasFailure = true;
|
(void) verifyCoreOperands(coreOp, diagnostics);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (auto coreBatchOp = dyn_cast<pim::PimCoreBatchOp>(&op)) {
|
if (auto coreBatchOp = dyn_cast<pim::PimCoreBatchOp>(&op)) {
|
||||||
if (failed(verifyCoreWeights(moduleOp, coreBatchOp)) || failed(verifyCoreOperands(coreBatchOp)))
|
(void) verifyCoreWeights(moduleOp, coreBatchOp, diagnostics);
|
||||||
hasFailure = true;
|
(void) verifyCoreOperands(coreBatchOp, diagnostics);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (auto returnOp = dyn_cast<func::ReturnOp>(&op)) {
|
if (auto returnOp = dyn_cast<func::ReturnOp>(&op)) {
|
||||||
if (failed(verifyReturnOp(returnOp)))
|
(void) verifyReturnOp(returnOp, diagnostics);
|
||||||
hasFailure = true;
|
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!isAddressOnlyHostOp(&op)) {
|
if (!isAddressOnlyHostOp(&op)) {
|
||||||
op.emitOpError("illegal host-side runtime op remains after PIM bufferization; "
|
diagnostics.report(&op, [](Operation* illegalOp) {
|
||||||
|
illegalOp->emitOpError("illegal host-side runtime op remains after PIM bufferization; "
|
||||||
"fold it to constants or lower it into pim.core");
|
"fold it to constants or lower it into pim.core");
|
||||||
hasFailure = true;
|
});
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (failed(verifyAddressOnlyHostOp(&op)))
|
(void) verifyAddressOnlyHostOp(&op, diagnostics);
|
||||||
hasFailure = true;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (hasFailure)
|
if (diagnostics.hasFailure()) {
|
||||||
|
diagnostics.emitSuppressedSummary(moduleOp, "verification failures");
|
||||||
|
moduleOp.emitError("PIM codegen verification failed; see diagnostics above");
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
template <typename CoreOpTy>
|
template <typename CoreOpTy>
|
||||||
static LogicalResult verifyCoreWeights(ModuleOp moduleOp, CoreOpTy coreOp) {
|
static LogicalResult
|
||||||
|
verifyCoreWeights(ModuleOp moduleOp, CoreOpTy coreOp, pim::CappedDiagnosticReporter& diagnostics) {
|
||||||
bool hasFailure = false;
|
bool hasFailure = false;
|
||||||
for (auto [weightIndex, weight] : llvm::enumerate(coreOp.getWeights())) {
|
for (auto it : llvm::enumerate(coreOp.getWeights())) {
|
||||||
|
size_t weightIndex = it.index();
|
||||||
|
Value weight = it.value();
|
||||||
auto getGlobalOp = weight.template getDefiningOp<memref::GetGlobalOp>();
|
auto getGlobalOp = weight.template getDefiningOp<memref::GetGlobalOp>();
|
||||||
if (!getGlobalOp && !isConstantGlobalView(weight)) {
|
if (!getGlobalOp && !isConstantGlobalView(weight)) {
|
||||||
|
diagnostics.report(coreOp.getOperation(), [&](Operation*) {
|
||||||
coreOp.emitOpError() << "weight #" << weightIndex
|
coreOp.emitOpError() << "weight #" << weightIndex
|
||||||
<< " must be materialized as a constant memref.global or a static view of one before JSON "
|
<< " must be materialized as a constant memref.global or a static view of one before "
|
||||||
"codegen";
|
"JSON codegen";
|
||||||
|
});
|
||||||
hasFailure = true;
|
hasFailure = true;
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@@ -212,14 +245,18 @@ private:
|
|||||||
|
|
||||||
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
||||||
if (!globalOp) {
|
if (!globalOp) {
|
||||||
|
diagnostics.report(coreOp.getOperation(), [&](Operation*) {
|
||||||
coreOp.emitOpError() << "weight #" << weightIndex << " references an unknown memref.global";
|
coreOp.emitOpError() << "weight #" << weightIndex << " references an unknown memref.global";
|
||||||
|
});
|
||||||
hasFailure = true;
|
hasFailure = true;
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!globalOp.getConstant() || !globalOp.getInitialValue()) {
|
if (!globalOp.getConstant() || !globalOp.getInitialValue()) {
|
||||||
|
diagnostics.report(coreOp.getOperation(), [&](Operation*) {
|
||||||
coreOp.emitOpError() << "weight #" << weightIndex
|
coreOp.emitOpError() << "weight #" << weightIndex
|
||||||
<< " must come from a constant memref.global with an initial value";
|
<< " must come from a constant memref.global with an initial value";
|
||||||
|
});
|
||||||
hasFailure = true;
|
hasFailure = true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -227,11 +264,15 @@ private:
|
|||||||
return success(!hasFailure);
|
return success(!hasFailure);
|
||||||
}
|
}
|
||||||
|
|
||||||
static LogicalResult verifyReturnOp(func::ReturnOp returnOp) {
|
static LogicalResult verifyReturnOp(func::ReturnOp returnOp, pim::CappedDiagnosticReporter& diagnostics) {
|
||||||
bool hasFailure = false;
|
bool hasFailure = false;
|
||||||
for (auto [resultIndex, operand] : llvm::enumerate(returnOp.getOperands())) {
|
for (auto it : llvm::enumerate(returnOp.getOperands())) {
|
||||||
|
size_t resultIndex = it.index();
|
||||||
|
Value operand = it.value();
|
||||||
if (!isCodegenAddressableValue(operand)) {
|
if (!isCodegenAddressableValue(operand)) {
|
||||||
|
diagnostics.report(returnOp.getOperation(), [&](Operation*) {
|
||||||
returnOp.emitOpError() << "result #" << resultIndex << " is not backed by contiguous addressable storage";
|
returnOp.emitOpError() << "result #" << resultIndex << " is not backed by contiguous addressable storage";
|
||||||
|
});
|
||||||
hasFailure = true;
|
hasFailure = true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -239,38 +280,50 @@ private:
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename CoreOpTy>
|
template <typename CoreOpTy>
|
||||||
static LogicalResult verifyCoreOperands(CoreOpTy coreOp) {
|
static LogicalResult verifyCoreOperands(CoreOpTy coreOp, pim::CappedDiagnosticReporter& diagnostics) {
|
||||||
return walkPimCoreBlock(
|
return walkPimCoreBlock(
|
||||||
coreOp.getBody().front(), StaticValueKnowledge {}, [](Operation& op, const StaticValueKnowledge& knowledge) {
|
coreOp.getBody().front(), StaticValueKnowledge {}, [&](Operation& op, const StaticValueKnowledge& knowledge) {
|
||||||
bool hasFailure = false;
|
bool hasFailure = false;
|
||||||
if (!isSupportedCoreInstructionOp(&op)) {
|
if (!isSupportedCoreInstructionOp(&op)) {
|
||||||
op.emitOpError("unsupported executable op reached PIM codegen verification");
|
diagnostics.report(&op, [](Operation* illegalOp) {
|
||||||
|
illegalOp->emitOpError("unsupported executable op reached PIM codegen verification");
|
||||||
|
});
|
||||||
hasFailure = true;
|
hasFailure = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
for (auto [operandIndex, operand] : llvm::enumerate(op.getOperands())) {
|
for (auto it : llvm::enumerate(op.getOperands())) {
|
||||||
|
size_t operandIndex = it.index();
|
||||||
|
Value operand = it.value();
|
||||||
if (!isa<BaseMemRefType>(operand.getType()))
|
if (!isa<BaseMemRefType>(operand.getType()))
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
auto resolvedAddress = resolveContiguousAddress(operand, knowledge);
|
auto resolvedAddress = resolveContiguousAddress(operand, knowledge);
|
||||||
if (failed(resolvedAddress)) {
|
if (failed(resolvedAddress)) {
|
||||||
op.emitOpError() << "operand #" << operandIndex << " is not backed by contiguous addressable storage";
|
diagnostics.report(&op, [&](Operation* illegalOp) {
|
||||||
|
illegalOp->emitOpError() << "operand #" << operandIndex
|
||||||
|
<< " is not backed by contiguous addressable storage";
|
||||||
|
});
|
||||||
hasFailure = true;
|
hasFailure = true;
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (isExplicitHostOperand(&op, operandIndex)) {
|
if (isExplicitHostOperand(&op, operandIndex)) {
|
||||||
if (!isCodegenAddressableValue(operand)) {
|
if (!isCodegenAddressableValue(operand, knowledge)) {
|
||||||
op.emitOpError() << "host operand #" << operandIndex
|
diagnostics.report(&op, [&](Operation* illegalOp) {
|
||||||
|
illegalOp->emitOpError() << "host operand #" << operandIndex
|
||||||
<< " is not backed by contiguous addressable storage";
|
<< " is not backed by contiguous addressable storage";
|
||||||
|
});
|
||||||
hasFailure = true;
|
hasFailure = true;
|
||||||
}
|
}
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!isa<memref::AllocOp>(resolvedAddress->base.getDefiningOp())) {
|
if (!isa<memref::AllocOp>(resolvedAddress->base.getDefiningOp())) {
|
||||||
op.emitOpError() << "operand #" << operandIndex
|
diagnostics.report(&op, [&](Operation* illegalOp) {
|
||||||
<< " must be backed by device-local memory; materialize host values with pim.memcp_hd";
|
illegalOp->emitOpError() << "operand #" << operandIndex
|
||||||
|
<< " must be backed by device-local memory; materialize host values with "
|
||||||
|
"pim.memcp_hd";
|
||||||
|
});
|
||||||
hasFailure = true;
|
hasFailure = true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -278,18 +331,20 @@ private:
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
static LogicalResult verifyAddressOnlyHostOp(Operation* op) {
|
static LogicalResult verifyAddressOnlyHostOp(Operation* op, pim::CappedDiagnosticReporter& diagnostics) {
|
||||||
if (auto subviewOp = dyn_cast<memref::SubViewOp>(op))
|
if (auto subviewOp = dyn_cast<memref::SubViewOp>(op))
|
||||||
return verifyAddressOnlyBase(op, subviewOp.getSource());
|
return verifyAddressOnlyBase(op, subviewOp.getSource(), diagnostics);
|
||||||
if (auto castOp = dyn_cast<memref::CastOp>(op))
|
if (auto castOp = dyn_cast<memref::CastOp>(op))
|
||||||
return verifyAddressOnlySource(op, castOp.getSource());
|
return verifyAddressOnlySource(op, castOp.getSource(), diagnostics);
|
||||||
if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(op))
|
if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(op))
|
||||||
return verifyAddressOnlySource(op, collapseOp.getSrc());
|
return verifyAddressOnlySource(op, collapseOp.getSrc(), diagnostics);
|
||||||
if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(op))
|
if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(op))
|
||||||
return verifyAddressOnlySource(op, expandOp.getSrc());
|
return verifyAddressOnlySource(op, expandOp.getSrc(), diagnostics);
|
||||||
if (auto copyOp = dyn_cast<memref::CopyOp>(op)) {
|
if (auto copyOp = dyn_cast<memref::CopyOp>(op)) {
|
||||||
if (!isBaseAddressableValue(copyOp.getSource()) || !isBaseAddressableValue(copyOp.getTarget())) {
|
if (!isBaseAddressableValue(copyOp.getSource()) || !isBaseAddressableValue(copyOp.getTarget())) {
|
||||||
op->emitOpError("depends on a value that is not backed by addressable storage");
|
diagnostics.report(op, [](Operation* illegalOp) {
|
||||||
|
illegalOp->emitOpError("depends on a value that is not backed by addressable storage");
|
||||||
|
});
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
return success();
|
return success();
|
||||||
@@ -297,19 +352,24 @@ private:
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
static LogicalResult verifyAddressOnlySource(Operation* op, Value source) {
|
static LogicalResult
|
||||||
|
verifyAddressOnlySource(Operation* op, Value source, pim::CappedDiagnosticReporter& diagnostics) {
|
||||||
if (isCodegenAddressableValue(source))
|
if (isCodegenAddressableValue(source))
|
||||||
return success();
|
return success();
|
||||||
|
|
||||||
op->emitOpError("depends on a value that is not backed by contiguous addressable storage");
|
diagnostics.report(op, [](Operation* illegalOp) {
|
||||||
|
illegalOp->emitOpError("depends on a value that is not backed by contiguous addressable storage");
|
||||||
|
});
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
|
||||||
static LogicalResult verifyAddressOnlyBase(Operation* op, Value source) {
|
static LogicalResult verifyAddressOnlyBase(Operation* op, Value source, pim::CappedDiagnosticReporter& diagnostics) {
|
||||||
if (isBaseAddressableValue(source))
|
if (isBaseAddressableValue(source))
|
||||||
return success();
|
return success();
|
||||||
|
|
||||||
op->emitOpError("depends on a value that is not backed by addressable storage");
|
diagnostics.report(op, [](Operation* illegalOp) {
|
||||||
|
illegalOp->emitOpError("depends on a value that is not backed by addressable storage");
|
||||||
|
});
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|||||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user