Compare commits
9 Commits
85e2750d6c
...
multiple-o
| Author | SHA1 | Date |
|---|---|---|
| | 87922d994f | |
| | 0f13269040 | |
| | dafc1d15b7 | |
| | 3fa140be25 | |
| | df703f0be9 | |
| | 9fa850c140 | |
| | 186c88d860 | |
| | 0368f96593 | |
| | 25ade1bd63 | |
```diff
@@ -1,10 +1,13 @@
-use anyhow::{bail, Context, Result};
+use anyhow::{Context, Result, bail};
 use clap::Parser;
 use glob::glob;
+use pimcore::cpu::crossbar::Crossbar;
 use pimcore::json_to_instruction::json_to_executor;
+use pimcore::memory_manager::CoreMemory;
 use pimcore::tracing::TRACER;
 use serde_json::Value;
-use std::fs;
+use std::collections::HashMap;
+use std::fs::{self, read_link};
 use std::io::Write;
 use std::path::PathBuf;

@@ -43,8 +46,10 @@ fn main() -> Result<()> {
     let config_json = retrive_config(&args)?;
     let core_jsons = retrive_cores(&args)?;
     let memory = retrive_memory(&args)?;
-    let mut executor = json_to_executor::json_to_executor(config_json, core_jsons.iter());
-    populate_crossbar(&args, &mut executor);
+    let global_crossbars = get_crossbars(&config_json, &args).unwrap();
+    let crossbars = map_crossbars_to_cores(&config_json, &args, &global_crossbars);
+    let mut executor =
+        json_to_executor::json_to_executor(config_json, core_jsons.iter(), crossbars);
     set_memory(&mut executor, memory);
     TRACER
         .lock()

@@ -55,46 +60,100 @@ fn main() -> Result<()> {
     Ok(())
 }

-fn populate_crossbar(args: &Args, executor: &mut pimcore::Executable) {
-    let num_cores = executor.cpu_mut().num_core();
+fn map_crossbars_to_cores<'c>(
+    config: &Value,
+    args: &Args,
+    global_crossbars: &'c HashMap<String, Crossbar>,
+) -> Vec<Vec<&'c Crossbar>> {
+    let mut res = Vec::new();
+    let num_cores = config.get("core_cnt").unwrap().as_i64().unwrap() as i32;

     if let Some(folder) = args.folder.as_ref() {
         for core_idx in 0..num_cores {
             let core_folder = folder.join(format!("core_{}", core_idx));
+            res.push(Vec::new());

             if !core_folder.is_dir() {
                 continue;
             }

-            let mut bin_files: Vec<(u32, std::path::PathBuf)> = std::fs::read_dir(&core_folder)
-                .expect("Failed to read core directory")
-                .filter_map(|entry| {
-                    let path = entry.ok()?.path();
-                    let file_name = path.file_name()?.to_str()?;
+            let mut sym_link_files: Vec<(u32, std::path::PathBuf)> =
+                std::fs::read_dir(&core_folder)
+                    .expect("Failed to read core directory")
+                    .filter_map(|entry| {
+                        let entry = entry.ok()?;
+                        assert!(entry.metadata().unwrap().is_symlink());
+                        let path = entry.path();
+                        let file_name = path.file_name()?.to_str()?;

                         if file_name.starts_with("crossbar_") && file_name.ends_with(".bin") {
                             let num_str = &file_name[9..file_name.len() - 4];
                             let num = num_str.parse::<u32>().ok()?;
                             Some((num, path))
                         } else {
                             None
                         }
                     })
                     .collect();
-            bin_files.sort_by_key(|&(num, _)| num);
-            let core = executor.cpu_mut().core(core_idx);
-            let (_memory, crossbars) = core.get_memory_crossbar();
+            sym_link_files.sort_by_key(|&(num, _)| num);

-            for (i, path) in bin_files {
-                let bytes = std::fs::read(path).expect("Failed to read binary file");
-                crossbars
-                    .get_mut(i as usize)
-                    .unwrap()
-                    .execute_store(&bytes)
-                    .unwrap();
+            for (_, symlink) in sym_link_files {
+                let real_path = read_link(symlink).unwrap();
+                let path_as_str = real_path.to_str().unwrap();
+                assert!(
+                    global_crossbars.contains_key(path_as_str),
+                    "symlink point to {:?}\n a not stored crossbar",
+                    real_path
+                );

+                res.iter_mut()
+                    .next_back()
+                    .unwrap()
+                    .push(global_crossbars.get(path_as_str).unwrap());
             }
         }
     }
+    res
+}
+
+fn get_crossbars(config: &Value, args: &Args) -> anyhow::Result<HashMap<String, Crossbar>> {
+    let xbar_size = config.get("xbar_size").unwrap().as_array().unwrap();
+    let rows_crossbar = xbar_size[0].as_i64().unwrap() as usize;
+    let column_corssbar = xbar_size[1].as_i64().unwrap() as usize;
+    let mut res = HashMap::new();
+
+    if let Some(folder) = args.folder.as_ref() {
+        let weight_folder = folder.join("weights");
+        if !weight_folder.is_dir() {
+            bail!("Not a directory");
+        }
+        for weight_file in
+            std::fs::read_dir(&weight_folder).context("Weight folder not iterable")?
+        {
+            let weight_file = weight_file.context("File not iterable")?;
+            if weight_file
+                .metadata()
+                .context("Doesn't contain metadata")?
+                .is_dir()
+            {
+                continue;
+            }
+
+            let bytes = std::fs::read(weight_file.path()).expect("Failed to read binary file");
+            let mut crossbar =
+                Crossbar::new(column_corssbar * 4, rows_crossbar, CoreMemory::new());
+            crossbar.execute_store(&bytes).unwrap();
+            res.insert(
+                weight_file
+                    .path()
+                    .to_str()
+                    .context("file name not utf-8")?
+                    .to_string(),
+                crossbar,
+            );
+        }
+    }
+    Ok(res)
+}

 fn dump_memory(mut executor: pimcore::Executable, args: &Args) -> Result<()> {
```
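The weight files are now loaded once into a global map keyed by their on-disk path, and each core's `crossbar_<n>.bin` entries are symlinks into that store. A minimal sketch of the lookup step, outside the diff — `Crossbar` here is a stand-in type, and it assumes (as `get_crossbars` above does) that the map key is the symlink target's exact textual path:

```rust
use std::collections::HashMap;
use std::fs::read_link;
use std::path::Path;

// Stand-in for pimcore's Crossbar; only the name is shared with the diff.
struct Crossbar;

/// Resolve a `crossbar_<n>.bin` symlink to the weight file it points at and
/// return the shared crossbar stored under that path, if any.
fn lookup_shared<'c>(
    symlink: &Path,
    global_crossbars: &'c HashMap<String, Crossbar>,
) -> Option<&'c Crossbar> {
    let real_path = read_link(symlink).ok()?;
    global_crossbars.get(real_path.to_str()?)
}

fn main() {
    let store: HashMap<String, Crossbar> = HashMap::new();
    // With an empty store (and no real symlink) this simply prints "not found".
    match lookup_shared(Path::new("core_0/crossbar_0.bin"), &store) {
        Some(_) => println!("found shared crossbar"),
        None => println!("not found"),
    }
}
```

If the target had been inserted under a canonicalized path instead, this lookup would miss; the `assert!` in the diff guards exactly that mismatch.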
```diff
@@ -170,7 +229,11 @@ fn retrive_cores(args: &Args) -> Result<Vec<Value>, anyhow::Error> {
     let pattern_str = pattern.to_str().context("Invalid path encoding")?;
     let mut paths: Vec<_> = glob(pattern_str)?.map(|x| x.unwrap()).collect();
     paths.sort_by_cached_key(|x| {
-        let mut x = x.file_stem().expect("Extracting the stem").to_str().expect("File not utf-8");
+        let mut x = x
+            .file_stem()
+            .expect("Extracting the stem")
+            .to_str()
+            .expect("File not utf-8");
         x = &x[5..];
         x.parse::<i32>().unwrap()
     });
```
```diff
@@ -38,14 +38,14 @@ impl Crossbar {
         self.memory.execute_store(0, element)
     }

-    pub fn load<T>(&mut self, size: usize) -> Result<Vec<&[T]>> where
+    pub fn load<T>(&self, size: usize) -> Result<Vec<&[T]>> where
         T: MemoryStorable, {
         if self.memory.get_len() < size
         //|| self.stored_bytes < size
         {
             bail!("Loading outside crossbar boundary [{} {}] < {}", self.stored_bytes, self.memory.get_len(), size);
         }
-        self.memory.load(0, size)
+        self.memory.load_const(0, size)
     }
```
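`load` drops its `&mut` receiver and delegates to the new `load_const` (shown in the memory-manager hunk further down). As far as these hunks show, the point of the split is borrow-level: cores now hold shared `&Crossbar` references, so reads must work through `&self`. A reduced sketch of the distinction, with a placeholder `Memory` type:

```rust
struct Memory(Vec<u8>);

impl Memory {
    // A mutable load can also record the access (bookkeeping elided here)...
    fn load(&mut self, addr: usize, size: usize) -> &[u8] {
        &self.0[addr..addr + size]
    }
    // ...while a const load works through a shared reference, which is what
    // a `Vec<&Crossbar>` held by several cores requires.
    fn load_const(&self, addr: usize, size: usize) -> &[u8] {
        &self.0[addr..addr + size]
    }
}

fn main() {
    let mut m = Memory(vec![1, 2, 3, 4]);
    assert_eq!(m.load(1, 2), &[2, 3]);
    let shared = &m; // shared borrow: only load_const is callable here
    assert_eq!(shared.load_const(0, 4), &[1, 2, 3, 4]);
}
```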
```diff
@@ -1,5 +1,5 @@
 use std::{collections::HashMap, fmt::Debug};
-use anyhow::{Context, Result};
+use anyhow::{Context, Result, ensure};

 use crate::{
     cpu::crossbar::Crossbar,

@@ -10,53 +10,44 @@ use crate::{
 pub mod crossbar;

 #[derive(Debug, Clone)]
-pub struct CPU {
-    cores: Box<[Core]>,
+pub struct CPU<'a> {
+    cores: Box<[Core<'a>]>,
 }

-impl CPU {
-    pub fn new(num_cores: impl TryToUsize) -> Self {
+impl<'a> CPU<'a> {
+    pub fn new(num_cores: impl TryToUsize, crossbars: Vec<Vec<&'a Crossbar>>) -> Self {
         let num_cores = num_cores.try_into().expect("num_cores can not be negative");
-        let mut cores: Vec<Core> = std::iter::repeat_with(Core::new)
-            .take(num_cores + 1)
-            .collect();
+        assert!(crossbars.len() == num_cores + 1);
+        let mut cores = Vec::new();
+        for crossbar in crossbars {
+            cores.push(Core::new(crossbar));
+        }
         Self {
             cores: cores.into(),
         }
     }

-    pub fn reserve_crossbar(
-        &mut self,
-        num_crossbar: impl TryToUsize,
-        byte_width: impl TryToUsize,
-        height: impl TryToUsize,
-    ) {
-        let num_crossbar = num_crossbar
-            .try_into()
-            .expect("num_crossbar can not be negative");
-        let byte_width = byte_width
-            .try_into()
-            .expect("byte_width can not be negative");
-        let height = height.try_into().expect("height can not be negative");
-        for core in &mut self.cores {
-            core.reserve_crossbar(num_crossbar, byte_width, height);
-        }
-    }
-
-    pub fn host(&mut self) -> &mut Core {
+    pub fn host<'b>(&'b mut self) -> &'b mut Core<'a>
+    where
+        'a: 'b,
+    {
         &mut self.cores[0]
     }

-    pub fn core(&mut self, index: impl TryToUsize) -> &mut Core {
+    pub fn core<'b>(&'b mut self, index: impl TryToUsize) -> &'b mut Core<'a>
+    where
+        'a: 'b,
+    {
         let index = index.try_into().expect("can not be negative");
         &mut self.cores[index]
     }

     pub fn num_core(&self) -> usize {
         self.cores.len()
     }

-    pub(crate) fn host_and_cores(&mut self, core: impl TryToUsize) -> (&mut Core, &mut Core) {
+    pub(crate) fn host_and_cores<'b, 'c>(&'b mut self, core: impl TryToUsize) -> (&'c mut Core<'a>, &'c mut Core<'a>)
+    where
+        'a: 'b,
+        'b: 'c,
+    {
         let core = core.try_into().expect("core can not be negative");
         assert_ne!(
             core, 0,

@@ -70,45 +61,29 @@ impl CPU {
         (host, core)
     }

-    pub fn get_multiple_cores<const N: usize>(&mut self, indices: [usize; N]) -> [&mut Core; N] {
+    pub fn get_multiple_cores<'b, const N: usize>(&'b mut self, indices: [usize; N]) -> [&'b mut Core<'a>; N]
+    where
+        'a: 'b,
+    {
         self.cores.get_disjoint_mut(indices).unwrap()
     }
 }

 #[derive(Debug, Clone)]
-pub struct Core {
-    crossbars: Vec<Crossbar>,
+pub struct Core<'a> {
+    crossbars: Vec<&'a Crossbar>,
     memory: CoreMemory,
     registers: [i32; 32],
 }

-impl Core {
-    fn new() -> Self {
+impl<'a> Core<'a> {
+    fn new(crossbars: Vec<&'a Crossbar>) -> Self {
         Self {
-            crossbars: Vec::new(),
+            crossbars,
             memory: CoreMemory::new(),
             registers: [0; 32],
         }
     }

-    pub fn reserve_crossbar(
-        &mut self,
-        num_crossbar: impl TryToUsize,
-        width: impl TryToUsize,
-        height: impl TryToUsize,
-    ) {
-        let num_crossbar = num_crossbar
-            .try_into()
-            .expect("num_crossbar can not be negative");
-        let width = width.try_into().expect("width can not be negative");
-        let height = height.try_into().expect("height can not be negative");
-        for _ in 0..num_crossbar {
-            let mut crossbar = CoreMemory::new();
-            crossbar.set_capacity(width * height);
-            self.crossbars.push(Crossbar::new(width, height, crossbar));
-        }
-    }
-
     pub fn execute_load<T>(&mut self) -> Result<Vec<&[T]>>
     where
         T: MemoryStorable,

@@ -157,7 +132,7 @@ impl Core {
         self.memory.load(address, size)
     }

-    pub fn get_memory_crossbar(&mut self) -> (&mut CoreMemory, &mut Vec<Crossbar>) {
+    pub fn get_memory_crossbar(&mut self) -> (&mut CoreMemory, &mut Vec<&'a Crossbar>) {
         let Self {
             crossbars,
             memory,
```
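`CPU` and `Core` stop owning their crossbars and instead borrow them for a lifetime `'a`, which is why nearly every accessor grows `'a`/`'b` annotations. A compilable miniature of the ownership scheme (names are illustrative, not pimcore's API):

```rust
// A reduced model of the new scheme: the crossbars live in one long-lived
// store, and each core borrows the subset it uses; two cores may share one.
struct Crossbar;

struct Core<'a> {
    crossbars: Vec<&'a Crossbar>,
}

struct Cpu<'a> {
    cores: Box<[Core<'a>]>,
}

fn main() {
    let owned = vec![Crossbar, Crossbar]; // must outlive the CPU
    let per_core = vec![vec![&owned[0]], vec![&owned[0], &owned[1]]];
    let cpu = Cpu {
        cores: per_core
            .into_iter()
            .map(|crossbars| Core { crossbars })
            .collect::<Vec<_>>()
            .into(),
    };
    assert_eq!(cpu.cores[1].crossbars.len(), 2);
}
```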
```diff
@@ -76,7 +76,8 @@ pub fn functor_to_name(functor: usize) -> &'static str {
 ///////////////////////////////////////////////////////////////
 /////////////////Scalar/register Instructions//////////////////
 ///////////////////////////////////////////////////////////////
-pub fn sldi(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
+pub fn sldi(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
+{
     TRACER.lock().unwrap().pre_sldi(cores, data);
     let (core_indx, rd, imm) = data.get_core_rd_imm();
     let core = cores.core(core_indx);

@@ -40,7 +40,9 @@ impl Instruction {
         Self { data, functor }
     }

-    pub fn execute(&self, cpu: &mut CPU) -> InstructionStatus {
+    pub fn execute<'a, 'b>(&'b self, cpu: &mut CPU<'a>) -> InstructionStatus
+    where
+        'a: 'b,
+    {
         (self.functor)(cpu, self.data)
             .with_context(|| format!("Instruction: {}", functor_to_name(self.functor as usize)))
             .with_context(|| format!("Error in core: {}", self.data.core_indx() - 1))
```
```diff
@@ -1,10 +1,11 @@
 use core::panic;
 use std::collections::HashMap;

 use serde_json::{Map, Value};

 use crate::{
     CoreInstructionsBuilder, Executable,
-    cpu::{CPU, crossbar},
+    cpu::{CPU, crossbar::{self, Crossbar}},
     instruction_set::{
         InstructionsBuilder,
         instruction_data::{self, InstructionData, InstructionDataBuilder},

@@ -13,18 +14,20 @@ use crate::{
     memory_manager::type_traits::TryToUsize,
 };

 pub fn json_to_executor<'a>(
     config: Value,
     mut cores: impl Iterator<Item = &'a Value>,
-) -> Executable {
+    crossbars: Vec<Vec<&'a Crossbar>>,
+) -> Executable<'a> {
     let cell_precision = config.get("cell_precision").unwrap().as_i64().unwrap() as i32;
     let core_cnt = config.get("core_cnt").unwrap().as_i64().unwrap() as i32 - 1;
-    let xbar_count = config.get("xbar_array_count").unwrap().as_i64().unwrap() as i32;
-    let xbar_size = config.get("xbar_size").unwrap().as_array().unwrap();
-    let rows_crossbar = xbar_size[0].as_i64().unwrap() as i32;
-    let column_corssbar = xbar_size[1].as_i64().unwrap() as i32;
-    let mut cpu = CPU::new(core_cnt);
-    cpu.reserve_crossbar(xbar_count, column_corssbar * 4, rows_crossbar);

+    let mut cpu = CPU::new(core_cnt, crossbars);
     let mut core_insts_builder = CoreInstructionsBuilder::new(core_cnt as usize);
     cores.next();
     for core_indx in 1..=core_cnt {
```
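With this change the executor no longer derives crossbar geometry (`xbar_array_count`, `xbar_size`) from the config and reserves storage itself: the caller builds the crossbars up front (see `get_crossbars` above) and `json_to_executor` only wires the pre-built `Vec<Vec<&Crossbar>>` into `CPU::new`.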
```diff
@@ -111,6 +111,24 @@ where {
         Ok(res)
     }

+    pub fn load_const<T>(&self, address: impl TryToUsize, size: impl TryToUsize) -> Result<Vec<&[T]>>
+    where
+        T: MemoryStorable,
+    {
+        let address = address.try_into().expect("address can not be negative");
+        let size = size.try_into().expect("size can not be negative");
+        let Self {
+            memory,
+            load_requests,
+        } = self;
+        let mut res = Vec::new();
+        let memory_slice = &memory[address..address + size];
+        let memory_slice = unsafe { slice_from_u8(memory_slice) }
+            .with_context(|| format!("Accessing from {} to {}", address, address + size))?;
+        res.push(memory_slice);
+        Ok(res)
+    }
+
     pub fn load<T>(&mut self, address: impl TryToUsize, size: impl TryToUsize) -> Result<Vec<&[T]>>
     where
         T: MemoryStorable,
```
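Note on the new `load_const`: unlike `load` it takes `&self`, so it can be called through the shared `&Crossbar` references the cores now hold. It destructures `load_requests` but never touches it; presumably skipping that bookkeeping is what makes the shared borrow possible — the hunk alone does not show how `load_requests` is used elsewhere.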
```diff
@@ -1,50 +1,54 @@
 #![allow(unused)]

+use std::time::{Duration, SystemTime};
+
 use crate::{
-    cpu::CPU, instruction_set::{Instruction, InstructionStatus, Instructions, isa::functor_to_name}, memory_manager::type_traits::TryToUsize, send_recv::{SendRecv, handle_send_recv}, tracing::TRACER
+    cpu::CPU,
+    instruction_set::{Instruction, InstructionStatus, Instructions, isa::functor_to_name},
+    memory_manager::type_traits::TryToUsize,
+    send_recv::{SendRecv, handle_send_recv},
+    tracing::TRACER,
 };
 pub mod cpu;
 pub mod instruction_set;
-pub mod json_to_instruction;
 pub mod memory_manager;
 pub mod send_recv;
-pub mod utility;
+pub mod json_to_instruction;
 pub mod tracing;
+
+
+pub mod utility;

 #[derive(Debug, Clone)]
 pub struct CoreInstructionsBuilder {
-    core_instructions : Vec<CoreInstruction>
+    core_instructions: Vec<CoreInstructions>,
 }

 impl CoreInstructionsBuilder {
-    pub fn new(size:usize) -> Self {
+    pub fn new(size: usize) -> Self {
         let mut core_instructions = Vec::with_capacity(size);
         for _ in 0..=size {
-            core_instructions.push(CoreInstruction::empty());
+            core_instructions.push(CoreInstructions::empty());
         }
         Self { core_instructions }
     }

-    pub fn build(self) -> Vec<CoreInstruction> {
+    pub fn build(self) -> Vec<CoreInstructions> {
         self.core_instructions
     }

-    pub fn set_core(&mut self, core : impl TryToUsize, core_instruction : Instructions) -> &mut Self{
-        self.core_instructions[core.try_into().expect("Set core with not valid size")] = core_instruction.into();
-        self
+    pub fn set_core(&mut self, core: impl TryToUsize, core_instruction: Instructions) -> &mut Self {
+        self.core_instructions[core.try_into().expect("Set core with not valid size")] =
+            core_instruction.into();
+        self
     }
 }

 #[derive(Debug, Clone)]
-pub struct CoreInstruction {
+pub struct CoreInstructions {
     instructions: Instructions,
     program_counter: usize,
 }

-impl CoreInstruction {
+impl CoreInstructions {
     fn new(instructions: Instructions, program_counter: usize) -> Self {
         Self {
             instructions,

@@ -53,13 +57,16 @@ impl CoreInstruction {
     }

     fn empty() -> Self {
-        Self { instructions: Vec::new(), program_counter: 0 }
+        Self {
+            instructions: Vec::new(),
+            program_counter: 0,
+        }
     }
 }

-impl From<Instructions> for CoreInstruction {
+impl From<Instructions> for CoreInstructions {
     fn from(value: Instructions) -> Self {
-        CoreInstruction {
+        CoreInstructions {
             instructions: value,
             program_counter: 0,
         }

@@ -67,39 +74,64 @@ impl From<Instructions> for CoreInstruction {
 }

 #[derive(Debug, Clone)]
-pub struct Executable {
-    cpu: CPU,
-    core_instructions: Vec<CoreInstruction>,
-    send_recv : SendRecv,
+pub struct Executable<'a> {
+    cpu: CPU<'a>,
+    core_instructions: Vec<CoreInstructions>,
+    send_recv: SendRecv,
 }

-impl Executable {
-    pub fn new(cpu: CPU, core_instructions: Vec<CoreInstruction>) -> Self {
+fn print_status(core_instructions: &[CoreInstructions]) {
+    let mut tot_instructions = 0;
+    let mut progress = 0;
+    for core_instruction in core_instructions.iter() {
+        tot_instructions += core_instruction.instructions.len();
+        progress += core_instruction.program_counter;
+    }
+    println!(
+        "Progress: {}% ({}/{}) ",
+        progress as f32 / tot_instructions as f32 * 100.0,
+        progress,
+        tot_instructions
+    );
+}
+
+impl<'a> Executable<'a> {
+    pub fn new(cpu: CPU<'a>, core_instructions: Vec<CoreInstructions>) -> Executable<'a> {
         let num_core = cpu.num_core();
         let send_recv = SendRecv::new(num_core);
-        assert_eq!(num_core, core_instructions.len(), "Some core doesn't have is list of istruction (required even if empty)");
+        assert_eq!(
+            num_core,
+            core_instructions.len(),
+            "Some core doesn't have is list of istruction (required even if empty)"
+        );
         Self {
             cpu,
             core_instructions,
-            send_recv
+            send_recv,
         }
     }

-    pub fn execute(&mut self) {
+    pub fn execute<'b>(&'b mut self)
+    where
+        'a: 'b,
+    {
         let Self {
             cpu,
-            core_instructions,
-            send_recv
+            core_instructions: cores_instructions,
+            send_recv,
         } = self;
         let mut cpu_progressed = 0;
-        let max_core = cpu.num_core();
-        let mut index_unit = 0;
+        let mut cpu_index = 0;
+        let mut now = SystemTime::now();

         while (cpu_progressed > -2) {
             let mut core_result = InstructionStatus::Completed;
-            while core_result.is_completed() && let Some(core_instruction) = core_instructions.get_mut(index_unit){
+            while core_result.is_completed()
+                && let Some(core_instruction) = cores_instructions.get_mut(cpu_index)
+            {
                 core_result = InstructionStatus::NotExecuted;
-                let CoreInstruction {
+                let CoreInstructions {
                     instructions,
                     program_counter,
                 } = core_instruction;
```
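Two things happen side by side in this file: `CoreInstruction` is renamed to the plural `CoreInstructions` (it wraps a whole instruction list plus its program counter), and `Executable` picks up the `'a` lifetime it inherits from `CPU<'a>`. The new free function `print_status` reports overall progress as the sum of all program counters over the total instruction count; the execute loop below prints it at most roughly once per second and once more on completion.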
```diff
@@ -112,29 +144,44 @@ impl Executable {
                     cpu_progressed = 0;
                     *program_counter += 1;
                 }
+                if (now.elapsed().unwrap() > Duration::from_secs(1)) {
+                    print_status(&cores_instructions);
+                    now = SystemTime::now();
+                }
             }
-            if handle_send_recv(cpu, core_instructions, send_recv, core_result) { cpu_progressed = 0; }
-            handle_wait_sync(cpu, core_instructions, core_result);
-            index_unit = if index_unit + 1 >= max_core {
-                cpu_progressed-=1;
-                0
-            } else {
-                index_unit + 1
-            };
+            handle_wait_sync(cpu, cores_instructions, core_result);
+            match handle_send_recv(cpu, cores_instructions, send_recv, core_result) {
+                (true, other_cpu_index) => {
+                    cpu_progressed = 0;
+                    cpu_index = other_cpu_index;
+                }
+                (false, 0) => {
+                    cpu_index = if cpu_index + 1 >= cores_instructions.len() {
+                        cpu_progressed -= 1;
+                        0
+                    } else {
+                        cpu_index + 1
+                    };
+                }
+                (false, other_cpu_index) => {
+                    cpu_index = other_cpu_index;
+                }
+            }
         }
+        print_status(cores_instructions);
     }

-    pub fn cpu(&self) -> &CPU {
+    pub fn cpu(&self) -> &CPU<'a> {
         &self.cpu
     }

-    pub fn cpu_mut(&mut self) -> &mut CPU {
+    pub fn cpu_mut(&mut self) -> &mut CPU<'a> {
         &mut self.cpu
     }

     pub fn dump(&self) {
         let core_instructions = &self.core_instructions;
         for (i, core_instruction) in core_instructions.iter().enumerate() {
             eprintln!("INST OF CORE {}:", i);
             for inst in &core_instruction.instructions {
                 inst.dump();

@@ -143,64 +190,12 @@ impl Executable {
             }
         }
     }
 }

-fn handle_wait_sync(cpu: &mut CPU, core_instructions: &mut [CoreInstruction], core_result: InstructionStatus) {
-}
-
-
-#[cfg(test)]
-mod test {
-    use super::*;
-    use crate::instruction_set::instruction_data::InstructionDataBuilder;
-    use crate::instruction_set::{InstructionsBuilder, isa::*};
-
-    #[test]
-    fn test_only_host() {
-        let mut cpu = CPU::new(0);
-        cpu.host()
-            .execute_store(0, &[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]);
-        let mut inst_builder = InstructionsBuilder::new();
-        let mut idata_build = InstructionDataBuilder::new();
-        idata_build.set_core_indx(0).fix_core_indx();
-        inst_builder.make_inst(sldi, idata_build.set_rdimm(1, 0).build());
-        inst_builder.make_inst(sld, idata_build.set_rdr1(1, 1).build());
-        inst_builder.make_inst(sldi, idata_build.set_rdimm(2, 8).build());
-        inst_builder.make_inst(sld, idata_build.set_rdr1(2, 2).build());
-        inst_builder.make_inst(sadd, idata_build.set_rdr1r2(2, 1, 2).build());
-        let mut core_instruction = vec![inst_builder.build().into()];
-        let mut executable = Executable::new(cpu, core_instruction);
-        executable.execute();
-        assert_eq!(executable.cpu_mut().host().register(2), 4, "Not sum to 4");
-    }
-
-    #[test]
-    fn test_10_core_same_code() {
-        let setup_core = |index: usize, cpu: &mut CPU| -> Instructions {
-            cpu.core(index)
-                .execute_store(0, &[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]);
-            let mut inst_builder = InstructionsBuilder::new();
-            let mut idata_build = InstructionDataBuilder::new();
-            idata_build.set_core_indx(index as i32).fix_core_indx();
-            inst_builder.make_inst(sldi, idata_build.set_rdimm(1, 0).build());
-            inst_builder.make_inst(sld, idata_build.set_rdr1(1, 1).build());
-            inst_builder.make_inst(sldi, idata_build.set_rdimm(2, 8).build());
-            inst_builder.make_inst(sld, idata_build.set_rdr1(2, 2).build());
-            inst_builder.make_inst(sadd, idata_build.set_rdr1r2(2, 1, 2).build());
-            inst_builder.build()
-        };
-
-        let mut cpu = CPU::new(10);
-        let mut core_instruction = Vec::new();
-        for i in 0..cpu.num_core() {
-            core_instruction.push(setup_core(i, &mut cpu).into())
-        }
-
-        let mut executable = Executable::new(cpu, core_instruction);
-        executable.execute();
-        for i in 0..executable.cpu.num_core() {
-            assert_eq!(executable.cpu_mut().core(i).register(2), 4, "Core {} not sum to 4", i);
-        }
-    }
-}
-
+fn handle_wait_sync<'a, 'b, 'c>(
+    cpu: &'b mut CPU<'a>,
+    core_instructions: &'c mut [CoreInstructions],
+    core_result: InstructionStatus,
+) where
+    'a: 'b,
+    'a: 'c,
+{
+}
```
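`handle_send_recv` (defined in the next hunks) now returns `(bool, usize)` instead of `bool`: whether a transfer completed, plus the index of the partner core that should run next; `(false, 0)` doubles as the "nothing to do, advance round-robin" case. A self-contained sketch of just that dispatch decision, with the transfer logic elided:

```rust
// `(true, idx)` means a transfer happened and execution jumps to core `idx`;
// `(false, 0)` means advance round-robin; `(false, idx)` jumps to `idx`
// without resetting the no-progress counter.
fn next_core(
    result: (bool, usize),
    cpu_index: usize,
    num_cores: usize,
    cpu_progressed: &mut i32,
) -> usize {
    match result {
        (true, other) => {
            *cpu_progressed = 0;
            other
        }
        (false, 0) => {
            if cpu_index + 1 >= num_cores {
                *cpu_progressed -= 1; // full sweep with no progress
                0
            } else {
                cpu_index + 1
            }
        }
        (false, other) => other,
    }
}

fn main() {
    let mut progressed = 0;
    assert_eq!(next_core((false, 0), 3, 4, &mut progressed), 0);
    assert_eq!(progressed, -1);
    assert_eq!(next_core((true, 2), 0, 4, &mut progressed), 2);
    assert_eq!(progressed, 0);
}
```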
```diff
@@ -1,7 +1,7 @@
 use anyhow::Context;

 use crate::{
-    CoreInstruction, cpu::CPU, instruction_set::InstructionStatus, tracing::TRACER,
+    CoreInstructions, cpu::CPU, instruction_set::InstructionStatus, tracing::TRACER,
     utility::add_offset_rd,
 };

@@ -41,14 +41,16 @@ impl SendRecv {
     }
 }

-pub fn handle_send_recv(
-    cpu: &mut CPU,
-    core_instructions: &mut [CoreInstruction],
-    send_recv: &mut SendRecv,
+pub fn handle_send_recv<'a, 'b>(
+    cpu: &'b mut CPU<'a>,
+    core_instructions: &mut [CoreInstructions],
+    send_recv: &mut SendRecv,
     core_result: InstructionStatus,
-) -> bool {
-    let transfer_memory = |cpu: &mut CPU,
-                           core_instructions: &mut [CoreInstruction],
+) -> (bool, usize)
+where
+    'a: 'b,
+{
+    let transfer_memory = |cpu: &'b mut CPU<'a>,
+                           core_instructions: &mut [CoreInstructions],
                            sender: Option<SendRecvInfo>,
                            receiver: Option<SendRecvInfo>| {
         if let Some(sender) = sender

@@ -117,7 +119,7 @@ pub fn handle_send_recv(
                 send_recv.sending[sender] = None;
                 send_recv.receiving[receiver] = None;
             }
-            return transfered;
+            (transfered, receiver)
         }
         InstructionStatus::Reciving(instruction_data) => {
             let (core_idx, imm_core) = instruction_data.get_core_immcore();

@@ -146,8 +148,8 @@ pub fn handle_send_recv(
                 send_recv.sending[sender] = None;
                 send_recv.receiving[receiver] = None;
             }
-            return transfered;
+            (transfered, sender)
         }
-        _ => false,
+        _ => (false, 0),
     }
 }
```
```diff
@@ -1,11 +1,10 @@
 mod tracing_isa;
 mod disable;
 mod pretty_print;

 #[cfg(feature = "tracing")]
-use std::{fs::File, path::{ PathBuf}};
 use std::sync::{LazyLock, Mutex};

 use crate::Executable;

 #[cfg(feature = "tracing")]
```
```diff
@@ -5,6 +5,7 @@
 #include "mlir/IR/BuiltinTypes.h"

 #include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/Support/FileSystem.h"
 #include "llvm/Support/JSON.h"
 #include "llvm/Support/raw_ostream.h"

@@ -33,6 +34,12 @@ MemEntry* PimMemory::gatherMemEntry(mlir::Value value) {
   return &memEntries.emplace_back(memEntry, value).first;
 }

+void PimMemory::allocateGatheredMemory() {
+  llvm::sort(memEntries, [](auto a, auto b) -> bool { return a.first.size > b.first.size; });
+  for (auto& [memEntry, value] : memEntries)
+    allocateMemoryForValue(value, memEntry);
+}
+
 void PimMemory::allocateMemoryForValue(mlir::Value value, MemEntry& memEntry) {
   memEntry.address = firstAvailableAddress;
   firstAvailableAddress += memEntry.size;

@@ -44,35 +51,37 @@ void PimMemory::allocateMemoryForValue(mlir::Value value, MemEntry& memEntry) {
 }

 void PimMemory::allocateHost(ModuleOp moduleOp, func::FuncOp funcOp) {
   // More than one SSA value per single global constant:
   // Cannot call gatherMemEntry for each of them, otherwise memory will be allocated multiple times
   // Thus, call gatherMemEntry only for the first SSA value and assign the same memEntry to all others
-  SmallDenseMap<memref::GlobalOp, MemEntry*, 8> globalConstants;
+  SmallDenseMap<memref::GlobalOp, mlir::Value, 8> globalConstants;
+  SmallVector<std::pair<mlir::Value, mlir::Value>, 16> globalAliases;
   funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
     if (!hasWeightAlways(getGlobalOp)) {
       auto globalMemrefOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
-      auto iter = globalConstants.find(globalMemrefOp);
-      if (iter == globalConstants.end())
-        globalConstants[globalMemrefOp] = gatherMemEntry(getGlobalOp);
-      else {
-        MemEntry memEntry = *iter->second;
-        globalMemEntriesMap[getGlobalOp] = memEntry;
-      }
+      auto [iter, inserted] = globalConstants.try_emplace(globalMemrefOp, getGlobalOp.getResult());
+      if (inserted)
+        gatherMemEntry(getGlobalOp.getResult());
+      else
+        globalAliases.push_back({getGlobalOp.getResult(), iter->second});
     }
   });

   for (mlir::Value arg : funcOp.getArguments())
     gatherMemEntry(arg);

   allocateCore(funcOp);
   funcOp.walk([&](memref::AllocOp allocOp) {
     if (!allocOp->getParentOfType<pim::PimCoreOp>())
       gatherMemEntry(allocOp.getResult());
   });

+  allocateGatheredMemory();
+
+  for (auto [alias, original] : globalAliases)
+    globalMemEntriesMap[alias] = getMemEntry(original);
 }

 void PimMemory::allocateCore(Operation* op) {
   op->walk([&](memref::AllocOp allocOp) { gatherMemEntry(allocOp); });

-  llvm::sort(memEntries, [](auto a, auto b) -> bool { return a.first.size > b.first.size; });
-  for (auto& [memEntry, value] : memEntries)
-    allocateMemoryForValue(value, memEntry);
+  allocateGatheredMemory();
 }

 MemEntry PimMemory::getMemEntry(mlir::Value value) const {

@@ -465,6 +474,19 @@ std::string getMemorySizeAsString(size_t size) {
   return std::to_string(size) + " Bytes";
 }

+static SmallVector<unsigned, 8> getUsedWeightIndices(pim::PimCoreOp coreOp) {
+  SmallVector<unsigned, 8> indices;
+  auto addIndex = [&](unsigned weightIndex) {
+    if (!llvm::is_contained(indices, weightIndex))
+      indices.push_back(weightIndex);
+  };
+
+  coreOp.walk([&](pim::PimMVMOp mvmOp) { addIndex(mvmOp.getWeightIndex()); });
+  coreOp.walk([&](pim::PimVMMOp vmmOp) { addIndex(vmmOp.getWeightIndex()); });
+  llvm::sort(indices);
+  return indices;
+}
+
 /// Write global constant data into a binary memory image at their allocated addresses.
 static OnnxMlirCompilerErrorCodes
 writeMemoryBinary(ModuleOp moduleOp, func::FuncOp funcOp, PimAcceleratorMemory& memory, StringRef outputDirPath) {

@@ -478,12 +500,15 @@ writeMemoryBinary(ModuleOp moduleOp, func::FuncOp funcOp, PimAcceleratorMemory&

   std::vector<char> memoryBuffer(memory.hostMem.getFirstAvailableAddress(), 0);

+  SmallPtrSet<Operation*, 16> writtenGlobals;
   funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
     if (hasWeightAlways(getGlobalOp))
       return;
     auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
     if (!globalOp)
       return;
+    if (!writtenGlobals.insert(globalOp.getOperation()).second)
+      return;
     auto initialValue = globalOp.getInitialValue();
     if (!initialValue)
       return;

@@ -658,7 +683,12 @@ createAndPopulateWeightFolder(func::FuncOp funcOp, StringRef outputDirPath) {
   llvm::DenseMap<memref::GlobalOp, std::string> mapGlobalOpToFileName;

   for (pim::PimCoreOp coreOp : funcOp.getOps<pim::PimCoreOp>()) {
-    for (auto [index, weight] : llvm::enumerate(coreOp.getWeights())) {
+    for (unsigned index : getUsedWeightIndices(coreOp)) {
+      if (index >= coreOp.getWeights().size()) {
+        coreOp.emitWarning("Weight index " + std::to_string(index) + " is out of range");
+        assert(index < coreOp.getWeights().size() && "Weight index is out of range");
+      }
+      mlir::Value weight = coreOp.getWeights()[index];

       auto getGlobalOp = weight.getDefiningOp<memref::GetGlobalOp>();
       if (!getGlobalOp) {

@@ -855,7 +885,12 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::

     auto& mapWeightToFile = mapCoreWeightToFileName[coreOp];
     json::Array xbarsPerGroup;
-    for (auto [index, weight] : llvm::enumerate(coreOp.getWeights())) {
+    for (unsigned index : getUsedWeightIndices(coreOp)) {
+      if (index >= coreOp.getWeights().size()) {
+        coreOp.emitWarning("Weight index " + std::to_string(index) + " is out of range");
+        assert(index < coreOp.getWeights().size() && "Weight index is out of range");
+      }
+      mlir::Value weight = coreOp.getWeights()[index];
       xbarsPerGroup.push_back(index);
       assert(mapWeightToFile.contains(weight) && "Weight was not materialized into a file!!");
       auto& fileName = mapWeightToFile[weight];

@@ -24,6 +24,7 @@ class PimMemory {
   size_t firstAvailableAddress = 0;

   MemEntry* gatherMemEntry(mlir::Value value);
+  void allocateGatheredMemory();
   void allocateMemoryForValue(mlir::Value value, MemEntry& memEntry);

 public:
```
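The allocator becomes an explicit two-phase scheme: `gatherMemEntry` only records entries, and the new `allocateGatheredMemory` sorts them by decreasing size and assigns addresses in one pass. Duplicate `memref.get_global`s of the same global no longer get their own entries; they are collected as aliases and pointed at the original's `MemEntry` after allocation, and `writeMemoryBinary` adds a `SmallPtrSet` guard for the same reason, so each global's initial value is written into the memory image exactly once. `getUsedWeightIndices` likewise ensures only weights actually referenced by a core's MVM/VMM ops are materialized and listed, in sorted order.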
```diff
@@ -47,6 +47,12 @@ llvm::cl::opt<long> coresCount("core-count",
     llvm::cl::desc("Number of cores in the chip. `-1` to use the minimum amount of cores."),
     llvm::cl::init(-1));

+llvm::cl::opt<size_t> dcpCriticalWindowSize(
+    "dcp-critical-window-size",
+    llvm::cl::desc("Number of lowest-slack virtual nodes considered by each DCP coarsening iteration. "
+                   "Use 0 to run the legacy full-graph DCP analysis."),
+    llvm::cl::init(1024));
+
 llvm::cl::opt<bool>
     ignoreConcatError("ignore-concat-error",
         llvm::cl::desc("Ignore ConcatOp corner case: do not assert and do a simplification"),

@@ -29,6 +29,7 @@ extern llvm::cl::opt<bool> useExperimentalConvImpl;
 extern llvm::cl::opt<size_t> crossbarSize;
 extern llvm::cl::opt<size_t> crossbarCountInCore;
 extern llvm::cl::opt<long> coresCount;
+extern llvm::cl::opt<size_t> dcpCriticalWindowSize;

 // This option, by default set to false, will ignore an error when resolving a
 // specific tiles of the operands of a concat. This specific case is when the
```
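The new `dcp-critical-window-size` option caps how many lowest-slack virtual nodes each DCP coarsening iteration examines (default 1024, per the declaration above); passing 0 restores the previous full-graph analysis.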
```diff
@@ -182,7 +182,7 @@ auto createSpatCompute(RewriterT& rewriter,
     mlir::ValueRange inputs,
     BodyFn&& body) {
   assert(inputs.size() == NumInputs && "NumInputs must match the number of input values");
-  auto computeOp = spatial::SpatWeightedCompute::create(rewriter, loc, resultTypes, weights, inputs);
+  auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs);

   auto* block = new mlir::Block();
   for (mlir::Value input : inputs)

@@ -198,10 +198,10 @@ auto createSpatCompute(RewriterT& rewriter,
     if (mlir::failed(bodyResult)) {
       rewriter.setInsertionPointAfter(computeOp);
       rewriter.eraseOp(computeOp);
-      return mlir::FailureOr<spatial::SpatWeightedCompute>(mlir::failure());
+      return mlir::FailureOr<spatial::SpatCompute>(mlir::failure());
     }
     rewriter.setInsertionPointAfter(computeOp);
-    return mlir::FailureOr<spatial::SpatWeightedCompute>(computeOp);
+    return mlir::FailureOr<spatial::SpatCompute>(computeOp);
   }
   else {
     static_assert(std::is_same_v<BodyResult, void>, "createSpatCompute body must return void or mlir::LogicalResult");

@@ -219,7 +219,7 @@ auto createSpatCompute(RewriterT& rewriter,
     mlir::ValueRange weights,
     mlir::ValueRange inputs,
     BodyFn&& body) {
-  auto computeOp = spatial::SpatWeightedCompute::create(rewriter, loc, resultTypes, weights, inputs);
+  auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs);

   auto* block = new mlir::Block();
   for (mlir::Value input : inputs)

@@ -234,10 +234,10 @@ auto createSpatCompute(RewriterT& rewriter,
     if (mlir::failed(bodyResult)) {
       rewriter.setInsertionPointAfter(computeOp);
       rewriter.eraseOp(computeOp);
-      return mlir::FailureOr<spatial::SpatWeightedCompute>(mlir::failure());
+      return mlir::FailureOr<spatial::SpatCompute>(mlir::failure());
     }
     rewriter.setInsertionPointAfter(computeOp);
-    return mlir::FailureOr<spatial::SpatWeightedCompute>(computeOp);
+    return mlir::FailureOr<spatial::SpatCompute>(computeOp);
   }
   else {
     static_assert(std::is_same_v<BodyResult, void>, "createSpatCompute body must return void or mlir::LogicalResult");

@@ -133,7 +133,7 @@ void ONNXToSpatialPass::runOnOperation() {
   if (coresCount != -1) {
     int computeOpsCount = 0;
     for (auto& op : entryFunc->getFunctionBody().front().getOperations())
-      if (isa<spatial::SpatWeightedCompute>(op))
+      if (isa<spatial::SpatCompute>(op))
         computeOpsCount++;

     if (computeOpsCount > coresCount) {

@@ -167,16 +167,16 @@ bool encapsulator(IRRewriter& rewriter, Location loc, Operation* inst, std::func
   if (T toRemoveOp = llvm::dyn_cast_if_present<T>(inst)) {
     Value source = funcSource(toRemoveOp);
     rewriter.setInsertionPointAfter(toRemoveOp);
-    if (isa_and_present<spatial::SpatWeightedCompute>(source.getDefiningOp())) {
-      auto newCompute = spatial::SpatWeightedCompute::create(rewriter, loc, inst->getResultTypes().front(), source);
+    if (isa_and_present<spatial::SpatCompute>(source.getDefiningOp())) {
+      auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), source);
       auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), {source.getType()}, {loc});
       newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) 1});
       rewriter.setInsertionPointToEnd(BB);
       IRMapping mapper;
       mapper.map(source, BB->getArgument(0));
       auto newInst = rewriter.clone(*inst, mapper);
-      spatial::SpatYieldOp::create(rewriter, loc, newInst->getResult(0));
-      inst->replaceAllUsesWith(newCompute);
+      spatial::SpatYieldOp::create(rewriter, loc, newInst->getResults());
+      inst->replaceAllUsesWith(newCompute->getResults());
       inst->erase();
       return true;
     }

@@ -189,8 +189,8 @@ bool encapsulateConcat(IRRewriter& rewriter, Location loc, Operation* inst) {
   auto sources = toRemoveOp.getInputs();
   rewriter.setInsertionPointAfter(toRemoveOp);
   if (llvm::any_of(
-          sources, [](auto source) { return isa_and_present<spatial::SpatWeightedCompute>(source.getDefiningOp()); })) {
-    auto newCompute = spatial::SpatWeightedCompute::create(rewriter, loc, inst->getResultTypes().front(), sources);
+          sources, [](auto source) { return isa_and_present<spatial::SpatCompute>(source.getDefiningOp()); })) {
+    auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), sources);
     SmallVector<Type> sourceTypes;
     SmallVector<Location> sourceLoc;
     for (auto source : sources) {

@@ -204,8 +204,8 @@ bool encapsulateConcat(IRRewriter& rewriter, Location loc, Operation* inst) {
     for (auto [source, bbArg] : llvm::zip(sources, BB->getArguments()))
       mapper.map(source, bbArg);
     auto newConcat = rewriter.clone(*inst, mapper);
-    spatial::SpatYieldOp::create(rewriter, loc, newConcat->getResult(0));
-    inst->replaceAllUsesWith(newCompute);
+    spatial::SpatYieldOp::create(rewriter, loc, newConcat->getResults());
+    inst->replaceAllUsesWith(newCompute->getResults());
     inst->erase();
     return true;
   }

@@ -298,14 +298,15 @@ void ONNXToSpatialPass::encapsulateGlobalInstruction(func::FuncOp funcOp) {
 void ONNXToSpatialPass::mergeTriviallyConnectedComputes(func::FuncOp funcOp) {
   Location loc = funcOp.getLoc();
   IRRewriter rewriter(&getContext());
-  SmallVector<spatial::SpatWeightedCompute> trivialComputes;
-  llvm::SmallSet<spatial::SpatWeightedCompute, 8> toErase;
+  SmallVector<spatial::SpatCompute> trivialComputes;
+  llvm::SmallSet<spatial::SpatCompute, 8> toErase;

-  for (auto compute : funcOp.getOps<spatial::SpatWeightedCompute>())
+  for (auto compute : funcOp.getOps<spatial::SpatCompute>())
     if (compute->hasOneUse()) {
-      auto user = dyn_cast<spatial::SpatWeightedCompute>(*compute->getUsers().begin());
+      auto& use = *compute->getUses().begin();
+      auto user = dyn_cast<spatial::SpatCompute>(use.getOwner());

-      if (user && user.getInputs().size() == 1)
+      if (user && user.getInputs().size() == 1 && use.getOperandNumber() >= user.getWeights().size())
        trivialComputes.push_back(compute);
     }

@@ -317,12 +318,15 @@ void ONNXToSpatialPass::mergeTriviallyConnectedComputes(func::FuncOp funcOp) {
       trivialComputes.pop_back();
       continue;
     }
-    auto child = cast<spatial::SpatWeightedCompute>(*compute->getUsers().begin());
+    auto& computeUse = *compute->getUses().begin();
+    auto child = cast<spatial::SpatCompute>(computeUse.getOwner());
+    auto usedResult = cast<OpResult>(computeUse.get()).getResultNumber();
+    auto childArgIndex = computeUse.getOperandNumber() - child.getWeights().size();

     rewriter.setInsertionPointAfter(compute.getOperation());

     auto newCompute =
-        spatial::SpatWeightedCompute::create(rewriter, loc, child.getResultTypes(), compute.getOperands());
+        spatial::SpatCompute::create(rewriter, loc, child.getResultTypes(), compute.getOperands());
     newCompute.getProperties().setOperandSegmentSizes(
         {static_cast<int>(compute.getWeights().size()), static_cast<int>(compute.getInputs().size())});

@@ -343,7 +347,7 @@ void ONNXToSpatialPass::mergeTriviallyConnectedComputes(func::FuncOp funcOp) {

     compute.getBodyRegion().cloneInto(&newCompute.getBodyRegion(), mapper);
     auto newTerminator = newCompute.getBody().front().getTerminator();
-    mapper.map(*child.getBody().front().getArguments().begin(), newTerminator->getOperand(0));
+    mapper.map(child.getBody().front().getArgument(childArgIndex), newTerminator->getOperand(usedResult));
     newTerminator->erase();
     rewriter.setInsertionPoint(&newCompute.getBody().front(), newCompute.getBody().front().end());
     for (auto& op : child.getBody().front()) {

@@ -371,14 +375,16 @@ void ONNXToSpatialPass::mergeTriviallyConnectedComputes(func::FuncOp funcOp) {
       toErase.insert(compute);

     if (newCompute->hasOneUse()) {
-      auto user = dyn_cast<spatial::SpatWeightedCompute>(*newCompute->getUsers().begin());
-      if (user && user.getInputs().size() == 1)
+      auto& use = *newCompute->getUses().begin();
+      auto user = dyn_cast<spatial::SpatCompute>(use.getOwner());
+      if (user && user.getInputs().size() == 1 && use.getOperandNumber() >= user.getWeights().size())
         trivialComputes.push_back(newCompute);
     }
   }

   for (auto compute : toErase) {
-    compute.getResult(0).dropAllUses();
+    for (Value result : compute->getResults())
+      result.dropAllUses();
     compute.erase();
   }
 }

@@ -386,7 +392,7 @@ void ONNXToSpatialPass::mergeTriviallyConnectedComputes(func::FuncOp funcOp) {
 void ONNXToSpatialPass::annotateWeightsConstants(func::FuncOp funcOp) const {
   funcOp.walk([&](arith::ConstantOp constantOp) {
     bool isAlwaysWeight =
-        llvm::all_of(constantOp->getUsers(), [](auto user) -> bool { return isa<spatial::SpatWeightedCompute>(user); });
+        llvm::all_of(constantOp->getUsers(), [](auto user) -> bool { return isa<spatial::SpatCompute>(user); });
     if (isAlwaysWeight)
       markWeightAlways(constantOp);
   });

@@ -394,7 +400,7 @@ void ONNXToSpatialPass::annotateWeightsConstants(func::FuncOp funcOp) const {

 LogicalResult ONNXToSpatialPass::promoteConstantInputsToWeights(func::FuncOp funcOp) {
   IRRewriter rewriter(&getContext());
-  SmallVector<spatial::SpatWeightedCompute> computes(funcOp.getOps<spatial::SpatWeightedCompute>());
+  SmallVector<spatial::SpatCompute> computes(funcOp.getOps<spatial::SpatCompute>());

   for (auto compute : computes) {
     SmallVector<bool> promoteInput(compute.getInputs().size(), false);

@@ -430,7 +436,7 @@ LogicalResult ONNXToSpatialPass::promoteConstantInputsToWeights(func::FuncOp fun
   }

   auto newCompute =
-      spatial::SpatWeightedCompute::create(rewriter, compute.getLoc(), compute.getResultTypes(), newWeights, newInputs);
+      spatial::SpatCompute::create(rewriter, compute.getLoc(), compute.getResultTypes(), newWeights, newInputs);
   auto* newBlock =
       rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), newInputTypes, newInputLocs);
   newCompute.getProperties().setOperandSegmentSizes(
```
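Beyond the mechanical `SpatWeightedCompute` → `SpatCompute` rename, these hunks generalize the pass to multi-result compute ops: creation takes `inst->getResultTypes()` rather than `.front()`, yields and use-replacement go through `getResults()`, and the merge logic now records which operand/result pair actually connects parent and child (`usedResult`, `childArgIndex`). Inputs are distinguished from weights positionally — an operand is an input when `use.getOperandNumber() >= user.getWeights().size()`, matching the op's operand segment layout.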
@@ -147,33 +147,37 @@ static Value buildPackedBias(bool hasBias,
|
||||
return arith::ConstantOp::create(rewriter, loc, packedBiasType, packedBiasAttr).getResult();
|
||||
}
|
||||
|
||||
static Value createIm2colCompute(Value x,
|
||||
RankedTensorType xType,
|
||||
RankedTensorType im2colType,
|
||||
RankedTensorType rowType,
|
||||
int64_t batchSize,
|
||||
int64_t numChannelsIn,
|
||||
int64_t xHeight,
|
||||
int64_t xWidth,
|
||||
int64_t wHeight,
|
||||
int64_t wWidth,
|
||||
int64_t padHeightBegin,
|
||||
int64_t padHeightEnd,
|
||||
int64_t padWidthBegin,
|
||||
int64_t padWidthEnd,
|
||||
int64_t strideHeight,
|
||||
int64_t strideWidth,
|
||||
int64_t dilationHeight,
|
||||
int64_t dilationWidth,
|
||||
int64_t outWidth,
|
||||
int64_t patchSize,
|
||||
int64_t numPatches,
|
||||
int64_t numPatchesPerBatch,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
static SmallVector<Value> createIm2colRowComputes(Value x,
|
||||
RankedTensorType xType,
|
||||
RankedTensorType im2colType,
|
||||
RankedTensorType im2colRowType,
|
||||
RankedTensorType gemmInputRowType,
|
||||
int64_t batchSize,
|
||||
int64_t numChannelsIn,
|
||||
int64_t xHeight,
|
||||
int64_t xWidth,
|
||||
int64_t wHeight,
|
||||
int64_t wWidth,
|
||||
int64_t padHeightBegin,
|
||||
int64_t padHeightEnd,
|
||||
int64_t padWidthBegin,
|
||||
int64_t padWidthEnd,
|
||||
int64_t strideHeight,
|
||||
int64_t strideWidth,
|
||||
int64_t dilationHeight,
|
||||
int64_t dilationWidth,
|
||||
int64_t outWidth,
|
||||
int64_t patchSize,
|
||||
int64_t numPatches,
|
||||
int64_t numPatchesPerBatch,
|
||||
int64_t packFactor,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
auto elemType = xType.getElementType();
|
||||
constexpr size_t numInputs = 1;
|
||||
auto im2colComputeOp = createSpatCompute<numInputs>(rewriter, loc, im2colType, {}, x, [&](Value xArg) {
|
||||
const int64_t packedNumRows = ceilIntegerDivide(numPatches, packFactor);
|
||||
SmallVector<Type> resultTypes(packedNumRows, gemmInputRowType);
|
||||
auto im2colComputeOp = createSpatCompute<numInputs>(rewriter, loc, resultTypes, {}, x, [&](Value xArg) {
|
||||
Value paddedInput = xArg;
|
||||
|
||||
// Pad input with zeros if needed:
|
||||
@@ -240,7 +244,7 @@ static Value createIm2colCompute(Value x,
|
||||
|
||||
Value row = tensor::CollapseShapeOp::create(rewriter,
|
||||
loc,
|
||||
rowType,
|
||||
im2colRowType,
|
||||
patch,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0},
|
||||
@@ -256,121 +260,117 @@ static Value createIm2colCompute(Value x,
|
||||
|
||||
rewriter.setInsertionPointAfter(im2colLoop);
|
||||
Value im2col = im2colLoop.getResult(0);
|
||||
spatial::SpatYieldOp::create(rewriter, loc, im2col);
|
||||
});
|
||||
return im2colComputeOp.getResult(0);
|
||||
}
|
||||
|
||||
static Value createPackedIm2colRows(Value im2col,
|
||||
RankedTensorType im2colType,
|
||||
Type elemType,
|
||||
int64_t numPatches,
|
||||
int64_t patchSize,
|
||||
int64_t packFactor,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
if (packFactor == 1)
|
||||
return im2col;
|
||||
|
||||
const int64_t packedNumRows = ceilIntegerDivide(numPatches, packFactor);
|
||||
const int64_t paddedNumPatches = packedNumRows * packFactor;
|
||||
auto groupedType = RankedTensorType::get({packedNumRows, packFactor, patchSize}, elemType);
|
||||
auto packedType = RankedTensorType::get({packedNumRows, packFactor * patchSize}, elemType);
|
||||
auto packedComputeOp = createSpatCompute<1>(rewriter, loc, packedType, {}, im2col, [&](Value im2colArg) {
|
||||
Value paddedIm2col = createPaddedRows(im2colArg, im2colType, paddedNumPatches, rewriter, loc);
|
||||
Value groupedIm2col = tensor::ExpandShapeOp::create(rewriter,
|
||||
loc,
|
||||
groupedType,
|
||||
paddedIm2col,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0, 1},
|
||||
{2}
|
||||
});
|
||||
Value packedIm2col = tensor::CollapseShapeOp::create(rewriter,
|
||||
loc,
|
||||
packedType,
|
||||
groupedIm2col,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0},
|
||||
{1, 2}
|
||||
});
|
||||
spatial::SpatYieldOp::create(rewriter, loc, packedIm2col);
|
||||
});
|
||||
return packedComputeOp.getResult(0);
|
||||
}
|
||||
|
||||
static Value createUnpackedOutput(Value packedOutput,
|
||||
RankedTensorType gemmOutType,
|
||||
RankedTensorType outType,
|
||||
int64_t numPatches,
|
||||
int64_t numChannelsOut,
|
||||
int64_t packFactor,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
if (packFactor == 1)
|
||||
return packedOutput;
|
||||
|
||||
const int64_t packedNumRows = ceilIntegerDivide(numPatches, packFactor);
|
||||
const int64_t paddedNumPatches = packedNumRows * packFactor;
|
||||
auto expandedType = RankedTensorType::get({packedNumRows, packFactor, numChannelsOut}, outType.getElementType());
|
||||
auto paddedType = RankedTensorType::get({paddedNumPatches, numChannelsOut}, outType.getElementType());
|
||||
auto unpackComputeOp = createSpatCompute<1>(rewriter, loc, gemmOutType, {}, packedOutput, [&](Value packedOutputArg) {
|
||||
Value expandedOutput = tensor::ExpandShapeOp::create(rewriter,
|
||||
loc,
|
||||
expandedType,
|
||||
packedOutputArg,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0},
|
||||
{1, 2}
|
||||
});
|
||||
Value paddedOutput = tensor::CollapseShapeOp::create(rewriter,
|
||||
loc,
|
||||
paddedType,
|
||||
expandedOutput,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0, 1},
|
||||
{2}
|
||||
});
|
||||
|
||||
Value unpackedOutput = paddedOutput;
|
||||
if (paddedNumPatches != numPatches) {
|
||||
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(numPatches), rewriter.getIndexAttr(numChannelsOut)};
|
||||
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
unpackedOutput =
|
||||
tensor::ExtractSliceOp::create(rewriter, loc, gemmOutType, paddedOutput, offsets, sizes, strides);
|
||||
Value gemmInputRows = im2col;
|
||||
if (packFactor != 1) {
|
||||
const int64_t paddedNumPatches = packedNumRows * packFactor;
|
||||
auto groupedType = RankedTensorType::get({packedNumRows, packFactor, patchSize}, elemType);
|
||||
auto packedType = RankedTensorType::get({packedNumRows, packFactor * patchSize}, elemType);
|
||||
Value paddedIm2col = createPaddedRows(im2col, im2colType, paddedNumPatches, rewriter, loc);
|
||||
Value groupedIm2col = tensor::ExpandShapeOp::create(rewriter,
|
||||
loc,
|
||||
groupedType,
|
||||
paddedIm2col,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0, 1},
|
||||
{2}
|
||||
});
|
||||
gemmInputRows = tensor::CollapseShapeOp::create(rewriter,
|
||||
loc,
|
||||
packedType,
|
||||
groupedIm2col,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0},
|
||||
{1, 2}
|
||||
});
|
||||
}
|
||||
|
||||
spatial::SpatYieldOp::create(rewriter, loc, unpackedOutput);
|
||||
SmallVector<Value> rowResults;
|
||||
rowResults.reserve(packedNumRows);
|
||||
for (int64_t rowIdx = 0; rowIdx < packedNumRows; rowIdx++) {
|
||||
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(rowIdx), rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(packFactor * patchSize)};
|
||||
      SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
      rowResults.push_back(
          tensor::ExtractSliceOp::create(rewriter, loc, gemmInputRowType, gemmInputRows, offsets, sizes, strides));
    }
    spatial::SpatYieldOp::create(rewriter, loc, rowResults);
  });
  return unpackComputeOp.getResult(0);

  SmallVector<Value> rows;
  rows.reserve(im2colComputeOp.getNumResults());
  for (Value result : im2colComputeOp.getResults())
    rows.push_back(result);
  return rows;
}

static Value createCollectedConvOutput(Value gemmOut,
static Value createCollectedConvOutput(ValueRange gemmRows,
    Type convType,
    RankedTensorType gemmOutType,
    RankedTensorType nhwcType,
    RankedTensorType outType,
    int64_t numPatches,
    int64_t numChannelsOut,
    int64_t packFactor,
    ConversionPatternRewriter& rewriter,
    Location loc) {
  auto collectComputeOp =
      createSpatCompute(rewriter, loc, convType, {}, ValueRange {gemmOut}, [&](ValueRange gemmOutArgs) {
        Value gemmOutArg = gemmOutArgs.front();

        // Restore to NCHW layout:
        // [numPatches, numChannelsOut]
        // -> [1, outHeight, outWidth, numChannelsOut]
        // -> [1, numChannelsOut, outHeight, outWidth]
        Value nhwcOut = tensor::ExpandShapeOp::create(rewriter,
            loc,
            nhwcType,
            gemmOutArg,
            SmallVector<ReassociationIndices> {
                {0, 1, 2},
                {3}
  const int64_t packedNumRows = ceilIntegerDivide(numPatches, packFactor);
  const int64_t paddedNumPatches = packedNumRows * packFactor;
  auto collectComputeOp = createSpatCompute(rewriter, loc, convType, {}, gemmRows, [&](ValueRange gemmRowArgs) {
    Value gemmOut;
    if (packFactor == 1) {
      gemmOut = gemmRowArgs.size() == 1 ? gemmRowArgs.front()
                                        : tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, gemmRowArgs).getResult();
    }
    else {
      auto expandedType = RankedTensorType::get({packedNumRows, packFactor, numChannelsOut}, outType.getElementType());
      auto paddedType = RankedTensorType::get({paddedNumPatches, numChannelsOut}, outType.getElementType());
      Value packedOutput =
          gemmRowArgs.size() == 1
              ? gemmRowArgs.front()
              : tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, gemmRowArgs).getResult();
      Value expandedOutput = tensor::ExpandShapeOp::create(rewriter,
          loc,
          expandedType,
          packedOutput,
          SmallVector<ReassociationIndices> {
              {0},
              {1, 2}
          });
        Value nchwOut = ONNXTransposeOp::create(rewriter, loc, outType, nhwcOut, rewriter.getI64ArrayAttr({0, 3, 1, 2}));
        spatial::SpatYieldOp::create(rewriter, loc, nchwOut);
      Value paddedOutput = tensor::CollapseShapeOp::create(rewriter,
          loc,
          paddedType,
          expandedOutput,
          SmallVector<ReassociationIndices> {
              {0, 1},
              {2}
          });

      gemmOut = paddedOutput;
      if (paddedNumPatches != numPatches) {
        SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
        SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(numPatches), rewriter.getIndexAttr(numChannelsOut)};
        SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
        gemmOut = tensor::ExtractSliceOp::create(rewriter, loc, gemmOutType, paddedOutput, offsets, sizes, strides);
      }
    }

    // Restore to NCHW layout:
    // [numPatches, numChannelsOut]
    // -> [1, outHeight, outWidth, numChannelsOut]
    // -> [1, numChannelsOut, outHeight, outWidth]
    Value nhwcOut = tensor::ExpandShapeOp::create(rewriter,
        loc,
        nhwcType,
        gemmOut,
        SmallVector<ReassociationIndices> {
            {0, 1, 2},
            {3}
        });
    Value nchwOut = ONNXTransposeOp::create(rewriter, loc, outType, nhwcOut, rewriter.getI64ArrayAttr({0, 3, 1, 2}));
    spatial::SpatYieldOp::create(rewriter, loc, nchwOut);
  });
  return collectComputeOp.getResult(0);
}
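A quick sanity check of the padding arithmetic used above, as a standalone sketch (the real ceilIntegerDivide helper lives outside this diff, and the numeric values are illustrative):

#include <cassert>
#include <cstdint>

// Stand-in for the project's ceilIntegerDivide helper, whose definition is
// not part of this diff.
static int64_t ceilIntegerDivide(int64_t a, int64_t b) { return (a + b - 1) / b; }

int main() {
  const int64_t numPatches = 10, packFactor = 4;  // illustrative values
  const int64_t packedNumRows = ceilIntegerDivide(numPatches, packFactor);
  const int64_t paddedNumPatches = packedNumRows * packFactor;
  // [packedNumRows, packFactor, numChannelsOut] collapses to
  // [paddedNumPatches, numChannelsOut]; the extract_slice above then keeps
  // only the first numPatches rows, dropping the padding.
  assert(packedNumRows == 3 && paddedNumPatches == 12);
  return 0;
}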

@@ -487,11 +487,11 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,

  // Pass bias through directly; Gemm handles rank-1 C canonicalization.
  bool hasB = !isa<ONNXNoneOp>(b.getDefiningOp());
  Value gemmC = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());
  Value gemmBias = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());
  Value biasMatrix;
  DenseElementsAttr biasDenseAttr;
  if (hasB) {
    gemmC = b;
    gemmBias = b;
    biasDenseAttr = getDenseConstantAttr(b);
    biasMatrix = expandBiasIfNeeded(b, rewriter, loc);
  }
@@ -500,94 +500,86 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
  const int64_t effectiveMaxParallelPixels =
      (canPackWeightsAsConstants && canPackBiasAsConstants) ? maxParallelPixels : 1;

  Value im2col = createIm2colCompute(x,
      xType,
      im2colType,
      rowType,
      batchSize,
      numChannelsIn,
      xHeight,
      xWidth,
      wHeight,
      wWidth,
      padHeightBegin,
      padHeightEnd,
      padWidthBegin,
      padWidthEnd,
      strideHeight,
      strideWidth,
      dilationHeight,
      dilationWidth,
      outWidth,
      patchSize,
      numPatches,
      numPatchesPerBatch,
      rewriter,
      loc);
  // Keep the standard im2col view of convolution:
  //   A (im2col): [numPatches, patchSize] -- one row per output spatial position
  //   B (weights): [patchSize, cOut] -- W^T, stored in crossbar columns
  // and optionally repack several old rows into one GEMM row to use the available crossbar size better.
  //
  // The im2col compute yields each GEMM input row as a separate result so every GEMM consumes only
  // the row it needs instead of receiving a full packed tensor and slicing it locally.
  auto gemmInputRowType =
      RankedTensorType::get({1, effectiveMaxParallelPixels * patchSize}, elemType);
  auto gemmOutputRowType =
      RankedTensorType::get({1, effectiveMaxParallelPixels * numChannelsOut}, outType.getElementType());
  SmallVector<Value> gemmInputRows = createIm2colRowComputes(x,
      xType,
      im2colType,
      rowType,
      gemmInputRowType,
      batchSize,
      numChannelsIn,
      xHeight,
      xWidth,
      wHeight,
      wWidth,
      padHeightBegin,
      padHeightEnd,
      padWidthBegin,
      padWidthEnd,
      strideHeight,
      strideWidth,
      dilationHeight,
      dilationWidth,
      outWidth,
      patchSize,
      numPatches,
      numPatchesPerBatch,
      effectiveMaxParallelPixels,
      rewriter,
      loc);

  Value gemmOut;
  if (effectiveMaxParallelPixels == 1) {
    // Fallback to the plain im2col GEMM when a single crossbar cannot fit multiple pixels.
    gemmOut = ONNXGemmOp::create(rewriter,
        loc,
        gemmOutType,
        im2col,
        wTrans,
        gemmC,
        rewriter.getF32FloatAttr(1.0f),
        rewriter.getF32FloatAttr(1.0f),
        rewriter.getBoolAttr(false),
        rewriter.getBoolAttr(false))
        .getY();
  }
  else {
    // Keep the standard im2col view of convolution:
    //   A (im2col): [numPatches, patchSize] -- one row per output spatial position
    //   B (weights): [patchSize, cOut] -- W^T, stored in crossbar columns
    // but repack several old rows into one new row so we use the available crossbar size better.
    //
    // We want to process N spatial pixels at the exact same time. Instead of doing N separate
    // operations of (1 x patchSize) x (patchSize x cOut), we construct a block-diagonal weight matrix
    // containing N copies of W^T and concatenate N im2col rows into one longer row:
    //   A_packed: [ceil(numPatches / N), N * patchSize]
    //   B_packed: [N * patchSize, N * cOut]
    //   Y_packed: [ceil(numPatches / N), N * cOut]
    // The downstream GemmToManyGemv pass still splits by row, but now there are fewer, longer rows.
    const int64_t packedNumRows = ceilIntegerDivide(numPatches, effectiveMaxParallelPixels);
    auto packedOutType =
        RankedTensorType::get({packedNumRows, effectiveMaxParallelPixels * numChannelsOut}, outType.getElementType());
    Value gemmB = buildPackedWeight(wDenseAttr,
        wTrans,
        wType,
        numChannelsIn,
        numChannelsOut,
        wHeight,
        wWidth,
        patchSize,
        effectiveMaxParallelPixels,
        rewriter,
        loc);
    Value gemmC = buildPackedBias(
        hasB, gemmBias, biasMatrix, biasDenseAttr, outType, numChannelsOut, effectiveMaxParallelPixels, rewriter, loc);

    Value packedA = createPackedIm2colRows(
        im2col, im2colType, elemType, numPatches, patchSize, effectiveMaxParallelPixels, rewriter, loc);
    Value packedB = buildPackedWeight(wDenseAttr,
        wTrans,
        wType,
        numChannelsIn,
        numChannelsOut,
        wHeight,
        wWidth,
        patchSize,
        effectiveMaxParallelPixels,
        rewriter,
        loc);
    Value packedC = buildPackedBias(
        hasB, gemmC, biasMatrix, biasDenseAttr, outType, numChannelsOut, effectiveMaxParallelPixels, rewriter, loc);
    Value packedOut = ONNXGemmOp::create(rewriter,
        loc,
        packedOutType,
        packedA,
        packedB,
        packedC,
        rewriter.getF32FloatAttr(1.0f),
        rewriter.getF32FloatAttr(1.0f),
        rewriter.getBoolAttr(false),
        rewriter.getBoolAttr(false))
        .getY();
    gemmOut = createUnpackedOutput(
        packedOut, gemmOutType, outType, numPatches, numChannelsOut, effectiveMaxParallelPixels, rewriter, loc);
    SmallVector<Value> gemmRows;
    gemmRows.reserve(gemmInputRows.size());
    for (Value gemmInputRow : gemmInputRows) {
      Value gemmRow = ONNXGemmOp::create(rewriter,
          loc,
          gemmOutputRowType,
          gemmInputRow,
          gemmB,
          gemmC,
          rewriter.getF32FloatAttr(1.0f),
          rewriter.getF32FloatAttr(1.0f),
          rewriter.getBoolAttr(false),
          rewriter.getBoolAttr(false))
          .getY();
      gemmRows.push_back(gemmRow);
    }

  rewriter.replaceOp(convOp, createCollectedConvOutput(gemmOut, convOp.getType(), nhwcType, outType, rewriter, loc));
  rewriter.replaceOp(convOp,
      createCollectedConvOutput(gemmRows,
          convOp.getType(),
          gemmOutType,
          nhwcType,
          outType,
          numPatches,
          numChannelsOut,
          effectiveMaxParallelPixels,
          rewriter,
          loc));
  return success();
}
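The block-diagonal packing described in the comment block above is easiest to see in a plain dense sketch. This is illustrative only, not the actual buildPackedWeight implementation; all names are local to the example:

#include <cstdint>
#include <vector>

// Packs n copies of a [patchSize x cOut] matrix W^T along the diagonal of a
// [n * patchSize, n * cOut] matrix, leaving the off-diagonal blocks zero.
std::vector<float> packBlockDiagonal(
    const std::vector<float>& wT, int64_t patchSize, int64_t cOut, int64_t n) {
  const int64_t packedCols = n * cOut;
  std::vector<float> packed(n * patchSize * packedCols, 0.0f);
  for (int64_t block = 0; block < n; ++block)
    for (int64_t r = 0; r < patchSize; ++r)
      for (int64_t c = 0; c < cOut; ++c)
        packed[(block * patchSize + r) * packedCols + (block * cOut + c)] =
            wT[r * cOut + c];
  return packed;
}
// One packed GEMM row [1, n * patchSize] times this matrix yields
// [1, n * cOut]: n output pixels per product, matching A_packed/B_packed above.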

@@ -42,15 +42,15 @@ private:
  raw_ostream& os;

  /**
   * Draws the subgraph for a given spatial::SpatWeightedCompute, including:
   * Draws the subgraph for a given spatial::SpatCompute, including:
   * 1. Input nodes (block arguments)
   * 2. Operations
   * 3. Edges between yield (output) and its users
   *
   * @param op The spatial::SpatWeightedCompute to draw the subgraph for.
   * @param op The spatial::SpatCompute to draw the subgraph for.
   * @param computeNum The number of the compute operation.
   */
  void drawComputeOpSubgraph(spatial::SpatWeightedCompute op, size_t computeNum) {
  void drawComputeOpSubgraph(spatial::SpatCompute op, size_t computeNum) {
    os << "\tsubgraph cluster" << computeNum << " {\n\t\tlabel=\"Compute" << computeNum << "\";\n"
       << "\t\tstyle=filled;\n"
       << "\t\tcolor=lightblue;\n";

@@ -217,7 +217,7 @@ void SpatialToGraphvizPass::runOnOperation() {
  // 1. Print their subgraph
  // 2. Print the edges from its inputs to its outputs
  for (Operation& op : func.getOps()) {
    if (auto computeOp = dyn_cast<spatial::SpatWeightedCompute>(op)) {
    if (auto computeOp = dyn_cast<spatial::SpatCompute>(op)) {
      drawComputeOpSubgraph(computeOp, computeNum++);
    }
    else if (auto concatOp = dyn_cast<tensor::ConcatOp>(op)) {

@@ -62,7 +62,7 @@ private:
  void runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, IRRewriter& rewriter);
  void
  addReceiveOps(Value channelSourceOp, spatial::SpatChannelNewOp& channel, bool useBroadcastOp, IRRewriter& rewriter);
  void replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompute& computeOp,
  void replaceBlockArgumentWithRecvOp(spatial::SpatCompute& computeOp,
      unsigned int argIndex,
      Value channelSourceOp,
      Value consumerValue,
@@ -73,7 +73,7 @@ private:
  void annotateChannelCoreIds(func::FuncOp funcOp);
  void lowerBroadcastChannelOps(func::FuncOp funcOp, IRRewriter& rewriter);

  void runOnComputeOp(spatial::SpatWeightedCompute computeOp, IRRewriter& rewriter);
  void runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter& rewriter);

  void enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter);

@@ -116,7 +116,7 @@ static size_t countComputeLeafUsers(Value value) {
  auto walkUses = [&](Value currentValue, auto& self) -> void {
    for (OpOperand& use : currentValue.getUses()) {
      Operation* owner = use.getOwner();
      if (isa<spatial::SpatWeightedCompute>(owner)) {
      if (isa<spatial::SpatCompute>(owner)) {
        leafUserCount++;
        continue;
      }
@@ -174,7 +174,7 @@ void SpatialToPimPass::runOnOperation() {
    markOpToRemove(receiveOp);
    runOnReceiveOp(receiveOp, rewriter);
  }
  for (auto computeOp : funcOp.getOps<spatial::SpatWeightedCompute>()) {
  for (auto computeOp : funcOp.getOps<spatial::SpatCompute>()) {
    markOpToRemove(computeOp);
    runOnComputeOp(computeOp, rewriter);
  }
@@ -222,7 +222,7 @@ void SpatialToPimPass::runOnOperation() {
  dumpModule(moduleOp, "pim0");
}

void SpatialToPimPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IRRewriter& rewriter) {
void SpatialToPimPass::runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter& rewriter) {
  Location loc = computeOp->getLoc();

  auto& block = computeOp.getRegion().front();
@@ -504,7 +504,7 @@ LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::Fu

  llvm::SmallSet<tensor::ExtractSliceOp, 8> sliceOpsToRemove;
  for (auto& op : funcOp.getBody().getOps())
    if (auto computeOp = dyn_cast<spatial::SpatWeightedCompute>(op)) {
    if (auto computeOp = dyn_cast<spatial::SpatCompute>(op)) {
      unsigned numComputeWeights = computeOp.getWeights().size();
      for (auto [computeInputIdx, computeOpInput] : llvm::enumerate(computeOp.getInputs())) {
        TypedValue<TensorType> tensorSource;
@@ -513,7 +513,7 @@ LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::Fu
        if (auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(computeOpInput.getDefiningOp())) {
          tensorSource = cast<TypedValue<TensorType>>(sliceOp.getSource());

          if (isa<spatial::SpatWeightedCompute>(tensorSource.getDefiningOp()))
          if (isa<spatial::SpatCompute>(tensorSource.getDefiningOp()))
            continue;

          ArrayRef<int64_t> sourceShape = tensorSource.getType().getShape();
@@ -538,7 +538,7 @@ LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::Fu
          tensorSource = cast<TypedValue<TensorType>>(computeOpInput);

          // Compute results must be transferred through channels via send/receive
          if (isa<spatial::SpatWeightedCompute>(tensorSource.getDefiningOp()))
          if (isa<spatial::SpatCompute>(tensorSource.getDefiningOp()))
            continue;

          BlockArgument computeBlockArgToReplace = computeOp.getBody().front().getArgument(computeInputIdx);
@@ -553,7 +553,7 @@ LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::Fu
  return success();
}

void SpatialToPimPass::replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompute& computeOp,
void SpatialToPimPass::replaceBlockArgumentWithRecvOp(spatial::SpatCompute& computeOp,
    unsigned int argIndex,
    Value channelSourceOp,
    Value consumerValue,
@@ -614,7 +614,7 @@ void SpatialToPimPass::addReceiveOps(Value channelSourceOp,
  auto replayUsesIntoConsumers = [&](Value currentValue, auto& self) -> void {
    for (OpOperand& use : currentValue.getUses()) {
      Operation* owner = use.getOwner();
      if (auto computeUser = dyn_cast<spatial::SpatWeightedCompute>(owner)) {
      if (auto computeUser = dyn_cast<spatial::SpatCompute>(owner)) {
        replaceBlockArgumentWithRecvOp(
            computeUser, use.getOperandNumber(), channelSourceOp, currentValue, channel, useBroadcastOp, rewriter);
        continue;

@@ -93,15 +93,22 @@ void PimBufferizationPass::runOnOperation() {
}

void PimBufferizationPass::annotateWeightsMemrefs(ModuleOp moduleOp, func::FuncOp funcOp) const {
  funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
    bool isAlwaysWeight = !getGlobalOp->getUsers().empty()
        && all_of(getGlobalOp->getUsers(), [](auto user) -> bool { return isa<PimCoreOp>(user); });
    if (isAlwaysWeight) {
  funcOp.walk([&](PimCoreOp coreOp) {
    auto annotateWeight = [&](unsigned weightIndex) {
      if (weightIndex >= coreOp.getWeights().size())
        return;
      Value weight = coreOp.getWeights()[weightIndex];
      auto getGlobalOp = weight.getDefiningOp<memref::GetGlobalOp>();
      if (!getGlobalOp)
        return;
      auto globalMemrefOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
      assert("Weights must be constants" && globalMemrefOp.getConstant());
      markWeightAlways(getGlobalOp);
      markWeightAlways(globalMemrefOp);
    }
    };

    coreOp.walk([&](PimMVMOp mvmOp) { annotateWeight(mvmOp.getWeightIndex()); });
    coreOp.walk([&](PimVMMOp vmmOp) { annotateWeight(vmmOp.getWeightIndex()); });
  });
}

@@ -32,7 +32,7 @@ def SpatChannelType : SpatType<"SpatChannel", "ch"> {
// Execution
//===----------------------------------------------------------------------===//

def SpatWeightedCompute : SpatOp<"compute", [SingleBlock, AttrSizedOperandSegments]> {
def SpatCompute : SpatOp<"compute", [SingleBlock, AttrSizedOperandSegments]> {
  let summary = "Compute region with attached constant weights";

  let arguments = (ins

@@ -119,7 +119,7 @@ inline LogicalResult mvmOpVerifySize4(SpatWeightedMVMOp* emitter,
}

llvm::FailureOr<ArrayRef<int64_t>> getWeightShapeForWeightedOp(Operation* weightedOp, size_t weightIndex) {
  auto wcomputeOp = dyn_cast<SpatWeightedCompute>(weightedOp->getParentOp());
  auto wcomputeOp = dyn_cast<SpatCompute>(weightedOp->getParentOp());
  if (wcomputeOp)
    return cast<ShapedType>(wcomputeOp.getWeights()[weightIndex].getType()).getShape();

@@ -134,7 +134,7 @@ llvm::FailureOr<ArrayRef<int64_t>> getWeightShapeForWeightedOp(Operation* weigth
LogicalResult SpatWeightedMVMOp::verify() {
  auto matrixShapeOpt = getWeightShapeForWeightedOp(this->getOperation(), this->getWeightIndex());
  if (failed(matrixShapeOpt))
    return emitError("SpatWeightedMVMOp was not within a SpatWeightedCompute or Core op");
    return emitError("SpatWeightedMVMOp was not within a SpatCompute or Core op");
  auto matrixShape = *matrixShapeOpt;
  auto vectorShape = getInput().getType().getShape();
  auto outputShape = getOutput().getType().getShape();
@@ -155,7 +155,7 @@ LogicalResult SpatWeightedMVMOp::verify() {
LogicalResult SpatWeightedVMMOp::verify() {
  auto matrixShapeOpt = getWeightShapeForWeightedOp(this->getOperation(), this->getWeightIndex());
  if (failed(matrixShapeOpt))
    return emitError("SpatWeightedVMMOp was not within a SpatWeightedCompute or Core op");
    return emitError("SpatWeightedVMMOp was not within a SpatCompute or Core op");
  auto matrixShape = *matrixShapeOpt;
  auto vectorShape = getInput().getType().getShape();
  auto outputShape = getOutput().getType().getShape();
@@ -200,9 +200,8 @@ LogicalResult SpatVMaxOp::verify() {
  return OpTrait::impl::verifySameOperandsAndResultType(*this);
}

LogicalResult SpatWeightedCompute::verify() {
  // Check that it has a terminator, it is a yieldOp, and it has a single
  // operand with the same type as the result
LogicalResult SpatCompute::verify() {
  // Check that the terminator yields the same number and types as the compute results.
  auto& block = getBody().front();
  if (block.mightHaveTerminator()) {
    auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator());
@@ -257,7 +256,7 @@ LogicalResult SpatWeightedCompute::verify() {
  return success();
}

LogicalResult SpatWeightedCompute::fold(FoldAdaptor adaptor, ::llvm::SmallVectorImpl<::mlir::OpFoldResult>& results) {
LogicalResult SpatCompute::fold(FoldAdaptor adaptor, ::llvm::SmallVectorImpl<::mlir::OpFoldResult>& results) {
  Block& block = getBody().front();
  if (!llvm::hasSingleElement(block))
    return failure();

@@ -6,10 +6,18 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Casting.h"

#include <algorithm>
#include <iterator>
#include <map>
#include <numeric>
#include <optional>
#include <set>
#include <utility>
#include <vector>

#include "DCPAnalysis.hpp"
#include "Graph.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Support/TypeUtilities.hpp"

namespace onnx_mlir {

@@ -17,7 +25,362 @@ namespace spatial {

using namespace mlir;

SpatWeightedCompute getOriginalSpatWeightedCompute(Operation* op) {
namespace {

struct VirtualNode {
  llvm::SmallVector<size_t, 4> originalComputeIndices;
  Weight weight = 0;
  CrossbarUsage crossbarUsage = 0;
};

struct VirtualGraph {
  std::vector<VirtualNode> nodes;
  std::vector<IndexedEdge> edges;
};

struct TimingInfo {
  std::vector<Time> aest;
  std::vector<Time> alst;
  std::vector<size_t> topologicalOrder;
  bool valid = false;
};

struct WindowScheduleResult {
  std::vector<std::vector<size_t>> mergeGroups;
  bool usedAllAvailableCpus = false;
};

std::vector<IndexedEdge> aggregateEdges(llvm::ArrayRef<IndexedEdge> edges) {
  std::map<std::pair<size_t, size_t>, Weight> edgeWeights;
  for (auto [start, end, weight] : edges) {
    size_t startIndex = static_cast<size_t>(start);
    size_t endIndex = static_cast<size_t>(end);
    if (startIndex == endIndex)
      continue;
    auto key = std::make_pair(startIndex, endIndex);
    Weight edgeWeight = static_cast<Weight>(weight);
    auto it = edgeWeights.find(key);
    if (it == edgeWeights.end())
      edgeWeights.insert({key, edgeWeight});
    else
      it->second = std::max(it->second, edgeWeight);
  }

  std::vector<IndexedEdge> aggregatedEdges;
  aggregatedEdges.reserve(edgeWeights.size());
  for (auto [key, weight] : edgeWeights)
    aggregatedEdges.push_back(
        {static_cast<int64_t>(key.first), static_cast<int64_t>(key.second), static_cast<int64_t>(weight)});
  return aggregatedEdges;
}
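A small usage sketch of the aggregation rule (edge values illustrative): parallel edges between the same pair of nodes collapse to the one with the largest transfer cost, and self-edges are dropped.

// Input:  {0 -> 1, w=8}, {0 -> 1, w=32}, {2 -> 2, w=5}
// Output: {0 -> 1, w=32}   (max of the parallel pair; the self-edge is gone)
std::vector<IndexedEdge> example = {{0, 1, 8}, {0, 1, 32}, {2, 2, 5}};
auto aggregated = aggregateEdges(example);  // aggregated.size() == 1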

VirtualGraph buildInitialVirtualGraph(llvm::ArrayRef<SpatCompute> spatComputes,
    llvm::ArrayRef<IndexedEdge> edges) {
  VirtualGraph graph;
  graph.nodes.reserve(spatComputes.size());
  for (auto [index, spatCompute] : llvm::enumerate(spatComputes)) {
    VirtualNode node;
    node.originalComputeIndices.push_back(index);
    node.weight = getSpatComputeWeight(spatCompute);
    node.crossbarUsage = getSpatComputeCrossbarUsage(spatCompute);
    graph.nodes.push_back(std::move(node));
  }
  graph.edges = aggregateEdges(edges);
  return graph;
}

TimingInfo computeTiming(const VirtualGraph& graph) {
  TimingInfo timing;
  size_t nodeCount = graph.nodes.size();
  timing.aest.assign(nodeCount, 0);
  timing.alst.assign(nodeCount, 0);
  timing.topologicalOrder.reserve(nodeCount);

  std::vector<std::vector<std::pair<size_t, Weight>>> parents(nodeCount);
  std::vector<std::vector<std::pair<size_t, Weight>>> children(nodeCount);
  std::vector<size_t> incomingEdgeCount(nodeCount, 0);

  for (auto [start, end, weight] : graph.edges) {
    size_t startIndex = static_cast<size_t>(start);
    size_t endIndex = static_cast<size_t>(end);
    Weight edgeWeight = static_cast<Weight>(weight);
    assert(startIndex < nodeCount && endIndex < nodeCount && "virtual edge endpoint out of range");
    children[startIndex].push_back({endIndex, edgeWeight});
    parents[endIndex].push_back({startIndex, edgeWeight});
    incomingEdgeCount[endIndex]++;
  }

  std::vector<size_t> readyNodes;
  readyNodes.reserve(nodeCount);
  for (size_t i = 0; i < nodeCount; ++i)
    if (incomingEdgeCount[i] == 0)
      readyNodes.push_back(i);

  size_t readyIndex = 0;
  while (readyIndex != readyNodes.size()) {
    size_t current = readyNodes[readyIndex++];
    timing.topologicalOrder.push_back(current);
    for (auto [child, weight] : children[current]) {
      (void) weight;
      assert(incomingEdgeCount[child] > 0 && "incoming edge count underflow");
      incomingEdgeCount[child]--;
      if (incomingEdgeCount[child] == 0)
        readyNodes.push_back(child);
    }
  }

  if (timing.topologicalOrder.size() != nodeCount)
    return timing;

  Time dcpl = 0;
  for (size_t nodeIndex : timing.topologicalOrder) {
    Time maxParentAest = 0;
    for (auto [parent, transferCost] : parents[nodeIndex]) {
      maxParentAest =
          std::max(maxParentAest, addOrMax(addOrMax(timing.aest[parent], graph.nodes[parent].weight), transferCost));
    }
    timing.aest[nodeIndex] = maxParentAest;
    dcpl = std::max(dcpl, addOrMax(maxParentAest, graph.nodes[nodeIndex].weight));
  }

  for (size_t nodeIndex : llvm::reverse(timing.topologicalOrder)) {
    Time minAlst = std::numeric_limits<Time>::max();
    if (children[nodeIndex].empty())
      minAlst = subtractOrZero(dcpl, graph.nodes[nodeIndex].weight);
    for (auto [child, transferCost] : children[nodeIndex]) {
      minAlst =
          std::min(minAlst, subtractOrZero(timing.alst[child], addOrMax(graph.nodes[nodeIndex].weight, transferCost)));
    }
    timing.alst[nodeIndex] = minAlst;
  }

  timing.valid = true;
  return timing;
}
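A worked pass over a two-node chain helps check the recurrences (weights and transfer cost are illustrative):

// Chain A -> B, weight(A) = weight(B) = 100, transfer cost 32:
//   aest[A] = 0
//   aest[B] = aest[A] + weight[A] + 32     = 132
//   dcpl    = max(0 + 100, 132 + 100)      = 232
//   alst[B] = dcpl - weight[B]             = 132   (slack 0)
//   alst[A] = alst[B] - (weight[A] + 32)   = 0     (slack 0)
// Both nodes sit on the critical path, as expected for a single chain.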

std::vector<size_t> selectCriticalWindow(const TimingInfo& timing, size_t windowSize) {
  std::vector<size_t> selected(timing.aest.size());
  std::iota(selected.begin(), selected.end(), 0);
  std::stable_sort(selected.begin(), selected.end(), [&](size_t lhs, size_t rhs) {
    Time lhsSlack = slackOrZero(timing.aest[lhs], timing.alst[lhs]);
    Time rhsSlack = slackOrZero(timing.aest[rhs], timing.alst[rhs]);
    if (lhsSlack != rhsSlack)
      return lhsSlack < rhsSlack;
    if (timing.aest[lhs] != timing.aest[rhs])
      return timing.aest[lhs] < timing.aest[rhs];
    return lhs < rhs;
  });
  selected.resize(std::min(windowSize, selected.size()));
  return selected;
}
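Illustrative selection (slack and AEST values assumed, not taken from a run):

// slacks {n0: 0, n1: 50, n2: 0}, aest {n0: 0, n2: 40}
// sorted order: n0, n2, n1   (slack first, AEST breaks the tie)
// windowSize = 2 selects {n0, n2}, the two most time-critical nodes.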

std::vector<size_t> getOriginalSignature(const VirtualGraph& graph, llvm::ArrayRef<size_t> selectedNodes) {
  std::vector<size_t> signature;
  for (size_t nodeIndex : selectedNodes) {
    const VirtualNode& node = graph.nodes[nodeIndex];
    signature.insert(signature.end(), node.originalComputeIndices.begin(), node.originalComputeIndices.end());
  }
  std::sort(signature.begin(), signature.end());
  return signature;
}

std::vector<IndexedEdge> buildWindowEdges(const VirtualGraph& graph, const std::vector<int64_t>& nodeToWindowIndex) {
  std::vector<IndexedEdge> windowEdges;
  windowEdges.reserve(graph.edges.size());
  for (auto [start, end, weight] : graph.edges) {
    int64_t mappedStart = nodeToWindowIndex[static_cast<size_t>(start)];
    int64_t mappedEnd = nodeToWindowIndex[static_cast<size_t>(end)];
    if (mappedStart == -1 || mappedEnd == -1)
      continue;
    windowEdges.push_back({mappedStart, mappedEnd, weight});
  }
  return aggregateEdges(windowEdges);
}

WindowScheduleResult
scheduleWindow(const VirtualGraph& graph, llvm::ArrayRef<size_t> selectedNodes, MLIRContext* context) {
  std::vector<Weight> windowWeights;
  std::vector<CrossbarUsage> windowCrossbarUsage;
  std::vector<int64_t> nodeToWindowIndex(graph.nodes.size(), -1);
  windowWeights.reserve(selectedNodes.size());
  windowCrossbarUsage.reserve(selectedNodes.size());

  for (auto [windowIndex, nodeIndex] : llvm::enumerate(selectedNodes)) {
    nodeToWindowIndex[nodeIndex] = static_cast<int64_t>(windowIndex);
    windowWeights.push_back(graph.nodes[nodeIndex].weight);
    windowCrossbarUsage.push_back(graph.nodes[nodeIndex].crossbarUsage);
  }

  GraphDCP windowGraph(windowWeights, buildWindowEdges(graph, nodeToWindowIndex), windowCrossbarUsage);
  if (coresCount.getValue() > 0)
    windowGraph.setMaxCpuCount(static_cast<int>(coresCount.getValue()));
  windowGraph.setContext(context);
  windowGraph.runDcp();

  WindowScheduleResult result;
  result.usedAllAvailableCpus = windowGraph.cpuCount() >= windowGraph.getMaxCpuCount();
  for (CPU cpu = 0; cpu < windowGraph.cpuCount(); ++cpu) {
    auto scheduledTasks = windowGraph.getScheduledTasks(cpu);
    if (scheduledTasks.size() < 2)
      continue;

    std::vector<size_t> mergeGroup;
    mergeGroup.reserve(scheduledTasks.size());
    for (const auto& task : scheduledTasks)
      mergeGroup.push_back(selectedNodes[task.nodeIndex]);
    std::sort(mergeGroup.begin(), mergeGroup.end());
    result.mergeGroups.push_back(std::move(mergeGroup));
  }
  return result;
}

bool coarsenGraph(const VirtualGraph& graph,
    llvm::ArrayRef<std::vector<size_t>> mergeGroups,
    VirtualGraph& coarsenedGraph) {
  std::vector<int64_t> nodeToMergeGroup(graph.nodes.size(), -1);
  for (auto [groupIndex, mergeGroup] : llvm::enumerate(mergeGroups)) {
    if (mergeGroup.size() < 2)
      continue;
    for (size_t nodeIndex : mergeGroup) {
      assert(nodeIndex < graph.nodes.size() && "merge group node out of range");
      nodeToMergeGroup[nodeIndex] = static_cast<int64_t>(groupIndex);
    }
  }

  std::vector<std::optional<size_t>> mergeGroupToNewNode(mergeGroups.size());
  std::vector<size_t> oldToNewNode(graph.nodes.size(), 0);
  bool mergedAny = false;
  coarsenedGraph.nodes.clear();
  coarsenedGraph.edges.clear();
  coarsenedGraph.nodes.reserve(graph.nodes.size());

  for (size_t nodeIndex = 0; nodeIndex < graph.nodes.size(); ++nodeIndex) {
    int64_t mergeGroupIndex = nodeToMergeGroup[nodeIndex];
    if (mergeGroupIndex == -1) {
      oldToNewNode[nodeIndex] = coarsenedGraph.nodes.size();
      coarsenedGraph.nodes.push_back(graph.nodes[nodeIndex]);
      continue;
    }

    auto& newNodeIndex = mergeGroupToNewNode[static_cast<size_t>(mergeGroupIndex)];
    if (newNodeIndex.has_value()) {
      oldToNewNode[nodeIndex] = *newNodeIndex;
      continue;
    }

    VirtualNode mergedNode;
    for (size_t memberIndex : mergeGroups[static_cast<size_t>(mergeGroupIndex)]) {
      const VirtualNode& memberNode = graph.nodes[memberIndex];
      mergedNode.originalComputeIndices.append(memberNode.originalComputeIndices.begin(),
          memberNode.originalComputeIndices.end());
      mergedNode.weight = addOrMax(mergedNode.weight, memberNode.weight);
      mergedNode.crossbarUsage = addOrMax(mergedNode.crossbarUsage, memberNode.crossbarUsage);
    }
    std::sort(mergedNode.originalComputeIndices.begin(), mergedNode.originalComputeIndices.end());

    mergedAny = true;
    newNodeIndex = coarsenedGraph.nodes.size();
    for (size_t memberIndex : mergeGroups[static_cast<size_t>(mergeGroupIndex)])
      oldToNewNode[memberIndex] = *newNodeIndex;
    coarsenedGraph.nodes.push_back(std::move(mergedNode));
  }

  if (!mergedAny)
    return false;

  std::vector<IndexedEdge> remappedEdges;
  remappedEdges.reserve(graph.edges.size());
  for (auto [start, end, weight] : graph.edges) {
    size_t newStart = oldToNewNode[static_cast<size_t>(start)];
    size_t newEnd = oldToNewNode[static_cast<size_t>(end)];
    if (newStart == newEnd)
      continue;
    remappedEdges.push_back({static_cast<int64_t>(newStart), static_cast<int64_t>(newEnd), weight});
  }
  coarsenedGraph.edges = aggregateEdges(remappedEdges);

  return computeTiming(coarsenedGraph).valid;
}

bool coarsenGraphWithFallback(const VirtualGraph& graph,
    llvm::ArrayRef<std::vector<size_t>> mergeGroups,
    VirtualGraph& coarsenedGraph) {
  if (coarsenGraph(graph, mergeGroups, coarsenedGraph))
    return true;

  std::vector<size_t> orderedGroupIndices(mergeGroups.size());
  std::iota(orderedGroupIndices.begin(), orderedGroupIndices.end(), 0);
  std::stable_sort(orderedGroupIndices.begin(), orderedGroupIndices.end(), [&](size_t lhs, size_t rhs) {
    return mergeGroups[lhs].size() > mergeGroups[rhs].size();
  });

  std::vector<std::vector<size_t>> acceptedMergeGroups;
  acceptedMergeGroups.reserve(mergeGroups.size());
  for (size_t groupIndex : orderedGroupIndices) {
    std::vector<std::vector<size_t>> candidateMergeGroups = acceptedMergeGroups;
    candidateMergeGroups.push_back(mergeGroups[groupIndex]);

    VirtualGraph candidateGraph;
    if (!coarsenGraph(graph, candidateMergeGroups, candidateGraph))
      continue;

    acceptedMergeGroups = std::move(candidateMergeGroups);
    coarsenedGraph = std::move(candidateGraph);
  }
  return !acceptedMergeGroups.empty();
}

std::vector<size_t> computeOriginalTopologicalOrder(size_t computeCount, llvm::ArrayRef<IndexedEdge> edges) {
  VirtualGraph graph;
  graph.nodes.resize(computeCount);
  graph.edges = aggregateEdges(edges);
  TimingInfo timing = computeTiming(graph);
  if (timing.valid)
    return timing.topologicalOrder;

  std::vector<size_t> fallbackOrder(computeCount);
  std::iota(fallbackOrder.begin(), fallbackOrder.end(), 0);
  return fallbackOrder;
}

DCPAnalysisResult buildResultFromVirtualGraph(const VirtualGraph& graph,
    llvm::ArrayRef<SpatCompute> spatComputes,
    llvm::ArrayRef<IndexedEdge> originalEdges) {
  DCPAnalysisResult result;
  std::vector<size_t> originalToVirtualNode(spatComputes.size(), 0);
  for (auto [virtualNodeIndex, virtualNode] : llvm::enumerate(graph.nodes))
    for (size_t originalIndex : virtualNode.originalComputeIndices)
      originalToVirtualNode[originalIndex] = virtualNodeIndex;

  auto dominanceOrder = computeOriginalTopologicalOrder(spatComputes.size(), originalEdges);
  result.dominanceOrderCompute.reserve(dominanceOrder.size());
  for (size_t originalIndex : dominanceOrder) {
    SpatCompute spatCompute = spatComputes[originalIndex];
    size_t cpu = originalToVirtualNode[originalIndex];
    result.dominanceOrderCompute.push_back(spatCompute);
    result.computeToCpuMap[spatCompute] = cpu;
    result.cpuToLastComputeMap[cpu] = spatCompute;
  }

  for (auto [cpu, lastCompute] : result.cpuToLastComputeMap)
    result.isLastComputeOfCpu.insert(lastCompute);
  return result;
}

DCPAnalysisResult runLegacyDcp(llvm::ArrayRef<SpatCompute> spatComputes,
    llvm::ArrayRef<IndexedEdge> edges,
    MLIRContext* context) {
  GraphDCP graphDCP(spatComputes, edges);
  if (coresCount.getValue() > 0)
    graphDCP.setMaxCpuCount(static_cast<int>(coresCount.getValue()));
  graphDCP.setContext(context);
  graphDCP.runDcp();
  return graphDCP.getResult();
}

} // namespace

SpatCompute getOriginalSpatCompute(Operation* op) {
  if (!op)
    return {};
  while (auto extract = llvm::dyn_cast<tensor::ExtractSliceOp>(op)) {
@@ -25,38 +388,59 @@ SpatWeightedCompute getOriginalSpatWeightedCompute(Operation* op) {
    if (!op)
      return {};
  }
  if (auto res = llvm::dyn_cast<SpatWeightedCompute>(op))
  if (auto res = llvm::dyn_cast<SpatCompute>(op))
    return res;
  return {};
}

DCPAnalysisResult DCPAnalysis::run() {
  llvm::SmallVector<SpatWeightedCompute, 10> spatWeightedComputes;
  llvm::SmallVector<IndexedEdge, 10> edges;
  SmallVector<SpatCompute, 10> spatComputes;
  SmallVector<IndexedEdge, 10> edges;
  for (auto& region : entryOp->getRegions())
    for (SpatWeightedCompute spatWeightedCompute : region.getOps<SpatWeightedCompute>())
      spatWeightedComputes.push_back(spatWeightedCompute);
    for (SpatCompute spatCompute : region.getOps<SpatCompute>())
      spatComputes.push_back(spatCompute);

  for (auto [indexEndEdge, spatWeightedCompute] : llvm::enumerate(spatWeightedComputes)) {
    for (Value input : spatWeightedCompute.getInputs()) {
      if (auto producerCompute = getOriginalSpatWeightedCompute(input.getDefiningOp())) {
        auto producerIt = llvm::find(spatWeightedComputes, producerCompute);
        assert(producerIt != spatWeightedComputes.end());
        auto indexStartEdge = std::distance(spatWeightedComputes.begin(), producerIt);
        ResultRange outputs = producerCompute.getResults();
        int64_t totalSize = 0;
        for (auto output : outputs) {
          ShapedType resultType = cast<ShapedType>(output.getType());
          totalSize += getSizeInBytes(resultType);
        }
        edges.push_back({indexStartEdge, indexEndEdge, totalSize});
  for (auto [indexEndEdge, spatCompute] : llvm::enumerate(spatComputes)) {
    for (Value input : spatCompute.getInputs()) {
      if (auto producerCompute = getOriginalSpatCompute(input.getDefiningOp())) {
        auto producerIt = llvm::find(spatComputes, producerCompute);
        assert(producerIt != spatComputes.end());
        auto indexStartEdge = std::distance(spatComputes.begin(), producerIt);
        edges.push_back({indexStartEdge, indexEndEdge, getSizeInBytes(cast<ShapedType>(input.getType()))});
      }
    }
  }
  GraphDCP graphDCP(spatWeightedComputes, edges);
  graphDCP.setContext(entryOp->getContext());
  graphDCP.runDcp();
  return graphDCP.getResult();

  if (dcpCriticalWindowSize.getValue() == 0)
    return runLegacyDcp(spatComputes, edges, entryOp->getContext());

  VirtualGraph virtualGraph = buildInitialVirtualGraph(spatComputes, edges);
  std::set<std::vector<size_t>> seenCriticalWindows;
  while (virtualGraph.nodes.size() > 1) {
    TimingInfo timing = computeTiming(virtualGraph);
    if (!timing.valid)
      break;

    auto selectedNodes = selectCriticalWindow(timing, dcpCriticalWindowSize.getValue());
    if (selectedNodes.size() < 2)
      break;

    if (!seenCriticalWindows.insert(getOriginalSignature(virtualGraph, selectedNodes)).second)
      break;

    WindowScheduleResult windowSchedule = scheduleWindow(virtualGraph, selectedNodes, entryOp->getContext());
    if (windowSchedule.mergeGroups.empty())
      break;

    VirtualGraph coarsenedGraph;
    if (!coarsenGraphWithFallback(virtualGraph, windowSchedule.mergeGroups, coarsenedGraph))
      break;
    virtualGraph = std::move(coarsenedGraph);
    if (windowSchedule.usedAllAvailableCpus)
      break;
  }

  return buildResultFromVirtualGraph(virtualGraph, spatComputes, edges);
}

} // namespace spatial

@@ -10,10 +10,10 @@
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"

struct DCPAnalysisResult {
  std::vector<onnx_mlir::spatial::SpatWeightedCompute> dominanceOrderCompute;
  llvm::DenseMap<onnx_mlir::spatial::SpatWeightedCompute, size_t> computeToCpuMap;
  llvm::DenseSet<onnx_mlir::spatial::SpatWeightedCompute> isLastComputeOfCpu;
  llvm::DenseMap<size_t, onnx_mlir::spatial::SpatWeightedCompute> cpuToLastComputeMap;
  std::vector<onnx_mlir::spatial::SpatCompute> dominanceOrderCompute;
  llvm::DenseMap<onnx_mlir::spatial::SpatCompute, size_t> computeToCpuMap;
  llvm::DenseSet<onnx_mlir::spatial::SpatCompute> isLastComputeOfCpu;
  llvm::DenseMap<size_t, onnx_mlir::spatial::SpatCompute> cpuToLastComputeMap;
};

namespace onnx_mlir {

@@ -6,7 +6,7 @@
// consumer land on different CPUs.
//
// Output: an assignment of every task to a CPU and an order within that CPU,
// aiming to minimise the overall critical-path length (DCPL).
// aiming to minimize the overall critical-path length (DCPL).
//
// Every task keeps two timing estimates:
//   AEST - earliest start time, driven by parent completions + transfers.
@@ -16,9 +16,9 @@
// Main loop (runDcp):
//   1. Build a topological order and seed AEST/ALST from the unscheduled DAG.
//   2. While there are ready tasks (all dependency parents scheduled):
//      a. Pick the candidate with tightest slack (earliest AEST breaks ties).
//      a. Pick the candidate with the tightest slack (earliest AEST breaks ties).
//      b. selectProcessor() tries every candidate CPU and picks the one that
//         minimises a composite cost (own slot + smallest unscheduled child).
//         minimizes a composite cost (own slot + the smallest unscheduled child).
//      c. Commit the placement and refresh AEST/ALST.
//      d. Release any child whose dependency parents are now all scheduled.
//
@@ -43,7 +43,6 @@
#include <cassert>
#include <chrono>
#include <cstdio>
#include <cstdlib>
#include <vector>

#include "DCPAnalysis.hpp"
@@ -1261,7 +1260,7 @@ DCPAnalysisResult GraphDCP::getResult() {
  auto dominanceOrder = dcp_graph::collectDominanceOrder(getRoots(), nodes.size());
  ret.dominanceOrderCompute.reserve(dominanceOrder.size());
  for (auto elem : dominanceOrder)
    ret.dominanceOrderCompute.push_back(elem->getSpatWeightedCompute());
    ret.dominanceOrderCompute.push_back(elem->getSpatCompute());

  for (CPU cpu = 0; cpu < getLastCpu(); ++cpu) {
    const CpuTaskList* tasks = findCpuTasks(cpu);
@@ -1269,10 +1268,10 @@ DCPAnalysisResult GraphDCP::getResult() {
      continue;
    size_t i = 0;
    for (auto node : *tasks) {
      ret.computeToCpuMap[node->getSpatWeightedCompute()] = cpu;
      ret.computeToCpuMap[node->getSpatCompute()] = cpu;
      if (i++ == tasks->size() - 1) {
        ret.isLastComputeOfCpu.insert(node->getSpatWeightedCompute());
        ret.cpuToLastComputeMap[cpu] = node->getSpatWeightedCompute();
        ret.isLastComputeOfCpu.insert(node->getSpatCompute());
        ret.cpuToLastComputeMap[cpu] = node->getSpatCompute();
      }
    }
  }

@@ -115,11 +115,11 @@ private:

public:
  void runDcp();
  GraphDCP(llvm::ArrayRef<onnx_mlir::spatial::SpatWeightedCompute> spatWeightedComputes,
  GraphDCP(llvm::ArrayRef<onnx_mlir::spatial::SpatCompute> spatComputes,
      llvm::ArrayRef<IndexedEdge> edges)
      : nodes(), cpuTasks(), cpuCrossbarUsage() {
    for (auto spatWeightedCompute : spatWeightedComputes)
      nodes.emplace_back(spatWeightedCompute);
    for (auto spatCompute : spatComputes)
      nodes.emplace_back(spatCompute);
    for (auto [start, end, weight] : edges)
      makeEdge(start, end, weight);
  }

@@ -8,7 +8,7 @@
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"

class TaskDCP : public onnx_mlir::LabeledListNode<TaskDCP> {
  onnx_mlir::spatial::SpatWeightedCompute spatWeightedCompute;
  onnx_mlir::spatial::SpatCompute spatCompute;
  Time aest;
  Time alst;
  std::optional<CPU> scheduledCpu;
@@ -38,22 +38,22 @@ public:
  std::vector<Edge> parents;
  std::vector<Edge> children;
  TaskDCP() = default;
  TaskDCP(onnx_mlir::spatial::SpatWeightedCompute spatWeightedCompute)
  TaskDCP(onnx_mlir::spatial::SpatCompute spatCompute)
      : onnx_mlir::LabeledListNode<TaskDCP>(),
        spatWeightedCompute(spatWeightedCompute),
        spatCompute(spatCompute),
        aest(0),
        alst(0),
        scheduledCpu(),
        weight(getSpatComputeWeight(spatWeightedCompute)),
        weight(getSpatComputeWeight(spatCompute)),
        baseWeight(weight),
        crossbarUsage(getSpatComputeCrossbarUsage(spatWeightedCompute)),
        crossbarUsage(getSpatComputeCrossbarUsage(spatCompute)),
        syntheticId(-1),
        parents(),
        children() {}

  TaskDCP(int64_t id, Weight weight, CrossbarUsage crossbarUsage = 0)
      : onnx_mlir::LabeledListNode<TaskDCP>(),
        spatWeightedCompute(),
        spatCompute(),
        aest(0),
        alst(0),
        scheduledCpu(),
@@ -90,14 +90,14 @@ public:
  void setAlst(Time value) { alst = value; }
  bool hasDescendant(TaskDCP* child);
  int64_t Id() const {
    if (spatWeightedCompute)
      return reinterpret_cast<int64_t>(spatWeightedCompute.getAsOpaquePointer());
    if (spatCompute)
      return reinterpret_cast<int64_t>(spatCompute.getAsOpaquePointer());
    return syntheticId;
  }

  bool isCriticalPath() const { return alst == aest; }
  bool isScheduled() const { return scheduledCpu.has_value(); }
  onnx_mlir::spatial::SpatWeightedCompute getSpatWeightedCompute() const { return spatWeightedCompute; }
  onnx_mlir::spatial::SpatCompute getSpatCompute() const { return spatCompute; }

  void setFlag(long long val) { flag = val; }
  long long getFlag() const { return flag; }

@@ -92,18 +92,18 @@ inline T subtractOrZero(T lhs, T rhs) {

inline Time slackOrZero(Time earliestStart, Time latestStart) { return subtractOrZero(latestStart, earliestStart); }

inline Weight getSpatComputeWeight(onnx_mlir::spatial::SpatWeightedCompute spatWeightedCompute) {
inline Weight getSpatComputeWeight(onnx_mlir::spatial::SpatCompute spatCompute) {
  constexpr Weight kOperationWeight = 100;
  Weight numOperations = 0;
  for (auto& block : spatWeightedCompute.getBody())
  for (auto& block : spatCompute.getBody())
    for ([[maybe_unused]] auto& op : block)
      numOperations = checkedAdd(numOperations, static_cast<Weight>(1));
  return checkedMultiply(numOperations, kOperationWeight);
}
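Worked instance of this cost model (illustrative op count):

// A compute body holding 3 payload ops plus the spat.yield terminator counts
// 4 operations, so the compute gets weight 4 * 100 = 400.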
|
||||
|
||||
inline CrossbarUsage getSpatComputeCrossbarUsage(onnx_mlir::spatial::SpatWeightedCompute spatWeightedCompute) {
|
||||
inline CrossbarUsage getSpatComputeCrossbarUsage(onnx_mlir::spatial::SpatCompute spatCompute) {
|
||||
CrossbarUsage crossbarUsage = 0;
|
||||
for (auto& region : spatWeightedCompute.getBody())
|
||||
for (auto& region : spatCompute.getBody())
|
||||
for (auto& inst : region)
|
||||
if (llvm::isa<onnx_mlir::spatial::SpatWeightedVMMOp>(inst))
|
||||
crossbarUsage = checkedAdd(crossbarUsage, static_cast<CrossbarUsage>(1));
|
||||
|
||||
@@ -24,30 +24,29 @@ using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
using SpatWeightedCompute = spatial::SpatWeightedCompute;
|
||||
using SpatCompute = spatial::SpatCompute;
|
||||
|
||||
struct ComputeValueResults {
|
||||
// Value yielded by the yieldOp
|
||||
Value innerValue;
|
||||
SmallVector<Value> innerValues;
|
||||
|
||||
Value get(size_t resultIndex) const {
|
||||
assert(resultIndex < innerValues.size() && "compute result index out of range");
|
||||
return innerValues[resultIndex];
|
||||
}
|
||||
};
|
||||
|
||||
class LazyInsertComputeResult {
|
||||
using InsertPoint = mlir::IRRewriter::InsertPoint;
|
||||
ComputeValueResults computeResults;
|
||||
Value channelValue;
|
||||
bool onlyChannel;
|
||||
std::function<void(InsertPoint insertPoint)> channelSendInserter;
|
||||
InsertPoint sendInsertPoint;
|
||||
std::function<std::pair<Value, std::function<void(InsertPoint)>>()> channelNewInserter;
|
||||
std::function<std::pair<Value, std::function<void(InsertPoint)>>(size_t)> channelNewInserter;
|
||||
|
||||
public:
|
||||
LazyInsertComputeResult(ComputeValueResults computeValueResults,
|
||||
std::function<std::pair<Value, std::function<void(InsertPoint)>>()> channelNewInserter,
|
||||
std::function<std::pair<Value, std::function<void(InsertPoint)>>(size_t)> channelNewInserter,
|
||||
bool isOnlyChannel)
|
||||
: computeResults(computeValueResults),
|
||||
onlyChannel(isOnlyChannel),
|
||||
channelSendInserter(nullptr),
|
||||
sendInsertPoint({}),
|
||||
channelNewInserter(channelNewInserter) {}
|
||||
|
||||
struct ChannelOrLocalOp {
|
||||
@@ -57,12 +56,12 @@ public:
|
||||
|
||||
bool onlyChanneled() const { return onlyChannel; }
|
||||
|
||||
ChannelOrLocalOp getAsChannelValueAndInsertSender(SpatWeightedCompute currentCompute) {
|
||||
ChannelOrLocalOp getAsChannelValueAndInsertSender(SpatCompute currentCompute, size_t resultIndex) {
|
||||
Value innerValue = computeResults.get(resultIndex);
|
||||
|
||||
auto [newChannelValue, senderInserter] = channelNewInserter();
|
||||
channelValue = newChannelValue;
|
||||
channelSendInserter = senderInserter;
|
||||
auto* block = computeResults.innerValue.getParentBlock();
|
||||
auto [channelValue, channelSendInserter] = channelNewInserter(resultIndex);
|
||||
InsertPoint sendInsertPoint;
|
||||
auto* block = innerValue.getParentBlock();
|
||||
if (!block->empty() && isa<spatial::SpatYieldOp>(block->back()))
|
||||
sendInsertPoint = InsertPoint(block, --block->end());
|
||||
else
|
||||
@@ -70,28 +69,30 @@ public:
|
||||
if (currentCompute) {
|
||||
for (auto& block : currentCompute.getBody())
|
||||
if (&block == sendInsertPoint.getBlock())
|
||||
return {computeResults.innerValue, false};
|
||||
return {innerValue, false};
|
||||
}
|
||||
channelSendInserter(sendInsertPoint);
|
||||
return {channelValue, true};
|
||||
}
|
||||
|
||||
ChannelOrLocalOp getAsChannelValueAndInsertSender() { return getAsChannelValueAndInsertSender({}); }
|
||||
ChannelOrLocalOp getAsChannelValueAndInsertSender(size_t resultIndex) {
|
||||
return getAsChannelValueAndInsertSender({}, resultIndex);
|
||||
}
|
||||
};
|
||||
|
||||
struct MergeComputeNodesPass : PassWrapper<MergeComputeNodesPass, OperationPass<func::FuncOp>> {
|
||||
|
||||
private:
|
||||
DenseMap<SpatWeightedCompute, LazyInsertComputeResult> newComputeNodeResults;
|
||||
DenseMap<SpatWeightedCompute, SpatWeightedCompute> oldToNewComputeMap;
|
||||
DenseMap<int64_t, SpatWeightedCompute> cpuToNewComputeMap;
|
||||
DenseMap<SpatCompute, LazyInsertComputeResult> newComputeNodeResults;
|
||||
DenseMap<SpatCompute, SpatCompute> oldToNewComputeMap;
|
||||
DenseMap<int64_t, SpatCompute> cpuToNewComputeMap;
|
||||
|
||||
public:
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MergeComputeNodesPass)
|
||||
|
||||
StringRef getArgument() const override { return "pim-merge-compute-nodes-pass"; }
|
||||
StringRef getDescription() const override {
|
||||
return "Merge Spatial-Weighted-Compute-Nodes in order to reduce the total "
|
||||
return "Merge Spatial-Compute-Nodes in order to reduce the total "
|
||||
"execution time";
|
||||
}
|
||||
|
||||
@@ -105,22 +106,22 @@ public:
|
||||
for (auto currentComputeNode : analysisResult.dominanceOrderCompute) {
|
||||
size_t cpu = analysisResult.computeToCpuMap.at(currentComputeNode);
|
||||
if (!cpuToNewComputeMap.contains(cpu)) {
|
||||
ValueTypeRange<ResultRange> newWeightedComputeType = cpuToLastComputeMap.at(cpu).getResultTypes();
|
||||
auto [newWeightedCompute, computeValueResult] = createNewComputeNode(
|
||||
currentComputeNode, newWeightedComputeType, lastComputeOfCpu.contains(currentComputeNode));
|
||||
cpuToNewComputeMap[cpu] = newWeightedCompute;
|
||||
ValueTypeRange<ResultRange> newComputeType = cpuToLastComputeMap.at(cpu).getResultTypes();
|
||||
auto [newCompute, computeValueResult] = createNewComputeNode(
|
||||
currentComputeNode, newComputeType, lastComputeOfCpu.contains(currentComputeNode));
|
||||
cpuToNewComputeMap[cpu] = newCompute;
|
||||
newComputeNodeResults.insert(
|
||||
std::make_pair(currentComputeNode,
|
||||
createLazyComputeResult(
|
||||
newWeightedCompute, computeValueResult, lastComputeOfCpu.contains(currentComputeNode))));
|
||||
newCompute, computeValueResult, lastComputeOfCpu.contains(currentComputeNode))));
|
||||
}
|
||||
else {
|
||||
auto [newWeightedCompute, computeValueResult] = mergeIntoComputeNode(
|
||||
auto [newCompute, computeValueResult] = mergeIntoComputeNode(
|
||||
cpuToNewComputeMap[cpu], currentComputeNode, lastComputeOfCpu.contains(currentComputeNode));
|
||||
newComputeNodeResults.insert(
|
||||
std::make_pair(currentComputeNode,
|
||||
createLazyComputeResult(
|
||||
newWeightedCompute, computeValueResult, lastComputeOfCpu.contains(currentComputeNode))));
|
||||
newCompute, computeValueResult, lastComputeOfCpu.contains(currentComputeNode))));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -134,8 +135,8 @@ public:
|
||||
}
|
||||
|
||||
private:
|
||||
std::pair<SpatWeightedCompute, ComputeValueResults> createNewComputeNode(
|
||||
SpatWeightedCompute oldWeightedCompute, ValueTypeRange<ResultRange> newWeightedComputeType, bool lastCompute) {
|
||||
std::pair<SpatCompute, ComputeValueResults> createNewComputeNode(
|
||||
SpatCompute oldCompute, ValueTypeRange<ResultRange> newComputeType, bool lastCompute) {
|
||||
func::FuncOp func = getOperation();
|
||||
auto loc = func.getLoc();
|
||||
IRRewriter rewriter(&getContext());
|
||||
@@ -148,50 +149,53 @@ private:
|
||||
llvm::SmallVector<Type> newBBOperandType;
|
||||
llvm::SmallVector<Location> newBBLocations;
|
||||
|
||||
for (auto arg : oldWeightedCompute.getWeights())
|
||||
for (auto arg : oldCompute.getWeights())
|
||||
newComputeOperand.push_back(arg);
|
||||
|
||||
for (auto arg : oldWeightedCompute.getInputs())
|
||||
if (!llvm::isa_and_present<SpatWeightedCompute>(arg.getDefiningOp())) {
|
||||
for (auto arg : oldCompute.getInputs())
|
||||
if (!llvm::isa_and_present<SpatCompute>(arg.getDefiningOp())) {
|
||||
newComputeOperand.push_back(arg);
|
||||
newBBOperandType.push_back(arg.getType());
|
||||
newBBLocations.push_back(loc);
|
||||
}
|
||||
|
||||
auto newWeightedCompute = SpatWeightedCompute::create(rewriter, loc, newWeightedComputeType, newComputeOperand);
|
||||
auto newCompute = SpatCompute::create(rewriter, loc, newComputeType, newComputeOperand);
|
||||
|
||||
rewriter.createBlock(
|
||||
&newWeightedCompute.getBody(), newWeightedCompute.getBody().end(), newBBOperandType, newBBLocations);
|
||||
newWeightedCompute.getProperties().setOperandSegmentSizes(
|
||||
{(int) oldWeightedCompute.getWeights().size(), (int) newBBOperandType.size()});
|
||||
&newCompute.getBody(), newCompute.getBody().end(), newBBOperandType, newBBLocations);
|
||||
newCompute.getProperties().setOperandSegmentSizes(
|
||||
{(int) oldCompute.getWeights().size(), (int) newBBOperandType.size()});
|
||||
|
||||
auto& newBB = newWeightedCompute.getBody().front();
|
||||
auto& oldBB = oldWeightedCompute.getBody().front();
|
||||
auto& newBB = newCompute.getBody().front();
|
||||
auto& oldBB = oldCompute.getBody().front();
|
||||
rewriter.setInsertionPointToEnd(&newBB);
|
||||
|
||||
int indexNew = 0;
|
||||
size_t indexOld = oldWeightedCompute.getWeights().size();
|
||||
size_t indexOldStart = oldWeightedCompute.getWeights().size();
|
||||
for (; indexOld < oldWeightedCompute.getNumOperands(); ++indexOld) {
|
||||
if (!llvm::isa_and_present<SpatWeightedCompute>(oldWeightedCompute.getOperand(indexOld).getDefiningOp())) {
|
||||
size_t indexOld = oldCompute.getWeights().size();
|
||||
size_t indexOldStart = oldCompute.getWeights().size();
|
||||
for (; indexOld < oldCompute.getNumOperands(); ++indexOld) {
|
||||
if (!llvm::isa_and_present<SpatCompute>(oldCompute.getOperand(indexOld).getDefiningOp())) {
|
||||
mapper.map(oldBB.getArgument(indexOld - indexOldStart), newBB.getArgument(indexNew++));
|
||||
}
|
||||
else {
|
||||
auto argWeightCompute =
|
||||
llvm::dyn_cast_if_present<SpatWeightedCompute>(oldWeightedCompute.getOperand(indexOld).getDefiningOp());
|
||||
llvm::dyn_cast_if_present<SpatCompute>(oldCompute.getOperand(indexOld).getDefiningOp());
|
||||
auto argResultIndex = cast<OpResult>(oldCompute.getOperand(indexOld)).getResultNumber();
|
||||
|
||||
LazyInsertComputeResult& lazyArgWeight = newComputeNodeResults.at(argWeightCompute);
|
||||
auto [channelVal, isChannel] = lazyArgWeight.getAsChannelValueAndInsertSender();
|
||||
auto [channelVal, isChannel] = lazyArgWeight.getAsChannelValueAndInsertSender(argResultIndex);
|
||||
assert(isChannel == true);
|
||||
spatial::SpatChannelReceiveOp receiveOp =
|
||||
spatial::SpatChannelReceiveOp::create(rewriter, loc, argWeightCompute.getType(0), channelVal);
|
||||
spatial::SpatChannelReceiveOp receiveOp = spatial::SpatChannelReceiveOp::create(
|
||||
rewriter, loc, oldCompute.getOperand(indexOld).getType(), channelVal);
|
||||
mapper.map(oldBB.getArgument(indexOld - indexOldStart), receiveOp);
|
||||
}
|
||||
}
|
||||
|
||||
for (auto& op : oldWeightedCompute.getOps()) {
|
||||
for (auto& op : oldCompute.getOps()) {
|
||||
if (auto yield = dyn_cast<spatial::SpatYieldOp>(&op)) {
|
||||
computeValueResults.innerValue = mapper.lookup(yield.getOperand(0));
|
||||
computeValueResults.innerValues.reserve(yield.getNumOperands());
|
||||
for (Value yieldOperand : yield.getOperands())
|
||||
computeValueResults.innerValues.push_back(mapper.lookup(yieldOperand));
|
||||
if (lastCompute)
|
||||
rewriter.clone(op, mapper);
|
||||
}
|
||||
@@ -199,16 +203,18 @@ private:
        rewriter.clone(op, mapper);
    }

    for (auto& use : llvm::make_early_inc_range(oldWeightedCompute->getUses()))
      if (isa<func::ReturnOp>(use.getOwner()))
        use.assign(newWeightedCompute.getResult(0));
    for (auto& use : llvm::make_early_inc_range(oldCompute->getUses()))
      if (isa<func::ReturnOp>(use.getOwner())) {
        auto resultIndex = cast<OpResult>(use.get()).getResultNumber();
        use.assign(newCompute.getResult(resultIndex));
      }

    oldToNewComputeMap.insert({oldWeightedCompute, newWeightedCompute});
    return {cast<SpatWeightedCompute>(newWeightedCompute), computeValueResults};
    oldToNewComputeMap.insert({oldCompute, newCompute});
    return {cast<SpatCompute>(newCompute), computeValueResults};
  }

  std::pair<SpatWeightedCompute, ComputeValueResults>
  mergeIntoComputeNode(SpatWeightedCompute toCompute, SpatWeightedCompute fromCompute, bool lastCompute) {
  std::pair<SpatCompute, ComputeValueResults>
  mergeIntoComputeNode(SpatCompute toCompute, SpatCompute fromCompute, bool lastCompute) {
    func::FuncOp func = getOperation();
    auto loc = func.getLoc();
    IRRewriter rewriter(&getContext());
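This hunk generalizes the return rewiring: instead of always substituting result 0, each `func.return` use is repointed at the new op's result with the same index. A sketch of that logic, assuming a "use" records which result of its producer it reads (plain dicts here, not MLIR's OpOperand):

    # Sketch of the use-rewiring above; field names are hypothetical.
    def reroute_return_uses(uses, new_results):
        # Point each return-op use at the same-indexed result of the new op.
        for use in uses:
            if use["owner"] == "func.return":
                # Previously only result 0 was handled; now the index survives.
                use["value"] = new_results[use["result_index"]]

    uses = [{"owner": "func.return", "result_index": 1, "value": None}]
    reroute_return_uses(uses, ["new%0", "new%1"])
    assert uses[0]["value"] == "new%1"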
@@ -239,14 +245,15 @@ private:
    // Insert receiveOp
    rewriter.setInsertionPointToEnd(&toBB);
    for (auto [bbIndex, arg] : llvm::enumerate(fromCompute.getInputs())) {
      if (auto argWeightCompute = llvm::dyn_cast_if_present<SpatWeightedCompute>(arg.getDefiningOp())) {
      if (auto argWeightCompute = llvm::dyn_cast_if_present<SpatCompute>(arg.getDefiningOp())) {
        LazyInsertComputeResult& lazyArgWeight = newComputeNodeResults.at(argWeightCompute);
        auto argResultIndex = cast<OpResult>(arg).getResultNumber();

        LazyInsertComputeResult::ChannelOrLocalOp channelOrLocal =
            lazyArgWeight.getAsChannelValueAndInsertSender(toCompute);
            lazyArgWeight.getAsChannelValueAndInsertSender(toCompute, argResultIndex);
        if (channelOrLocal.isChannel) {
          spatial::SpatChannelReceiveOp receiveOp =
              spatial::SpatChannelReceiveOp::create(rewriter, loc, argWeightCompute.getType(0), channelOrLocal.data);
              spatial::SpatChannelReceiveOp::create(rewriter, loc, arg.getType(), channelOrLocal.data);
          mapper.map(fromBB.getArgument(bbIndex), receiveOp.getResult());
        }
        else {
@@ -286,7 +293,9 @@ private:
    };
    for (auto& op : fromCompute.getOps()) {
      if (auto yield = dyn_cast<spatial::SpatYieldOp>(&op)) {
        computeValueResults.innerValue = mapper.lookup(yield.getOperand(0));
        computeValueResults.innerValues.reserve(yield.getNumOperands());
        for (Value yieldOperand : yield.getOperands())
          computeValueResults.innerValues.push_back(mapper.lookup(yieldOperand));
        if (lastCompute)
          rewriter.clone(op, mapper);
      }
@@ -299,33 +308,36 @@ private:
      }
    }

    for (auto users : fromCompute->getUsers())
      if (auto funcRet = dyn_cast<func::ReturnOp>(users))
        funcRet.setOperand(0, toCompute.getResult(0));
    for (auto& use : llvm::make_early_inc_range(fromCompute->getUses()))
      if (isa<func::ReturnOp>(use.getOwner())) {
        auto resultIndex = cast<OpResult>(use.get()).getResultNumber();
        use.assign(toCompute.getResult(resultIndex));
      }

    oldToNewComputeMap.insert({fromCompute, toCompute});
    return {cast<SpatWeightedCompute>(toCompute), computeValueResults};
    return {cast<SpatCompute>(toCompute), computeValueResults};
  }

  LazyInsertComputeResult createLazyComputeResult(SpatWeightedCompute weightedCompute,
  LazyInsertComputeResult createLazyComputeResult(SpatCompute compute,
                                                  ComputeValueResults computeValueResults,
                                                  bool lastCompute) {
    func::FuncOp funcOp = cast<func::FuncOp>(weightedCompute->getParentOp());
    func::FuncOp funcOp = cast<func::FuncOp>(compute->getParentOp());
    auto* context = &getContext();
    auto loc = funcOp.getLoc();
    IRRewriter rewriter(context);

    rewriter.setInsertionPointToStart(&funcOp.front());
    auto savedChannelInsertPoint = rewriter.saveInsertionPoint();
    auto insertNew = [savedChannelInsertPoint, context, loc, computeValueResults]() {
    auto insertNew = [savedChannelInsertPoint, context, loc, computeValueResults](size_t resultIndex) {
      IRRewriter rewriter(context);
      rewriter.restoreInsertionPoint(savedChannelInsertPoint);
      auto channelOp = spatial::SpatChannelNewOp::create(rewriter, loc, spatial::SpatChannelType::get(context));
      auto channelVal = channelOp.getResult();
      auto insertVal = [&context, loc, computeValueResults, channelVal](mlir::IRRewriter::InsertPoint sendInsertPoint) {
      auto insertVal =
          [&context, loc, computeValueResults, channelVal, resultIndex](mlir::IRRewriter::InsertPoint sendInsertPoint) {
        IRRewriter rewriter(context);
        rewriter.restoreInsertionPoint(sendInsertPoint);
        auto spatSend = spatial::SpatChannelSendOp::create(rewriter, loc, channelVal, computeValueResults.innerValue);
        auto spatSend = spatial::SpatChannelSendOp::create(rewriter, loc, channelVal, computeValueResults.get(resultIndex));
        return spatSend;
      };
      std::pair<Value, std::function<void(mlir::IRRewriter::InsertPoint)>> ret {channelVal, insertVal};

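The `insertNew` lambda now takes a `resultIndex`, so every result of a compute node gets its own channel and its own deferred send. A closure-based Python model of that lazy pattern, with plain stand-ins for the channel and send ops (names here are illustrative, not the spatial dialect API):

    # Minimal closure model of the lazy channel insertion: creating the
    # channel is immediate, the send is deferred until a consumer asks.
    def make_insert_new(compute_results):
        def insert_new(result_index):
            channel = f"channel<{result_index}>"  # SpatChannelNewOp stand-in
            def insert_send(insert_point):
                # SpatChannelSendOp stand-in: sends the indexed result.
                return (insert_point, channel, compute_results[result_index])
            return channel, insert_send
        return insert_new

    insert_new = make_insert_new(["r0", "r1"])
    channel, insert_send = insert_new(1)
    assert insert_send("end_of_block") == ("end_of_block", "channel<1>", "r1")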
@@ -31,7 +31,7 @@ struct CountInstructionPass : public PassWrapper<CountInstructionPass, Operation
    unsigned totalInstructionCount = 0;

    unsigned computeId = 0;
    for (auto computeOp : func.getOps<spatial::SpatWeightedCompute>()) {
    for (auto computeOp : func.getOps<spatial::SpatCompute>()) {
      unsigned instructionCount = 0;
      instructionCount += computeOp.getBody().front().getOperations().size();
      llvm::outs() << "Compute " << computeId << ": " << instructionCount << " instructions\n";

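The counter pass walks each compute op in the function and reports the size of its body block. The same shape as a toy Python sketch, where op bodies are plain lists rather than MLIR blocks:

    # Toy model of the instruction counter: each "compute op" holds one
    # block, represented here as a list of operations.
    computes = {0: ["load", "mul", "add", "yield"], 1: ["load", "yield"]}

    total = 0
    for compute_id, body in computes.items():
        count = len(body)
        total += count
        print(f"Compute {compute_id}: {count} instructions")
    print(f"Total: {total} instructions")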
@@ -26,6 +26,10 @@ STAGE_COUNT = len(STAGE_TITLES)
GENERATED_DIR_NAMES = ("inputs", "outputs", "raptor", "runner", "simulation")


def sanitize_output_name(name):
    return "".join(ch if ch.isalnum() or ch in "_.-" else "_" for ch in name[:255])


@dataclass
class ValidationResult:
    passed: bool
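The new helper keeps alphanumerics plus `_`, `.`, and `-`, replaces everything else with `_`, and truncates to 255 characters before the name is embedded in a CSV filename. A quick self-contained check of that behavior (the function body is copied from the diff):

    def sanitize_output_name(name):
        return "".join(ch if ch.isalnum() or ch in "_.-" else "_" for ch in name[:255])

    # Path separators and spaces collapse to underscores; safe chars pass through.
    assert sanitize_output_name("conv/relu out") == "conv_relu_out"
    assert sanitize_output_name("layer-1.bias") == "layer-1.bias"
    assert len(sanitize_output_name("x" * 300)) == 255  # truncated before joining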
@@ -205,7 +209,7 @@ def build_dump_ranges(config_path, outputs_descriptor):

def run_pim_simulator(simulator_dir, pim_dir, output_bin_path, dump_ranges, reporter=None):
    run_command(
        ["cargo", "run", "--release", "--package", "pim-simulator", "--bin", "pim-simulator", "--",
        ["cargo", "run", "--no-default-features", "--release", "--package", "pim-simulator", "--bin", "pim-simulator", "--",
         "-f", str(pim_dir), "-o", str(output_bin_path), "-d", dump_ranges],
        cwd=simulator_dir,
        reporter=reporter,
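The only change here adds cargo's standard `--no-default-features` flag, which builds the pim-simulator crate without its default feature set. `run_command` itself is not shown in this diff; a plausible minimal shape, assuming it wraps `subprocess.run` and that `reporter` is a callable (both assumptions, not the repo's actual helper):

    # Hypothetical minimal run_command; the real helper, with its reporter
    # protocol, is defined elsewhere in the repository.
    import subprocess

    def run_command(cmd, cwd=None, reporter=None):
        if reporter is not None:
            reporter(" ".join(map(str, cmd)))    # assumed reporter interface
        subprocess.run(cmd, cwd=cwd, check=True)  # raise on non-zero exit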
@@ -229,7 +233,7 @@ def validate_outputs(sim_arrays, runner_out_dir, outputs_descriptor, threshold=1
    all_passed = True
    rows = []
    for sim_array, (oi, name, _, shape) in zip(sim_arrays, outputs_descriptor):
        csv_name = f"output{oi}_{name}.csv"
        csv_name = f"output{oi}_{sanitize_output_name(name)}.csv"
        runner_array = np.loadtxt(runner_out_dir / csv_name, delimiter=',', dtype=np.float32).reshape(shape)
        max_diff = float(np.max(np.abs(sim_array.astype(np.float64) - runner_array.astype(np.float64))))
        passed = max_diff <= threshold

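The comparison is done in float64 so the tolerance check is not distorted by float32 rounding during the subtraction. A self-contained rerun of the same max-abs-diff check on synthetic data; the threshold value below is assumed for the demo, since the actual default is truncated in this view:

    import numpy as np

    threshold = 1e-3  # assumed for illustration only

    sim_array = np.array([1.0, 2.0, 3.0], dtype=np.float32)
    runner_array = np.array([1.0, 2.0005, 3.0], dtype=np.float32)

    # Same expression as in validate_outputs above.
    max_diff = float(np.max(np.abs(sim_array.astype(np.float64) - runner_array.astype(np.float64))))
    passed = max_diff <= threshold
    assert passed  # ~5e-4 is within the 1e-3 tolerance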