31 Commits

Author SHA1 Message Date
ilgeco 6aaf1c0870 Merge branch 'refactorone' of chef.heaplab.deib.polimi.it:nnicolosi/Raptor into refactorone
Validate Operations / validate-operations (push) Waiting to run
2026-05-21 14:44:19 +02:00
ilgeco fe35b3ed43 Equivalent Class but broken 2026-05-21 14:43:59 +02:00
NiccoloN 90a9339686 better cmake to keep IDEs analyses happy
Validate Operations / validate-operations (push) Waiting to run
2026-05-21 14:13:54 +02:00
NiccoloN a50e77ff38 refactorone
Validate Operations / validate-operations (push) Has been cancelled
2026-05-20 19:06:41 +02:00
NiccoloN f56c4159b5 Merge branch 'main' of chef.heaplab.deib.polimi.it:nnicolosi/Raptor 2026-05-19 15:01:26 +02:00
ilgeco 5637c861b4 Merge branch 'main' of chef.heaplab.deib.polimi.it:nnicolosi/Raptor
Validate Operations / validate-operations (push) Has been cancelled
2026-05-19 15:00:11 +02:00
ilgeco 94157a8404 Very big timeout 2026-05-19 14:53:34 +02:00
ilgeco 68a3521978 Perft topological fix 2026-05-19 14:52:54 +02:00
NiccoloN a103ba328b remove dead logic 2026-05-19 12:23:01 +02:00
NiccoloN e263e05f56 remove dead logic
Validate Operations / validate-operations (push) Has been cancelled
2026-05-18 18:32:40 +02:00
ilgeco 34c29fdec4 Materialize modification
Validate Operations / validate-operations (push) Has been cancelled
2026-05-18 17:22:13 +02:00
ilgeco aa088e2ba5 Verify fix
Validate Operations / validate-operations (push) Has been cancelled
2026-05-18 17:20:40 +02:00
NiccoloN 2836e759ab remove useless file
Validate Operations / validate-operations (push) Has been cancelled
2026-05-18 14:51:03 +02:00
NiccoloN 8071ebab0b faster refactored merge pass
Validate Operations / validate-operations (push) Has been cancelled
2026-05-18 14:50:19 +02:00
NiccoloN f1602c0550 add peft scheduling
Validate Operations / validate-operations (push) Has been cancelled
better deadlock report by pim simulator
2026-05-18 12:09:27 +02:00
NiccoloN de0a2f4561 remove useless guard in gemm lowering
Validate Operations / validate-operations (push) Has been cancelled
2026-05-15 18:22:13 +02:00
NiccoloN 1c4a5bde76 compact softmax op lowering
Validate Operations / validate-operations (push) Has been cancelled
2026-05-15 18:14:59 +02:00
NiccoloN 78242e2887 compact resize op lowering
Validate Operations / validate-operations (push) Has been cancelled
2026-05-15 17:36:12 +02:00
NiccoloN fe244d5aa1 new ops tests for matmul, grouped conv, concat and reshape
Validate Operations / validate-operations (push) Has been cancelled
related fixes
2026-05-14 15:54:06 +02:00
NiccoloN d09e76c8f9 fix matmul rewriting/lowering
Validate Operations / validate-operations (push) Has been cancelled
fix reshape lowering
add support for grouped-convolution lowering
quieter verifier with capped error messages
2026-05-14 14:09:30 +02:00
NiccoloN c5e608fa5b replace greedy pattern rewrites with partial conversions
Validate Operations / validate-operations (push) Has been cancelled
better failure messages
2026-05-14 11:48:16 +02:00
ilgeco 43f3ccdd21 new yolo nodes with 100% more statics
Validate Operations / validate-operations (push) Has been cancelled
2026-05-14 10:47:31 +02:00
NiccoloN 8d95c604a6 automatic code formatting
Validate Operations / validate-operations (push) Has been cancelled
2026-05-13 21:51:19 +02:00
NiccoloN 55eda487dc use seed in validate.py for deterministic tests 2026-05-13 21:49:36 +02:00
NiccoloN 061139aefb fix wrong send/receive reordering in post dcp merge instructions compaction 2026-05-13 21:48:49 +02:00
NiccoloN ea61540e08 fix failing validations after last commit
Validate Operations / validate-operations (push) Has been cancelled
2026-05-13 17:46:19 +02:00
NiccoloN 324178cba8 fix instructions explosion in pim host constant folding pass
Validate Operations / validate-operations (push) Has been cancelled
2026-05-13 17:31:05 +02:00
NiccoloN e71ba07cd5 fix pim-simulator stale tests
Validate Operations / validate-operations (push) Has been cancelled
2026-05-13 16:59:53 +02:00
NiccoloN 64a3805619 fix pim-simulator stale tests
Validate Operations / validate-operations (push) Has been cancelled
2026-05-13 16:59:43 +02:00
NiccoloN 9f9e7c0892 Merge remote-tracking branch 'origin/main'
Validate Operations / validate-operations (push) Has been cancelled
2026-05-13 16:38:33 +02:00
NiccoloN 03eab42971 remove host core generation
strip config.json emitted by raptor
add actual pimsim-nn configs in validation pimsim-configs
2026-05-13 16:31:01 +02:00
175 changed files with 8552 additions and 3797 deletions
+92 -24
View File
@@ -3,31 +3,99 @@ cmake_minimum_required(VERSION 3.20.0)
project(raptor)
# Add symlink to PIM as accelerator in onnx-mlir
function(raptor_ensure_symlink link_path target_path)
get_filename_component(link_parent "${link_path}" DIRECTORY)
# Materialize a CMake shim directory
function(raptor_write_external_cmake_shim shim_dir external_source_dir description)
get_filename_component(real_external_source_dir "${external_source_dir}" REALPATH)
file(RELATIVE_PATH relative_external_source_dir "${shim_dir}" "${real_external_source_dir}")
if(NOT EXISTS "${link_parent}")
message(FATAL_ERROR "Directory not found: ${link_parent}")
endif()
if(NOT EXISTS "${link_path}")
message(STATUS "Creating symlink ${link_path} -> ${target_path}")
file(CREATE_LINK
"${target_path}"
"${link_path}"
SYMBOLIC
if (NOT EXISTS "${real_external_source_dir}/CMakeLists.txt")
message(FATAL_ERROR
"External CMake source directory not found or missing CMakeLists.txt:\n"
" ${real_external_source_dir}"
)
endif()
endif ()
if (IS_SYMLINK "${shim_dir}")
message(STATUS "Removing old full-directory symlink: ${shim_dir}")
file(REMOVE "${shim_dir}")
endif ()
if (EXISTS "${shim_dir}" AND NOT IS_DIRECTORY "${shim_dir}")
message(FATAL_ERROR "Expected directory or absent path, got file: ${shim_dir}")
endif ()
file(MAKE_DIRECTORY "${shim_dir}")
set(shim_file "${shim_dir}/CMakeLists.txt")
set(shim_contents
"get_filename_component(raptor_external_source_dir
\"\${CMAKE_CURRENT_LIST_DIR}/${relative_external_source_dir}\"
REALPATH
)
add_subdirectory(
\"\${raptor_external_source_dir}\"
\"\${CMAKE_CURRENT_BINARY_DIR}/raptor-external\"
)
if (DEFINED PIM_ENABLED)
set(PIM_ENABLED \"\${PIM_ENABLED}\" PARENT_SCOPE)
endif ()
"
)
if (EXISTS "${shim_file}")
file(READ "${shim_file}" old_contents)
else ()
set(old_contents "")
endif ()
if (NOT old_contents STREQUAL shim_contents)
file(WRITE "${shim_file}" "${shim_contents}")
message(STATUS "Wrote CMake shim for ${description}: ${shim_file}")
else ()
message(STATUS "CMake shim already up to date for ${description}")
endif ()
# Mirror the external tree's first-level entries into the shim directory
# so legacy includes like src/Accelerators/PIM/Compiler/... keep working.
file(GLOB children RELATIVE "${real_external_source_dir}" "${real_external_source_dir}/*")
foreach (child IN LISTS children)
if (child STREQUAL "CMakeLists.txt")
continue()
endif ()
set(real_child "${real_external_source_dir}/${child}")
set(shim_child "${shim_dir}/${child}")
if (IS_SYMLINK "${shim_child}")
file(READ_SYMLINK "${shim_child}" existing_link_target)
if (existing_link_target STREQUAL real_child)
continue()
endif ()
file(REMOVE_RECURSE "${shim_child}")
elseif (EXISTS "${shim_child}")
# Do not delete real files/directories. This protects the generated shim.
continue()
endif ()
file(CREATE_LINK
"${real_child}"
"${shim_child}"
SYMBOLIC
)
endforeach ()
endfunction()
raptor_ensure_symlink(
"${CMAKE_CURRENT_SOURCE_DIR}/onnx-mlir/src/Accelerators/PIM"
"${CMAKE_CURRENT_SOURCE_DIR}/src/PIM"
raptor_write_external_cmake_shim(
"${CMAKE_CURRENT_SOURCE_DIR}/onnx-mlir/src/Accelerators/PIM"
"${CMAKE_CURRENT_SOURCE_DIR}/src/PIM"
"PIM accelerator"
)
raptor_ensure_symlink(
"${CMAKE_CURRENT_SOURCE_DIR}/onnx-mlir/test/accelerators/PIM"
"${CMAKE_CURRENT_SOURCE_DIR}/test/PIM"
raptor_write_external_cmake_shim(
"${CMAKE_CURRENT_SOURCE_DIR}/onnx-mlir/test/accelerators/PIM"
"${CMAKE_CURRENT_SOURCE_DIR}/test/PIM"
"PIM accelerator tests"
)
# Patch onnx-mlir sources for PIM accelerator support.
@@ -38,21 +106,21 @@ function(raptor_apply_patch file_path anchor replacement description)
# Already applied replacement text is present
string(FIND "${contents}" "${replacement}" already_applied_pos)
if(NOT already_applied_pos EQUAL -1)
if (NOT already_applied_pos EQUAL -1)
message(STATUS "Patch already applied: ${description}")
return()
endif()
endif ()
# Anchor must exist for the patch to be applicable
string(FIND "${contents}" "${anchor}" anchor_pos)
if(anchor_pos EQUAL -1)
if (anchor_pos EQUAL -1)
message(FATAL_ERROR
"Patch anchor not found onnx-mlir may have changed.\n"
" Patch : ${description}\n"
" File : ${file_path}\n"
" Anchor: ${anchor}"
)
endif()
endif ()
string(REPLACE "${anchor}" "${replacement}" patched "${contents}")
file(WRITE "${file_path}" "${patched}")
+45 -2
View File
@@ -114,7 +114,9 @@ Pass these on the `onnx-mlir` command line when compiling for PIM:
run only the codegen tail.
- `--crossbar-size=<N>` / `--crossbar-count=<N>` — crossbar dimensions and
per-core count.
- `--core-count=<N>` — number of cores (`-1` picks the minimum).
- `--core-count=<N>` — number of cores. Required for PIM compilation.
- `--pim-merge-scheduler={peft,dcp}` — scheduler used by the Spatial
merge-compute-nodes pass (default: `peft`).
- `--dcp-critical-window-size=<N>` — DCP coarsening window (0 = legacy).
- `--use-experimental-conv-impl` — alternative convolution lowering.
- `--ignore-concat-error` — soft-fail corner case in `ConcatOp`.
@@ -129,7 +131,8 @@ Per-operation validation (from `validation/`):
```
validate.py \
--raptor-path ../cmake-build-release/Release/bin/onnx-mlir \
--onnx-include-dir ../onnx-mlir/include
--onnx-include-dir ../onnx-mlir/include \
--core-count 1000
```
End-to-end network validation (example: first 4 layers of YOLOv11n):
@@ -142,6 +145,46 @@ validate.py \
--crossbar-size 2048 --crossbar-count 256 --core-count 1000
```
Each validation run writes debugging artifacts into the benchmark's workspace
directory (for example `validation/operations/gemm/small/`):
- `inputs/` — generated input CSVs used for the run.
- `outputs/` — reference outputs dumped by the native ONNX runner.
- `raptor/` — compiler artifacts:
`*.onnx.mlir`, `dialects/spatial0.mlir`, `dialects/spatial1_dcp_merged.mlir`,
`dialects/pim0.mlir`, `dialects/pim1_buff.mlir`, `dialects/pim2_coalesced.mlir`,
`dialects/pim3_folded.mlir`, `dialects/pim4_materialized.mlir`,
`pim/config.json`, `pim/core_*.pim`, `pim/memory.bin`, and reports under
`raptor/reports/` such as `dcp_merge_report.txt`,
`memory_report.txt`, and `static_memory_coalescing_report.txt`.
- `runner/` — generated reference runner source, build tree, and shared library.
- `simulation/out.bin` — raw simulator output dump used for output comparison.
That means you usually do not need to rerun standalone `--EmitSpatial` or
`--EmitPim` commands while debugging validation failures: the per-pass dialect
dumps are already available under `raptor/dialects/`.
The validator does not currently expose a simulator tracing flag, but once a
validation has produced `raptor/pim/` you can rerun the simulator manually with
tracing enabled:
```bash
cd backend-simulators/pim/pim-simulator
cargo run --no-default-features --features tracing --release \
--package pim-simulator --bin pim-simulator -- \
-f /path/to/workspace/raptor/pim \
-o /path/to/workspace/simulation/out.bin \
-d <addr0>,<size0>,<addr1>,<size1>,...
```
With `--features tracing`, the simulator writes per-core traces as
`simulation/TraceCore0`, `simulation/TraceCore1`, ... next to `simulation/out.bin`.
The validator normally computes the `-d` dump ranges from `raptor/pim/config.json`
and the model output shapes. If you need a clean slate before rerunning, use:
```bash
validate.py --clean
```
Available networks under `validation/networks/`: `vgg16`, `yolo11n`.
Available operations under `validation/operations/`: `add`, `conv`, `div`,
`gather`, `gemm`, `gemv`, `mul`, `pool`, `reduce_mean`, `relu`, `resize`,
@@ -67,7 +67,7 @@ fn main() -> Result<()> {
.lock()
.unwrap()
.init(executor.cpu().num_core(), args.output.clone());
executor.execute();
executor.execute()?;
dump_memory(executor, &args)?;
Ok(())
}
@@ -77,7 +77,7 @@ fn map_crossbars_to_cores<'c>(
args: &Args,
global_crossbars: &'c HashMap<String, Crossbar>,
) -> Vec<Vec<&'c Crossbar>> {
let mut res = Vec::new();
let mut res = vec![Vec::new()];
let num_cores = config.get("core_cnt").unwrap().as_i64().unwrap() as i32;
if let Some(folder) = args.folder.as_ref() {
@@ -312,7 +312,7 @@ fn append_record(
29 => {
inst_data_builder
.set_rd_u8(rd)
.set_imm_core(r2_or_imm)
.set_imm_core(r2_or_imm + 1)
.set_imm_len(generic3)
.set_offset_select_value(generic1, generic2);
inst_builder.make_inst(send, inst_data_builder.build());
@@ -320,7 +320,7 @@ fn append_record(
30 => {
inst_data_builder
.set_rd_u8(rd)
.set_imm_core(r2_or_imm)
.set_imm_core(r2_or_imm + 1)
.set_imm_len(generic3)
.set_offset_select_value(generic1, generic2);
inst_builder.make_inst(recv, inst_data_builder.build());
@@ -366,23 +366,19 @@ fn binary_to_instructions(
pub fn binary_to_executor<'a, 'b>(
config: Value,
mut cores: impl Iterator<Item = &'b Vec<u8>>,
cores: impl Iterator<Item = &'b Vec<u8>>,
crossbars: Vec<Vec<&'a Crossbar>>,
) -> Result<Executable<'a>> {
let core_cnt = config
.get("core_cnt")
.context("missing core_cnt in config")?
.as_i64()
.context("core_cnt is not an integer")? as i32
- 1;
.context("core_cnt is not an integer")? as i32;
let cpu = CPU::new(core_cnt, crossbars);
let mut core_insts_builder = CoreInstructionsBuilder::new(core_cnt as usize);
cores.next();
for core_indx in 1..=core_cnt {
let core_bytes = cores
.next()
.unwrap_or_else(|| panic!("cores files less than {}", core_indx));
for (external_core_indx, core_bytes) in cores.enumerate() {
let core_indx = external_core_indx as i32 + 1;
let instructions = binary_to_instructions(core_bytes, core_indx)?;
core_insts_builder.set_core(core_indx, instructions);
}
@@ -396,6 +392,7 @@ mod tests {
HEADER_SIZE, InstructionRecord, MAGIC, RECORD_SIZE, VERSION, binary_to_instructions,
};
use crate::{
functor_to_name,
instruction_set::{InstructionsBuilder, instruction_data::InstructionDataBuilder},
json_to_instruction::json_isa::json_to_instruction,
};
@@ -490,7 +487,10 @@ mod tests {
assert_eq!(json_instructions.len(), binary_instructions.len());
for (json_inst, binary_inst) in json_instructions.iter().zip(binary_instructions.iter()) {
assert_eq!(json_inst.functor_name(), binary_inst.functor_name());
assert_eq!(
functor_to_name(json_inst.functor as usize),
functor_to_name(binary_inst.functor as usize)
);
assert_eq!(json_inst.data, binary_inst.data);
}
}
@@ -567,7 +567,7 @@ fn json_to_send(
let (offset_select, offset_value) = json_to_offset(json.get("offset").unwrap());
inst_data_builder
.set_rd(rd)
.set_imm_core(core)
.set_imm_core(core + 1)
.set_imm_len(size)
.set_offset_select(offset_select)
.set_offset_value(offset_value);
@@ -588,7 +588,7 @@ fn json_to_recv(
let (offset_select, offset_value) = json_to_offset(json.get("offset").unwrap());
inst_data_builder
.set_rd(rd)
.set_imm_core(core)
.set_imm_core(core + 1)
.set_imm_len(size)
.set_offset_select(offset_select)
.set_offset_value(offset_value);
@@ -1,49 +1,34 @@
use core::panic;
use std::io::{Read, Write};
use std::{collections::HashMap, fs::File, io::BufReader};
use serde_json::{Deserializer, Map, Value};
use serde_json::Value;
use std::{fs::File, io::BufReader};
use crate::{
CoreInstructionsBuilder, Executable,
cpu::{
CPU,
crossbar::{self, Crossbar},
},
instruction_set::{
InstructionsBuilder,
instruction_data::{self, InstructionData, InstructionDataBuilder},
},
json_to_instruction::{self, json_isa},
memory_manager::type_traits::TryToUsize,
cpu::{CPU, crossbar::Crossbar},
instruction_set::{InstructionsBuilder, instruction_data::InstructionDataBuilder},
json_to_instruction::json_isa,
};
pub fn json_to_executor<'a, 'b>(
config: Value,
mut cores: &mut Vec<BufReader<File>>,
cores: &'b mut Vec<BufReader<File>>,
crossbars: Vec<Vec<&'a Crossbar>>,
) -> Executable<'a> {
let cell_precision = config.get("cell_precision").unwrap().as_i64().unwrap() as i32;
let core_cnt = config.get("core_cnt").unwrap().as_i64().unwrap() as i32 - 1;
let xbar_count = config.get("xbar_array_count").unwrap().as_i64().unwrap() as i32;
let xbar_size = config.get("xbar_size").unwrap().as_array().unwrap();
let rows_crossbar = xbar_size[0].as_i64().unwrap() as i32;
let column_corssbar = xbar_size[1].as_i64().unwrap() as i32;
let core_cnt = config.get("core_cnt").unwrap().as_i64().unwrap() as i32;
let mut cpu = CPU::new(core_cnt, crossbars);
let cpu = CPU::new(core_cnt, crossbars);
let mut core_insts_builder = CoreInstructionsBuilder::new(core_cnt as usize);
// Note: cores[0] is intentionally empty and discarded
for core_indx in 1..=core_cnt {
for (external_core_indx, json_core_reader) in cores.iter_mut().enumerate() {
let core_indx = external_core_indx as i32 + 1;
let mut insts_builder = InstructionsBuilder::new();
let mut inst_data_builder = InstructionDataBuilder::new();
inst_data_builder.set_core_indx(core_indx).fix_core_indx();
let stream = Deserializer::from_reader(&mut cores[core_indx as usize]).into_iter::<Value>();
for (i, json_inst_result) in stream.enumerate() {
let json_inst = json_inst_result.expect("Failed to parse instruction");
// Pass the single Value to your parser
json_isa::json_to_instruction(&mut insts_builder, &mut inst_data_builder, &json_inst);
drop(json_inst);
let json_core: Value = serde_json::from_reader(json_core_reader)
.unwrap_or_else(|err| panic!("failed to parse core{}: {}", external_core_indx, err));
let json_core_insts = json_core
.as_array()
.unwrap_or_else(|| panic!("core{} has not a list of instruction", external_core_indx));
for json_inst in json_core_insts {
json_isa::json_to_instruction(&mut insts_builder, &mut inst_data_builder, json_inst);
}
core_insts_builder.set_core(core_indx, insts_builder.build());
}
@@ -1,2 +1,2 @@
mod json_isa;
pub(crate) mod json_isa;
pub mod json_to_executor;
@@ -1,5 +1,6 @@
#![allow(unused)]
use anyhow::{Result, bail};
use std::{
collections::{HashMap, HashSet},
time::{Duration, SystemTime},
@@ -87,6 +88,11 @@ pub struct Executable<'a> {
send_recv: SendRecv,
}
struct DeadlockInfo {
cycle: String,
states: String,
}
fn print_status(core_instructions: &[CoreInstructions]) {
let mut tot_instructions = 0;
let mut progress = 0;
@@ -118,7 +124,7 @@ impl<'a> Executable<'a> {
}
}
pub fn execute<'b>(&'b mut self)
pub fn execute<'b>(&'b mut self) -> Result<()>
where
'a: 'b,
{
@@ -153,7 +159,13 @@ impl<'a> Executable<'a> {
}
if (now.elapsed().unwrap() > Duration::from_secs(5)) {
print_status(cores_instructions);
check_cycle(cpu, cores_instructions, send_recv);
if let Some(deadlock) = detect_deadlock(cores_instructions) {
bail!(
"Deadlock cycle detected: {} [{}]",
deadlock.cycle,
deadlock.states
);
}
now = SystemTime::now();
}
}
@@ -178,8 +190,23 @@ impl<'a> Executable<'a> {
}
print_status(cores_instructions);
if let Some(deadlock) = detect_deadlock(cores_instructions) {
bail!(
"Deadlock cycle detected: {} [{}]",
deadlock.cycle,
deadlock.states
);
}
if cores_instructions
.iter()
.any(|core_inst| core_inst.program_counter < core_inst.instructions.len())
{
bail!("Execution stalled with unfinished instructions");
}
#[cfg(feature = "profile_time")]
TRACER.lock().unwrap().report();
Ok(())
}
pub fn cpu(&self) -> &CPU<'a> {
@@ -201,11 +228,11 @@ impl<'a> Executable<'a> {
}
}
fn check_cycle(cpu: &mut CPU, cores_instructions: &[CoreInstructions], send_recv: &mut SendRecv) {
fn detect_deadlock(cores_instructions: &[CoreInstructions]) -> Option<DeadlockInfo> {
#[derive(Debug, PartialEq, Eq)]
enum CoreState {
SendingTo(i32),
ReceivingFrom(i32),
SendingTo(i32, i32),
ReceivingFrom(i32, i32),
Working,
Halted,
}
@@ -223,9 +250,9 @@ fn check_cycle(cpu: &mut CPU, cores_instructions: &[CoreInstructions], send_recv
let (this_core, target_core) = data.get_core_immcore();
if isa_recv(functor_address) {
states.insert(this_core, CoreState::ReceivingFrom(target_core));
states.insert(this_core, CoreState::ReceivingFrom(target_core, data.imm_len()));
} else if isa_send(functor_address) {
states.insert(this_core, CoreState::SendingTo(target_core));
states.insert(this_core, CoreState::SendingTo(target_core, data.imm_len()));
} else {
states.insert(this_core, CoreState::Working);
}
@@ -235,15 +262,15 @@ fn check_cycle(cpu: &mut CPU, cores_instructions: &[CoreInstructions], send_recv
for (&core_id, state) in states.iter() {
match state {
CoreState::SendingTo(target_core) => {
CoreState::SendingTo(target_core, size) => {
let target_state = states.get(target_core).unwrap_or(&CoreState::Halted);
if target_state != &CoreState::ReceivingFrom(core_id) {
if target_state != &CoreState::ReceivingFrom(core_id, *size) {
wait_for.insert(core_id, *target_core);
}
}
CoreState::ReceivingFrom(target_core) => {
CoreState::ReceivingFrom(target_core, size) => {
let target_state = states.get(target_core).unwrap_or(&CoreState::Halted);
if target_state != &CoreState::SendingTo(core_id) {
if target_state != &CoreState::SendingTo(core_id, *size) {
wait_for.insert(core_id, *target_core);
}
}
@@ -279,11 +306,33 @@ fn check_cycle(cpu: &mut CPU, cores_instructions: &[CoreInstructions], send_recv
.collect::<Vec<_>>()
.join(" -> ");
let cycle = cycle
.iter()
.copied()
.chain(std::iter::once(waiting_for))
.collect::<Vec<_>>();
let cycle_msg = format!("{} -> {}", cycle_str, waiting_for);
let states_msg = cycle
.iter()
.filter_map(|core| {
states.get(core).map(|state| match state {
CoreState::SendingTo(target, size) => {
format!("core {} send {}B -> {}", core, size, target)
}
CoreState::ReceivingFrom(source, size) => {
format!("core {} recv {}B <- {}", core, size, source)
}
CoreState::Working => format!("core {} working", core),
CoreState::Halted => format!("core {} halted", core),
})
})
.collect::<Vec<_>>()
.join(", ");
println!("Fatal: Deadlock cycle detected: {}", cycle_msg);
// bail!("Deadlock detected: {}", cycle_msg);
break; // Stop tracing
return Some(DeadlockInfo {
cycle: cycle_msg,
states: states_msg,
});
}
// Hit a known branch that didn't result in a cycle
@@ -294,6 +343,7 @@ fn check_cycle(cpu: &mut CPU, cores_instructions: &[CoreInstructions], send_recv
current_core = waiting_for;
}
}
None
}
fn handle_wait_sync<'a, 'b, 'c>(
@@ -1,6 +1,11 @@
use std::path::Path;
use pimcore::{Executable, cpu::CPU, instruction_set::{InstructionsBuilder, instruction_data::InstructionDataBuilder, isa::*}};
use pimcore::{
Executable,
cpu::crossbar::Crossbar,
instruction_set::{InstructionsBuilder, instruction_data::InstructionDataBuilder, isa::*},
memory_manager::CoreMemory,
};
fn simple_read(path: &Path) -> Vec<f32> {
if !path.exists() {
@@ -17,14 +22,12 @@ fn simple_read(path: &Path) -> Vec<f32> {
fn mvmul_f32(err: &str)
where
{
let mut cpu = CPU::new(0);
cpu.reserve_crossbar(1, 1024 * size_of::<f32>(), 1024);
let (memory, crossbars) = cpu.host().get_memory_crossbar();
let matrix = simple_read(Path::new("B.txt")) ;
crossbars.get_mut(0).unwrap().execute_store( &matrix).unwrap();
let vector = simple_read(Path::new("A.txt"));
let matrix = simple_read(Path::new("tests/B.txt"));
let mut crossbar = Crossbar::new(1024 * size_of::<f32>(), 1024, CoreMemory::new());
crossbar.execute_store(&matrix).unwrap();
let mut cpu = pimcore::cpu::CPU::new(0, vec![vec![&crossbar]]);
let (memory, _) = cpu.host().get_memory_crossbar();
let vector = simple_read(Path::new("tests/A.txt"));
memory.execute_store(0, &vector).unwrap();
let mut inst_builder = InstructionsBuilder::new();
@@ -57,7 +60,7 @@ where
.cpu_mut()
.host()
.load::<f32>(1024 * size_of::<f32>(), 1024*size_of::<f32>()).unwrap()[0].iter().zip(
simple_read(Path::new("X.txt")) ).all(|(&a,b) : (&f32, f32)| {a-b < 0.001}),
simple_read(Path::new("tests/X.txt")) ).all(|(&a,b) : (&f32, f32)| {a-b < 0.001}),
"Wrong result for {}",
err
);
@@ -69,5 +72,3 @@ fn mvmul_big_test() {
}
@@ -0,0 +1,5 @@
use pimcore::cpu::CPU;
pub fn empty_cpu(num_cores: usize) -> CPU<'static> {
CPU::new(num_cores, vec![Vec::new(); num_cores + 1])
}
@@ -1,51 +1,103 @@
use std::{fs, io::BufReader, path::Path};
use std::{
fs::{self, File},
io::BufReader,
path::{Path, PathBuf},
};
use anyhow::{Context, Result};
use pimcore::json_to_instruction::json_to_executor;
use pimcore::{
cpu::crossbar::Crossbar,
json_to_instruction::json_to_executor,
memory_manager::CoreMemory,
};
use serde_json::Value;
fn collect_json_from_subfolders<P: AsRef<Path>>(root: P) -> Result<Vec<(Value, Vec<Value>)>> {
fn collect_examples<P: AsRef<Path>>(root: P) -> Result<Vec<PathBuf>> {
let mut result = Vec::new();
for entry in fs::read_dir(root)? {
let entry = entry.context("Root not found")?;
let path = entry.path();
if path.is_dir() {
let mut cores = Vec::new();
let mut config: Option<Value> = None;
for sub_entry in fs::read_dir(&path)
.with_context(|| format!("File {} not readable", path.display()))?
{
let sub_entry =
sub_entry.with_context(|| format!("File {} not readable", path.display()))?;
let sub_path = sub_entry.path();
if sub_path.is_file()
&& sub_path.extension().and_then(|s| s.to_str()) == Some("json")
{
let file = fs::File::open(&sub_path)
.with_context(|| format!("Subpath {} not opened", sub_path.display()))?;
let reader = BufReader::new(file);
let val: Value = serde_json::from_reader(reader).with_context(|| format!(
"Serde reader fail for subpath {}",
sub_path.display()
))?;
if sub_path.file_name().unwrap() == "config.json" {
config = Some(val);
} else {
cores.push(val);
}
}
}
result.push((config.unwrap(), cores));
result.push(path);
}
}
Ok(result)
}
fn core_sort_key(path: &Path) -> i32 {
let stem = path.file_stem().unwrap().to_str().unwrap();
stem[5..].parse::<i32>().unwrap()
}
fn crossbar_sort_key(path: &Path) -> i32 {
let stem = path.file_stem().unwrap().to_str().unwrap();
stem[9..].parse::<i32>().unwrap()
}
fn load_crossbars(folder: &Path, config: &Value) -> Result<Vec<Vec<Crossbar>>> {
let xbar_size = config.get("xbar_size").unwrap().as_array().unwrap();
let rows = xbar_size[0].as_i64().unwrap() as usize;
let cols = xbar_size[1].as_i64().unwrap() as usize;
let core_cnt = config.get("core_cnt").unwrap().as_i64().unwrap() as usize;
let mut owned_crossbars = Vec::with_capacity(core_cnt + 1);
owned_crossbars.push(Vec::new());
for core_idx in 0..core_cnt {
let core_folder = folder.join(format!("core_{core_idx}"));
let mut core_crossbars = Vec::new();
if core_folder.is_dir() {
let mut paths: Vec<_> = fs::read_dir(&core_folder)?
.map(|entry| entry.map(|entry| entry.path()))
.collect::<std::io::Result<Vec<_>>>()?;
paths.sort_by_cached_key(|path| crossbar_sort_key(path));
for path in paths {
if path.extension().and_then(|ext| ext.to_str()) != Some("bin") {
continue;
}
let bytes = fs::read(&path)
.with_context(|| format!("failed to read crossbar {}", path.display()))?;
let mut crossbar = Crossbar::new(cols * 4, rows, CoreMemory::new());
crossbar.execute_store(&bytes)?;
core_crossbars.push(crossbar);
}
}
owned_crossbars.push(core_crossbars);
}
Ok(owned_crossbars)
}
#[test]
fn json_folder_tester() {
let examples = collect_json_from_subfolders("data").unwrap();
for example in examples {
let (config, cores) = example;
json_to_executor::json_to_executor(config, cores.iter()).execute();
let examples = collect_examples("tests/data").unwrap();
for folder in examples {
let config_path = folder.join("config.json");
let config_file = File::open(&config_path).unwrap();
let config: Value = serde_json::from_reader(BufReader::new(config_file)).unwrap();
let core_cnt = config.get("core_cnt").unwrap().as_i64().unwrap() as usize;
let mut core_paths: Vec<_> = fs::read_dir(&folder)
.unwrap()
.map(|entry| entry.unwrap().path())
.filter(|path| path.extension().and_then(|ext| ext.to_str()) == Some("json"))
.filter(|path| path.file_name().unwrap() != "config.json")
.collect();
core_paths.sort_by_cached_key(|path| core_sort_key(path));
assert_eq!(core_paths.len(), core_cnt);
let mut core_readers: Vec<_> = core_paths
.into_iter()
.map(|path| BufReader::new(File::open(path).unwrap()))
.collect();
let owned_crossbars = load_crossbars(&folder, &config).unwrap();
let crossbars = owned_crossbars
.iter()
.map(|core_crossbars| core_crossbars.iter().collect())
.collect();
let mut executable = json_to_executor::json_to_executor(config, &mut core_readers, crossbars);
let memory = fs::read(folder.join("memory.bin")).unwrap();
executable.cpu_mut().host().execute_store(0, &memory).unwrap();
executable.execute();
}
}
@@ -1,11 +1,17 @@
mod common;
use pimcore::{Executable, cpu::CPU, instruction_set::{InstructionType, InstructionsBuilder, instruction_data::InstructionDataBuilder, isa::*}};
use pimcore::{
Executable,
instruction_set::{
InstructionType, InstructionsBuilder, instruction_data::InstructionDataBuilder, isa::*,
},
};
#[test]
#[should_panic(expected = "Function not found for the requested size") ]
fn wrong_size_place_holder() {
let cpu = CPU::new(0);
let cpu = common::empty_cpu(0);
let mut inst_builder = InstructionsBuilder::new();
let mut idata_build = InstructionDataBuilder::new();
idata_build.set_core_indx(0).fix_core_indx();
@@ -30,7 +36,7 @@ fn wrong_size_place_holder() {
fn place_holder(inst : InstructionType) {
let mut cpu = CPU::new(0);
let mut cpu = common::empty_cpu(0);
let mut idata_build = InstructionDataBuilder::new();
idata_build.set_core_indx(0).fix_core_indx();
inst(&mut cpu, idata_build.build()).unwrap();
@@ -1,8 +1,10 @@
mod common;
use pimcore::{
Executable,
cpu::CPU,
cpu::crossbar::Crossbar,
instruction_set::{InstructionsBuilder, instruction_data::InstructionDataBuilder, isa::*},
memory_manager::{MemoryStorable, type_traits::UpcastDestTraits},
memory_manager::{CoreMemory, MemoryStorable, type_traits::UpcastDestTraits},
};
/// VVADD Test
@@ -11,7 +13,7 @@ where
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable,
{
let mut cpu = CPU::new(0);
let mut cpu = common::empty_cpu(0);
let buff: [F; _] = [
1.0.into(),
2.0.into(),
@@ -115,7 +117,7 @@ where
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable,
{
let mut cpu = CPU::new(0);
let mut cpu = common::empty_cpu(0);
let buff: [F; _] = [
1.0.into(),
2.0.into(),
@@ -219,7 +221,7 @@ where
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable,
{
let mut cpu = CPU::new(0);
let mut cpu = common::empty_cpu(0);
let buff: [F; _] = [
1.0.into(),
2.0.into(),
@@ -323,7 +325,7 @@ where
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable,
{
let mut cpu = CPU::new(0);
let mut cpu = common::empty_cpu(0);
let buff: [F; _] = [
1.0.into(),
2.0.into(),
@@ -420,7 +422,7 @@ where
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable,
{
let mut cpu = CPU::new(0);
let mut cpu = common::empty_cpu(0);
let buff: [F; _] = [
9.0.into(),
2.0.into(),
@@ -524,7 +526,7 @@ where
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable,
{
let mut cpu = CPU::new(0);
let mut cpu = common::empty_cpu(0);
let buff: [F; _] = [
9.0.into(),
2.0.into(),
@@ -562,6 +564,7 @@ where
vavg,
idata_build
.set_rdr1r2(3, 1, 1)
.set_offset_select(1)
.set_imm_len(8 * size_of::<F>() as i32)
.build(),
);
@@ -617,7 +620,7 @@ where
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable,
{
let mut cpu = CPU::new(0);
let mut cpu = common::empty_cpu(0);
let buff: [F; _] = [
(-9.0).into(),
2.0.into(),
@@ -717,7 +720,7 @@ where
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable + UpcastDestTraits<T>,
{
let mut cpu = CPU::new(0);
let mut cpu = common::empty_cpu(0);
let buff: [F; _] = [
0.1.into(),
0.2.into(),
@@ -819,7 +822,7 @@ where
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable + UpcastDestTraits<T>,
{
let mut cpu = CPU::new(0);
let mut cpu = common::empty_cpu(0);
let buff: [F; _] = [
0.1.into(),
0.2.into(),
@@ -923,9 +926,6 @@ where
M: From<f32> + std::fmt::Debug + PartialEq<M> + MemoryStorable,
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable + UpcastDestTraits<T>,
{
let mut cpu = CPU::new(0);
cpu.reserve_crossbar(1, 4 * size_of::<M>(), 4);
let (memory, crossbars) = cpu.host().get_memory_crossbar();
let matrix: [M; _] = [
1.0.into(),
2.0.into(),
@@ -944,7 +944,10 @@ where
15.0.into(),
16.0.into(),
];
crossbars.get_mut(0).unwrap().execute_store( &matrix).unwrap();
let mut crossbar = Crossbar::new(4 * size_of::<M>(), 4, CoreMemory::new());
crossbar.execute_store(&matrix).unwrap();
let mut cpu = pimcore::cpu::CPU::new(0, vec![vec![&crossbar]]);
let (memory, _) = cpu.host().get_memory_crossbar();
let vector: [F; _] = [
1.0.into(),
2.0.into(),
@@ -1054,5 +1057,3 @@ fn mvmul_test() {
mvmul_test_generic::<f64,f64,f64>("mvmul<f64,f64,f64>",1);
}
@@ -1,12 +1,13 @@
mod common;
use pimcore::{
Executable, CoreInstructionsBuilder,
cpu::CPU,
instruction_set::{InstructionsBuilder, instruction_data::InstructionDataBuilder, isa::*},
};
#[test]
fn ld_test() {
let mut cpu = CPU::new(1);
let mut cpu = common::empty_cpu(1);
let mut core_instruction_builder = CoreInstructionsBuilder::new(1);
let buff: [f32; _] = [
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
@@ -41,7 +42,7 @@ fn ld_test() {
#[test]
fn st_test() {
let mut cpu = CPU::new(1);
let mut cpu = common::empty_cpu(1);
let mut core_instruction_builder = CoreInstructionsBuilder::new(1);
let buff: [f32; _] = [
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
@@ -76,7 +77,7 @@ fn st_test() {
#[test]
fn lldi_test() {
let cpu = CPU::new(1);
let cpu = common::empty_cpu(1);
let mut core_instruction_builder = CoreInstructionsBuilder::new(1);
let mut inst_builder = InstructionsBuilder::new();
let mut idata_build = InstructionDataBuilder::new();
@@ -106,7 +107,7 @@ fn lldi_test() {
#[test]
fn lmv_test() {
let mut cpu = CPU::new(1);
let mut cpu = common::empty_cpu(1);
let mut core_instruction_builder = CoreInstructionsBuilder::new(1);
let buff: [f32; _] = [
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
@@ -148,7 +149,7 @@ fn lmv_test() {
#[test]
fn simple_send_recv_test() {
let mut cpu = CPU::new(2);
let mut cpu = common::empty_cpu(2);
let mut core_instruction_builder = CoreInstructionsBuilder::new(2);
let buff: [f32; _] = [
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
@@ -207,7 +208,7 @@ fn simple_send_recv_test() {
#[test]
fn multiple_send_recv_test() {
let mut cpu = CPU::new(4);
let mut cpu = common::empty_cpu(4);
let mut core_instruction_builder = CoreInstructionsBuilder::new(4);
let buff: [f32; _] = [
1.0, 1.0, 1.0, 1.0, 1.0
@@ -226,7 +227,7 @@ fn multiple_send_recv_test() {
];
cpu.core(4).execute_store(0, &buff).unwrap();
let send_inst = |cpu :&mut CPU, core_instruction_builder: &mut CoreInstructionsBuilder, inst_builder: &mut InstructionsBuilder, from : i32, to : i32| {
let send_inst = |inst_builder: &mut InstructionsBuilder, from: i32, to: i32| {
let mut idata_build = InstructionDataBuilder::new();
idata_build.set_core_indx(from).fix_core_indx();
inst_builder.make_inst(sldi, idata_build.set_rdimm(1, from*size_of::<f32>() as i32).build());
@@ -240,7 +241,7 @@ fn multiple_send_recv_test() {
);
};
let recv_inst = |cpu :&mut CPU, core_instruction_builder: &mut CoreInstructionsBuilder, mut inst_builder: &mut InstructionsBuilder, to : i32, from : i32| {
let recv_inst = |inst_builder: &mut InstructionsBuilder, to: i32, from: i32| {
let mut idata_build = InstructionDataBuilder::new();
idata_build.set_core_indx(to).fix_core_indx();
inst_builder.make_inst(sldi, idata_build.set_rdimm(1, from*size_of::<f32>() as i32).build());
@@ -257,26 +258,26 @@ fn multiple_send_recv_test() {
// 1 -> 3
send_inst(&mut cpu,&mut core_instruction_builder,&mut inst_builder,1, 3);
send_inst(&mut inst_builder, 1, 3);
core_instruction_builder.set_core(1, inst_builder.build());
// 2 -> 3
// 2 <- 4
send_inst(&mut cpu,&mut core_instruction_builder,&mut inst_builder,2, 3);
recv_inst(&mut cpu,&mut core_instruction_builder,&mut inst_builder,2, 4);
send_inst(&mut inst_builder, 2, 3);
recv_inst(&mut inst_builder, 2, 4);
core_instruction_builder.set_core(2, inst_builder.build());
// 3 <- 2
// 3 <- 4
// 3 <- 1
recv_inst(&mut cpu,&mut core_instruction_builder,&mut inst_builder,3, 2);
recv_inst(&mut cpu,&mut core_instruction_builder,&mut inst_builder,3, 4);
recv_inst(&mut cpu,&mut core_instruction_builder,&mut inst_builder,3, 1);
recv_inst(&mut inst_builder, 3, 2);
recv_inst(&mut inst_builder, 3, 4);
recv_inst(&mut inst_builder, 3, 1);
core_instruction_builder.set_core(3, inst_builder.build());
// 4 -> 2
// 4 -> 3
send_inst(&mut cpu,&mut core_instruction_builder,&mut inst_builder,4, 2);
send_inst(&mut cpu,&mut core_instruction_builder,&mut inst_builder,4, 3);
send_inst(&mut inst_builder, 4, 2);
send_inst(&mut inst_builder, 4, 3);
core_instruction_builder.set_core(4, inst_builder.build());
let mut executable = Executable::new(cpu, core_instruction_builder.build());
+1
View File
@@ -1,5 +1,6 @@
add_pim_library(OMPimCommon
IR/AddressAnalysis.cpp
IR/ConstantUtils.cpp
IR/CoreBlockUtils.cpp
IR/EntryPointUtils.cpp
IR/ShapeUtils.cpp
+54
View File
@@ -1,5 +1,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
@@ -55,6 +57,47 @@ mlir::Value resolveLoopCarriedAliasImpl(mlir::Value value, const StaticValueKnow
}
llvm::FailureOr<int64_t> resolveOpFoldResult(mlir::OpFoldResult ofr, const StaticValueKnowledge* knowledge);
llvm::FailureOr<int64_t> resolveIndexValueImpl(mlir::Value value, const StaticValueKnowledge* knowledge);
static llvm::FailureOr<int64_t> resolveConstantGlobalLoad(mlir::memref::LoadOp loadOp,
const StaticValueKnowledge* knowledge) {
auto getGlobalOp = loadOp.getMemRef().getDefiningOp<mlir::memref::GetGlobalOp>();
if (!getGlobalOp)
return mlir::failure();
auto moduleOp = loadOp->getParentOfType<mlir::ModuleOp>();
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
if (!globalOp || !globalOp.getConstant() || !globalOp.getInitialValue())
return mlir::failure();
auto denseAttr = mlir::dyn_cast<mlir::DenseElementsAttr>(*globalOp.getInitialValue());
auto globalType = mlir::dyn_cast<mlir::MemRefType>(getGlobalOp.getType());
if (!denseAttr || !globalType || !globalType.hasStaticShape())
return mlir::failure();
auto elementType = denseAttr.getElementType();
if (!elementType.isIndex() && !elementType.isInteger())
return mlir::failure();
llvm::SmallVector<int64_t> indices;
indices.reserve(loadOp.getIndices().size());
for (mlir::Value index : loadOp.getIndices()) {
auto resolvedIndex = resolveIndexValueImpl(index, knowledge);
if (failed(resolvedIndex))
return mlir::failure();
indices.push_back(*resolvedIndex);
}
if (indices.size() != static_cast<size_t>(globalType.getRank()))
return mlir::failure();
auto strides = computeRowMajorStrides(globalType.getShape());
int64_t linearIndex = linearizeIndex(indices, strides);
if (linearIndex < 0 || linearIndex >= globalType.getNumElements())
return mlir::failure();
return denseAttr.getValues<llvm::APInt>()[linearIndex].getSExtValue();
}
llvm::FailureOr<int64_t> resolveIndexValueImpl(mlir::Value value, const StaticValueKnowledge* knowledge) {
value = resolveAlias(value, knowledge);
@@ -110,6 +153,14 @@ llvm::FailureOr<int64_t> resolveIndexValueImpl(mlir::Value value, const StaticVa
return static_cast<int64_t>(static_cast<uint64_t>(*lhs) / static_cast<uint64_t>(*rhs));
}
if (auto minOp = mlir::dyn_cast<mlir::arith::MinUIOp>(definingOp)) {
auto lhs = resolveIndexValueImpl(minOp.getLhs(), knowledge);
auto rhs = resolveIndexValueImpl(minOp.getRhs(), knowledge);
if (failed(lhs) || failed(rhs))
return mlir::failure();
return static_cast<int64_t>(std::min(static_cast<uint64_t>(*lhs), static_cast<uint64_t>(*rhs)));
}
if (auto remOp = mlir::dyn_cast<mlir::arith::RemUIOp>(definingOp)) {
auto lhs = resolveIndexValueImpl(remOp.getLhs(), knowledge);
auto rhs = resolveIndexValueImpl(remOp.getRhs(), knowledge);
@@ -118,6 +169,9 @@ llvm::FailureOr<int64_t> resolveIndexValueImpl(mlir::Value value, const StaticVa
return static_cast<int64_t>(static_cast<uint64_t>(*lhs) % static_cast<uint64_t>(*rhs));
}
if (auto loadOp = mlir::dyn_cast<mlir::memref::LoadOp>(definingOp))
return resolveConstantGlobalLoad(loadOp, knowledge);
return mlir::failure();
}
+62
View File
@@ -0,0 +1,62 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"
#include "ConstantUtils.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
namespace onnx_mlir {
Block* getHostConstantBlock(Operation* anchorOp) {
assert(anchorOp && "expected a valid anchor operation");
for (Operation* current = anchorOp; current; current = current->getParentOp())
if (isa<spatial::SpatCompute, spatial::SpatComputeBatch, pim::PimCoreOp, pim::PimCoreBatchOp>(current))
return current->getBlock();
if (auto funcOp = anchorOp->getParentOfType<func::FuncOp>())
return &funcOp.getBody().front();
if (auto moduleOp = anchorOp->getParentOfType<ModuleOp>())
return moduleOp.getBody();
return anchorOp->getBlock();
}
Value getOrCreateHostConstant(Operation* anchorOp, Attribute value, Type type, OperationFolder& folder) {
assert(anchorOp && "expected a valid anchor operation");
Block* hostBlock = getHostConstantBlock(anchorOp);
for (Operation& op : *hostBlock) {
auto constantOp = dyn_cast<arith::ConstantOp>(&op);
if (!constantOp || constantOp.getType() != type || constantOp.getValue() != value)
continue;
return constantOp.getResult();
}
auto* arithDialect = anchorOp->getContext()->getOrLoadDialect<arith::ArithDialect>();
return folder.getOrCreateConstant(hostBlock, arithDialect, value, type);
}
Value getOrCreateHostConstantLike(arith::ConstantOp constantOp, OperationFolder& folder) {
return getOrCreateHostConstant(constantOp.getOperation(), constantOp.getValue(), constantOp.getType(), folder);
}
Value getOrCreateHostIndexConstant(Operation* anchorOp, int64_t value, OperationFolder& folder) {
Builder builder(anchorOp->getContext());
return getOrCreateHostConstant(anchorOp, builder.getIndexAttr(value), builder.getIndexType(), folder);
}
Value getOrCreateHostI32Constant(Operation* anchorOp, int32_t value, OperationFolder& folder) {
Builder builder(anchorOp->getContext());
return getOrCreateHostConstant(anchorOp, builder.getI32IntegerAttr(value), builder.getI32Type(), folder);
}
Value getOrCreateHostI64Constant(Operation* anchorOp, int64_t value, OperationFolder& folder) {
Builder builder(anchorOp->getContext());
return getOrCreateHostConstant(anchorOp, builder.getI64IntegerAttr(value), builder.getI64Type(), folder);
}
} // namespace onnx_mlir
+28
View File
@@ -0,0 +1,28 @@
#pragma once
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"
#include "mlir/Transforms/FoldUtils.h"
namespace onnx_mlir {
mlir::Block* getHostConstantBlock(mlir::Operation* anchorOp);
mlir::Value getOrCreateHostConstant(mlir::Operation* anchorOp,
mlir::Attribute value,
mlir::Type type,
mlir::OperationFolder& folder);
mlir::Value getOrCreateHostConstantLike(mlir::arith::ConstantOp constantOp, mlir::OperationFolder& folder);
mlir::Value getOrCreateHostIndexConstant(mlir::Operation* anchorOp, int64_t value, mlir::OperationFolder& folder);
mlir::Value getOrCreateHostI32Constant(mlir::Operation* anchorOp, int32_t value, mlir::OperationFolder& folder);
mlir::Value getOrCreateHostI64Constant(mlir::Operation* anchorOp, int64_t value, mlir::OperationFolder& folder);
} // namespace onnx_mlir
+5
View File
@@ -1,4 +1,5 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "src/Accelerators/PIM/Common/IR/CoreBlockUtils.hpp"
@@ -12,6 +13,7 @@ bool isCoreStaticAddressOp(mlir::Operation* op) {
mlir::arith::SubIOp,
mlir::arith::MulIOp,
mlir::arith::DivUIOp,
mlir::arith::MinUIOp,
mlir::arith::RemUIOp,
mlir::arith::IndexCastOp,
mlir::memref::AllocOp,
@@ -29,6 +31,9 @@ walkPimCoreBlock(mlir::Block& block,
for (mlir::Operation& op : block) {
if (mlir::isa<pim::PimHaltOp, mlir::scf::YieldOp>(op) || isCoreStaticAddressOp(&op))
continue;
if (auto loadOp = mlir::dyn_cast<mlir::memref::LoadOp>(op);
loadOp && succeeded(resolveIndexValue(loadOp.getResult(), knowledge)))
continue;
if (auto forOp = mlir::dyn_cast<mlir::scf::ForOp>(op)) {
mlir::Block& loopBody = forOp.getRegion().front();
+1 -2
View File
@@ -1,7 +1,6 @@
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
using namespace mlir;
+25 -16
View File
@@ -21,12 +21,13 @@ namespace {
template <typename MVMOpTy, typename VMMOpTy, typename ParentOpTy>
bool hasMvmVmmWeightUse(ParentOpTy parentOp, unsigned weightIndex) {
mlir::Value weightArg = parentOp.getWeightArgument(weightIndex);
bool found = false;
parentOp.walk([&](mlir::Operation* op) {
if (auto mvmOp = mlir::dyn_cast<MVMOpTy>(op))
found |= mvmOp.getWeightIndex() == weightIndex;
found |= mvmOp.getWeight() == weightArg;
else if (auto vmmOp = mlir::dyn_cast<VMMOpTy>(op))
found |= vmmOp.getWeightIndex() == weightIndex;
found |= vmmOp.getWeight() == weightArg;
});
return found;
}
@@ -35,13 +36,18 @@ template <typename MVMOpTy, typename VMMOpTy, typename ParentOpTy>
void walkMvmVmmWeightUses(ParentOpTy parentOp, llvm::function_ref<void(mlir::OpOperand&)> callback) {
auto weights = parentOp.getWeights();
llvm::SmallSet<unsigned, 8> visited;
auto walkWeightIndex = [&](unsigned weightIndex) {
if (weightIndex < weights.size() && visited.insert(weightIndex).second)
callback(parentOp->getOpOperand(weightIndex));
auto walkWeight = [&](mlir::Value weight) {
for (unsigned weightIndex = 0; weightIndex < weights.size(); ++weightIndex) {
if (parentOp.getWeightArgument(weightIndex) != weight)
continue;
if (visited.insert(weightIndex).second)
callback(parentOp->getOpOperand(weightIndex));
break;
}
};
parentOp.walk([&](MVMOpTy op) { walkWeightIndex(op.getWeightIndex()); });
parentOp.walk([&](VMMOpTy op) { walkWeightIndex(op.getWeightIndex()); });
parentOp.walk([&](MVMOpTy op) { walkWeight(op.getWeight()); });
parentOp.walk([&](VMMOpTy op) { walkWeight(op.getWeight()); });
}
} // namespace
@@ -90,18 +96,21 @@ void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref<void(mlir
assert(root && "expected valid root op");
root->walk([&](pim::PimCoreOp coreOp) {
coreOp.walk([&](pim::PimVMMOp vmmOp) {
auto weights = coreOp.getWeights();
unsigned weightIndex = vmmOp.getWeightIndex();
if (weightIndex < weights.size())
callback(coreOp->getOpOperand(weightIndex));
for (unsigned weightIndex = 0; weightIndex < coreOp.getWeights().size(); ++weightIndex)
if (coreOp.getWeightArgument(weightIndex) == vmmOp.getWeight()) {
callback(coreOp->getOpOperand(weightIndex));
break;
}
});
});
root->walk([&](pim::PimCoreBatchOp coreBatchOp) {
auto weights = coreBatchOp.getWeights();
for (auto weight : weights)
for (mlir::OpOperand& use : weight.getUses())
if (use.getOwner() == coreBatchOp.getOperation())
callback(use);
coreBatchOp.walk([&](pim::PimVMMOp vmmOp) {
for (unsigned weightIndex = 0; weightIndex < coreBatchOp.getWeights().size(); ++weightIndex)
if (coreBatchOp.getWeightArgument(weightIndex) == vmmOp.getWeight()) {
callback(coreBatchOp->getOpOperand(weightIndex));
break;
}
});
});
}
+1
View File
@@ -12,6 +12,7 @@
#include "llvm/ADT/StringRef.h"
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/CoreBlockUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/EntryPointUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
+24
View File
@@ -7,10 +7,34 @@
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/StringRef.h"
#include <cstdint>
#include <system_error>
namespace onnx_mlir::pim {
struct CappedDiagnosticReporter {
explicit CappedDiagnosticReporter(int64_t maxReportedFailures = 8) : maxReportedFailures(maxReportedFailures) {}
template <typename EmitFn>
void report(mlir::Operation* op, EmitFn&& emit) {
numFailures++;
if (numFailures <= maxReportedFailures)
emit(op);
}
void emitSuppressedSummary(mlir::Operation* op, llvm::StringRef failureDescription) const {
if (numFailures > maxReportedFailures)
op->emitError() << "suppressed " << (numFailures - maxReportedFailures) << " additional "
<< failureDescription;
}
bool hasFailure() const { return numFailures != 0; }
private:
int64_t maxReportedFailures;
int64_t numFailures = 0;
};
/// Emits a consistent diagnostic for target paths that require static shapes.
mlir::InFlightDiagnostic emitUnsupportedStaticShapeDiagnostic(mlir::Operation* op, llvm::StringRef valueDescription);
+1 -2
View File
@@ -1,8 +1,7 @@
#include "src/Accelerators/PIM/Common/Support/ReportUtils.hpp"
#include "llvm/Support/Format.h"
#include "src/Accelerators/PIM/Common/Support/FileSystemUtils.hpp"
#include "src/Accelerators/PIM/Common/Support/ReportUtils.hpp"
namespace onnx_mlir {
+1 -2
View File
@@ -1,10 +1,9 @@
#pragma once
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/raw_ostream.h"
#include <cstdint>
#include <fstream>
#include <limits>
#include <string>
-35
View File
@@ -20,38 +20,6 @@ using namespace mlir;
namespace onnx_mlir {
OnnxMlirCompilerErrorCodes writeHostCoreArtifacts(StringRef outputDirPath) {
std::error_code errorCode;
std::string outputHostCorePath = outputDirPath.str() + "/core_0.pim";
raw_fd_ostream hostFileStream(outputHostCorePath, errorCode, sys::fs::OF_None);
if (errorCode) {
errs() << "Error while opening host core file `" << outputHostCorePath << "`: " << errorCode.message() << '\n';
return InvalidOutputFileAccess;
}
pim_binary::writeHeader(hostFileStream);
pim_binary::InstructionRecord noop;
noop.opcode = pim_binary::Opcode::sldi;
pim_binary::writeInstructionRecord(hostFileStream, noop);
pim_binary::writeInstructionRecord(hostFileStream, noop);
pim_binary::patchInstructionCount(hostFileStream, 2);
hostFileStream.close();
if (pimEmitJson.getValue()) {
std::string outputHostJsonPath = outputDirPath.str() + "/core_0.json";
raw_fd_ostream hostJsonStream(outputHostJsonPath, errorCode);
if (errorCode) {
errs() << "Error while opening host core json file `" << outputHostJsonPath << "`: " << errorCode.message()
<< '\n';
return InvalidOutputFileAccess;
}
// The host core json contains two no-op-like instructions to satisfy pimsim-nn
hostJsonStream << "[{\"imm\":0,\"op\":\"sldi\",\"rd\":0},{\"imm\":0,\"op\":\"sldi\",\"rd\":0}]";
hostJsonStream.close();
}
return CompilerSuccess;
}
OnnxMlirCompilerErrorCodes
writeMemoryBinary(ModuleOp moduleOp, func::FuncOp funcOp, PimAcceleratorMemory& memory, StringRef outputDirPath) {
auto memoryFilePath = (outputDirPath + "/memory.bin").str();
@@ -109,9 +77,6 @@ OnnxMlirCompilerErrorCodes writeConfigJson(func::FuncOp funcOp,
json::Object configJson;
configJson["core_cnt"] = maxCoreId + 1;
configJson["adc_count"] = 16;
configJson["cell_precision"] = 2;
configJson["xbar_array_count"] = crossbarCountInCore.getValue();
configJson["xbar_size"] = {crossbarSize.getValue(), crossbarSize.getValue()};
configJson["array_group_map"] = std::move(xbarsPerArrayGroup);
-1
View File
@@ -12,7 +12,6 @@ namespace onnx_mlir {
class PimAcceleratorMemory;
OnnxMlirCompilerErrorCodes writeHostCoreArtifacts(llvm::StringRef outputDirPath);
OnnxMlirCompilerErrorCodes writeMemoryBinary(mlir::ModuleOp moduleOp,
mlir::func::FuncOp funcOp,
PimAcceleratorMemory& memory,
+92 -69
View File
@@ -1,7 +1,11 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IRMapping.h"
#include "llvm/ADT/StringRef.h"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp"
@@ -24,113 +28,132 @@ static SmallVector<int32_t> getLaneChunkCoreIds(ArrayRef<int32_t> coreIds, size_
return laneCoreIds;
}
static void scalarizeBatchOpsInCore(pim::PimCoreOp scalarCore, size_t laneCount, unsigned lane) {
IRRewriter rewriter(scalarCore.getContext());
SmallVector<Operation*> batchOps;
scalarCore.walk([&](Operation* op) {
if (isa<pim::PimSendBatchOp,
pim::PimSendTensorBatchOp,
pim::PimReceiveBatchOp,
pim::PimReceiveTensorBatchOp,
pim::PimMemCopyHostToDevBatchOp>(op)) {
batchOps.push_back(op);
}
});
static void cloneScalarizedLaneBody(OpBuilder& builder,
pim::PimCoreBatchOp coreBatchOp,
unsigned lane,
OperationFolder& constantFolder) {
Block& oldBlock = coreBatchOp.getBody().front();
size_t laneCount = static_cast<size_t>(coreBatchOp.getLaneCount());
size_t weightCount = coreBatchOp.getWeights().size();
for (Operation* op : batchOps) {
rewriter.setInsertionPoint(op);
IRMapping mapper;
for (auto [argIndex, blockArg] : llvm::enumerate(oldBlock.getArguments())) {
if (blockArg.getType().isIndex()) {
mapper.map(blockArg, getOrCreateHostIndexConstant(coreBatchOp, static_cast<int64_t>(lane), constantFolder));
continue;
}
if (argIndex <= weightCount) {
mapper.map(blockArg, coreBatchOp.getWeights()[argIndex - 1]);
continue;
}
size_t inputIndex = argIndex - 1 - weightCount;
assert(inputIndex < coreBatchOp.getInputs().size() && "pim.core_batch block input index out of range");
mapper.map(blockArg, coreBatchOp.getInputs()[inputIndex]);
}
for (Operation& op : oldBlock) {
if (isa<pim::PimHaltOp>(op))
continue;
if (auto sendBatchOp = dyn_cast<pim::PimSendBatchOp>(op)) {
pim::PimSendOp::create(rewriter,
sendBatchOp.getLoc(),
sendBatchOp.getInput(),
sendBatchOp.getSizeAttr(),
rewriter.getI32IntegerAttr(sendBatchOp.getTargetCoreIds()[lane]));
rewriter.eraseOp(op);
Operation* anchorOp = builder.getInsertionBlock()->getParentOp();
pim::PimSendOp::create(
builder,
sendBatchOp.getLoc(),
mapper.lookup(sendBatchOp.getInput()),
sendBatchOp.getSizeAttr(),
getOrCreateHostIndexConstant(anchorOp, sendBatchOp.getTargetCoreIds()[lane], constantFolder));
continue;
}
if (auto sendTensorBatchOp = dyn_cast<pim::PimSendTensorBatchOp>(op)) {
pim::PimSendTensorOp::create(
rewriter,
builder,
sendTensorBatchOp.getLoc(),
sendTensorBatchOp.getInput(),
rewriter.getDenseI32ArrayAttr(getLaneChunkCoreIds(sendTensorBatchOp.getTargetCoreIds(), laneCount, lane)));
rewriter.eraseOp(op);
mapper.lookup(sendTensorBatchOp.getInput()),
builder.getDenseI32ArrayAttr(getLaneChunkCoreIds(sendTensorBatchOp.getTargetCoreIds(), laneCount, lane)));
continue;
}
if (auto receiveBatchOp = dyn_cast<pim::PimReceiveBatchOp>(op)) {
auto scalarReceive =
pim::PimReceiveOp::create(rewriter,
receiveBatchOp.getLoc(),
receiveBatchOp.getOutput().getType(),
receiveBatchOp.getOutputBuffer(),
receiveBatchOp.getSizeAttr(),
rewriter.getI32IntegerAttr(receiveBatchOp.getSourceCoreIds()[lane]));
rewriter.replaceOp(op, scalarReceive->getResults());
Operation* anchorOp = builder.getInsertionBlock()->getParentOp();
auto scalarReceive = pim::PimReceiveOp::create(
builder,
receiveBatchOp.getLoc(),
receiveBatchOp.getOutput().getType(),
mapper.lookup(receiveBatchOp.getOutputBuffer()),
receiveBatchOp.getSizeAttr(),
getOrCreateHostIndexConstant(anchorOp, receiveBatchOp.getSourceCoreIds()[lane], constantFolder));
mapper.map(receiveBatchOp.getOutput(), scalarReceive.getOutput());
continue;
}
if (auto receiveTensorBatchOp = dyn_cast<pim::PimReceiveTensorBatchOp>(op)) {
auto scalarReceive = pim::PimReceiveTensorOp::create(
rewriter,
builder,
receiveTensorBatchOp.getLoc(),
receiveTensorBatchOp.getOutput().getType(),
receiveTensorBatchOp.getOutputBuffer(),
rewriter.getDenseI32ArrayAttr(getLaneChunkCoreIds(receiveTensorBatchOp.getSourceCoreIds(), laneCount, lane)));
rewriter.replaceOp(op, scalarReceive->getResults());
mapper.lookup(receiveTensorBatchOp.getOutputBuffer()),
builder.getDenseI32ArrayAttr(getLaneChunkCoreIds(receiveTensorBatchOp.getSourceCoreIds(), laneCount, lane)));
mapper.map(receiveTensorBatchOp.getOutput(), scalarReceive.getOutput());
continue;
}
auto memcpBatchOp = cast<pim::PimMemCopyHostToDevBatchOp>(op);
auto scalarCopy = pim::PimMemCopyHostToDevOp::create(rewriter,
memcpBatchOp.getLoc(),
memcpBatchOp.getOutput().getType(),
memcpBatchOp.getDeviceTarget(),
memcpBatchOp.getHostSource(),
memcpBatchOp.getDeviceTargetOffsetAttr(),
memcpBatchOp.getHostSourceOffsetAttr(),
memcpBatchOp.getSizeAttr());
rewriter.replaceOp(op, scalarCopy->getResults());
if (auto memcpBatchOp = dyn_cast<pim::PimMemCopyHostToDevBatchOp>(op)) {
auto scalarCopy = pim::PimMemCopyHostToDevOp::create(
builder,
memcpBatchOp.getLoc(),
memcpBatchOp.getOutput().getType(),
getOrCreateHostIndexConstant(coreBatchOp, memcpBatchOp.getDeviceTargetOffset(), constantFolder),
getOrCreateHostIndexConstant(coreBatchOp, memcpBatchOp.getHostSourceOffset(), constantFolder),
mapper.lookup(memcpBatchOp.getDeviceTarget()),
mapper.lookup(memcpBatchOp.getHostSource()),
memcpBatchOp.getSizeAttr());
mapper.map(memcpBatchOp.getOutput(), scalarCopy.getOutput());
continue;
}
Operation* cloned = builder.clone(op, mapper);
for (auto [originalResult, clonedResult] : llvm::zip(op.getResults(), cloned->getResults()))
mapper.map(originalResult, clonedResult);
}
}
} // namespace
LogicalResult withScalarCoreFromBatchLane(pim::PimCoreBatchOp coreBatchOp,
unsigned lane,
llvm::function_ref<LogicalResult(pim::PimCoreOp)> callback) {
LogicalResult withScalarCoreFromBatchLanes(pim::PimCoreBatchOp coreBatchOp,
ArrayRef<unsigned> lanes,
llvm::function_ref<LogicalResult(pim::PimCoreOp)> callback) {
assert(!lanes.empty() && "expected at least one batch lane");
OwningOpRef<ModuleOp> scratchModule = ModuleOp::create(coreBatchOp.getLoc());
OpBuilder builder(scratchModule->getContext());
OperationFolder constantFolder(scratchModule->getContext());
builder.setInsertionPointToStart(scratchModule->getBody());
size_t laneCount = static_cast<size_t>(coreBatchOp.getLaneCount());
size_t weightsPerLane = coreBatchOp.getWeights().size() / laneCount;
SmallVector<Value> laneWeights;
laneWeights.reserve(weightsPerLane);
for (size_t weightIndex = 0; weightIndex < weightsPerLane; ++weightIndex)
laneWeights.push_back(coreBatchOp.getWeights()[lane * weightsPerLane + weightIndex]);
SmallVector<Value> weights(coreBatchOp.getWeights().begin(), coreBatchOp.getWeights().end());
auto coreIds = getBatchCoreIds(coreBatchOp);
auto scalarCore = pim::PimCoreOp::create(
builder, coreBatchOp.getLoc(), ValueRange(laneWeights), builder.getI32IntegerAttr(coreIds[lane]));
int32_t coreId = coreIds[lanes.front()];
for (unsigned lane : lanes)
assert(coreIds[lane] == coreId && "all grouped lanes must target the same core");
auto scalarCore =
pim::PimCoreOp::create(builder, coreBatchOp.getLoc(), ValueRange(weights), builder.getI32IntegerAttr(coreId));
Block* block = builder.createBlock(&scalarCore.getBody(), scalarCore.getBody().end());
IRMapping mapper;
if (coreBatchOp.getBody().front().getNumArguments() == 1)
mapper.map(coreBatchOp.getBody().front().getArgument(0), coreBatchOp.getInputs()[lane]);
builder.setInsertionPointToEnd(block);
for (Operation& op : coreBatchOp.getBody().front()) {
Operation* cloned = builder.clone(op, mapper);
for (auto [originalResult, clonedResult] : llvm::zip(op.getResults(), cloned->getResults()))
mapper.map(originalResult, clonedResult);
}
for (unsigned lane : lanes)
cloneScalarizedLaneBody(builder, coreBatchOp, lane, constantFolder);
if (block->empty() || !isa<pim::PimHaltOp>(block->back()))
pim::PimHaltOp::create(builder, coreBatchOp.getLoc());
scalarizeBatchOpsInCore(scalarCore, laneCount, lane);
return callback(scalarCore);
}
LogicalResult withScalarCoreFromBatchLane(pim::PimCoreBatchOp coreBatchOp,
unsigned lane,
llvm::function_ref<LogicalResult(pim::PimCoreOp)> callback) {
return withScalarCoreFromBatchLanes(coreBatchOp, ArrayRef<unsigned> {lane}, callback);
}
} // namespace onnx_mlir
+3
View File
@@ -9,5 +9,8 @@ namespace onnx_mlir {
mlir::LogicalResult withScalarCoreFromBatchLane(pim::PimCoreBatchOp coreBatchOp,
unsigned lane,
llvm::function_ref<mlir::LogicalResult(pim::PimCoreOp)> callback);
mlir::LogicalResult withScalarCoreFromBatchLanes(pim::PimCoreBatchOp coreBatchOp,
llvm::ArrayRef<unsigned> lanes,
llvm::function_ref<mlir::LogicalResult(pim::PimCoreOp)> callback);
} // namespace onnx_mlir
+36 -43
View File
@@ -70,9 +70,7 @@ inline void writeUint32LE(llvm::raw_ostream& os, uint32_t value) {
os.write(bytes.data(), bytes.size());
}
inline void writeInt32LE(llvm::raw_ostream& os, int32_t value) {
writeUint32LE(os, static_cast<uint32_t>(value));
}
inline void writeInt32LE(llvm::raw_ostream& os, int32_t value) { writeUint32LE(os, static_cast<uint32_t>(value)); }
inline void writeHeader(llvm::raw_ostream& os) {
os.write(kMagic, sizeof(kMagic));
@@ -186,39 +184,39 @@ inline Opcode opcodeFromString(llvm::StringRef opName) {
inline llvm::StringRef opcodeToString(Opcode opcode) {
switch (opcode) {
case Opcode::nop: return "nop";
case Opcode::sldi: return "sldi";
case Opcode::sld: return "sld";
case Opcode::sadd: return "sadd";
case Opcode::ssub: return "ssub";
case Opcode::smul: return "smul";
case Opcode::saddi: return "saddi";
case Opcode::smuli: return "smuli";
case Opcode::setbw: return "setbw";
case Opcode::mvmul: return "mvmul";
case Opcode::vvadd: return "vvadd";
case Opcode::vvsub: return "vvsub";
case Opcode::vvmul: return "vvmul";
case Opcode::vvdmul: return "vvdmul";
case Opcode::vvmax: return "vvmax";
case Opcode::vvsll: return "vvsll";
case Opcode::vvsra: return "vvsra";
case Opcode::vavg: return "vavg";
case Opcode::vrelu: return "vrelu";
case Opcode::vtanh: return "vtanh";
case Opcode::vsigm: return "vsigm";
case Opcode::nop: return "nop";
case Opcode::sldi: return "sldi";
case Opcode::sld: return "sld";
case Opcode::sadd: return "sadd";
case Opcode::ssub: return "ssub";
case Opcode::smul: return "smul";
case Opcode::saddi: return "saddi";
case Opcode::smuli: return "smuli";
case Opcode::setbw: return "setbw";
case Opcode::mvmul: return "mvmul";
case Opcode::vvadd: return "vvadd";
case Opcode::vvsub: return "vvsub";
case Opcode::vvmul: return "vvmul";
case Opcode::vvdmul: return "vvdmul";
case Opcode::vvmax: return "vvmax";
case Opcode::vvsll: return "vvsll";
case Opcode::vvsra: return "vvsra";
case Opcode::vavg: return "vavg";
case Opcode::vrelu: return "vrelu";
case Opcode::vtanh: return "vtanh";
case Opcode::vsigm: return "vsigm";
case Opcode::vsoftmax: return "vsoftmax";
case Opcode::vmv: return "vmv";
case Opcode::vrsu: return "vrsu";
case Opcode::vrsl: return "vrsl";
case Opcode::ld: return "ld";
case Opcode::st: return "st";
case Opcode::lldi: return "lldi";
case Opcode::lmv: return "lmv";
case Opcode::send: return "send";
case Opcode::recv: return "recv";
case Opcode::wait: return "wait";
case Opcode::sync: return "sync";
case Opcode::vmv: return "vmv";
case Opcode::vrsu: return "vrsu";
case Opcode::vrsl: return "vrsl";
case Opcode::ld: return "ld";
case Opcode::st: return "st";
case Opcode::lldi: return "lldi";
case Opcode::lmv: return "lmv";
case Opcode::send: return "send";
case Opcode::recv: return "recv";
case Opcode::wait: return "wait";
case Opcode::sync: return "sync";
}
llvm_unreachable("Unsupported PIM binary opcode");
}
@@ -235,9 +233,7 @@ inline InstructionRecord makeInstructionRecord(const llvm::json::Object& instruc
case Opcode::sldi:
case Opcode::saddi:
case Opcode::smuli:
case Opcode::lldi:
record.r2OrImm = getOptionalInt(instruction, "imm");
break;
case Opcode::lldi: record.r2OrImm = getOptionalInt(instruction, "imm"); break;
case Opcode::mvmul:
record.r2OrImm = getOptionalInt(instruction, "mbiw");
record.generic1 = getOptionalInt(instruction, "relu");
@@ -252,9 +248,7 @@ inline InstructionRecord makeInstructionRecord(const llvm::json::Object& instruc
record.r2OrImm = getOptionalInt(instruction, "core");
record.generic3 = getOptionalInt(instruction, "size");
break;
default:
record.r2OrImm = getOptionalInt(instruction, "rs2");
break;
default: record.r2OrImm = getOptionalInt(instruction, "rs2"); break;
}
if (record.opcode != Opcode::mvmul && record.opcode != Opcode::setbw) {
@@ -371,8 +365,7 @@ inline llvm::json::Object makeInstructionJson(const InstructionRecord& record) {
break;
case Opcode::wait:
case Opcode::sync:
case Opcode::nop:
break;
case Opcode::nop: break;
}
return instruction;
+68 -25
View File
@@ -41,15 +41,23 @@ using namespace mlir;
using namespace onnx_mlir;
using namespace onnx_mlir::compact_asm;
static size_t getElementTypeSizeInBytes(mlir::Type elementType) {
if (elementType.isIndex())
return sizeof(int64_t);
if (elementType.isIntOrFloat())
return elementType.getIntOrFloatBitWidth() / 8;
llvm_unreachable("unsupported shaped element type");
}
static size_t getValueSizeInBytes(mlir::Value value) {
auto type = cast<ShapedType>(value.getType());
return type.getNumElements() * type.getElementTypeBitWidth() / 8;
return type.getNumElements() * getElementTypeSizeInBytes(type.getElementType());
}
MemEntry* PimMemory::gatherMemEntry(mlir::Value value) {
auto type = cast<ShapedType>(value.getType());
assert("Only static shape is supported" && type.hasStaticShape());
size_t allocSize = type.getNumElements() * type.getElementType().getIntOrFloatBitWidth() / 8;
size_t allocSize = type.getNumElements() * getElementTypeSizeInBytes(type.getElementType());
MemEntry memEntry = {0, allocSize};
return &memEntries.emplace_back(memEntry, value).first;
}
@@ -367,7 +375,7 @@ void PimCodeGen::emitMemCopyOp(StringRef opName,
instruction.generic1 = 0;
instruction.generic2 = 0;
instruction.generic3 = static_cast<int32_t>(size);
(void)sizeFieldName;
(void) sizeFieldName;
emitInstruction(instruction);
}
@@ -398,20 +406,28 @@ void PimCodeGen::emitMvmOp(size_t groupId, size_t rdAddr, size_t rdOffset, size_
}
void PimCodeGen::codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp, const StaticValueKnowledge& knowledge) const {
auto deviceTargetOffset = resolveIndexValue(loadOp.getDeviceTargetOffset(), knowledge);
auto hostSourceOffset = resolveIndexValue(loadOp.getHostSourceOffset(), knowledge);
assert(succeeded(deviceTargetOffset) && succeeded(hostSourceOffset)
&& "pim.memcp_hd offsets must be statically resolvable during codegen");
emitMemCopyOp("ld",
addressOf(loadOp.getDeviceTarget(), knowledge),
loadOp.getDeviceTargetOffset(),
*deviceTargetOffset,
addressOf(loadOp.getHostSource(), knowledge),
loadOp.getHostSourceOffset(),
*hostSourceOffset,
loadOp.getSize());
}
void PimCodeGen::codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp, const StaticValueKnowledge& knowledge) const {
auto hostTargetOffset = resolveIndexValue(storeOp.getHostTargetOffset(), knowledge);
auto deviceSourceOffset = resolveIndexValue(storeOp.getDeviceSourceOffset(), knowledge);
assert(succeeded(hostTargetOffset) && succeeded(deviceSourceOffset)
&& "pim.memcp_dh offsets must be statically resolvable during codegen");
emitMemCopyOp("st",
addressOf(storeOp.getHostTarget(), knowledge),
storeOp.getHostTargetOffset(),
*hostTargetOffset,
addressOf(storeOp.getDeviceSource(), knowledge),
storeOp.getDeviceSourceOffset(),
*deviceSourceOffset,
storeOp.getSize());
}
@@ -426,8 +442,9 @@ void PimCodeGen::codeGenLmvOp(pim::PimMemCopyOp lmvOp, const StaticValueKnowledg
}
void PimCodeGen::codeGenReceiveOp(pim::PimReceiveOp receiveOp, const StaticValueKnowledge& knowledge) const {
emitCommunicationOp(
"recv", addressOf(receiveOp.getOutputBuffer(), knowledge), receiveOp.getSourceCoreId(), receiveOp.getSize());
auto sourceCoreId = resolveIndexValue(receiveOp.getSourceCoreId(), knowledge);
assert(succeeded(sourceCoreId) && "pim.receive source core id must be statically resolvable during codegen");
emitCommunicationOp("recv", addressOf(receiveOp.getOutputBuffer(), knowledge), *sourceCoreId, receiveOp.getSize());
}
void PimCodeGen::codeGenReceiveTensorOp(pim::PimReceiveTensorOp receiveTensorOp,
@@ -439,7 +456,9 @@ void PimCodeGen::codeGenReceiveTensorOp(pim::PimReceiveTensorOp receiveTensorOp,
}
void PimCodeGen::codeGenSendOp(pim::PimSendOp sendOp, const StaticValueKnowledge& knowledge) const {
emitCommunicationOp("send", addressOf(sendOp.getInput(), knowledge), sendOp.getTargetCoreId(), sendOp.getSize());
auto targetCoreId = resolveIndexValue(sendOp.getTargetCoreId(), knowledge);
assert(succeeded(targetCoreId) && "pim.send target core id must be statically resolvable during codegen");
emitCommunicationOp("send", addressOf(sendOp.getInput(), knowledge), *targetCoreId, sendOp.getSize());
}
void PimCodeGen::codeGenSendTensorOp(pim::PimSendTensorOp sendTensorOp, const StaticValueKnowledge& knowledge) const {
@@ -728,12 +747,19 @@ std::string getMemorySizeAsString(size_t size) {
static SmallVector<unsigned, 8> getUsedWeightIndices(Block& block) {
SmallVector<unsigned, 8> indices;
auto addIndex = [&](unsigned weightIndex) {
if (!llvm::is_contained(indices, weightIndex))
indices.push_back(weightIndex);
auto coreOp = dyn_cast<pim::PimCoreOp>(block.getParentOp());
auto addWeight = [&](mlir::Value weight) {
if (!coreOp)
return;
for (unsigned weightIndex = 0; weightIndex < coreOp.getWeights().size(); ++weightIndex) {
if (coreOp.getWeightArgument(weightIndex) != weight)
continue;
if (!llvm::is_contained(indices, weightIndex))
indices.push_back(weightIndex);
return;
}
};
block.walk([&](pim::PimVMMOp vmmOp) { addIndex(vmmOp.getWeightIndex()); });
block.walk([&](pim::PimVMMOp vmmOp) { addWeight(vmmOp.getWeight()); });
llvm::sort(indices);
return indices;
}
@@ -795,6 +821,15 @@ static void aliasMaterializedHostGlobals(ModuleOp moduleOp,
/// fully resolved before the JSON instructions are emitted.
/// Returns the number of emitted instructions, or -1 on failure.
static int64_t codeGenCoreOps(Block& block, PimCodeGen& coreCodeGen) {
auto resolveWeightIndex = [&](pim::PimVMMOp vmmOp) -> std::optional<unsigned> {
auto coreOp = vmmOp->getParentOfType<pim::PimCoreOp>();
if (!coreOp)
return std::nullopt;
for (unsigned weightIndex = 0; weightIndex < coreOp.getWeights().size(); ++weightIndex)
if (coreOp.getWeightArgument(weightIndex) == vmmOp.getWeight())
return weightIndex;
return std::nullopt;
};
size_t processedOperations = 0;
auto result =
walkPimCoreBlock(block, StaticValueKnowledge {}, [&](Operation& op, const StaticValueKnowledge& knowledge) {
@@ -814,8 +849,12 @@ static int64_t codeGenCoreOps(Block& block, PimCodeGen& coreCodeGen) {
coreCodeGen.codeGenSendTensorOp(sendTensorOp, knowledge);
else if (auto concatOp = dyn_cast<pim::PimConcatOp>(op))
coreCodeGen.codeGenConcatOp(concatOp, knowledge);
else if (auto vmmOp = dyn_cast<pim::PimVMMOp>(op))
coreCodeGen.codeGenMVMLikeOp<pim::PimVMMOp>(vmmOp.getWeightIndex(), vmmOp, true, knowledge);
else if (auto vmmOp = dyn_cast<pim::PimVMMOp>(op)) {
auto weightIndex = resolveWeightIndex(vmmOp);
if (!weightIndex)
return failure();
coreCodeGen.codeGenMVMLikeOp<pim::PimVMMOp>(*weightIndex, vmmOp, true, knowledge);
}
else if (auto transposeOp = dyn_cast<pim::PimTransposeOp>(op))
coreCodeGen.codeGenTransposeOp(transposeOp, knowledge);
else if (auto vvaddOp = dyn_cast<pim::PimVVAddOp>(op))
@@ -875,11 +914,6 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std::
if (auto err = writeMemoryBinary(moduleOp, funcOp, memory, outputDirPath))
return err;
if (auto err = writeHostCoreArtifacts(outputDirPath))
return err;
// For each core, specify the number of crossbar per array group.
// This implementation always assigns one crossbar per group.
json::Object xbarsPerArrayGroup;
size_t maxCoreId = 0;
uint64_t nextBatchReportId = 0;
@@ -891,7 +925,7 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std::
SmallDenseMap<memref::GlobalOp, MemEntry, 16> materializedHostGlobals =
collectMaterializedHostGlobals(moduleOp, funcOp, memory);
llvm::DenseMap<size_t, size_t> emittedCoreIds;
size_t nextEmittedCoreId = 1;
size_t nextEmittedCoreId = 0;
for (Operation* op : coreLikeOps) {
if (auto coreOp = dyn_cast<pim::PimCoreOp>(op)) {
@@ -1009,10 +1043,19 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std::
reportedCoreIds.reserve(batchCoreIds.size());
MemoryReportRow batchRow;
std::optional<MemoryReportRow> batchPerCoreRow;
llvm::DenseMap<size_t, SmallVector<unsigned>> lanesByCoreId;
SmallVector<size_t> orderedOriginalCoreIds;
for (unsigned lane = 0; lane < static_cast<unsigned>(coreBatchOp.getLaneCount()); ++lane) {
size_t originalCoreId = static_cast<size_t>(batchCoreIds[lane]);
auto [it, inserted] = lanesByCoreId.try_emplace(originalCoreId);
if (inserted)
orderedOriginalCoreIds.push_back(originalCoreId);
it->second.push_back(lane);
}
for (size_t originalCoreId : orderedOriginalCoreIds) {
OnnxMlirCompilerErrorCodes laneResult = CompilerSuccess;
if (failed(withScalarCoreFromBatchLane(coreBatchOp, lane, [&](pim::PimCoreOp coreOp) {
size_t originalCoreId = static_cast<size_t>(batchCoreIds[lane]);
if (failed(withScalarCoreFromBatchLanes(coreBatchOp, lanesByCoreId[originalCoreId], [&](pim::PimCoreOp coreOp) {
size_t coreId = emittedCoreIds.lookup(originalCoreId);
reportedCoreIds.push_back(static_cast<int32_t>(coreId));
MemoryReportRow laneRow;
+22 -3
View File
@@ -1,5 +1,7 @@
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "llvm/Support/ErrorHandling.h"
#define DEBUG_TYPE "PimCompilerOptions"
namespace onnx_mlir {
@@ -13,6 +15,14 @@ llvm::cl::opt<PimEmissionTargetType> pimEmissionTarget(
llvm::cl::init(EmitPimCodegen),
llvm::cl::cat(OnnxMlirOptions));
llvm::cl::opt<PimMergeSchedulerType> pimMergeScheduler(
"pim-merge-scheduler",
llvm::cl::desc("Scheduler used by the Spatial merge-compute-nodes pass"),
llvm::cl::values(clEnumValN(MergeSchedulerPeft, "peft", "Use PEFT scheduling")),
llvm::cl::values(clEnumValN(MergeSchedulerDcp, "dcp", "Use the legacy DCP-inspired scheduler")),
llvm::cl::init(MergeSchedulerPeft),
llvm::cl::cat(OnnxMlirOptions));
llvm::cl::opt<bool>
pimOnlyCodegen("pim-only-codegen",
llvm::cl::desc("Only generate code for PIM (assume input is already in bufferized PIM IR)"),
@@ -30,19 +40,19 @@ llvm::cl::opt<bool> pimEmitJson("pim-emit-json",
llvm::cl::cat(OnnxMlirOptions));
llvm::cl::opt<size_t>
crossbarSize("crossbar-size", llvm::cl::desc("Width and heigth of a single crossbar"), llvm::cl::init(2));
crossbarSize("crossbar-size", llvm::cl::desc("Width and height of a single crossbar"), llvm::cl::init(2));
llvm::cl::opt<size_t>
crossbarCountInCore("crossbar-count", llvm::cl::desc("Number of crossbars in each core"), llvm::cl::init(256));
llvm::cl::opt<long> coresCount("core-count",
llvm::cl::desc("Number of cores in the chip. `-1` to use the minimum amount of cores."),
llvm::cl::desc("Number of cores in the chip. Required for PIM compilation."),
llvm::cl::init(-1));
llvm::cl::opt<size_t> dcpCriticalWindowSize(
"dcp-critical-window-size",
llvm::cl::desc("Number of lowest-slack virtual nodes considered by each DCP coarsening iteration. "
"Use 0 to run the legacy full-graph DCP analysis."),
"Use 0 to run the legacy full-graph DCP analysis. Only used by the DCP scheduler."),
llvm::cl::init(4000));
llvm::cl::opt<bool>
@@ -50,4 +60,13 @@ llvm::cl::opt<bool>
llvm::cl::desc("Ignore ConcatOp corner case: do not assert and do a simplification"),
llvm::cl::init(false));
bool hasExplicitPimCoreCount() { return coresCount.getNumOccurrences() != 0; }
void verifyExplicitPimCoreCount() {
if (!hasExplicitPimCoreCount())
llvm::report_fatal_error("PIM compilation requires an explicit --core-count=<positive integer>");
if (coresCount.getValue() <= 0)
llvm::report_fatal_error("PIM compilation requires --core-count to be a positive integer");
}
} // namespace onnx_mlir
+9
View File
@@ -20,8 +20,14 @@ typedef enum {
EmitPimCodegen = 3
} PimEmissionTargetType;
typedef enum {
MergeSchedulerPeft = 0,
MergeSchedulerDcp = 1,
} PimMergeSchedulerType;
extern llvm::cl::OptionCategory OnnxMlirOptions;
extern llvm::cl::opt<PimEmissionTargetType> pimEmissionTarget;
extern llvm::cl::opt<PimMergeSchedulerType> pimMergeScheduler;
extern llvm::cl::opt<bool> pimOnlyCodegen;
extern llvm::cl::opt<bool> useExperimentalConvImpl;
@@ -32,6 +38,9 @@ extern llvm::cl::opt<size_t> crossbarCountInCore;
extern llvm::cl::opt<long> coresCount;
extern llvm::cl::opt<size_t> dcpCriticalWindowSize;
bool hasExplicitPimCoreCount();
void verifyExplicitPimCoreCount();
// This option, by default set to false, will ignore an error when resolving a
// specific tiles of the operands of a concat. This specific case is when the
// wanted tile is generated by two separate operands of the concat. If this is
+1
View File
@@ -17,6 +17,7 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
PassManager& pm,
EmissionTargetType& emissionTarget,
std::string outputNameNoExt) {
verifyExplicitPimCoreCount();
if (pimOnlyCodegen) {
// Skip all the lowering passes and directly generate code for PIM.
+64 -15
View File
@@ -33,7 +33,7 @@ struct DenseWeightView {
};
FailureOr<DenseWeightView> resolveDenseWeightView(ModuleOp moduleOp, mlir::Value weight) {
SmallVector<memref::SubViewOp> subviews;
SmallVector<Operation*> viewOps;
mlir::Value current = weight;
memref::GetGlobalOp getGlobalOp;
@@ -46,7 +46,7 @@ FailureOr<DenseWeightView> resolveDenseWeightView(ModuleOp moduleOp, mlir::Value
if (auto subview = dyn_cast<memref::SubViewOp>(defOp)) {
if (!hasAllStaticSubviewParts(subview))
return failure();
subviews.push_back(subview);
viewOps.push_back(subview);
current = subview.getSource();
continue;
}
@@ -54,6 +54,24 @@ FailureOr<DenseWeightView> resolveDenseWeightView(ModuleOp moduleOp, mlir::Value
current = cast.getSource();
continue;
}
if (auto collapse = dyn_cast<memref::CollapseShapeOp>(defOp)) {
auto srcType = dyn_cast<MemRefType>(collapse.getSrc().getType());
auto resultType = dyn_cast<MemRefType>(collapse.getResult().getType());
if (!srcType || !resultType || !srcType.hasStaticShape() || !resultType.hasStaticShape())
return failure();
viewOps.push_back(collapse);
current = collapse.getSrc();
continue;
}
if (auto expand = dyn_cast<memref::ExpandShapeOp>(defOp)) {
auto srcType = dyn_cast<MemRefType>(expand.getSrc().getType());
auto resultType = dyn_cast<MemRefType>(expand.getResult().getType());
if (!srcType || !resultType || !srcType.hasStaticShape() || !resultType.hasStaticShape())
return failure();
viewOps.push_back(expand);
current = expand.getSrc();
continue;
}
return failure();
}
@@ -70,16 +88,39 @@ FailureOr<DenseWeightView> resolveDenseWeightView(ModuleOp moduleOp, mlir::Value
view.shape.assign(denseAttr.getType().getShape().begin(), denseAttr.getType().getShape().end());
view.strides = computeRowMajorStrides(view.shape);
for (memref::SubViewOp subview : llvm::reverse(subviews)) {
SmallVector<int64_t> nextStrides;
nextStrides.reserve(subview.getStaticStrides().size());
for (auto [offset, stride, sourceStride] :
llvm::zip_equal(subview.getStaticOffsets(), subview.getStaticStrides(), view.strides)) {
view.offset += offset * sourceStride;
nextStrides.push_back(stride * sourceStride);
for (Operation* viewOp : llvm::reverse(viewOps)) {
if (auto subview = dyn_cast<memref::SubViewOp>(viewOp)) {
SmallVector<int64_t> nextStrides;
nextStrides.reserve(subview.getStaticStrides().size());
for (auto [offset, stride, sourceStride] :
llvm::zip_equal(subview.getStaticOffsets(), subview.getStaticStrides(), view.strides)) {
view.offset += offset * sourceStride;
nextStrides.push_back(stride * sourceStride);
}
view.shape.assign(subview.getStaticSizes().begin(), subview.getStaticSizes().end());
view.strides = std::move(nextStrides);
continue;
}
// Collapse/expand are accepted only as contiguous static reshapes of a
// dense global view, so a row-major stride recomputation preserves layout.
if (auto collapse = dyn_cast<memref::CollapseShapeOp>(viewOp)) {
if (view.strides != computeRowMajorStrides(view.shape))
return failure();
auto resultType = cast<MemRefType>(collapse.getResult().getType());
view.shape.assign(resultType.getShape().begin(), resultType.getShape().end());
view.strides = computeRowMajorStrides(view.shape);
continue;
}
if (auto expand = dyn_cast<memref::ExpandShapeOp>(viewOp)) {
if (view.strides != computeRowMajorStrides(view.shape))
return failure();
auto resultType = cast<MemRefType>(expand.getResult().getType());
view.shape.assign(resultType.getShape().begin(), resultType.getShape().end());
view.strides = computeRowMajorStrides(view.shape);
continue;
}
view.shape.assign(subview.getStaticSizes().begin(), subview.getStaticSizes().end());
view.strides = std::move(nextStrides);
}
return view;
@@ -87,12 +128,20 @@ FailureOr<DenseWeightView> resolveDenseWeightView(ModuleOp moduleOp, mlir::Value
SmallVector<unsigned, 8> getUsedWeightIndices(Block& block) {
SmallVector<unsigned, 8> indices;
auto addIndex = [&](unsigned weightIndex) {
if (!llvm::is_contained(indices, weightIndex))
indices.push_back(weightIndex);
auto coreOp = dyn_cast<pim::PimCoreOp>(block.getParentOp());
auto addWeight = [&](mlir::Value weight) {
if (!coreOp)
return;
for (unsigned weightIndex = 0; weightIndex < coreOp.getWeights().size(); ++weightIndex) {
if (coreOp.getWeightArgument(weightIndex) != weight)
continue;
if (!llvm::is_contained(indices, weightIndex))
indices.push_back(weightIndex);
return;
}
};
block.walk([&](pim::PimVMMOp vmmOp) { addIndex(vmmOp.getWeightIndex()); });
block.walk([&](pim::PimVMMOp vmmOp) { addWeight(vmmOp.getWeight()); });
llvm::sort(indices);
return indices;
}
@@ -18,13 +18,17 @@ namespace detail {
inline mlir::ValueRange getBlockArgs(mlir::Block* block) { return mlir::ValueRange(block->getArguments()); }
inline mlir::ValueRange getInputBlockArgs(mlir::Block* block, size_t weightCount) {
return mlir::ValueRange(block->getArguments()).drop_front(weightCount);
}
template <typename Fn, size_t... Is>
decltype(auto) invokeWithBlockArgs(Fn&& fn, mlir::Block* block, std::index_sequence<Is...>) {
return std::forward<Fn>(fn)(block->getArgument(Is)...);
}
template <typename Fn, size_t... Is>
decltype(auto) invokeWithValues(Fn&& fn, mlir::ArrayRef<mlir::Value> values, std::index_sequence<Is...>) {
decltype(auto) invokeWithValues(Fn&& fn, mlir::ValueRange values, std::index_sequence<Is...>) {
return std::forward<Fn>(fn)(values[Is]...);
}
@@ -85,6 +89,8 @@ auto createSpatCompute(RewriterT& rewriter,
auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs);
auto* block = new mlir::Block();
for (mlir::Value weight : weights)
block->addArgument(weight.getType(), loc);
for (mlir::Value input : inputs)
block->addArgument(input.getType(), loc);
@@ -93,14 +99,15 @@ auto createSpatCompute(RewriterT& rewriter,
using BodyResult = detail::InvokeWithBlockArgsResultT<std::decay_t<BodyFn>, std::make_index_sequence<NumInputs>>;
if constexpr (std::is_same_v<BodyResult, void>) {
detail::invokeWithBlockArgs(std::forward<BodyFn>(body), block, std::make_index_sequence<NumInputs> {});
detail::invokeWithValues(
std::forward<BodyFn>(body), detail::getInputBlockArgs(block, weights.size()), std::make_index_sequence<NumInputs> {});
rewriter.setInsertionPointAfter(computeOp);
return computeOp;
}
else {
auto bodyResult =
detail::invokeWithBlockArgs(std::forward<BodyFn>(body), block, std::make_index_sequence<NumInputs> {});
auto bodyResult = detail::invokeWithValues(
std::forward<BodyFn>(body), detail::getInputBlockArgs(block, weights.size()), std::make_index_sequence<NumInputs> {});
if (mlir::failed(bodyResult)) {
rewriter.setInsertionPointAfter(computeOp);
rewriter.eraseOp(computeOp);
@@ -123,6 +130,8 @@ auto createSpatCompute(RewriterT& rewriter,
auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs);
auto* block = new mlir::Block();
for (mlir::Value weight : weights)
block->addArgument(weight.getType(), loc);
for (mlir::Value input : inputs)
block->addArgument(input.getType(), loc);
@@ -131,13 +140,13 @@ auto createSpatCompute(RewriterT& rewriter,
using BodyResult = detail::InvokeWithValueRangeResultT<std::decay_t<BodyFn>>;
if constexpr (std::is_same_v<BodyResult, void>) {
std::forward<BodyFn>(body)(detail::getBlockArgs(block));
std::forward<BodyFn>(body)(detail::getInputBlockArgs(block, weights.size()));
rewriter.setInsertionPointAfter(computeOp);
return computeOp;
}
else {
auto bodyResult = std::forward<BodyFn>(body)(detail::getBlockArgs(block));
auto bodyResult = std::forward<BodyFn>(body)(detail::getInputBlockArgs(block, weights.size()));
if (mlir::failed(bodyResult)) {
rewriter.setInsertionPointAfter(computeOp);
rewriter.eraseOp(computeOp);
@@ -100,18 +100,27 @@ DenseMap<HSliceId, DenseMap<CoreId, SmallVector<Value>>> tileMatrix(
return tiles;
}
tensor::SplatOp
broadcastToVector(Value scalarToBroadcast, int64_t length, ConversionPatternRewriter& rewriter, Location loc) {
Value broadcastToVector(Value scalarToBroadcast, int64_t length, ConversionPatternRewriter& rewriter, Location loc) {
auto oldType = cast<RankedTensorType>(scalarToBroadcast.getType());
Type elementType = oldType.getElementType();
int64_t shape[2] = {1, length};
Type type = oldType.cloneWith(ArrayRef(shape), elementType);
auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0).getResult();
SmallVector<Value> index(oldType.getRank(), zero);
auto elementValue = tensor::ExtractOp::create(rewriter, loc, scalarToBroadcast, index).getResult();
auto buildBroadcast = [&](Value input) -> Value {
auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0).getResult();
SmallVector<Value> index(oldType.getRank(), zero);
auto elementValue = tensor::ExtractOp::create(rewriter, loc, input, index).getResult();
return tensor::SplatOp::create(rewriter, loc, type, elementValue);
};
return tensor::SplatOp::create(rewriter, loc, type, elementValue);
if (isHostFoldableValue(scalarToBroadcast))
return buildBroadcast(scalarToBroadcast);
auto broadcastCompute =
createSpatCompute<1>(rewriter, loc, TypeRange {type}, {}, ValueRange {scalarToBroadcast}, [&](Value input) {
spatial::SpatYieldOp::create(rewriter, loc, buildBroadcast(input));
});
return broadcastCompute.getResult(0);
}
} // namespace onnx_mlir
@@ -136,9 +136,9 @@ tileMatrix(mlir::Value& matrixToTile,
mlir::ConversionPatternRewriter& rewriter,
mlir::Location& loc);
mlir::tensor::SplatOp broadcastToVector(mlir::Value scalarToBroadcast,
int64_t length,
mlir::ConversionPatternRewriter& rewriter,
mlir::Location loc);
mlir::Value broadcastToVector(mlir::Value scalarToBroadcast,
int64_t length,
mlir::ConversionPatternRewriter& rewriter,
mlir::Location loc);
} // namespace onnx_mlir
@@ -1,8 +1,12 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
@@ -18,6 +22,11 @@ static bool hasStaticUnitStrides(tensor::ExtractSliceOp extractSliceOp) {
return llvm::all_of(extractSliceOp.getStaticStrides(), [](int64_t stride) { return stride == 1; });
}
static bool hasConstantIndices(tensor::ExtractOp extractOp) {
return llvm::all_of(extractOp.getIndices(),
[](Value index) { return isa_and_nonnull<arith::ConstantIndexOp>(index.getDefiningOp()); });
}
static bool isStaticTensorResult(Operation* op) {
return llvm::all_of(op->getResultTypes(), [](Type type) {
auto shapedType = dyn_cast<ShapedType>(type);
@@ -25,6 +34,167 @@ static bool isStaticTensorResult(Operation* op) {
});
}
static SmallVector<int64_t> computeRowMajorStrides(ArrayRef<int64_t> shape) {
SmallVector<int64_t> strides(shape.size(), 1);
for (int64_t dim = static_cast<int64_t>(shape.size()) - 2; dim >= 0; --dim)
strides[dim] = strides[dim + 1] * shape[dim + 1];
return strides;
}
static FailureOr<DenseElementsAttr> transposeDenseElements(DenseElementsAttr denseAttr, ArrayRef<int64_t> perms) {
auto tensorType = dyn_cast<RankedTensorType>(denseAttr.getType());
if (!tensorType)
return failure();
int64_t rank = tensorType.getRank();
if (static_cast<int64_t>(perms.size()) != rank)
return failure();
llvm::SmallBitVector seen(rank);
SmallVector<int64_t> transposedShape;
transposedShape.reserve(rank);
for (int64_t perm : perms) {
if (perm < 0 || perm >= rank || seen.test(perm))
return failure();
seen.set(perm);
transposedShape.push_back(tensorType.getShape()[perm]);
}
auto transposedType = RankedTensorType::get(transposedShape, tensorType.getElementType(), tensorType.getEncoding());
if (denseAttr.isSplat())
return DenseElementsAttr::get(transposedType, denseAttr.getSplatValue<Attribute>());
SmallVector<Attribute> originalValues(denseAttr.getValues<Attribute>());
SmallVector<Attribute> transposedValues(originalValues.size());
SmallVector<int64_t> originalStrides = computeRowMajorStrides(tensorType.getShape());
SmallVector<int64_t> transposedStrides = computeRowMajorStrides(transposedShape);
SmallVector<int64_t> originalIndices(rank);
for (auto [linearIndex, value] : llvm::enumerate(originalValues)) {
int64_t remaining = static_cast<int64_t>(linearIndex);
for (int64_t dim = 0; dim < rank; ++dim) {
originalIndices[dim] = remaining / originalStrides[dim];
remaining %= originalStrides[dim];
}
int64_t transposedLinearIndex = 0;
for (int64_t dim = 0; dim < rank; ++dim)
transposedLinearIndex += originalIndices[perms[dim]] * transposedStrides[dim];
transposedValues[transposedLinearIndex] = value;
}
return DenseElementsAttr::get(transposedType, transposedValues);
}
static FailureOr<DenseElementsAttr> reshapeDenseElements(DenseElementsAttr denseAttr, RankedTensorType resultType) {
auto sourceType = dyn_cast<RankedTensorType>(denseAttr.getType());
if (!sourceType || !resultType || sourceType.getNumElements() != resultType.getNumElements())
return failure();
if (denseAttr.isSplat())
return DenseElementsAttr::get(resultType, denseAttr.getSplatValue<Attribute>());
SmallVector<Attribute> values(denseAttr.getValues<Attribute>());
return DenseElementsAttr::get(resultType, values);
}
static FailureOr<DenseElementsAttr> extractSliceDenseElements(DenseElementsAttr denseAttr,
tensor::ExtractSliceOp extractSliceOp) {
auto sourceType = dyn_cast<RankedTensorType>(denseAttr.getType());
auto resultType = dyn_cast<RankedTensorType>(extractSliceOp.getType());
if (!sourceType || !resultType || !sourceType.hasStaticShape() || !resultType.hasStaticShape())
return failure();
ArrayRef<int64_t> offsets = extractSliceOp.getStaticOffsets();
ArrayRef<int64_t> sizes = extractSliceOp.getStaticSizes();
ArrayRef<int64_t> strides = extractSliceOp.getStaticStrides();
if (llvm::any_of(offsets, [](int64_t value) { return ShapedType::isDynamic(value); })
|| llvm::any_of(sizes, [](int64_t value) { return ShapedType::isDynamic(value); })
|| llvm::any_of(strides, [](int64_t stride) { return ShapedType::isDynamic(stride) || stride != 1; }))
return failure();
if (denseAttr.isSplat())
return DenseElementsAttr::get(resultType, denseAttr.getSplatValue<Attribute>());
SmallVector<Attribute> sourceValues(denseAttr.getValues<Attribute>());
SmallVector<int64_t> sourceStrides = computeRowMajorStrides(sourceType.getShape());
SmallVector<int64_t> resultStrides = computeRowMajorStrides(resultType.getShape());
SmallVector<Attribute> resultValues;
resultValues.reserve(resultType.getNumElements());
for (int64_t linearIndex = 0; linearIndex < resultType.getNumElements(); ++linearIndex) {
int64_t remaining = linearIndex;
int64_t sourceLinearIndex = 0;
for (int64_t dim = 0; dim < resultType.getRank(); ++dim) {
const int64_t resultIndex = resultStrides.empty() ? 0 : remaining / resultStrides[dim];
remaining = resultStrides.empty() ? 0 : remaining % resultStrides[dim];
sourceLinearIndex += (offsets[dim] + resultIndex) * sourceStrides[dim];
}
resultValues.push_back(sourceValues[sourceLinearIndex]);
}
return DenseElementsAttr::get(resultType, resultValues);
}
static DenseElementsAttr getDirectDenseConstantAttr(Value value) {
if (auto constantOp = value.getDefiningOp<arith::ConstantOp>())
return dyn_cast<DenseElementsAttr>(constantOp.getValue());
if (auto constantOp = value.getDefiningOp<ONNXConstantOp>())
return dyn_cast_or_null<DenseElementsAttr>(constantOp.getValueAttr());
return nullptr;
}
static DenseElementsAttr getHostFoldableDenseElementsAttrImpl(Value value, llvm::SmallPtrSetImpl<Operation*>& visited) {
auto* definingOp = value.getDefiningOp();
if (!definingOp || !visited.insert(definingOp).second)
return nullptr;
// Rebuild dense attributes through view-only host-foldable chains so later
// lowering stages can still recognize grouped/sliced constants.
if (auto denseAttr = getDirectDenseConstantAttr(value))
return denseAttr;
if (auto transposeOp = dyn_cast<ONNXTransposeOp>(definingOp)) {
auto inputAttr = getHostFoldableDenseElementsAttrImpl(transposeOp.getData(), visited);
if (!inputAttr)
return nullptr;
SmallVector<int64_t> perm;
perm.reserve(transposeOp.getPermAttr().size());
for (IntegerAttr attr : transposeOp.getPermAttr().getAsRange<IntegerAttr>())
perm.push_back(attr.getInt());
auto transposedAttr = transposeDenseElements(inputAttr, perm);
return succeeded(transposedAttr) ? *transposedAttr : nullptr;
}
if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(definingOp)) {
auto inputAttr = getHostFoldableDenseElementsAttrImpl(collapseShapeOp.getSrc(), visited);
if (!inputAttr)
return nullptr;
auto reshapedAttr = reshapeDenseElements(inputAttr, cast<RankedTensorType>(collapseShapeOp.getType()));
return succeeded(reshapedAttr) ? *reshapedAttr : nullptr;
}
if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(definingOp)) {
auto inputAttr = getHostFoldableDenseElementsAttrImpl(expandShapeOp.getSrc(), visited);
if (!inputAttr)
return nullptr;
auto reshapedAttr = reshapeDenseElements(inputAttr, cast<RankedTensorType>(expandShapeOp.getType()));
return succeeded(reshapedAttr) ? *reshapedAttr : nullptr;
}
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(definingOp)) {
auto inputAttr = getHostFoldableDenseElementsAttrImpl(extractSliceOp.getSource(), visited);
if (!inputAttr)
return nullptr;
auto slicedAttr = extractSliceDenseElements(inputAttr, extractSliceOp);
return succeeded(slicedAttr) ? *slicedAttr : nullptr;
}
return nullptr;
}
static bool isHostFoldableOpImpl(Operation* op, llvm::SmallPtrSetImpl<Operation*>& visited) {
if (!op || !visited.insert(op).second)
return false;
@@ -32,6 +202,9 @@ static bool isHostFoldableOpImpl(Operation* op, llvm::SmallPtrSetImpl<Operation*
if (isa<arith::ConstantOp, ONNXConstantOp, ONNXNoneOp>(op))
return true;
if (auto extractOp = dyn_cast<tensor::ExtractOp>(op))
return hasConstantIndices(extractOp) && isHostFoldableValue(extractOp.getTensor());
if (!isStaticTensorResult(op))
return false;
@@ -47,6 +220,9 @@ static bool isHostFoldableOpImpl(Operation* op, llvm::SmallPtrSetImpl<Operation*
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op))
return hasStaticUnitStrides(extractSliceOp) && isHostFoldableValue(extractSliceOp.getSource());
if (auto splatOp = dyn_cast<tensor::SplatOp>(op))
return isHostFoldableValue(splatOp.getInput());
if (auto extractRowsOp = dyn_cast<spatial::SpatExtractRowsOp>(op))
return isHostFoldableValue(extractRowsOp.getInput());
@@ -72,4 +248,9 @@ bool isHostFoldableOp(Operation* op) {
return isHostFoldableOpImpl(op, visited);
}
DenseElementsAttr getHostFoldableDenseElementsAttr(Value value) {
llvm::SmallPtrSet<Operation*, 8> visited;
return getHostFoldableDenseElementsAttrImpl(value, visited);
}
} // namespace onnx_mlir
@@ -1,5 +1,6 @@
#pragma once
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"
@@ -9,4 +10,6 @@ bool isHostFoldableValue(mlir::Value value);
bool isHostFoldableOp(mlir::Operation* op);
mlir::DenseElementsAttr getHostFoldableDenseElementsAttr(mlir::Value value);
} // namespace onnx_mlir
@@ -2,6 +2,7 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostLegality.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
@@ -11,7 +12,7 @@ using namespace mlir;
namespace onnx_mlir {
LogicalResult verifyONNXToSpatialHostLegality(func::FuncOp funcOp) {
bool hasFailure = false;
pim::CappedDiagnosticReporter diagnostics;
for (Operation& op : funcOp.getFunctionBody().front()) {
if (isa<func::ReturnOp, spatial::SpatCompute, spatial::SpatComputeBatch>(&op))
@@ -19,11 +20,15 @@ LogicalResult verifyONNXToSpatialHostLegality(func::FuncOp funcOp) {
if (isHostFoldableOp(&op))
continue;
op.emitOpError("non-foldable top-level runtime op remains after ONNX-to-Spatial; lower it inside spat.compute");
hasFailure = true;
diagnostics.report(&op, [](Operation* illegalOp) {
illegalOp->emitOpError("non-foldable top-level runtime op remains after ONNX-to-Spatial; lower it inside "
"spat.compute");
});
}
return success(!hasFailure);
diagnostics.emitSuppressedSummary(funcOp, "ONNX-to-Spatial host legality failures");
return success(!diagnostics.hasFailure());
}
} // namespace onnx_mlir
@@ -5,17 +5,15 @@
#include "mlir/IR/IRMapping.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"
#include "Common/Common.hpp"
#include "Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostLegality.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PostPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PrePatterns.hpp"
@@ -46,7 +44,8 @@ static void populateEmptyFunction(func::FuncOp funcOp) {
IRRewriter rewriter(funcOp.getContext());
IRMapping mapper;
SmallVector<spatial::SpatCompute> computes(funcOp.getOps<spatial::SpatCompute>());
if (!computes.empty())
SmallVector<spatial::SpatComputeBatch> computeBatches(funcOp.getOps<spatial::SpatComputeBatch>());
if (!computes.empty() || !computeBatches.empty())
return;
auto returnOp = cast<func::ReturnOp>(funcOp.getFunctionBody().front().getTerminator());
@@ -87,17 +86,68 @@ static void populateEmptyFunction(func::FuncOp funcOp) {
returnOp.setOperand(index, computeResult);
}
static void wrapTopLevelRuntimeTransposes(func::FuncOp funcOp) {
IRRewriter rewriter(funcOp.getContext());
Block& entryBlock = funcOp.getFunctionBody().front();
for (Operation& op : llvm::make_early_inc_range(entryBlock)) {
auto transposeOp = dyn_cast<ONNXTransposeOp>(&op);
if (!transposeOp || isHostFoldableOp(transposeOp))
continue;
// Transpose stays globally legal because constant/view-only cases are
// allowed on the host. Any residual runtime transpose must be sunk into
// spat.compute before the host legality check.
auto resultType = transposeOp.getResult().getType();
rewriter.setInsertionPoint(transposeOp);
auto computeOp = createSpatCompute<1>(
rewriter, transposeOp.getLoc(), TypeRange {resultType}, {}, ValueRange {transposeOp.getData()}, [&](Value input) {
Value transposed =
ONNXTransposeOp::create(rewriter, transposeOp.getLoc(), resultType, input, transposeOp.getPermAttr());
spatial::SpatYieldOp::create(rewriter, transposeOp.getLoc(), transposed);
});
rewriter.replaceOp(transposeOp, computeOp.getResult(0));
}
}
void ONNXToSpatialPass::runOnOperation() {
ModuleOp moduleOp = getOperation();
MLIRContext* ctx = &getContext();
ConversionTarget preTarget(*ctx);
preTarget.addLegalDialect<spatial::SpatialDialect,
ONNXDialect,
tensor::TensorDialect,
arith::ArithDialect,
scf::SCFDialect>();
preTarget.addIllegalOp<ONNXConstantOp, ONNXFlattenOp>();
RewritePatternSet prePatterns(ctx);
populatePrePatterns(prePatterns, ctx);
if (failed(applyPatternsGreedily(moduleOp, std::move(prePatterns))))
moduleOp.emitWarning("failed to apply ONNX-to-Spatial pre-patterns; continuing");
if (failed(applyPartialConversion(moduleOp, preTarget, std::move(prePatterns)))) {
moduleOp.emitError("failed to apply ONNX-to-Spatial pre-rewrites");
signalPassFailure();
return;
}
auto entryFunc = getPimEntryFunc(moduleOp);
if (failed(entryFunc)) {
moduleOp.emitError("failed to locate the PIM entry function during ONNX-to-Spatial lowering");
signalPassFailure();
return;
}
RewritePatternSet matmulPatterns(ctx);
populateMatMulRewritePatterns(matmulPatterns, ctx);
walkAndApplyPatterns(moduleOp, std::move(matmulPatterns));
bool hasUnloweredMatMul = false;
moduleOp.walk([&](ONNXMatMulOp matmulOp) {
hasUnloweredMatMul = true;
matmulOp.emitOpError("remaining ONNX MatMul before the required ONNX-to-Spatial conversion");
});
if (hasUnloweredMatMul) {
moduleOp.emitError("failed to lower all ONNX MatMul ops before ONNX-to-Spatial conversion");
signalPassFailure();
return;
}
@@ -130,30 +180,17 @@ void ONNXToSpatialPass::runOnOperation() {
RewritePatternSet conversionPatterns(ctx);
populateConversionPatterns(conversionPatterns, ctx);
if (failed(applyPartialConversion(moduleOp, target, std::move(conversionPatterns)))) {
moduleOp.emitError("failed to convert required ONNX ops to Spatial ops");
signalPassFailure();
return;
}
RewritePatternSet earlyPostPatterns(ctx);
populateEarlyPostPatterns(earlyPostPatterns, ctx);
if (failed(applyPatternsGreedily(*entryFunc, std::move(earlyPostPatterns)))) {
signalPassFailure();
return;
}
if (coresCount != -1) {
int computeOpsCount = 0;
for (Operation& op : entryFunc->getFunctionBody().front().getOperations())
if (isa<spatial::SpatCompute>(op))
computeOpsCount++;
if (computeOpsCount > coresCount) {
entryFunc->emitError() << "number of compute ops (" << computeOpsCount << ") exceeds the core count ("
<< coresCount << ")";
signalPassFailure();
return;
}
}
ConversionTarget earlyPostTarget(*ctx);
earlyPostTarget.addLegalDialect<spatial::SpatialDialect,
ONNXDialect,
tensor::TensorDialect,
arith::ArithDialect,
scf::SCFDialect>();
PassManager cleanupPM(ctx);
cleanupPM.addPass(createCanonicalizerPass());
@@ -162,14 +199,29 @@ void ONNXToSpatialPass::runOnOperation() {
annotateWeightsConstants(*entryFunc);
ConversionTarget postTarget(*ctx);
postTarget.addLegalDialect<spatial::SpatialDialect,
ONNXDialect,
tensor::TensorDialect,
arith::ArithDialect,
scf::SCFDialect>();
postTarget.addDynamicallyLegalOp<spatial::SpatCompute>(
[](spatial::SpatCompute computeOp) { return !requiresPostRewrite(computeOp); });
postTarget.addDynamicallyLegalOp<spatial::SpatComputeBatch>(
[](spatial::SpatComputeBatch computeOp) { return !requiresPostRewrite(computeOp); });
RewritePatternSet postPatterns(ctx);
populatePostPatterns(postPatterns, ctx);
if (failed(applyPatternsGreedily(*entryFunc, std::move(postPatterns)))) {
if (failed(applyPartialConversion(*entryFunc, postTarget, std::move(postPatterns)))) {
moduleOp.emitError("failed to normalize weight-like Spatial compute operands before Spatial-to-PIM lowering");
signalPassFailure();
return;
}
wrapTopLevelRuntimeTransposes(*entryFunc);
if (failed(verifyONNXToSpatialHostLegality(*entryFunc))) {
moduleOp.emitError("ONNX-to-Spatial host legality verification failed");
signalPassFailure();
return;
}
@@ -11,6 +11,7 @@
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -27,16 +28,6 @@ struct ConvToGemm : OpConversionPattern<ONNXConvOp> {
ConversionPatternRewriter& rewriter) const override;
};
static DenseElementsAttr getDenseConstantAttr(Value value) {
if (auto constantOp = value.getDefiningOp<arith::ConstantOp>())
return dyn_cast<DenseElementsAttr>(constantOp.getValue());
if (auto constantOp = value.getDefiningOp<ONNXConstantOp>())
return dyn_cast_or_null<DenseElementsAttr>(constantOp.getValueAttr());
return nullptr;
}
static int64_t getI64FromArrayAttr(ArrayAttr arr, size_t idx) { return cast<IntegerAttr>(arr[idx]).getInt(); }
static Value expandBiasIfNeeded(Value bias, ConversionPatternRewriter& rewriter, Location loc) {
@@ -355,49 +346,22 @@ static Value createCollectedConvOutput(ValueRange gemmRows,
return collectComputeOp.getResult(0);
}
} // namespace
LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
ONNXConvOpAdaptor convOpAdaptor,
ConversionPatternRewriter& rewriter) const {
Location loc = convOp.getLoc();
Value x = convOpAdaptor.getX();
Value w = convOpAdaptor.getW();
Value b = convOpAdaptor.getB();
auto xType = cast<RankedTensorType>(x.getType());
auto wType = cast<RankedTensorType>(w.getType());
auto outType = cast<RankedTensorType>(convOp.getY().getType());
if (!xType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv input");
return failure();
}
if (!wType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv weight");
return failure();
}
if (!outType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv result");
return failure();
}
if (xType.getRank() != 4) {
pim::emitUnsupportedRankDiagnostic(convOp, "conv input", xType.getRank(), {4});
return failure();
}
if (wType.getRank() != 4) {
pim::emitUnsupportedRankDiagnostic(convOp, "conv weight", wType.getRank(), {4});
return failure();
}
if (outType.getRank() != 4) {
pim::emitUnsupportedRankDiagnostic(convOp, "conv result", outType.getRank(), {4});
return failure();
}
if (convOp.getGroup() != 1) {
convOp.emitOpError("only group=1 convolution is supported for Spatial lowering");
return failure();
}
static Value lowerSingleConvGroup(Value x,
Value w,
Value b,
RankedTensorType xType,
RankedTensorType wType,
RankedTensorType outType,
int64_t padHeightBegin,
int64_t padHeightEnd,
int64_t padWidthBegin,
int64_t padWidthEnd,
int64_t strideHeight,
int64_t strideWidth,
int64_t dilationHeight,
int64_t dilationWidth,
ConversionPatternRewriter& rewriter,
Location loc) {
const int64_t batchSize = xType.getDimSize(0);
const int64_t numChannelsIn = xType.getDimSize(1);
const int64_t xHeight = xType.getDimSize(2);
@@ -408,71 +372,6 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
const int64_t outHeight = outType.getDimSize(2);
const int64_t outWidth = outType.getDimSize(3);
// Read optional conv attributes (ONNX defaults: stride=1, dilation=1, pad=0)
const auto stridesAttr = convOp.getStrides();
const auto dilationsAttr = convOp.getDilations();
const auto padsAttr = convOp.getPads();
if (stridesAttr && stridesAttr->size() != 2) {
convOp.emitOpError("requires exactly two stride values for Spatial lowering");
return failure();
}
if (dilationsAttr && dilationsAttr->size() != 2) {
convOp.emitOpError("requires exactly two dilation values for Spatial lowering");
return failure();
}
if (padsAttr && padsAttr->size() != 4) {
convOp.emitOpError("requires exactly four pad values for 2D Spatial lowering");
return failure();
}
const int64_t strideHeight = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 0) : 1;
const int64_t strideWidth = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 1) : 1;
const int64_t dilationHeight = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 0) : 1;
const int64_t dilationWidth = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 1) : 1;
int64_t padHeightBegin = 0;
int64_t padHeightEnd = 0;
int64_t padWidthBegin = 0;
int64_t padWidthEnd = 0;
if (padsAttr) {
padHeightBegin = getI64FromArrayAttr(*padsAttr, 0);
padWidthBegin = getI64FromArrayAttr(*padsAttr, 1);
padHeightEnd = getI64FromArrayAttr(*padsAttr, 2);
padWidthEnd = getI64FromArrayAttr(*padsAttr, 3);
}
else {
// Compute padding from auto_pad attribute
const auto autoPad = convOp.getAutoPad();
if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") {
const int64_t effectiveKernelH = (wHeight - 1) * dilationHeight + 1;
const int64_t effectiveKernelW = (wWidth - 1) * dilationWidth + 1;
const int64_t totalPadH =
std::max(static_cast<int64_t>(0), (outHeight - 1) * strideHeight + effectiveKernelH - xHeight);
const int64_t totalPadW =
std::max(static_cast<int64_t>(0), (outWidth - 1) * strideWidth + effectiveKernelW - xWidth);
if (autoPad == "SAME_UPPER") {
padHeightBegin = totalPadH / 2;
padHeightEnd = totalPadH - padHeightBegin;
padWidthBegin = totalPadW / 2;
padWidthEnd = totalPadW - padWidthBegin;
}
else { // SAME_LOWER
padHeightEnd = totalPadH / 2;
padHeightBegin = totalPadH - padHeightEnd;
padWidthEnd = totalPadW / 2;
padWidthBegin = totalPadW - padWidthEnd;
}
}
else if (autoPad != "NOTSET" && autoPad != "VALID") {
convOp.emitOpError() << "unsupported auto_pad value `" << autoPad << "` for Spatial lowering";
return failure();
}
// "NOTSET" or "VALID" -> all pads stay 0
}
// im2col layout (flipped with respect to the standard, so filters sit in B = crossbar):
// A (im2col): [numPatches, patchSize] -- one row per output spatial position
// B (weights): [patchSize, cOut] -- W^T, stored in crossbar columns
@@ -492,7 +391,7 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
const int64_t xbarSize = static_cast<int64_t>(crossbarSize.getValue());
const int64_t wMaxDim = std::max(patchSize, numChannelsOut);
const int64_t maxParallelPixels = std::max<int64_t>(1, xbarSize / wMaxDim);
auto wDenseAttr = getDenseConstantAttr(w);
auto wDenseAttr = getHostFoldableDenseElementsAttr(w);
// Prepare weight matrix W for crossbar storage:
// W: [numChannelsOut, numChannelsIn, wHeight, wWidth] -> [numChannelsOut, patchSize] -> [patchSize, numChannelsOut]
@@ -513,7 +412,7 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
DenseElementsAttr biasDenseAttr;
if (hasB) {
gemmBias = b;
biasDenseAttr = getDenseConstantAttr(b);
biasDenseAttr = getHostFoldableDenseElementsAttr(b);
biasMatrix = expandBiasIfNeeded(b, rewriter, loc);
}
const bool canPackWeightsAsConstants = static_cast<bool>(wDenseAttr);
@@ -589,17 +488,246 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
rewriter.getBoolAttr(false))
.getY();
rewriter.replaceOp(convOp,
createCollectedConvOutput(ValueRange {gemmRows},
convOp.getType(),
gemmOutType,
nhwcType,
outType,
numPatches,
numChannelsOut,
effectiveMaxParallelPixels,
rewriter,
loc));
return createCollectedConvOutput(ValueRange {gemmRows},
outType,
gemmOutType,
nhwcType,
outType,
numPatches,
numChannelsOut,
effectiveMaxParallelPixels,
rewriter,
loc);
}
} // namespace
LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
ONNXConvOpAdaptor convOpAdaptor,
ConversionPatternRewriter& rewriter) const {
Location loc = convOp.getLoc();
Value x = convOpAdaptor.getX();
Value w = convOpAdaptor.getW();
Value b = convOpAdaptor.getB();
auto xType = cast<RankedTensorType>(x.getType());
auto wType = cast<RankedTensorType>(w.getType());
auto outType = cast<RankedTensorType>(convOp.getY().getType());
if (!xType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv input");
return failure();
}
if (!wType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv weight");
return failure();
}
if (!outType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv result");
return failure();
}
if (xType.getRank() != 4) {
pim::emitUnsupportedRankDiagnostic(convOp, "conv input", xType.getRank(), {4});
return failure();
}
if (wType.getRank() != 4) {
pim::emitUnsupportedRankDiagnostic(convOp, "conv weight", wType.getRank(), {4});
return failure();
}
if (outType.getRank() != 4) {
pim::emitUnsupportedRankDiagnostic(convOp, "conv result", outType.getRank(), {4});
return failure();
}
if (convOp.getGroup() < 1) {
convOp.emitOpError("requires group >= 1 for Spatial lowering");
return failure();
}
const int64_t batchSize = xType.getDimSize(0);
const int64_t numChannelsIn = xType.getDimSize(1);
const int64_t xHeight = xType.getDimSize(2);
const int64_t xWidth = xType.getDimSize(3);
const int64_t numChannelsOut = wType.getDimSize(0);
const int64_t wHeight = wType.getDimSize(2);
const int64_t wWidth = wType.getDimSize(3);
const int64_t outHeight = outType.getDimSize(2);
const int64_t outWidth = outType.getDimSize(3);
const int64_t group = convOp.getGroup();
const bool hasB = !isa<ONNXNoneOp>(b.getDefiningOp());
if (numChannelsIn % group != 0) {
convOp.emitOpError() << "requires input channels " << numChannelsIn << " to be divisible by group " << group
<< " for Spatial lowering";
return failure();
}
if (numChannelsOut % group != 0) {
convOp.emitOpError() << "requires output channels " << numChannelsOut << " to be divisible by group " << group
<< " for Spatial lowering";
return failure();
}
const int64_t numChannelsInPerGroup = numChannelsIn / group;
const int64_t numChannelsOutPerGroup = numChannelsOut / group;
if (wType.getDimSize(1) != numChannelsInPerGroup) {
convOp.emitOpError() << "requires grouped conv weight input channels " << wType.getDimSize(1)
<< " to match input channels per group " << numChannelsInPerGroup << " for Spatial lowering";
return failure();
}
if (wType.getDimSize(0) != numChannelsOut) {
convOp.emitOpError() << "requires weight output channels " << wType.getDimSize(0) << " to match result channels "
<< numChannelsOut << " for Spatial lowering";
return failure();
}
// Read optional conv attributes (ONNX defaults: stride=1, dilation=1, pad=0)
const auto stridesAttr = convOp.getStrides();
const auto dilationsAttr = convOp.getDilations();
const auto padsAttr = convOp.getPads();
if (stridesAttr && stridesAttr->size() != 2) {
convOp.emitOpError("requires exactly two stride values for Spatial lowering");
return failure();
}
if (dilationsAttr && dilationsAttr->size() != 2) {
convOp.emitOpError("requires exactly two dilation values for Spatial lowering");
return failure();
}
if (padsAttr && padsAttr->size() != 4) {
convOp.emitOpError("requires exactly four pad values for 2D Spatial lowering");
return failure();
}
const int64_t strideHeight = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 0) : 1;
const int64_t strideWidth = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 1) : 1;
const int64_t dilationHeight = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 0) : 1;
const int64_t dilationWidth = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 1) : 1;
int64_t padHeightBegin = 0;
int64_t padHeightEnd = 0;
int64_t padWidthBegin = 0;
int64_t padWidthEnd = 0;
if (padsAttr) {
padHeightBegin = getI64FromArrayAttr(*padsAttr, 0);
padWidthBegin = getI64FromArrayAttr(*padsAttr, 1);
padHeightEnd = getI64FromArrayAttr(*padsAttr, 2);
padWidthEnd = getI64FromArrayAttr(*padsAttr, 3);
}
else {
// Compute padding from auto_pad attribute
const auto autoPad = convOp.getAutoPad();
if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") {
const int64_t effectiveKernelH = (wHeight - 1) * dilationHeight + 1;
const int64_t effectiveKernelW = (wWidth - 1) * dilationWidth + 1;
const int64_t totalPadH =
std::max(static_cast<int64_t>(0), (outHeight - 1) * strideHeight + effectiveKernelH - xHeight);
const int64_t totalPadW =
std::max(static_cast<int64_t>(0), (outWidth - 1) * strideWidth + effectiveKernelW - xWidth);
if (autoPad == "SAME_UPPER") {
padHeightBegin = totalPadH / 2;
padHeightEnd = totalPadH - padHeightBegin;
padWidthBegin = totalPadW / 2;
padWidthEnd = totalPadW - padWidthBegin;
}
else { // SAME_LOWER
padHeightEnd = totalPadH / 2;
padHeightBegin = totalPadH - padHeightEnd;
padWidthEnd = totalPadW / 2;
padWidthBegin = totalPadW - padWidthEnd;
}
}
else if (autoPad != "NOTSET" && autoPad != "VALID") {
convOp.emitOpError() << "unsupported auto_pad value `" << autoPad << "` for Spatial lowering";
return failure();
}
// "NOTSET" or "VALID" -> all pads stay 0
}
if (group == 1) {
rewriter.replaceOp(convOp,
lowerSingleConvGroup(x,
w,
b,
xType,
wType,
outType,
padHeightBegin,
padHeightEnd,
padWidthBegin,
padWidthEnd,
strideHeight,
strideWidth,
dilationHeight,
dilationWidth,
rewriter,
loc));
return success();
}
SmallVector<Value> xSlices = sliceTensor(x, /*axis=*/1, numChannelsInPerGroup, rewriter, loc);
SmallVector<Value> wSlices = sliceTensor(w, /*axis=*/0, numChannelsOutPerGroup, rewriter, loc);
SmallVector<Value> bSlices;
if (hasB) {
auto biasType = cast<RankedTensorType>(b.getType());
int64_t biasAxis = -1;
if (biasType.getRank() == 1)
biasAxis = 0;
else if (biasType.getRank() == 2)
biasAxis = biasType.getDimSize(0) != 1 ? 0 : 1;
else {
convOp.emitOpError() << "requires rank-1 or rank-2 bias for grouped convolution Spatial lowering, but got rank "
<< biasType.getRank();
return failure();
}
bSlices = sliceTensor(b, biasAxis, numChannelsOutPerGroup, rewriter, loc);
}
if (xSlices.size() != static_cast<size_t>(group) || wSlices.size() != static_cast<size_t>(group)
|| (hasB && bSlices.size() != static_cast<size_t>(group))) {
convOp.emitOpError("failed to partition grouped convolution operands for Spatial lowering");
return failure();
}
SmallVector<Value> groupResults;
groupResults.reserve(group);
auto groupOutType =
RankedTensorType::get({batchSize, numChannelsOutPerGroup, outHeight, outWidth}, outType.getElementType());
Value noBias = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());
for (int64_t groupId = 0; groupId < group; groupId++) {
Value groupX = xSlices[groupId];
Value groupW = wSlices[groupId];
Value groupB = hasB ? bSlices[groupId] : noBias;
groupResults.push_back(lowerSingleConvGroup(groupX,
groupW,
groupB,
cast<RankedTensorType>(groupX.getType()),
cast<RankedTensorType>(groupW.getType()),
groupOutType,
padHeightBegin,
padHeightEnd,
padWidthBegin,
padWidthEnd,
strideHeight,
strideWidth,
dilationHeight,
dilationWidth,
rewriter,
loc));
}
Value result;
if (llvm::all_of(groupResults, isHostFoldableValue)) {
result = createSpatConcat(rewriter, loc, /*axis=*/1, groupResults);
}
else {
auto concatCompute = createSpatCompute(rewriter, loc, TypeRange {outType}, {}, groupResults, [&](ValueRange args) {
spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, /*axis=*/1, args));
});
result = concatCompute.getResult(0);
}
rewriter.replaceOp(convOp, result);
return success();
}
@@ -402,24 +402,37 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
for (size_t aSliceId = 0; aSliceId < aHSlices[coreId].size(); aSliceId++)
weights.push_back(bTiles[outSliceId][coreId][aSliceId]);
auto computeOp = createSpatCompute(
rewriter, gemmLoc, currOutHSliceType, weights, aHSlices[coreId], [&](ValueRange aHSlicesArgs) -> LogicalResult {
SmallVector<Value> vmmOutputs;
vmmOutputs.reserve(aHSlicesArgs.size());
for (auto [aHSliceId, computeArg] : llvm::enumerate(aHSlicesArgs))
vmmOutputs.push_back(
spatial::SpatVMMOp::create(rewriter, gemmLoc, currOutHSliceType, aHSliceId, computeArg));
if (vmmOutputs.empty()) {
gemmOp.emitOpError("requires at least one non-empty slice when lowering tiled Gemm to Spatial VMMs");
return failure();
}
auto computeOp =
spatial::SpatCompute::create(rewriter, gemmLoc, TypeRange {currOutHSliceType}, weights, aHSlices[coreId]);
SmallVector<Type> blockArgTypes;
SmallVector<Location> blockArgLocs;
blockArgTypes.reserve(weights.size() + aHSlices[coreId].size());
blockArgLocs.reserve(weights.size() + aHSlices[coreId].size());
for (Value weight : weights) {
blockArgTypes.push_back(weight.getType());
blockArgLocs.push_back(gemmLoc);
}
for (Value input : aHSlices[coreId]) {
blockArgTypes.push_back(input.getType());
blockArgLocs.push_back(gemmLoc);
}
Block* body =
rewriter.createBlock(&computeOp.getBody(), computeOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
rewriter.setInsertionPointToEnd(body);
Value partialVmmSum = sumTensors(vmmOutputs, rewriter);
spatial::SpatYieldOp::create(rewriter, gemmLoc, partialVmmSum);
return success();
});
if (failed(computeOp))
SmallVector<Value> vmmOutputs;
vmmOutputs.reserve(aHSlices[coreId].size());
for (auto aHSliceId : llvm::seq<size_t>(0, aHSlices[coreId].size()))
vmmOutputs.push_back(spatial::SpatVMMOp::create(
rewriter, gemmLoc, currOutHSliceType, computeOp.getWeightArgument(aHSliceId), computeOp.getInputArgument(aHSliceId)));
if (vmmOutputs.empty()) {
gemmOp.emitOpError("requires at least one non-empty slice when lowering tiled Gemm to Spatial VMMs");
return failure();
}
Value partialVmmSum = sumTensors(vmmOutputs, rewriter);
spatial::SpatYieldOp::create(rewriter, gemmLoc, partialVmmSum);
rewriter.setInsertionPointAfter(computeOp);
partialResults.push_back(computeOp->getResult(0));
}
@@ -502,9 +515,6 @@ LogicalResult GemmToSpatialComputeBatch::matchAndRewrite(ONNXGemmOp gemmOp,
}
(void) bType;
if (!isHostFoldableValue(b))
return failure();
Value sharedBias;
if (hasC) {
auto scaledC = materializeScaledConstantTensor(c, gemmOpAdaptor.getBeta().convertToFloat(), rewriter, loc);
@@ -533,37 +543,47 @@ LogicalResult GemmToSpatialComputeBatch::matchAndRewrite(ONNXGemmOp gemmOp,
sharedBias = c;
}
SmallVector<Value> aSlices = materializeBatchRowSlices(a, aType, rewriter, loc);
auto aSliceType = cast<RankedTensorType>(aSlices.front().getType());
auto outRowType = RankedTensorType::get({1, outType.getDimSize(1)}, outType.getElementType());
SmallVector<Type> resultTypes(static_cast<size_t>(numOutRows), outRowType);
SmallVector<Value> weights(static_cast<size_t>(numOutRows), b);
auto aRowType = RankedTensorType::get({1, aType.getDimSize(1)}, aType.getElementType());
auto batchOp = spatial::SpatComputeBatch::create(rewriter,
loc,
TypeRange(resultTypes),
TypeRange {outType},
rewriter.getI32IntegerAttr(static_cast<int32_t>(numOutRows)),
ValueRange(weights),
ValueRange(aSlices));
ValueRange {b},
ValueRange {a});
Block* body = rewriter.createBlock(
&batchOp.getBody(), batchOp.getBody().end(), TypeRange {aSliceType}, SmallVector<Location>(1, loc));
SmallVector<Type> blockArgTypes {rewriter.getIndexType(), bType, aType, outType};
SmallVector<Location> blockArgLocs(4, loc);
Block* body =
rewriter.createBlock(&batchOp.getBody(), batchOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
rewriter.setInsertionPointToEnd(body);
Value vmmResult = spatial::SpatVMMOp::create(rewriter, loc, outRowType, 0, body->getArgument(0)).getResult();
Value lane = batchOp.getLaneArgument();
Value weight = batchOp.getWeightArgument(0);
Value packedInput = batchOp.getInputArgument(0);
Value packedOutput = batchOp.getOutputArgument(0);
SmallVector<OpFoldResult> inputOffsets {lane, rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> inputSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(aType.getDimSize(1))};
SmallVector<OpFoldResult> unitStrides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
Value row =
tensor::ExtractSliceOp::create(rewriter, loc, aRowType, packedInput, inputOffsets, inputSizes, unitStrides)
.getResult();
Value vmmResult = spatial::SpatVMMOp::create(rewriter, loc, outRowType, weight, row).getResult();
Value laneResult = vmmResult;
if (sharedBias)
laneResult = spatial::SpatVAddOp::create(rewriter, loc, outRowType, vmmResult, sharedBias).getResult();
spatial::SpatYieldOp::create(rewriter, loc, laneResult);
auto inParallelOp = spatial::SpatInParallelOp::create(rewriter, loc);
rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front());
SmallVector<OpFoldResult> outputOffsets {lane, rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> outputSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(outType.getDimSize(1))};
tensor::ParallelInsertSliceOp::create(rewriter, loc, laneResult, packedOutput, outputOffsets, outputSizes,
unitStrides);
rewriter.setInsertionPointAfter(batchOp);
SmallVector<Value> laneResults(batchOp->result_begin(), batchOp->result_end());
auto concatComputeOp = createSpatCompute(rewriter, loc, gemmOp.getType(), {}, laneResults, [&](ValueRange args) {
spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, /*axis=*/0, args));
});
rewriter.replaceOp(gemmOp, concatComputeOp);
rewriter.replaceOp(gemmOp, batchOp.getResults());
return success();
}
@@ -2,8 +2,12 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include <functional>
#include <numeric>
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
@@ -19,6 +23,79 @@ static bool haveStaticPositiveShape(ArrayRef<int64_t> shape) {
return llvm::all_of(shape, [](int64_t dim) { return dim > 0; });
}
static int64_t getStaticShapeElementCount(ArrayRef<int64_t> shape) {
return std::accumulate(shape.begin(), shape.end(), int64_t {1}, std::multiplies<int64_t> {});
}
static FailureOr<SmallVector<int64_t>> inferSupportedBatchShape(ArrayRef<int64_t> lhsBatchShape,
ArrayRef<int64_t> rhsBatchShape) {
if (lhsBatchShape.empty())
return SmallVector<int64_t>(rhsBatchShape.begin(), rhsBatchShape.end());
if (rhsBatchShape.empty())
return SmallVector<int64_t>(lhsBatchShape.begin(), lhsBatchShape.end());
if (!llvm::equal(lhsBatchShape, rhsBatchShape))
return failure();
return SmallVector<int64_t>(lhsBatchShape.begin(), lhsBatchShape.end());
}
static Value collapseBatchDims(Value value,
int64_t batchSize,
int64_t rows,
int64_t cols,
PatternRewriter& rewriter,
Location loc) {
auto type = cast<RankedTensorType>(value.getType());
if (type.getRank() == 2 || type.getRank() == 3)
return value;
auto collapsedType =
RankedTensorType::get({batchSize, rows, cols}, type.getElementType(), type.getEncoding());
SmallVector<ReassociationIndices> reassociation = {
ReassociationIndices {},
ReassociationIndices {static_cast<int64_t>(type.getRank() - 2)},
ReassociationIndices {static_cast<int64_t>(type.getRank() - 1)}
};
for (int64_t dim = 0; dim < type.getRank() - 2; ++dim)
reassociation.front().push_back(dim);
auto buildCollapsed = [&](Value input) -> Value {
return tensor::CollapseShapeOp::create(rewriter, loc, collapsedType, input, reassociation);
};
if (isHostFoldableValue(value))
return buildCollapsed(value);
auto collapseCompute =
createSpatCompute<1>(rewriter, loc, TypeRange {collapsedType}, {}, ValueRange {value}, [&](Value input) {
spatial::SpatYieldOp::create(rewriter, loc, buildCollapsed(input));
});
return collapseCompute.getResult(0);
}
static Value expandBatchDims(Value value,
RankedTensorType outputType,
size_t batchRank,
PatternRewriter& rewriter,
Location loc) {
if (cast<RankedTensorType>(value.getType()) == outputType)
return value;
SmallVector<ReassociationIndices> reassociation = {
ReassociationIndices {},
ReassociationIndices {static_cast<int64_t>(batchRank)},
ReassociationIndices {static_cast<int64_t>(batchRank + 1)}
};
for (size_t dim = 0; dim < batchRank; ++dim)
reassociation.front().push_back(static_cast<int64_t>(dim));
auto expandCompute =
createSpatCompute<1>(rewriter, loc, TypeRange {outputType}, {}, ValueRange {value}, [&](Value input) {
Value expanded = tensor::ExpandShapeOp::create(rewriter, loc, outputType, input, reassociation);
spatial::SpatYieldOp::create(rewriter, loc, expanded);
});
return expandCompute.getResult(0);
}
static Value extractBatchMatrix(Value value,
int64_t batchIndex,
int64_t batchSize,
@@ -62,13 +139,29 @@ static Value extractBatchMatrix(Value value,
static Value transposeLastTwoDims(Value value, PatternRewriter& rewriter, Location loc) {
auto type = cast<RankedTensorType>(value.getType());
auto shape = type.getShape();
RankedTensorType transposedType;
SmallVector<int64_t> perm;
if (type.getRank() == 2) {
auto transposedType = RankedTensorType::get({shape[1], shape[0]}, type.getElementType());
return ONNXTransposeOp::create(rewriter, loc, transposedType, value, rewriter.getI64ArrayAttr({1, 0}));
transposedType = RankedTensorType::get({shape[1], shape[0]}, type.getElementType());
perm = {1, 0};
}
else {
transposedType = RankedTensorType::get({shape[0], shape[2], shape[1]}, type.getElementType());
perm = {0, 2, 1};
}
auto transposedType = RankedTensorType::get({shape[0], shape[2], shape[1]}, type.getElementType());
return ONNXTransposeOp::create(rewriter, loc, transposedType, value, rewriter.getI64ArrayAttr({0, 2, 1}));
auto buildTranspose = [&](Value input) -> Value {
return ONNXTransposeOp::create(rewriter, loc, transposedType, input, rewriter.getI64ArrayAttr(perm));
};
if (isHostFoldableValue(value))
return buildTranspose(value);
auto transposeCompute =
createSpatCompute<1>(rewriter, loc, TypeRange {transposedType}, {}, ValueRange {value}, [&](Value input) {
spatial::SpatYieldOp::create(rewriter, loc, buildTranspose(input));
});
return transposeCompute.getResult(0);
}
static Value transposeLastTwoDimsInCompute(Value value, PatternRewriter& rewriter, Location loc) {
@@ -120,24 +213,25 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
if (!lhsType || !rhsType || !outType || !lhsType.hasStaticShape() || !rhsType.hasStaticShape()
|| !outType.hasStaticShape())
return failure();
if ((lhsType.getRank() != 2 && lhsType.getRank() != 3) || (rhsType.getRank() != 2 && rhsType.getRank() != 3)
|| (outType.getRank() != 2 && outType.getRank() != 3))
if (lhsType.getRank() < 2 || rhsType.getRank() < 2 || outType.getRank() < 2)
return failure();
if (!haveStaticPositiveShape(lhsType.getShape()) || !haveStaticPositiveShape(rhsType.getShape())
|| !haveStaticPositiveShape(outType.getShape()))
return failure();
const int64_t lhsBatch = lhsType.getRank() == 3 ? lhsType.getDimSize(0) : 1;
const int64_t rhsBatch = rhsType.getRank() == 3 ? rhsType.getDimSize(0) : 1;
const int64_t batch = std::max(lhsBatch, rhsBatch);
if ((lhsBatch != 1 && lhsBatch != batch) || (rhsBatch != 1 && rhsBatch != batch))
SmallVector<int64_t> lhsBatchShape(lhsType.getShape().begin(), lhsType.getShape().end() - 2);
SmallVector<int64_t> rhsBatchShape(rhsType.getShape().begin(), rhsType.getShape().end() - 2);
auto batchShape = inferSupportedBatchShape(lhsBatchShape, rhsBatchShape);
if (failed(batchShape))
return failure();
const int64_t lhsBatch = lhsBatchShape.empty() ? 1 : getStaticShapeElementCount(lhsBatchShape);
const int64_t rhsBatch = rhsBatchShape.empty() ? 1 : getStaticShapeElementCount(rhsBatchShape);
const int64_t batch = batchShape->empty() ? 1 : getStaticShapeElementCount(*batchShape);
const int64_t m = lhsType.getRank() == 3 ? lhsType.getDimSize(1) : lhsType.getDimSize(0);
const int64_t k = lhsType.getRank() == 3 ? lhsType.getDimSize(2) : lhsType.getDimSize(1);
const int64_t rhsK = rhsType.getRank() == 3 ? rhsType.getDimSize(1) : rhsType.getDimSize(0);
const int64_t n = rhsType.getRank() == 3 ? rhsType.getDimSize(2) : rhsType.getDimSize(1);
const int64_t m = lhsType.getDimSize(lhsType.getRank() - 2);
const int64_t k = lhsType.getDimSize(lhsType.getRank() - 1);
const int64_t rhsK = rhsType.getDimSize(rhsType.getRank() - 2);
const int64_t n = rhsType.getDimSize(rhsType.getRank() - 1);
if (k != rhsK)
return failure();
@@ -146,15 +240,17 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
return failure();
}
else {
if (outType.getDimSize(0) != batch || outType.getDimSize(1) != m || outType.getDimSize(2) != n)
SmallVector<int64_t> outBatchShape(outType.getShape().begin(), outType.getShape().end() - 2);
if (!llvm::equal(outBatchShape, *batchShape) || outType.getDimSize(outType.getRank() - 2) != m
|| outType.getDimSize(outType.getRank() - 1) != n)
return failure();
}
Location loc = matmulOp.getLoc();
bool useTransposedForm = isHostFoldableValue(matmulOp.getA()) && !isHostFoldableValue(matmulOp.getB());
Value lhs = matmulOp.getA();
Value rhs = matmulOp.getB();
Value lhs = collapseBatchDims(matmulOp.getA(), lhsBatch, m, k, rewriter, loc);
Value rhs = collapseBatchDims(matmulOp.getB(), rhsBatch, k, n, rewriter, loc);
int64_t lhsBatchForGemm = lhsBatch;
int64_t rhsBatchForGemm = rhsBatch;
int64_t gemmM = m;
@@ -239,6 +335,7 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
}
Value result = concatValues(batchResults, /*axis=*/0, rewriter, loc);
result = expandBatchDims(result, outType, batchShape->size(), rewriter, loc);
rewriter.replaceOp(matmulOp, result);
return success();
}
@@ -1,9 +1,10 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -22,53 +23,83 @@ static SmallVector<int64_t> permuteShape(ArrayRef<int64_t> shape, ArrayRef<int64
return permutedShape;
}
static Value createSoftmaxCompute(Value input, ConversionPatternRewriter& rewriter, Location loc) {
static Value buildLoopSoftmaxSlice(Value input,
Value accumulator,
RankedTensorType inputType,
ArrayRef<Value> outerIndices,
ConversionPatternRewriter& rewriter,
Location loc) {
int64_t rank = inputType.getRank();
SmallVector<int64_t> sliceShape(static_cast<size_t>(rank - 1), 1);
sliceShape.push_back(inputType.getDimSize(rank - 1));
auto sliceType = RankedTensorType::get(sliceShape, inputType.getElementType(), inputType.getEncoding());
SmallVector<OpFoldResult> offsets;
SmallVector<OpFoldResult> sizes;
SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
offsets.reserve(rank);
sizes.reserve(rank);
for (Value outerIndex : outerIndices) {
offsets.push_back(outerIndex);
sizes.push_back(rewriter.getIndexAttr(1));
}
offsets.push_back(rewriter.getIndexAttr(0));
sizes.push_back(rewriter.getIndexAttr(inputType.getDimSize(rank - 1)));
Value inputSlice = tensor::ExtractSliceOp::create(rewriter, loc, sliceType, input, offsets, sizes, strides);
Value softmaxSlice = spatial::SpatSoftmaxOp::create(rewriter, loc, sliceType, inputSlice).getResult();
return tensor::InsertSliceOp::create(rewriter, loc, softmaxSlice, accumulator, offsets, sizes, strides);
}
static Value buildLoopSoftmaxNest(Value input,
Value accumulator,
RankedTensorType inputType,
int64_t axis,
SmallVectorImpl<Value>& outerIndices,
ConversionPatternRewriter& rewriter,
Location loc) {
if (axis == inputType.getRank() - 1)
return buildLoopSoftmaxSlice(input, accumulator, inputType, outerIndices, rewriter, loc);
Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
Value c1 = arith::ConstantIndexOp::create(rewriter, loc, 1);
Value cUpper = arith::ConstantIndexOp::create(rewriter, loc, inputType.getDimSize(axis));
auto loop = scf::ForOp::create(rewriter, loc, c0, cUpper, c1, ValueRange {accumulator});
rewriter.setInsertionPointToStart(loop.getBody());
Value loopIndex = loop.getInductionVar();
Value loopAccumulator = loop.getRegionIterArgs().front();
outerIndices.push_back(loopIndex);
Value updatedAccumulator =
buildLoopSoftmaxNest(input, loopAccumulator, inputType, axis + 1, outerIndices, rewriter, loc);
outerIndices.pop_back();
scf::YieldOp::create(rewriter, loc, updatedAccumulator);
rewriter.setInsertionPointAfter(loop);
return loop.getResult(0);
}
static Value createLoopSoftmaxCompute(Value input, ConversionPatternRewriter& rewriter, Location loc) {
auto inputType = cast<RankedTensorType>(input.getType());
constexpr size_t numInputs = 1;
auto computeOp =
createSpatCompute<numInputs>(rewriter, loc, TypeRange {inputType}, {}, ValueRange {input}, [&](Value x) {
auto softmaxOp = spatial::SpatSoftmaxOp::create(rewriter, loc, inputType, x);
spatial::SpatYieldOp::create(rewriter, loc, softmaxOp.getResult());
if (inputType.getRank() == 1) {
Value softmax = spatial::SpatSoftmaxOp::create(rewriter, loc, inputType, x).getResult();
spatial::SpatYieldOp::create(rewriter, loc, softmax);
return;
}
Value outputInit = tensor::EmptyOp::create(rewriter, loc, inputType.getShape(), inputType.getElementType());
SmallVector<Value> outerIndices;
Value result = buildLoopSoftmaxNest(x, outputInit, inputType, /*axis=*/0, outerIndices, rewriter, loc);
spatial::SpatYieldOp::create(rewriter, loc, result);
});
return computeOp.getResult(0);
}
static Value concatValues(ValueRange inputs, int64_t axis, ConversionPatternRewriter& rewriter, Location loc) {
auto firstType = cast<RankedTensorType>(inputs.front().getType());
SmallVector<int64_t> outputShape(firstType.getShape().begin(), firstType.getShape().end());
int64_t concatDimSize = 0;
for (Value input : inputs)
concatDimSize += cast<RankedTensorType>(input.getType()).getDimSize(axis);
outputShape[axis] = concatDimSize;
auto resultType = RankedTensorType::get(outputShape, firstType.getElementType(), firstType.getEncoding());
if (llvm::all_of(inputs, isHostFoldableValue))
return createSpatConcat(rewriter, loc, axis, inputs);
auto concatCompute = createSpatCompute(rewriter, loc, TypeRange {resultType}, {}, inputs, [&](ValueRange args) {
spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, axis, args));
});
return concatCompute.getResult(0);
}
static Value
buildSoftmax(Value input, int64_t softmaxAxis, int64_t axis, ConversionPatternRewriter& rewriter, Location loc) {
auto inputType = cast<RankedTensorType>(input.getType());
if (axis == inputType.getRank())
return createSoftmaxCompute(input, rewriter, loc);
if (axis == softmaxAxis)
return buildSoftmax(input, softmaxAxis, axis + 1, rewriter, loc);
SmallVector<Value> slices = sliceTensor(input, axis, /*sliceSize=*/1, rewriter, loc);
SmallVector<Value> rebuiltSlices;
rebuiltSlices.reserve(slices.size());
for (Value slice : slices)
rebuiltSlices.push_back(buildSoftmax(slice, softmaxAxis, axis + 1, rewriter, loc));
return concatValues(rebuiltSlices, axis, rewriter, loc);
}
struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> {
using OpConversionPattern::OpConversionPattern;
@@ -86,7 +117,7 @@ struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> {
Value input = adaptor.getInput();
Value result;
if (axis == inputType.getRank() - 1) {
result = buildSoftmax(input, axis, /*axis=*/0, rewriter, softmaxOp.getLoc());
result = createLoopSoftmaxCompute(input, rewriter, softmaxOp.getLoc());
}
else {
SmallVector<int64_t> permutation;
@@ -109,8 +140,7 @@ struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> {
spatial::SpatYieldOp::create(rewriter, softmaxOp.getLoc(), transposed);
});
Value transposedInput = preTransposeCompute.getResult(0);
Value transposedResult = buildSoftmax(
transposedInput, /*softmaxAxis=*/inputType.getRank() - 1, /*axis=*/0, rewriter, softmaxOp.getLoc());
Value transposedResult = createLoopSoftmaxCompute(transposedInput, rewriter, softmaxOp.getLoc());
auto postTransposeCompute =
createSpatCompute<1>(rewriter, softmaxOp.getLoc(), TypeRange {inputType}, {}, transposedResult, [&](Value x) {
Value transposed = ONNXTransposeOp::create(
@@ -80,6 +80,22 @@ static bool inferExpandReassociation(ArrayRef<int64_t> sourceShape,
return sourceIdx == sourceShape.size() && resultIdx == resultShape.size();
}
static SmallVector<ReassociationIndices> getCollapseTo1DReassociation(size_t rank) {
SmallVector<ReassociationIndices> reassociation(1);
reassociation.front().reserve(rank);
for (size_t dim = 0; dim < rank; ++dim)
reassociation.front().push_back(static_cast<int64_t>(dim));
return reassociation;
}
static SmallVector<ReassociationIndices> getExpandFrom1DReassociation(size_t rank) {
SmallVector<ReassociationIndices> reassociation(1);
reassociation.front().reserve(rank);
for (size_t dim = 0; dim < rank; ++dim)
reassociation.front().push_back(static_cast<int64_t>(dim));
return reassociation;
}
struct Reshape : OpConversionPattern<ONNXReshapeOp> {
using OpConversionPattern::OpConversionPattern;
@@ -126,6 +142,23 @@ struct Reshape : OpConversionPattern<ONNXReshapeOp> {
return tensor::ExpandShapeOp::create(rewriter, reshapeOp.getLoc(), resultType, data, reassociation);
});
if (sourceType.getNumElements() != resultType.getNumElements())
return failure();
return replaceWithReshape([&](Value data) -> Value {
Value reshaped = data;
if (sourceType.getRank() != 1) {
auto flatType = RankedTensorType::get({sourceType.getNumElements()}, sourceType.getElementType());
reshaped = tensor::CollapseShapeOp::create(
rewriter, reshapeOp.getLoc(), flatType, reshaped, getCollapseTo1DReassociation(sourceType.getRank()));
}
if (resultType.getRank() == 1)
return reshaped;
return tensor::ExpandShapeOp::create(
rewriter, reshapeOp.getLoc(), resultType, reshaped, getExpandFrom1DReassociation(resultType.getRank()))
.getResult();
});
return failure();
}
};
@@ -1,10 +1,10 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/STLExtras.h"
#include <algorithm>
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
@@ -15,42 +15,88 @@ using namespace mlir;
namespace onnx_mlir {
namespace {
static Value
extractSliceAt(Value input, int64_t axis, int64_t offset, ConversionPatternRewriter& rewriter, Location loc) {
auto inputType = cast<RankedTensorType>(input.getType());
SmallVector<OpFoldResult> offsets(inputType.getRank(), rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> sizes;
SmallVector<OpFoldResult> strides(inputType.getRank(), rewriter.getIndexAttr(1));
sizes.reserve(inputType.getRank());
for (int64_t dim : inputType.getShape())
sizes.push_back(rewriter.getIndexAttr(dim));
offsets[axis] = rewriter.getIndexAttr(offset);
sizes[axis] = rewriter.getIndexAttr(1);
return tensor::ExtractSliceOp::create(rewriter, loc, input, offsets, sizes, strides);
static Value buildNearestAsymmetricIndex(
Value outputIndex, int64_t inputDim, int64_t outputDim, ConversionPatternRewriter& rewriter, Location loc) {
Value cInputDim = arith::ConstantIndexOp::create(rewriter, loc, inputDim);
Value cOutputDim = arith::ConstantIndexOp::create(rewriter, loc, outputDim);
Value cInputDimLast = arith::ConstantIndexOp::create(rewriter, loc, inputDim - 1);
Value scaledIndex = arith::MulIOp::create(rewriter, loc, outputIndex, cInputDim);
Value inputIndex = arith::DivUIOp::create(rewriter, loc, scaledIndex, cOutputDim);
return arith::MinUIOp::create(rewriter, loc, inputIndex, cInputDimLast);
}
static int64_t nearestAsymmetricIndex(int64_t outputIndex, int64_t inputDim, int64_t outputDim) {
return std::min<int64_t>((outputIndex * inputDim) / outputDim, inputDim - 1);
}
static Value buildNearestResizeLoop(Value input,
RankedTensorType inputType,
RankedTensorType resultType,
ConversionPatternRewriter& rewriter,
Location loc) {
auto elemType = resultType.getElementType();
SmallVector<int64_t> unitShape(resultType.getRank(), 1);
auto unitTensorType = RankedTensorType::get(unitShape, elemType);
static Value buildNearestResize(Value input,
ArrayRef<int64_t> inputShape,
ArrayRef<int64_t> outputShape,
int64_t axis,
ConversionPatternRewriter& rewriter,
Location loc) {
if (axis == static_cast<int64_t>(outputShape.size()))
return input;
SmallVector<OpFoldResult> unitSizes(resultType.getRank(), rewriter.getIndexAttr(1));
SmallVector<OpFoldResult> unitStrides(resultType.getRank(), rewriter.getIndexAttr(1));
SmallVector<Value> slices;
slices.reserve(outputShape[axis]);
for (int64_t outputIndex = 0; outputIndex < outputShape[axis]; ++outputIndex) {
int64_t inputIndex = nearestAsymmetricIndex(outputIndex, inputShape[axis], outputShape[axis]);
Value slice = extractSliceAt(input, axis, inputIndex, rewriter, loc);
slices.push_back(buildNearestResize(slice, inputShape, outputShape, axis + 1, rewriter, loc));
}
Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
Value c1 = arith::ConstantIndexOp::create(rewriter, loc, 1);
Value cOutputN = arith::ConstantIndexOp::create(rewriter, loc, resultType.getDimSize(0));
Value cOutputC = arith::ConstantIndexOp::create(rewriter, loc, resultType.getDimSize(1));
Value cOutputH = arith::ConstantIndexOp::create(rewriter, loc, resultType.getDimSize(2));
Value cOutputW = arith::ConstantIndexOp::create(rewriter, loc, resultType.getDimSize(3));
return createSpatConcat(rewriter, loc, axis, slices);
Value outputInit = tensor::EmptyOp::create(rewriter, loc, resultType.getShape(), elemType);
auto batchLoop = scf::ForOp::create(rewriter, loc, c0, cOutputN, c1, ValueRange {outputInit});
rewriter.setInsertionPointToStart(batchLoop.getBody());
Value outputN = batchLoop.getInductionVar();
Value outputBatchAcc = batchLoop.getRegionIterArgs().front();
Value inputN = buildNearestAsymmetricIndex(outputN, inputType.getDimSize(0), resultType.getDimSize(0), rewriter, loc);
auto channelLoop = scf::ForOp::create(rewriter, loc, c0, cOutputC, c1, ValueRange {outputBatchAcc});
rewriter.setInsertionPointToStart(channelLoop.getBody());
Value outputC = channelLoop.getInductionVar();
Value outputChannelAcc = channelLoop.getRegionIterArgs().front();
Value inputC =
buildNearestAsymmetricIndex(outputC, inputType.getDimSize(1), resultType.getDimSize(1), rewriter, loc);
auto heightLoop = scf::ForOp::create(rewriter, loc, c0, cOutputH, c1, ValueRange {outputChannelAcc});
rewriter.setInsertionPointToStart(heightLoop.getBody());
Value outputH = heightLoop.getInductionVar();
Value outputHeightAcc = heightLoop.getRegionIterArgs().front();
Value inputH =
buildNearestAsymmetricIndex(outputH, inputType.getDimSize(2), resultType.getDimSize(2), rewriter, loc);
auto widthLoop = scf::ForOp::create(rewriter, loc, c0, cOutputW, c1, ValueRange {outputHeightAcc});
rewriter.setInsertionPointToStart(widthLoop.getBody());
Value outputW = widthLoop.getInductionVar();
Value outputWidthAcc = widthLoop.getRegionIterArgs().front();
Value inputW =
buildNearestAsymmetricIndex(outputW, inputType.getDimSize(3), resultType.getDimSize(3), rewriter, loc);
SmallVector<OpFoldResult> inputOffsets = {inputN, inputC, inputH, inputW};
Value inputSlice =
tensor::ExtractSliceOp::create(rewriter, loc, unitTensorType, input, inputOffsets, unitSizes, unitStrides);
SmallVector<OpFoldResult> outputOffsets = {outputN, outputC, outputH, outputW};
Value updatedOutput =
tensor::InsertSliceOp::create(rewriter, loc, inputSlice, outputWidthAcc, outputOffsets, unitSizes, unitStrides);
scf::YieldOp::create(rewriter, loc, updatedOutput);
rewriter.setInsertionPointAfter(widthLoop);
scf::YieldOp::create(rewriter, loc, widthLoop.getResult(0));
rewriter.setInsertionPointAfter(heightLoop);
scf::YieldOp::create(rewriter, loc, heightLoop.getResult(0));
rewriter.setInsertionPointAfter(channelLoop);
scf::YieldOp::create(rewriter, loc, channelLoop.getResult(0));
rewriter.setInsertionPointAfter(batchLoop);
return batchLoop.getResult(0);
}
struct Resize : OpConversionPattern<ONNXResizeOp> {
@@ -62,20 +108,22 @@ struct Resize : OpConversionPattern<ONNXResizeOp> {
auto inputType = dyn_cast<RankedTensorType>(adaptor.getX().getType());
auto resultType = dyn_cast<RankedTensorType>(resizeOp.getY().getType());
if (!inputType || !resultType || !inputType.hasStaticShape() || !resultType.hasStaticShape())
return failure();
return rewriter.notifyMatchFailure(resizeOp, "resize lowering requires static ranked tensor types.");
if (inputType.getRank() != 4 || resultType.getRank() != 4)
return rewriter.notifyMatchFailure(resizeOp, "resize lowering currently supports only rank-4 NCHW tensors.");
if (resizeOp.getMode() != "nearest" || resizeOp.getCoordinateTransformationMode() != "asymmetric"
|| resizeOp.getNearestMode() != "floor")
return failure();
return rewriter.notifyMatchFailure(
resizeOp, "resize lowering currently supports only nearest + asymmetric + floor.");
if (llvm::any_of(inputType.getShape(), [](int64_t dim) { return dim <= 0; })
|| llvm::any_of(resultType.getShape(), [](int64_t dim) { return dim <= 0; }))
return failure();
return rewriter.notifyMatchFailure(resizeOp, "resize lowering requires positive static dimensions.");
auto computeOp =
createSpatCompute<1>(rewriter, resizeOp.getLoc(), TypeRange {resultType}, {}, adaptor.getX(), [&](Value x) {
Value result =
buildNearestResize(x, inputType.getShape(), resultType.getShape(), /*axis=*/0, rewriter, resizeOp.getLoc());
Value result = buildNearestResizeLoop(x, inputType, resultType, rewriter, resizeOp.getLoc());
spatial::SpatYieldOp::create(rewriter, resizeOp.getLoc(), result);
});
rewriter.replaceOp(resizeOp, computeOp.getResults());
@@ -31,46 +31,18 @@ static bool isDirectConstantValue(Value value) {
return isa_and_nonnull<arith::ConstantOp, ONNXConstantOp>(value.getDefiningOp());
}
// Collapses one-lane batches so later phases do not carry batch-only structure unnecessarily.
struct FoldSingleLaneComputeBatchPattern : OpRewritePattern<spatial::SpatComputeBatch> {
using OpRewritePattern<spatial::SpatComputeBatch>::OpRewritePattern;
LogicalResult matchAndRewrite(spatial::SpatComputeBatch batchOp, PatternRewriter& rewriter) const override {
if (batchOp.getLaneCount() != 1)
return rewriter.notifyMatchFailure(batchOp, "requires a single lane");
auto loc = batchOp.getLoc();
rewriter.setInsertionPoint(batchOp);
auto computeOp =
spatial::SpatCompute::create(rewriter, loc, batchOp.getResultTypes(), batchOp.getWeights(), batchOp.getInputs());
computeOp.getProperties().setOperandSegmentSizes(
{static_cast<int>(batchOp.getWeights().size()), static_cast<int>(batchOp.getInputs().size())});
Block& templateBlock = batchOp.getBody().front();
SmallVector<Type> blockArgTypes;
SmallVector<Location> blockArgLocs;
blockArgTypes.reserve(templateBlock.getNumArguments());
blockArgLocs.reserve(templateBlock.getNumArguments());
for (BlockArgument arg : templateBlock.getArguments()) {
blockArgTypes.push_back(arg.getType());
blockArgLocs.push_back(loc);
}
auto* newBlock =
rewriter.createBlock(&computeOp.getBody(), computeOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
IRMapping mapper;
for (auto [oldArg, newArg] : llvm::zip(templateBlock.getArguments(), newBlock->getArguments()))
mapper.map(oldArg, newArg);
rewriter.setInsertionPointToEnd(newBlock);
for (Operation& op : templateBlock)
rewriter.clone(op, mapper);
batchOp->replaceAllUsesWith(computeOp->getResults());
rewriter.eraseOp(batchOp);
return success();
template <typename ComputeOpTy>
static bool hasPromotableWeightLikeInputs(ComputeOpTy compute) {
Block& block = compute.getBody().front();
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
if (!isWeightLikeComputeOperand(input))
continue;
if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(compute.getInputArgument(inputIdx)))
continue;
return true;
}
};
return false;
}
// Promotes foldable helper chains from runtime inputs to weights to avoid artificial compute inputs.
struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatCompute> {
@@ -81,11 +53,9 @@ struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatCom
bool needsRewrite = false;
Block& oldBlock = compute.getBody().front();
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
if (inputIdx >= oldBlock.getNumArguments())
continue;
if (!isWeightLikeComputeOperand(input))
continue;
if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(oldBlock.getArgument(inputIdx)))
if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(compute.getInputArgument(inputIdx)))
continue;
promoteInput[inputIdx] = true;
needsRewrite = true;
@@ -116,8 +86,16 @@ struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatCom
auto newCompute =
spatial::SpatCompute::create(rewriter, compute.getLoc(), compute.getResultTypes(), newWeights, newInputs);
SmallVector<Type> newBlockArgTypes;
SmallVector<Location> newBlockArgLocs;
for (Value weight : newWeights) {
newBlockArgTypes.push_back(weight.getType());
newBlockArgLocs.push_back(weight.getLoc());
}
llvm::append_range(newBlockArgTypes, newInputTypes);
llvm::append_range(newBlockArgLocs, newInputLocs);
auto* newBlock =
rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), newInputTypes, newInputLocs);
rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), TypeRange(newBlockArgTypes), newBlockArgLocs);
newCompute.getProperties().setOperandSegmentSizes(
{static_cast<int>(newWeights.size()), static_cast<int>(newInputs.size())});
rewriter.setInsertionPointToStart(newBlock);
@@ -126,14 +104,17 @@ struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatCom
bodyRewriter.setInsertionPointToStart(newBlock);
IRMapping mapper;
for (auto [weightIndex, weight] : llvm::enumerate(compute.getWeights()))
mapper.map(compute.getWeightArgument(weightIndex), newCompute.getWeightArgument(weightIndex));
size_t newInputIdx = 0;
for (auto [oldInputIdx, oldArg] : llvm::enumerate(oldBlock.getArguments())) {
for (auto [oldInputIdx, input] : llvm::enumerate(compute.getInputs())) {
BlockArgument oldArg = compute.getInputArgument(oldInputIdx);
if (!promoteInput[oldInputIdx]) {
mapper.map(oldArg, newBlock->getArgument(newInputIdx++));
mapper.map(oldArg, newCompute.getInputArgument(newInputIdx++));
continue;
}
auto clonedValue = materializeWeightLikeValueInBlock(compute.getInputs()[oldInputIdx], bodyRewriter, mapper);
auto clonedValue = materializeWeightLikeValueInBlock(input, bodyRewriter, mapper);
if (failed(clonedValue))
return rewriter.notifyMatchFailure(compute, "failed to materialize promoted weight-like operand");
mapper.map(oldArg, *clonedValue);
@@ -165,11 +146,9 @@ struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::Sp
bool needsRewrite = false;
Block& oldBlock = compute.getBody().front();
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
if (inputIdx >= oldBlock.getNumArguments())
continue;
if (!isWeightLikeComputeOperand(input))
continue;
if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(oldBlock.getArgument(inputIdx)))
if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(compute.getInputArgument(inputIdx)))
continue;
promoteInput[inputIdx] = true;
needsRewrite = true;
@@ -205,8 +184,25 @@ struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::Sp
rewriter.getI32IntegerAttr(static_cast<int32_t>(compute.getLaneCount())),
newWeights,
newInputs);
auto* newBlock =
rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), newInputTypes, newInputLocs);
SmallVector<Type> newBlockArgTypes;
SmallVector<Location> newBlockArgLocs;
newBlockArgTypes.reserve(1 + newWeights.size() + newInputTypes.size() + compute.getNumResults());
newBlockArgLocs.reserve(1 + newWeights.size() + newInputLocs.size() + compute.getNumResults());
newBlockArgTypes.push_back(compute.getLaneArgument().getType());
newBlockArgLocs.push_back(compute.getLaneArgument().getLoc());
for (Value weight : newWeights) {
newBlockArgTypes.push_back(weight.getType());
newBlockArgLocs.push_back(weight.getLoc());
}
llvm::append_range(newBlockArgTypes, newInputTypes);
llvm::append_range(newBlockArgLocs, newInputLocs);
for (auto [resultIndex, resultType] : llvm::enumerate(compute.getResultTypes())) {
newBlockArgTypes.push_back(resultType);
newBlockArgLocs.push_back(compute.getOutputArgument(resultIndex).getLoc());
}
auto* newBlock = rewriter.createBlock(
&newCompute.getBody(), newCompute.getBody().end(), TypeRange(newBlockArgTypes), newBlockArgLocs);
newCompute.getProperties().setOperandSegmentSizes(
{static_cast<int>(newWeights.size()), static_cast<int>(newInputs.size())});
rewriter.setInsertionPointToStart(newBlock);
@@ -215,31 +211,28 @@ struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::Sp
bodyRewriter.setInsertionPointToStart(newBlock);
IRMapping mapper;
mapper.map(compute.getLaneArgument(), newCompute.getLaneArgument());
for (auto [weightIndex, weight] : llvm::enumerate(compute.getWeights()))
mapper.map(compute.getWeightArgument(weightIndex), newCompute.getWeightArgument(weightIndex));
size_t newInputIdx = 0;
for (auto [oldInputIdx, oldArg] : llvm::enumerate(oldBlock.getArguments())) {
for (auto [oldInputIdx, input] : llvm::enumerate(compute.getInputs())) {
BlockArgument oldArg = compute.getInputArgument(oldInputIdx);
if (!promoteInput[oldInputIdx]) {
mapper.map(oldArg, newBlock->getArgument(newInputIdx++));
mapper.map(oldArg, newCompute.getInputArgument(newInputIdx++));
continue;
}
auto clonedValue = materializeWeightLikeValueInBlock(compute.getInputs()[oldInputIdx], bodyRewriter, mapper);
auto clonedValue = materializeWeightLikeValueInBlock(input, bodyRewriter, mapper);
if (failed(clonedValue))
return rewriter.notifyMatchFailure(compute, "failed to materialize promoted batch weight-like operand");
mapper.map(oldArg, *clonedValue);
}
for (auto resultIndex : llvm::seq<size_t>(0, compute.getNumResults()))
mapper.map(compute.getOutputArgument(resultIndex), newBlock->getArgument(1 + newWeights.size() + newInputs.size() + resultIndex));
for (Operation& op : oldBlock.without_terminator())
for (Operation& op : oldBlock)
rewriter.clone(op, mapper);
auto oldYield = cast<spatial::SpatYieldOp>(oldBlock.getTerminator());
SmallVector<Value> newYieldOperands;
newYieldOperands.reserve(oldYield.getOutputs().size());
for (Value operand : oldYield.getOutputs()) {
auto mapped = mapper.lookupOrNull(operand);
newYieldOperands.push_back(mapped ? cast<Value>(mapped) : operand);
}
spatial::SpatYieldOp::create(rewriter, oldYield.getLoc(), newYieldOperands);
rewriter.replaceOp(compute, newCompute.getResults());
return success();
}
@@ -247,10 +240,6 @@ struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::Sp
} // namespace
void populateEarlyPostPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.add<FoldSingleLaneComputeBatchPattern>(ctx);
}
void populatePostPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.add<PromoteWeightLikeComputeInputsPattern, PromoteWeightLikeComputeBatchInputsPattern>(ctx);
}
@@ -262,4 +251,8 @@ void annotateWeightsConstants(func::FuncOp funcOp) {
});
}
bool requiresPostRewrite(spatial::SpatCompute computeOp) { return hasPromotableWeightLikeInputs(computeOp); }
bool requiresPostRewrite(spatial::SpatComputeBatch computeOp) { return hasPromotableWeightLikeInputs(computeOp); }
} // namespace onnx_mlir
@@ -3,9 +3,13 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/MLIRContext.h"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
namespace onnx_mlir {
void populateEarlyPostPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
bool requiresPostRewrite(spatial::SpatCompute computeOp);
bool requiresPostRewrite(spatial::SpatComputeBatch computeOp);
void populatePostPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
@@ -17,9 +17,7 @@ void populatePrePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* c
patterns.add<convAddToConvWithBiasLeft>(ctx);
patterns.add<convAddToConvWithBiasRight>(ctx);
patterns.add<matMulAddToGemm>(ctx);
patterns.add<matMulToGemm>(ctx);
patterns.add<removeFlattenSameShape>(ctx);
populateMatMulRewritePatterns(patterns, ctx);
}
} // namespace onnx_mlir
@@ -202,6 +202,7 @@ void SpatialToGraphvizPass::runOnOperation() {
auto entryFunc = getPimEntryFunc(module);
if (failed(entryFunc)) {
module.emitError("failed to locate the PIM entry function for Spatial graph visualization");
signalPassFailure();
return;
}
@@ -2,6 +2,7 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
@@ -17,6 +18,37 @@ namespace {
static int32_t translateSpatialCoreIdToPimCoreId(size_t spatialCoreId) { return static_cast<int32_t>(spatialCoreId); }
static FailureOr<int32_t> getConstantI32Value(Value value) {
APInt constantValue;
if (!matchPattern(value, m_ConstantInt(&constantValue)))
return failure();
return static_cast<int32_t>(constantValue.getSExtValue());
}
static FailureOr<SmallVector<int32_t>> getConstantI32Values(ValueRange values) {
SmallVector<int32_t> constants;
constants.reserve(values.size());
for (Value value : values) {
FailureOr<int32_t> constantValue = getConstantI32Value(value);
if (failed(constantValue))
return failure();
constants.push_back(*constantValue);
}
return constants;
}
static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) {
if (isa<pim::PimMemCopyDevToHostOp>(op))
return operandIndex == 2;
return false;
}
static bool isUsedOnlyAsExplicitHostOperand(Value value) {
return !value.use_empty() && llvm::all_of(value.getUses(), [](OpOperand& use) {
return isExplicitHostOperand(use.getOwner(), use.getOperandNumber());
});
}
static SmallVector<int32_t> getPimCoreIdsForBatchOp(spatial::SpatComputeBatch computeBatchOp, size_t& fallbackCoreId) {
if (auto coreIdsAttr = computeBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName))
return SmallVector<int32_t>(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end());
@@ -28,27 +60,30 @@ static SmallVector<int32_t> getPimCoreIdsForBatchOp(spatial::SpatComputeBatch co
return coreIds;
}
static void lowerChannelSendTensorBatch(spatial::SpatChannelSendTensorBatchOp sendTensorBatchOp,
IRMapping& mapper,
IRRewriter& rewriter) {
SmallVector<int32_t> targetCoreIds;
targetCoreIds.reserve(sendTensorBatchOp.getTargetCoreIds().size());
for (int32_t targetCoreId : sendTensorBatchOp.getTargetCoreIds())
targetCoreIds.push_back(translateSpatialCoreIdToPimCoreId(targetCoreId));
static LogicalResult lowerChannelSendTensorBatch(spatial::SpatChannelSendTensorBatchOp sendTensorBatchOp,
IRMapping& mapper,
IRRewriter& rewriter) {
FailureOr<SmallVector<int32_t>> targetCoreIds = getConstantI32Values(sendTensorBatchOp.getTargetCoreIds());
if (failed(targetCoreIds))
return sendTensorBatchOp.emitOpError("expected constant targetCoreIds");
for (int32_t& targetCoreId : *targetCoreIds)
targetCoreId = translateSpatialCoreIdToPimCoreId(targetCoreId);
pim::PimSendTensorBatchOp::create(rewriter,
sendTensorBatchOp.getLoc(),
mapper.lookup(sendTensorBatchOp.getInput()),
rewriter.getDenseI32ArrayAttr(targetCoreIds));
rewriter.getDenseI32ArrayAttr(*targetCoreIds));
return success();
}
static void lowerChannelReceiveTensorBatch(spatial::SpatChannelReceiveTensorBatchOp receiveTensorBatchOp,
IRMapping& mapper,
IRRewriter& rewriter) {
SmallVector<int32_t> sourceCoreIds;
sourceCoreIds.reserve(receiveTensorBatchOp.getSourceCoreIds().size());
for (int32_t sourceCoreId : receiveTensorBatchOp.getSourceCoreIds())
sourceCoreIds.push_back(translateSpatialCoreIdToPimCoreId(sourceCoreId));
static LogicalResult lowerChannelReceiveTensorBatch(spatial::SpatChannelReceiveTensorBatchOp receiveTensorBatchOp,
IRMapping& mapper,
IRRewriter& rewriter) {
FailureOr<SmallVector<int32_t>> sourceCoreIds = getConstantI32Values(receiveTensorBatchOp.getSourceCoreIds());
if (failed(sourceCoreIds))
return receiveTensorBatchOp.emitOpError("expected constant sourceCoreIds");
for (int32_t& sourceCoreId : *sourceCoreIds)
sourceCoreId = translateSpatialCoreIdToPimCoreId(sourceCoreId);
auto outputType = cast<ShapedType>(receiveTensorBatchOp.getOutput().getType());
auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveTensorBatchOp.getLoc(), outputType);
@@ -56,24 +91,26 @@ static void lowerChannelReceiveTensorBatch(spatial::SpatChannelReceiveTensorBatc
receiveTensorBatchOp.getLoc(),
outputBuffer.getType(),
outputBuffer,
rewriter.getDenseI32ArrayAttr(sourceCoreIds))
rewriter.getDenseI32ArrayAttr(*sourceCoreIds))
.getOutput();
mapper.map(receiveTensorBatchOp.getOutput(), received);
return success();
}
} // namespace
LogicalResult
lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState& state, IRRewriter& rewriter) {
if (computeBatchOp.getNumResults() != 0)
return computeBatchOp.emitOpError(
"batched Spatial-to-PIM lowering currently requires channelized compute_batch with no results");
Location loc = computeBatchOp.getLoc();
Block& oldBlock = computeBatchOp.getBody().front();
auto oldYield = cast<spatial::SpatYieldOp>(oldBlock.getTerminator());
if (oldYield.getNumOperands() != 0)
return computeBatchOp.emitOpError("batched Spatial-to-PIM lowering currently requires empty spat.yield");
if (computeBatchOp.getNumResults() != 0)
return computeBatchOp.emitOpError(
"batched Spatial-to-PIM lowering currently requires channelized compute_batch with no results; "
"materialize explicit communication before lowering to PIM");
auto oldYield = dyn_cast<spatial::SpatYieldOp>(oldBlock.getTerminator());
if (!oldYield || oldYield.getNumOperands() != 0)
return computeBatchOp.emitOpError("resultless compute_batch lowering requires empty spat.yield");
SmallVector<int32_t> coreIds = getPimCoreIdsForBatchOp(computeBatchOp, state.nextCoreId);
SmallVector<Value> batchWeights(computeBatchOp.getWeights().begin(), computeBatchOp.getWeights().end());
@@ -102,7 +139,12 @@ lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState&
IRMapping mapper;
rewriter.setInsertionPointToStart(newBlock);
for (auto [oldArg, newArg] : llvm::zip(oldBlock.getArguments(), newBlock->getArguments())) {
mapper.map(computeBatchOp.getLaneArgument(), coreBatchOp.getLaneArgument());
for (unsigned weightIndex = 0; weightIndex < computeBatchOp.getWeights().size(); ++weightIndex)
mapper.map(computeBatchOp.getWeightArgument(weightIndex), coreBatchOp.getWeightArgument(weightIndex));
for (unsigned inputIndex = 0; inputIndex < computeBatchOp.getInputs().size(); ++inputIndex) {
BlockArgument oldArg = computeBatchOp.getInputArgument(inputIndex);
BlockArgument newArg = coreBatchOp.getInputArgument(inputIndex);
auto newArgType = cast<ShapedType>(newArg.getType());
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, newArgType);
auto copied = pim::PimMemCopyHostToDevBatchOp::create(rewriter,
@@ -142,20 +184,31 @@ lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState&
continue;
if (auto sendBatchOp = dyn_cast<spatial::SpatChannelSendBatchOp>(op)) {
FailureOr<SmallVector<int32_t>> targetCoreIds = getConstantI32Values(sendBatchOp.getTargetCoreIds());
if (failed(targetCoreIds))
return sendBatchOp.emitOpError("expected constant targetCoreIds");
for (int32_t& targetCoreId : *targetCoreIds)
targetCoreId = translateSpatialCoreIdToPimCoreId(targetCoreId);
pim::PimSendBatchOp::create(rewriter,
loc,
mapper.lookup(sendBatchOp.getInput()),
getTensorSizeInBytesAttr(rewriter, mapper.lookup(sendBatchOp.getInput())),
sendBatchOp.getTargetCoreIdsAttr());
rewriter.getDenseI32ArrayAttr(*targetCoreIds));
continue;
}
if (auto sendTensorBatchOp = dyn_cast<spatial::SpatChannelSendTensorBatchOp>(op)) {
lowerChannelSendTensorBatch(sendTensorBatchOp, mapper, rewriter);
if (failed(lowerChannelSendTensorBatch(sendTensorBatchOp, mapper, rewriter)))
return failure();
continue;
}
if (auto receiveBatchOp = dyn_cast<spatial::SpatChannelReceiveBatchOp>(op)) {
FailureOr<SmallVector<int32_t>> sourceCoreIds = getConstantI32Values(receiveBatchOp.getSourceCoreIds());
if (failed(sourceCoreIds))
return receiveBatchOp.emitOpError("expected constant sourceCoreIds");
for (int32_t& sourceCoreId : *sourceCoreIds)
sourceCoreId = translateSpatialCoreIdToPimCoreId(sourceCoreId);
auto outputType = cast<ShapedType>(receiveBatchOp.getOutput().getType());
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, outputType);
auto received = pim::PimReceiveBatchOp::create(rewriter,
@@ -163,14 +216,15 @@ lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState&
outputBuffer.getType(),
outputBuffer,
getTensorSizeInBytesAttr(rewriter, receiveBatchOp.getOutput()),
receiveBatchOp.getSourceCoreIdsAttr())
rewriter.getDenseI32ArrayAttr(*sourceCoreIds))
.getOutput();
mapper.map(receiveBatchOp.getOutput(), received);
continue;
}
if (auto receiveTensorBatchOp = dyn_cast<spatial::SpatChannelReceiveTensorBatchOp>(op)) {
lowerChannelReceiveTensorBatch(receiveTensorBatchOp, mapper, rewriter);
if (failed(lowerChannelReceiveTensorBatch(receiveTensorBatchOp, mapper, rewriter)))
return failure();
continue;
}
@@ -178,6 +232,10 @@ lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState&
if (isa_and_present<memref::GetGlobalOp>(toTensorOp.getBuffer().getDefiningOp())) {
Operation* cloned = rewriter.clone(op, mapper);
auto clonedTensor = cloned->getResult(0);
if (isUsedOnlyAsExplicitHostOperand(toTensorOp.getResult())) {
mapper.map(toTensorOp.getResult(), clonedTensor);
continue;
}
auto clonedType = cast<ShapedType>(clonedTensor.getType());
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, clonedType);
auto copied = pim::PimMemCopyHostToDevBatchOp::create(rewriter,
@@ -194,9 +252,11 @@ lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState&
}
}
for (Value operand : op.getOperands()) {
for (auto [operandIndex, operand] : llvm::enumerate(op.getOperands())) {
if (!isa<TensorType>(operand.getType()) || mapper.contains(operand))
continue;
if (isExplicitHostOperand(&op, operandIndex))
continue;
Operation* definingOp = operand.getDefiningOp();
if (definingOp && definingOp->getBlock() == &oldBlock)
@@ -22,6 +22,8 @@ add_pim_library(OMSpatialToPim
LINK_LIBS PUBLIC
MLIRSCFDialect
MLIRSCFUtils
MLIRTransformUtils
MLIRTosaDialect
OMCompilerOptions
OMPimCommon
@@ -1,4 +1,5 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Matchers.h"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ChannelLoweringPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
@@ -12,15 +13,24 @@ namespace {
static int32_t toPimCoreId(int32_t spatialCoreId) { return spatialCoreId; }
static FailureOr<SmallVector<int32_t>> getConstantI32Values(ValueRange values) {
SmallVector<int32_t> constants;
constants.reserve(values.size());
for (Value value : values) {
APInt constantValue;
if (!matchPattern(value, m_ConstantInt(&constantValue)))
return failure();
constants.push_back(static_cast<int32_t>(constantValue.getSExtValue()));
}
return constants;
}
struct ChannelSendLowering : OpRewritePattern<spatial::SpatChannelSendOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(spatial::SpatChannelSendOp op, PatternRewriter& rewriter) const override {
pim::PimSendOp::create(rewriter,
op.getLoc(),
op.getInput(),
getTensorSizeInBytesAttr(rewriter, op.getInput()),
rewriter.getI32IntegerAttr(toPimCoreId(op.getTargetCoreId())));
pim::PimSendOp::create(
rewriter, op.getLoc(), op.getInput(), getTensorSizeInBytesAttr(rewriter, op.getInput()), op.getTargetCoreId());
rewriter.eraseOp(op);
return success();
}
@@ -42,7 +52,7 @@ struct ChannelReceiveLowering : OpRewritePattern<spatial::SpatChannelReceiveOp>
op.getResult().getType(),
outputBuffer,
getTensorSizeInBytesAttr(rewriter, op.getResult()),
rewriter.getI32IntegerAttr(toPimCoreId(op.getSourceCoreId())))
op.getSourceCoreId())
.getOutput();
rewriter.replaceOp(op, received);
return success();
@@ -53,11 +63,12 @@ struct ChannelSendTensorLowering : OpRewritePattern<spatial::SpatChannelSendTens
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(spatial::SpatChannelSendTensorOp op, PatternRewriter& rewriter) const override {
SmallVector<int32_t> targetCoreIds;
targetCoreIds.reserve(op.getTargetCoreIds().size());
for (int32_t targetCoreId : op.getTargetCoreIds())
targetCoreIds.push_back(toPimCoreId(targetCoreId));
pim::PimSendTensorOp::create(rewriter, op.getLoc(), op.getInput(), rewriter.getDenseI32ArrayAttr(targetCoreIds));
FailureOr<SmallVector<int32_t>> targetCoreIds = getConstantI32Values(op.getTargetCoreIds());
if (failed(targetCoreIds))
return rewriter.notifyMatchFailure(op, "expected constant targetCoreIds");
for (int32_t& targetCoreId : *targetCoreIds)
targetCoreId = toPimCoreId(targetCoreId);
pim::PimSendTensorOp::create(rewriter, op.getLoc(), op.getInput(), rewriter.getDenseI32ArrayAttr(*targetCoreIds));
rewriter.eraseOp(op);
return success();
}
@@ -67,16 +78,17 @@ struct ChannelReceiveTensorLowering : OpRewritePattern<spatial::SpatChannelRecei
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(spatial::SpatChannelReceiveTensorOp op, PatternRewriter& rewriter) const override {
SmallVector<int32_t> sourceCoreIds;
sourceCoreIds.reserve(op.getSourceCoreIds().size());
for (int32_t sourceCoreId : op.getSourceCoreIds())
sourceCoreIds.push_back(toPimCoreId(sourceCoreId));
FailureOr<SmallVector<int32_t>> sourceCoreIds = getConstantI32Values(op.getSourceCoreIds());
if (failed(sourceCoreIds))
return rewriter.notifyMatchFailure(op, "expected constant sourceCoreIds");
for (int32_t& sourceCoreId : *sourceCoreIds)
sourceCoreId = toPimCoreId(sourceCoreId);
auto outputType = cast<ShapedType>(op.getOutput().getType());
Value outputBuffer =
tensor::EmptyOp::create(rewriter, op.getLoc(), outputType.getShape(), outputType.getElementType()).getResult();
Value received =
pim::PimReceiveTensorOp::create(
rewriter, op.getLoc(), op.getOutput().getType(), outputBuffer, rewriter.getDenseI32ArrayAttr(sourceCoreIds))
rewriter, op.getLoc(), op.getOutput().getType(), outputBuffer, rewriter.getDenseI32ArrayAttr(*sourceCoreIds))
.getOutput();
rewriter.replaceOp(op, received);
return success();
@@ -29,7 +29,10 @@ void replaceAndEraseDirectComputeLikeInput(PatternRewriter& rewriter,
unsigned inputIndex,
Value replacement) {
Block& body = owner->getRegion(0).front();
BlockArgument bodyArgument = body.getArgument(inputIndex);
BlockArgument bodyArgument = isa<spatial::SpatCompute>(owner)
? cast<spatial::SpatCompute>(owner).getInputArgument(inputIndex)
: cast<spatial::SpatComputeBatch>(owner).getInputArgument(inputIndex);
unsigned bodyArgIndex = bodyArgument.getArgNumber();
rewriter.startOpModification(owner);
bodyArgument.replaceAllUsesWith(replacement);
@@ -37,7 +40,7 @@ void replaceAndEraseDirectComputeLikeInput(PatternRewriter& rewriter,
compute.getInputsMutable().erase(inputIndex);
else
cast<spatial::SpatComputeBatch>(owner).getInputsMutable().erase(inputIndex);
body.eraseArgument(inputIndex);
body.eraseArgument(bodyArgIndex);
rewriter.finalizeOpModification(owner);
}
@@ -3,6 +3,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
@@ -27,7 +28,8 @@ static bool isChannelUseChainOp(Operation* op) {
pim::PimTransposeOp>(op);
}
static void cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewriter) {
static void
cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewriter, OperationFolder& constantFolder) {
for (Value operand : op->getOperands()) {
if (mapping.lookupOrNull(operand))
continue;
@@ -36,7 +38,12 @@ static void cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewri
if (!definingOp)
continue;
if (!isa<tensor::EmptyOp, arith::ConstantOp>(definingOp))
if (auto constantOp = dyn_cast<arith::ConstantOp>(definingOp)) {
mapping.map(operand, getOrCreateHostConstantLike(constantOp, constantFolder));
continue;
}
if (!isa<tensor::EmptyOp>(definingOp))
continue;
Operation* clonedOp = rewriter.clone(*definingOp, mapping);
@@ -48,6 +55,18 @@ static void cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewri
static int32_t translateSpatialCoreIdToPimCoreId(size_t spatialCoreId) { return static_cast<int32_t>(spatialCoreId); }
static FailureOr<SmallVector<int32_t>> getConstantI32Values(ValueRange values) {
SmallVector<int32_t> constants;
constants.reserve(values.size());
for (Value value : values) {
APInt constantValue;
if (!matchPattern(value, m_ConstantInt(&constantValue)))
return failure();
constants.push_back(static_cast<int32_t>(constantValue.getSExtValue()));
}
return constants;
}
static int32_t getPimCoreIdForComputeOp(spatial::SpatCompute computeOp, size_t& fallbackCoreId) {
if (auto spatialCoreIdAttr = computeOp->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
return static_cast<int32_t>(spatialCoreIdAttr.getInt());
@@ -92,7 +111,9 @@ static LogicalResult collectHelperComputeChain(spatial::SpatCompute computeOp,
return success();
}
static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatCompute computeOp, IRRewriter& rewriter) {
static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatCompute computeOp,
IRRewriter& rewriter,
OperationFolder& constantFolder) {
if (!computeOp.getInputs().empty() || computeOp.getNumResults() != 1)
return false;
if (!llvm::all_of(computeOp.getResult(0).getUsers(), [](Operation* user) {
@@ -101,7 +122,7 @@ static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatCompute
return false;
Block& block = computeOp.getBody().front();
if (block.getNumArguments() != 0)
if (block.getNumArguments() != computeOp.getWeights().size())
return false;
auto yieldOp = dyn_cast<spatial::SpatYieldOp>(block.getTerminator());
@@ -110,8 +131,10 @@ static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatCompute
rewriter.setInsertionPoint(computeOp);
IRMapping mapping;
for (auto [weightIndex, weight] : llvm::enumerate(computeOp.getWeights()))
mapping.map(computeOp.getWeightArgument(weightIndex), weight);
for (Operation& op : block.without_terminator()) {
cloneMappedHelperOperands(&op, mapping, rewriter);
cloneMappedHelperOperands(&op, mapping, rewriter, constantFolder);
Operation* clonedOp = rewriter.clone(op, mapping);
for (auto [originalResult, newResult] : llvm::zip(op.getResults(), clonedOp->getResults()))
mapping.map(originalResult, newResult);
@@ -133,7 +156,7 @@ void markOpToRemove(CoreLoweringState& state, Operation* op) {
LogicalResult lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState& state, IRRewriter& rewriter) {
Location loc = computeOp->getLoc();
if (inlineInputlessHelperComputeForWeightLikeUsers(computeOp, rewriter))
if (inlineInputlessHelperComputeForWeightLikeUsers(computeOp, rewriter, state.constantFolder))
return success();
SmallVector<Operation*> helperChain;
@@ -143,21 +166,42 @@ LogicalResult lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState&
auto& block = computeOp.getRegion().front();
auto yieldOp = cast<spatial::SpatYieldOp>(block.getTerminator());
for (auto [argIndex, blockArg] : llvm::enumerate(block.getArguments())) {
auto receiveOp = dyn_cast_or_null<spatial::SpatChannelReceiveOp>(computeOp.getInputs()[argIndex].getDefiningOp());
if (!receiveOp || blockArg.use_empty())
for (auto [inputIndex, input] : llvm::enumerate(computeOp.getInputs())) {
BlockArgument blockArg = computeOp.getInputArgument(inputIndex);
auto receiveOp = dyn_cast_or_null<spatial::SpatChannelReceiveOp>(input.getDefiningOp());
if (receiveOp && !blockArg.use_empty()) {
rewriter.setInsertionPoint(getEarliestUserWithinBlock(blockArg));
auto outputType = cast<ShapedType>(blockArg.getType());
auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveOp.getLoc(), outputType);
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, blockArg);
Value received =
PimReceiveOp::create(
rewriter, receiveOp.getLoc(), outputBuffer.getType(), outputBuffer, sizeAttr, receiveOp.getSourceCoreId())
.getOutput();
blockArg.replaceAllUsesWith(received);
markOpToRemove(state, receiveOp);
continue;
}
rewriter.setInsertionPoint(getEarliestUserWithinBlock(blockArg));
auto outputType = cast<ShapedType>(blockArg.getType());
auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveOp.getLoc(), outputType);
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, blockArg);
auto sourceCoreIdAttr = rewriter.getI32IntegerAttr(translateSpatialCoreIdToPimCoreId(receiveOp.getSourceCoreId()));
Value received = PimReceiveOp::create(
rewriter, receiveOp.getLoc(), outputBuffer.getType(), outputBuffer, sizeAttr, sourceCoreIdAttr)
.getOutput();
blockArg.replaceAllUsesWith(received);
markOpToRemove(state, receiveOp);
auto receiveTensorOp = dyn_cast_or_null<spatial::SpatChannelReceiveTensorOp>(input.getDefiningOp());
if (receiveTensorOp && !blockArg.use_empty()) {
FailureOr<SmallVector<int32_t>> sourceCoreIds = getConstantI32Values(receiveTensorOp.getSourceCoreIds());
if (failed(sourceCoreIds))
return receiveTensorOp.emitOpError("expected constant sourceCoreIds");
for (int32_t& sourceCoreId : *sourceCoreIds)
sourceCoreId = translateSpatialCoreIdToPimCoreId(sourceCoreId);
rewriter.setInsertionPoint(getEarliestUserWithinBlock(blockArg));
auto outputType = cast<ShapedType>(blockArg.getType());
auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveTensorOp.getLoc(), outputType);
Value received = PimReceiveTensorOp::create(rewriter,
receiveTensorOp.getLoc(),
outputBuffer.getType(),
outputBuffer,
rewriter.getDenseI32ArrayAttr(*sourceCoreIds))
.getOutput();
blockArg.replaceAllUsesWith(received);
markOpToRemove(state, receiveTensorOp);
}
}
if (computeOp.getNumResults() != yieldOp.getNumOperands())
@@ -197,11 +241,36 @@ LogicalResult lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState&
loc,
ValueRange(computeWeights),
rewriter.getI32IntegerAttr(getPimCoreIdForComputeOp(computeOp, state.nextCoreId)));
rewriter.setInsertionPointToStart(&block);
auto& coreOpBlocks = coreOp.getBody().getBlocks();
for (auto [argIndex, blockArg] : llvm::enumerate(block.getArguments()))
if (!blockArg.use_empty())
blockArg.replaceAllUsesWith(computeOp.getInputs()[argIndex]);
block.eraseArguments(0, block.getNumArguments());
for (auto [inputIndex, input] : llvm::enumerate(computeOp.getInputs())) {
BlockArgument blockArg = computeOp.getInputArgument(inputIndex);
if (blockArg.use_empty())
continue;
if (auto constantOp = input.getDefiningOp<arith::ConstantOp>()) {
blockArg.replaceAllUsesWith(getOrCreateHostConstantLike(constantOp, state.constantFolder));
continue;
}
auto inputType = dyn_cast<ShapedType>(input.getType());
if (!inputType)
return computeOp.emitOpError("expected shaped compute input during pim.core lowering");
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, inputType);
auto copied =
PimMemCopyHostToDevOp::create(rewriter,
loc,
outputBuffer.getType(),
getOrCreateHostIndexConstant(outputBuffer.getOperation(), 0, state.constantFolder),
getOrCreateHostIndexConstant(outputBuffer.getOperation(), 0, state.constantFolder),
outputBuffer,
input,
getTensorSizeInBytesAttr(rewriter, input))
.getOutput();
blockArg.replaceAllUsesWith(copied);
}
if (!computeOp.getInputs().empty())
block.eraseArguments(computeOp.getWeights().size(), computeOp.getInputs().size());
coreOpBlocks.splice(coreOpBlocks.begin(), computeOp.getBody().getBlocks());
Block* tempComputeBlock = new Block();
computeOp.getBody().push_back(tempComputeBlock);
@@ -1,6 +1,7 @@
#pragma once
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/FoldUtils.h"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ReturnPathNormalization.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
@@ -11,6 +12,7 @@ struct CoreLoweringState {
size_t& nextCoreId;
llvm::SmallVectorImpl<OutputTensorFactory>& outputTensors;
llvm::SmallVectorImpl<mlir::Operation*>& operationsToRemove;
mlir::OperationFolder& constantFolder;
};
void markOpToRemove(CoreLoweringState& state, mlir::Operation* op);
@@ -76,8 +76,7 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
auto inputIndex = getDirectComputeLikeInputIndex(spatCompute, uses.getOperandNumber());
if (!inputIndex)
return failure();
auto BBArgIndex = *inputIndex;
auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex);
auto BBArgValue = spatCompute.getInputArgument(*inputIndex);
if (BBArgValue.use_empty())
continue;
@@ -89,14 +88,13 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
}
replaceAndEraseDirectComputeLikeInput(
rewriter, spatCompute.getOperation(), BBArgIndex, mapSpatToExtract[spatCompute.getOperation()]);
rewriter, spatCompute.getOperation(), *inputIndex, mapSpatToExtract[spatCompute.getOperation()]);
}
else if (auto spatComputeBatch = dyn_cast<spatial::SpatComputeBatch>(uses.getOwner())) {
auto inputIndex = getDirectComputeLikeInputIndex(spatComputeBatch, uses.getOperandNumber());
if (!inputIndex)
return failure();
auto BBArgIndex = *inputIndex;
auto BBArgValue = spatComputeBatch.getBody().front().getArgument(BBArgIndex);
auto BBArgValue = spatComputeBatch.getInputArgument(*inputIndex);
if (BBArgValue.use_empty())
continue;
@@ -108,7 +106,7 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
}
replaceAndEraseDirectComputeLikeInput(
rewriter, spatComputeBatch.getOperation(), BBArgIndex, mapSpatToExtract[spatComputeBatch.getOperation()]);
rewriter, spatComputeBatch.getOperation(), *inputIndex, mapSpatToExtract[spatComputeBatch.getOperation()]);
}
else {
{
@@ -254,7 +252,7 @@ struct ArithConstToGlobalMemoryPattern final : OpRewritePattern<mlir::arith::Con
}
}
else if (constantOp.getType().isIntOrIndexOrFloat()) {
llvm::DenseMap<Operation*, Value> mapSpatComputeToConst;
Value hostConstant = constantOp.getResult();
for (auto& constUses : llvm::make_early_inc_range(constantOp->getUses())) {
auto constUsers = constUses.getOwner();
@@ -264,40 +262,22 @@ struct ArithConstToGlobalMemoryPattern final : OpRewritePattern<mlir::arith::Con
if (!inputIndex)
return failure();
auto BBArgIndex = *inputIndex;
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
auto newConst = rewriter.clone(*constantOp);
replaceAndEraseDirectComputeLikeInput(
rewriter, spatCompute.getOperation(), BBArgIndex, newConst->getResult(0));
replaceAndEraseDirectComputeLikeInput(rewriter, spatCompute.getOperation(), BBArgIndex, hostConstant);
}
else if (auto spatComputeBatch = llvm::dyn_cast<spatial::SpatComputeBatch>(constUsers)) {
auto inputIndex = getDirectComputeLikeInputIndex(spatComputeBatch, constUses.getOperandNumber());
if (!inputIndex)
return failure();
auto BBArgIndex = *inputIndex;
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
auto newConst = rewriter.clone(*constantOp);
replaceAndEraseDirectComputeLikeInput(
rewriter, spatComputeBatch.getOperation(), BBArgIndex, newConst->getResult(0));
replaceAndEraseDirectComputeLikeInput(rewriter, spatComputeBatch.getOperation(), BBArgIndex, hostConstant);
}
else if (auto parent = constUsers->getParentOfType<spatial::SpatCompute>()) {
if (!mapSpatComputeToConst.contains(parent)) {
rewriter.setInsertionPoint(&parent.getBody().front().front());
auto newConst = rewriter.clone(*constantOp);
mapSpatComputeToConst.insert({parent.getOperation(), newConst->getResult(0)});
}
constUses.set(mapSpatComputeToConst[parent.getOperation()]);
else if (constUsers->getParentOfType<spatial::SpatCompute>()) {
constUses.set(hostConstant);
}
else {
auto batchParent = constUsers->getParentOfType<spatial::SpatComputeBatch>();
assert(batchParent && "Global Constant used direcly not within a compute");
if (!mapSpatComputeToConst.contains(batchParent.getOperation())) {
rewriter.setInsertionPoint(&batchParent.getBody().front().front());
auto newConst = rewriter.clone(*constantOp);
mapSpatComputeToConst.insert({batchParent.getOperation(), newConst->getResult(0)});
}
constUses.set(mapSpatComputeToConst[batchParent.getOperation()]);
constUses.set(hostConstant);
}
}
}
@@ -6,8 +6,10 @@
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Transforms/FoldUtils.h"
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ReturnPathNormalization.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
@@ -318,7 +320,8 @@ static LogicalResult mapIndicesThroughHelperChain(ArrayRef<int64_t> sourceIndice
return success();
}
static void cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewriter) {
static void
cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewriter, OperationFolder& constantFolder) {
for (Value operand : op->getOperands()) {
if (mapping.lookupOrNull(operand))
continue;
@@ -327,7 +330,12 @@ static void cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewri
if (!definingOp)
continue;
if (!isa<tensor::EmptyOp, arith::ConstantOp>(definingOp))
if (auto constantOp = dyn_cast<arith::ConstantOp>(definingOp)) {
mapping.map(operand, getOrCreateHostConstantLike(constantOp, constantFolder));
continue;
}
if (!isa<tensor::EmptyOp>(definingOp))
continue;
Operation* clonedOp = rewriter.clone(*definingOp, mapping);
@@ -337,15 +345,18 @@ static void cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewri
}
}
static void
cloneHelperChain(Value sourceValue, ArrayRef<Operation*> helperChain, IRRewriter& rewriter, Value& clonedValue) {
static void cloneHelperChain(Value sourceValue,
ArrayRef<Operation*> helperChain,
IRRewriter& rewriter,
OperationFolder& constantFolder,
Value& clonedValue) {
IRMapping mapping;
mapping.map(sourceValue, sourceValue);
clonedValue = sourceValue;
rewriter.setInsertionPointAfterValue(sourceValue);
for (Operation* op : helperChain) {
cloneMappedHelperOperands(op, mapping, rewriter);
cloneMappedHelperOperands(op, mapping, rewriter, constantFolder);
Operation* clonedOp = rewriter.clone(*op, mapping);
for (auto [originalResult, newResult] : llvm::zip(op->getResults(), clonedOp->getResults()))
mapping.map(originalResult, newResult);
@@ -360,14 +371,19 @@ static Value emitHostCopy(IRRewriter& rewriter,
Value sourceValue,
int32_t hostTargetOffset,
int32_t deviceSourceOffset,
int32_t sizeInBytes) {
int32_t sizeInBytes,
OperationFolder& constantFolder) {
Operation* anchorOp = sourceValue.getDefiningOp() ? sourceValue.getDefiningOp() : outputTensor.getDefiningOp();
assert(anchorOp && "expected a concrete op anchor for return-path host copy constants");
Value hostTargetOffsetValue = getOrCreateHostIndexConstant(anchorOp, hostTargetOffset, constantFolder);
Value deviceSourceOffsetValue = getOrCreateHostIndexConstant(anchorOp, deviceSourceOffset, constantFolder);
return PimMemCopyDevToHostOp::create(rewriter,
loc,
outputTensor.getType(),
hostTargetOffsetValue,
deviceSourceOffsetValue,
outputTensor,
sourceValue,
rewriter.getI32IntegerAttr(hostTargetOffset),
rewriter.getI32IntegerAttr(deviceSourceOffset),
rewriter.getI32IntegerAttr(sizeInBytes))
.getOutput();
}
@@ -411,69 +427,84 @@ void addReturnOutputBuffers(func::ReturnOp returnOp,
}
}
ReturnPathLoweringResult lowerComputeResultReturnPath(
spatial::SpatCompute computeOp, OpResult result, Value yieldValue, ReturnPathState& state, IRRewriter& rewriter) {
Location loc = computeOp->getLoc();
auto yieldType = cast<TensorType>(yieldValue.getType());
ReturnPathLoweringResult lowerProducedValueReturnPath(
Operation* producerOp, Value producedValue, Value storedValue, ReturnPathState& state, IRRewriter& rewriter) {
Location loc = producerOp->getLoc();
OperationFolder constantFolder(producerOp->getContext());
auto storedTensorType = cast<TensorType>(storedValue.getType());
if (auto returnUse = analyzeReturnUse(result)) {
Value storedValue = yieldValue;
cloneHelperChain(yieldValue, returnUse->helperChain, rewriter, storedValue);
if (auto returnUse = analyzeReturnUse(producedValue)) {
Value currentStoredValue = storedValue;
cloneHelperChain(storedValue, returnUse->helperChain, rewriter, constantFolder, currentStoredValue);
for (Operation* op : returnUse->helperChain)
markOpToRemove(state, op);
auto storedType = cast<ShapedType>(storedValue.getType());
auto storedType = cast<ShapedType>(currentStoredValue.getType());
size_t elementSize = storedType.getElementTypeBitWidth() / 8;
if (auto storedOp = storedValue.getDefiningOp())
if (auto storedOp = currentStoredValue.getDefiningOp())
rewriter.setInsertionPointAfter(storedOp);
Value outputTensor = state.outputTensors[returnUse->returnIndex](rewriter, loc);
emitHostCopy(
rewriter, loc, outputTensor, storedValue, 0, 0, static_cast<int32_t>(storedType.getNumElements() * elementSize));
emitHostCopy(rewriter,
loc,
outputTensor,
currentStoredValue,
0,
0,
static_cast<int32_t>(storedType.getNumElements() * elementSize),
constantFolder);
return ReturnPathLoweringResult::Handled;
}
auto resultUses = result.getUses();
auto resultUses = producedValue.getUses();
if (rangeLength(resultUses) == 1) {
OpOperand& resultUse = *resultUses.begin();
Operation* resultUser = resultUse.getOwner();
if (isa<func::ReturnOp>(resultUser)) {
size_t resultIndexInReturn = resultUse.getOperandNumber();
size_t elementSize = yieldType.getElementType().getIntOrFloatBitWidth() / 8;
rewriter.setInsertionPointAfterValue(yieldValue);
size_t elementSize = storedTensorType.getElementType().getIntOrFloatBitWidth() / 8;
rewriter.setInsertionPointAfterValue(storedValue);
Value outputTensor = state.outputTensors[resultIndexInReturn](rewriter, loc);
emitHostCopy(
rewriter, loc, outputTensor, yieldValue, 0, 0, static_cast<int32_t>(yieldType.getNumElements() * elementSize));
emitHostCopy(rewriter,
loc,
outputTensor,
storedValue,
0,
0,
static_cast<int32_t>(storedTensorType.getNumElements() * elementSize),
constantFolder);
return ReturnPathLoweringResult::Handled;
}
}
if (auto concatReturnUse = analyzeConcatReturnUse(result)) {
size_t elementSize = yieldType.getElementTypeBitWidth() / 8;
if (auto concatReturnUse = analyzeConcatReturnUse(producedValue)) {
size_t elementSize = storedTensorType.getElementTypeBitWidth() / 8;
for (Operation* concatOp : concatReturnUse->concatChain)
markOpToRemove(state, concatOp);
if (concatReturnUse->helperChain.empty()) {
rewriter.setInsertionPointAfterValue(yieldValue);
rewriter.setInsertionPointAfterValue(storedValue);
Value outputTensor = state.outputTensors[concatReturnUse->returnIndex](rewriter, loc);
auto outputType = cast<ShapedType>(outputTensor.getType());
int64_t flatOffset = computeFlatElementIndex(concatReturnUse->sliceOffsets, outputType.getShape());
emitHostCopy(rewriter,
loc,
outputTensor,
yieldValue,
storedValue,
static_cast<int32_t>(flatOffset * elementSize),
0,
static_cast<int32_t>(yieldType.getNumElements() * elementSize));
static_cast<int32_t>(storedTensorType.getNumElements() * elementSize),
constantFolder);
return ReturnPathLoweringResult::Handled;
}
auto storedType = dyn_cast<RankedTensorType>(yieldValue.getType());
auto storedType = dyn_cast<RankedTensorType>(storedValue.getType());
if (!storedType) {
computeOp.emitOpError("has an unsupported non-ranked concat-return helper yield during Spatial-to-PIM lowering");
producerOp->emitOpError(
"has an unsupported non-ranked concat-return helper yield during Spatial-to-PIM lowering");
return ReturnPathLoweringResult::Failure;
}
rewriter.setInsertionPointAfterValue(yieldValue);
rewriter.setInsertionPointAfterValue(storedValue);
Value outputTensor = state.outputTensors[concatReturnUse->returnIndex](rewriter, loc);
auto outputType = cast<ShapedType>(outputTensor.getType());
for (int64_t linearIndex = 0; linearIndex < storedType.getNumElements(); ++linearIndex) {
@@ -484,7 +515,7 @@ ReturnPathLoweringResult lowerComputeResultReturnPath(
SmallVector<int64_t> destinationIndices;
if (failed(mapIndicesThroughHelperChain(
sourceIndices, concatReturnUse->concatShape, concatReturnUse->helperChain, destinationIndices))) {
computeOp.emitOpError("has an unsupported concat-return helper chain during Spatial-to-PIM lowering");
producerOp->emitOpError("has an unsupported concat-return helper chain during Spatial-to-PIM lowering");
return ReturnPathLoweringResult::Failure;
}
@@ -503,7 +534,7 @@ ReturnPathLoweringResult lowerComputeResultReturnPath(
auto scalarTensorType =
RankedTensorType::get(SmallVector<int64_t>(storedType.getRank(), 1), storedType.getElementType());
auto elementSlice = tensor::ExtractSliceOp::create(
rewriter, loc, scalarTensorType, yieldValue, extractOffsets, extractSizes, extractStrides);
rewriter, loc, scalarTensorType, storedValue, extractOffsets, extractSizes, extractStrides);
rewriter.setInsertionPointAfter(elementSlice);
int64_t destinationFlatOffset = computeFlatElementIndex(destinationIndices, outputType.getShape());
@@ -513,7 +544,8 @@ ReturnPathLoweringResult lowerComputeResultReturnPath(
elementSlice.getResult(),
static_cast<int32_t>(destinationFlatOffset * elementSize),
0,
static_cast<int32_t>(elementSize));
static_cast<int32_t>(elementSize),
constantFolder);
}
return ReturnPathLoweringResult::Handled;
}
@@ -521,6 +553,11 @@ ReturnPathLoweringResult lowerComputeResultReturnPath(
return ReturnPathLoweringResult::NotReturnPath;
}
ReturnPathLoweringResult lowerComputeResultReturnPath(
spatial::SpatCompute computeOp, OpResult result, Value yieldValue, ReturnPathState& state, IRRewriter& rewriter) {
return lowerProducedValueReturnPath(computeOp.getOperation(), result, yieldValue, state, rewriter);
}
void replaceReturnWithOutputBuffers(func::ReturnOp returnOp, IRRewriter& rewriter, ReturnPathState& state) {
auto markOwnedReturnChain = [&](Operation* op, auto&& markOwnedReturnChain) -> void {
if (!op)
@@ -569,7 +606,16 @@ void replaceReturnWithOutputBuffers(func::ReturnOp returnOp, IRRewriter& rewrite
markOpToRemove(state, concatOp);
for (Value operand : concatOp.getInputs())
markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain);
return;
}
if (auto receiveOp = dyn_cast<spatial::SpatChannelReceiveOp>(op)) {
markOpToRemove(state, receiveOp);
return;
}
if (auto receiveTensorOp = dyn_cast<spatial::SpatChannelReceiveTensorOp>(op))
markOpToRemove(state, receiveTensorOp);
};
SmallVector<Value> originalOperands(returnOp.getOperands().begin(), returnOp.getOperands().end());
@@ -32,6 +32,12 @@ ReturnPathLoweringResult lowerComputeResultReturnPath(spatial::SpatCompute compu
ReturnPathState& state,
mlir::IRRewriter& rewriter);
ReturnPathLoweringResult lowerProducedValueReturnPath(mlir::Operation* producerOp,
mlir::Value producedValue,
mlir::Value storedValue,
ReturnPathState& state,
mlir::IRRewriter& rewriter);
void replaceReturnWithOutputBuffers(mlir::func::ReturnOp returnOp, mlir::IRRewriter& rewriter, ReturnPathState& state);
} // namespace onnx_mlir
@@ -16,8 +16,8 @@ def onnxToPimTranspose : Pat<
>;
def spatToPimVMM : Pat<
(SpatVMMOp:$srcOpRes $weightIndex, $vector),
(PimVMMOp $weightIndex, $vector,
(SpatVMMOp:$srcOpRes $weight, $vector),
(PimVMMOp $weight, $vector,
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
>;
@@ -3,6 +3,7 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/BuiltinOps.h"
@@ -12,6 +13,8 @@
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/Value.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
#include "llvm/ADT/StringRef.h"
@@ -104,23 +107,34 @@ static memref::GlobalOp getOrCreateZeroGlobal(IRRewriter& rewriter, Location loc
IntegerAttr {});
}
static Value createZeroedDeviceHVector(IRRewriter& rewriter, Location loc, RankedTensorType tensorType) {
static Value createZeroedDeviceHVector(IRRewriter& rewriter,
Location loc,
RankedTensorType tensorType,
OperationFolder& constantFolder) {
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, tensorType);
auto zeroGlobal = getOrCreateZeroGlobal(rewriter, loc, tensorType);
auto zeroValue = memref::GetGlobalOp::create(rewriter, loc, zeroGlobal.getType(), zeroGlobal.getName());
auto zeroAttr = rewriter.getI32IntegerAttr(0);
auto zeroIndex = getOrCreateHostIndexConstant(outputBuffer.getOperation(), 0, constantFolder);
auto sizeAttr = rewriter.getI32IntegerAttr(static_cast<int32_t>(getShapedTypeSizeInBytes(tensorType)));
if (outputBuffer->getParentOfType<PimCoreBatchOp>())
return PimMemCopyHostToDevBatchOp::create(
rewriter, loc, tensorType, outputBuffer, zeroValue, zeroAttr, zeroAttr, sizeAttr)
return PimMemCopyHostToDevBatchOp::create(rewriter,
loc,
tensorType,
outputBuffer,
zeroValue,
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(0),
sizeAttr)
.getOutput();
return PimMemCopyHostToDevOp::create(rewriter, loc, tensorType, outputBuffer, zeroValue, zeroAttr, zeroAttr, sizeAttr)
return PimMemCopyHostToDevOp::create(
rewriter, loc, tensorType, zeroIndex, zeroIndex, outputBuffer, zeroValue, sizeAttr)
.getOutput();
}
static Value padHVectorInputToCrossbarSize(IRRewriter& rewriter, Location loc, Value vector) {
static Value
padHVectorInputToCrossbarSize(IRRewriter& rewriter, Location loc, Value vector, OperationFolder& constantFolder) {
auto vectorType = cast<RankedTensorType>(vector.getType());
ArrayRef<int64_t> shape = vectorType.getShape();
assert(isHVectorShape(shape) && "expected a horizontal vector");
@@ -131,25 +145,27 @@ static Value padHVectorInputToCrossbarSize(IRRewriter& rewriter, Location loc, V
auto paddedType = RankedTensorType::get(
{shape[0], static_cast<int64_t>(crossbarSize)}, vectorType.getElementType(), vectorType.getEncoding());
Value zeroed = createZeroedDeviceHVector(rewriter, loc, paddedType);
Value zeroed = createZeroedDeviceHVector(rewriter, loc, paddedType, constantFolder);
auto zeroAttr = rewriter.getI32IntegerAttr(0);
auto sizeAttr = rewriter.getI32IntegerAttr(static_cast<int32_t>(getShapedTypeSizeInBytes(vectorType)));
return PimMemCopyOp::create(rewriter, loc, paddedType, zeroed, vector, zeroAttr, zeroAttr, sizeAttr).getOutput();
}
void SpatialToPimPass::runOnOperation() {
coreId = 1;
coreId = 0;
ModuleOp moduleOp = getOperation();
MLIRContext* ctx = moduleOp.getContext();
auto entryFunc = getPimEntryFunc(moduleOp);
if (failed(entryFunc)) {
moduleOp.emitError("failed to locate the PIM entry function during Spatial-to-PIM lowering");
signalPassFailure();
return;
}
func::FuncOp funcOp = *entryFunc;
IRRewriter rewriter(&getContext());
OperationFolder constantFolder(&getContext());
ConversionTarget target(*ctx);
target.addLegalDialect<PimDialect,
@@ -169,34 +185,32 @@ void SpatialToPimPass::runOnOperation() {
spatial::SpatChannelSendTensorBatchOp,
spatial::SpatExtractRowsOp>();
{
RewritePatternSet patterns(ctx);
populateWithGenerated(patterns);
if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
signalPassFailure();
return;
}
}
{
RewritePatternSet patterns(ctx);
populateGlobalTensorMaterializationPatterns(patterns);
walkAndApplyPatterns(moduleOp, std::move(patterns));
}
auto returnOp = cast<func::ReturnOp>(funcOp.front().getTerminator());
addReturnOutputBuffers(returnOp, rewriter, outputTensors);
if (failed(allocateAndInitializeCoreLocalVariables(funcOp, rewriter))) {
RewritePatternSet initialPatterns(ctx);
populateWithGenerated(initialPatterns);
if (failed(applyPartialConversion(moduleOp, target, std::move(initialPatterns)))) {
moduleOp.emitError("failed to lower required Spatial ops to the initial PIM form");
signalPassFailure();
return;
}
CoreLoweringState coreLoweringState {coreId, outputTensors, operationsToRemove};
RewritePatternSet globalTensorPatterns(ctx);
populateGlobalTensorMaterializationPatterns(globalTensorPatterns);
walkAndApplyPatterns(moduleOp, std::move(globalTensorPatterns));
auto returnOp = cast<func::ReturnOp>(funcOp.front().getTerminator());
addReturnOutputBuffers(returnOp, rewriter, outputTensors);
ReturnPathState returnPathState {outputTensors, operationsToRemove};
if (failed(allocateAndInitializeCoreLocalVariables(funcOp, rewriter))) {
funcOp.emitOpError("failed to allocate or initialize core-local tensors during Spatial-to-PIM lowering");
signalPassFailure();
return;
}
CoreLoweringState coreLoweringState {coreId, outputTensors, operationsToRemove, constantFolder};
for (auto computeOp : funcOp.getOps<spatial::SpatCompute>()) {
markOpToRemove(computeOp);
if (failed(lowerComputeOp(computeOp, coreLoweringState, rewriter))) {
computeOp.emitOpError("failed to lower spat.compute to pim.core");
signalPassFailure();
return;
}
@@ -205,17 +219,16 @@ void SpatialToPimPass::runOnOperation() {
for (auto computeBatchOp : funcOp.getOps<spatial::SpatComputeBatch>()) {
markOpToRemove(computeBatchOp);
if (failed(lowerComputeBatchOp(computeBatchOp, coreLoweringState, rewriter))) {
computeBatchOp.emitOpError("failed to lower spat.compute_batch to pim.core_batch");
signalPassFailure();
return;
}
}
{
RewritePatternSet patterns(ctx);
populateTensorPackingPatterns(patterns);
walkAndApplyPatterns(funcOp, std::move(patterns));
eraseUnusedTensorPackingOps(funcOp, rewriter);
}
RewritePatternSet initialTensorPackingPatterns(ctx);
populateTensorPackingPatterns(initialTensorPackingPatterns);
walkAndApplyPatterns(funcOp, std::move(initialTensorPackingPatterns));
eraseUnusedTensorPackingOps(funcOp, rewriter);
SmallVector<spatial::SpatChannelReceiveOp> receiveOps;
for (auto op : funcOp.getOps<spatial::SpatChannelReceiveOp>())
@@ -229,74 +242,72 @@ void SpatialToPimPass::runOnOperation() {
}
}
{
RewritePatternSet coreBodyPatterns(ctx);
populateWithGenerated(coreBodyPatterns);
FrozenRewritePatternSet frozenCoreBodyPatterns(std::move(coreBodyPatterns));
RewritePatternSet coreBodyPatterns(ctx);
populateWithGenerated(coreBodyPatterns);
FrozenRewritePatternSet frozenCoreBodyPatterns(std::move(coreBodyPatterns));
SmallVector<pim::PimCoreOp> coreOps;
funcOp.walk([&](pim::PimCoreOp coreOp) { coreOps.push_back(coreOp); });
for (auto coreOp : coreOps) {
if (failed(applyFullConversion(coreOp.getOperation(), target, frozenCoreBodyPatterns))) {
signalPassFailure();
return;
}
}
SmallVector<pim::PimCoreBatchOp> coreBatchOps;
funcOp.walk([&](pim::PimCoreBatchOp coreBatchOp) { coreBatchOps.push_back(coreBatchOp); });
for (auto coreBatchOp : coreBatchOps) {
if (failed(applyFullConversion(coreBatchOp.getOperation(), target, frozenCoreBodyPatterns))) {
signalPassFailure();
return;
}
}
}
enlargeVMMOutTensorsToCrossbarSize(funcOp, rewriter);
ReturnPathState returnPathState {outputTensors, operationsToRemove};
replaceReturnWithOutputBuffers(returnOp, rewriter, returnPathState);
SmallVector<Operation*> pendingRemovals(operationsToRemove.begin(), operationsToRemove.end());
if (failed(erasePendingOps(pendingRemovals, rewriter))) {
signalPassFailure();
return;
}
{
RewritePatternSet patterns(ctx);
populateTensorPackingPatterns(patterns);
walkAndApplyPatterns(funcOp, std::move(patterns));
eraseUnusedTensorPackingOps(funcOp, rewriter);
}
{
ConversionTarget communicationTarget(*ctx);
communicationTarget.addLegalDialect<PimDialect,
tensor::TensorDialect,
arith::ArithDialect,
bufferization::BufferizationDialect,
func::FuncDialect,
memref::MemRefDialect,
scf::SCFDialect,
BuiltinDialect>();
communicationTarget.addLegalOp<ModuleOp>();
communicationTarget.addIllegalOp<spatial::SpatConcatOp,
spatial::SpatChannelReceiveOp,
spatial::SpatChannelReceiveTensorOp,
spatial::SpatChannelSendOp,
spatial::SpatChannelSendTensorOp,
spatial::SpatExtractRowsOp>();
RewritePatternSet communicationPatterns(ctx);
populateChannelLoweringPatterns(communicationPatterns);
if (failed(applyFullConversion(funcOp, communicationTarget, std::move(communicationPatterns)))) {
SmallVector<pim::PimCoreOp> coreOps;
funcOp.walk([&](pim::PimCoreOp coreOp) { coreOps.push_back(coreOp); });
for (auto coreOp : coreOps) {
if (failed(applyFullConversion(coreOp.getOperation(), target, frozenCoreBodyPatterns))) {
coreOp.emitOpError("failed to convert nested Spatial ops inside pim.core");
signalPassFailure();
return;
}
}
SmallVector<pim::PimCoreBatchOp> coreBatchOps;
funcOp.walk([&](pim::PimCoreBatchOp coreBatchOp) { coreBatchOps.push_back(coreBatchOp); });
for (auto coreBatchOp : coreBatchOps) {
if (failed(applyFullConversion(coreBatchOp.getOperation(), target, frozenCoreBodyPatterns))) {
coreBatchOp.emitOpError("failed to convert nested Spatial ops inside pim.core_batch");
signalPassFailure();
return;
}
}
enlargeVMMOutTensorsToCrossbarSize(funcOp, rewriter);
replaceReturnWithOutputBuffers(returnOp, rewriter, returnPathState);
SmallVector<Operation*> pendingRemovals(operationsToRemove.begin(), operationsToRemove.end());
if (failed(erasePendingOps(pendingRemovals, rewriter))) {
funcOp.emitOpError("failed to erase obsolete Spatial ops after lowering to PIM");
signalPassFailure();
return;
}
RewritePatternSet finalTensorPackingPatterns(ctx);
populateTensorPackingPatterns(finalTensorPackingPatterns);
walkAndApplyPatterns(funcOp, std::move(finalTensorPackingPatterns));
eraseUnusedTensorPackingOps(funcOp, rewriter);
ConversionTarget communicationTarget(*ctx);
communicationTarget.addLegalDialect<PimDialect,
tensor::TensorDialect,
arith::ArithDialect,
bufferization::BufferizationDialect,
func::FuncDialect,
memref::MemRefDialect,
scf::SCFDialect,
BuiltinDialect>();
communicationTarget.addLegalOp<ModuleOp>();
communicationTarget.addIllegalOp<spatial::SpatConcatOp,
spatial::SpatChannelReceiveOp,
spatial::SpatChannelReceiveTensorOp,
spatial::SpatChannelSendOp,
spatial::SpatChannelSendTensorOp,
spatial::SpatExtractRowsOp>();
RewritePatternSet communicationPatterns(ctx);
populateChannelLoweringPatterns(communicationPatterns);
if (failed(applyFullConversion(funcOp, communicationTarget, std::move(communicationPatterns)))) {
funcOp.emitOpError("failed to lower Spatial communication ops to PIM communication ops");
signalPassFailure();
return;
}
if (failed(verifySpatialToPimBoundary(moduleOp))) {
moduleOp.emitError("Spatial-to-PIM boundary verification failed");
signalPassFailure();
return;
}
@@ -306,6 +317,7 @@ void SpatialToPimPass::runOnOperation() {
}
void SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter) {
OperationFolder constantFolder(funcOp.getContext());
funcOp.walk([&](PimVMMOp vmmOp) {
auto outputType = cast<RankedTensorType>(vmmOp.getOutput().getType());
ArrayRef<int64_t> outputShape = outputType.getShape();
@@ -313,7 +325,7 @@ void SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, I
assert(outputShape[1] <= static_cast<int64_t>(crossbarSize) && "output width must fit in one crossbar");
rewriter.setInsertionPoint(vmmOp);
Value paddedInput = padHVectorInputToCrossbarSize(rewriter, vmmOp.getLoc(), vmmOp.getInput());
Value paddedInput = padHVectorInputToCrossbarSize(rewriter, vmmOp.getLoc(), vmmOp.getInput(), constantFolder);
auto paddedOutputType = RankedTensorType::get(
{outputShape[0], static_cast<int64_t>(crossbarSize)}, outputType.getElementType(), outputType.getEncoding());
Value paddedOutputBuffer = outputShape[1] == static_cast<int64_t>(crossbarSize)
@@ -340,10 +352,13 @@ void SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, I
LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter) {
Location loc = funcOp.getLoc();
OperationFolder constantFolder(funcOp.getContext());
auto insertMemCopyHostToDev = [&](Value inputTensor, int64_t elementsOffset) {
auto tensorType = cast<ShapedType>(inputTensor.getType());
Type elementType = tensorType.getElementType();
if (!elementType.isIntOrFloat())
return;
size_t elementByteSize = elementType.getIntOrFloatBitWidth() / 8;
rewriter.setInsertionPointAfter(inputTensor.getDefiningOp());
@@ -353,10 +368,11 @@ LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::Fu
rewriter,
loc,
tensorType,
getOrCreateHostIndexConstant(deviceTensor.getOperation(), 0, constantFolder),
getOrCreateHostIndexConstant(
deviceTensor.getOperation(), static_cast<int64_t>(elementsOffset * elementByteSize), constantFolder),
deviceTensor,
inputTensor,
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(static_cast<int32_t>(elementsOffset * elementByteSize)),
rewriter.getI32IntegerAttr(static_cast<int32_t>(tensorType.getNumElements() * elementByteSize)));
rewriter.replaceAllUsesExcept(inputTensor, memCopyHostToDevOp.getResult(), {memCopyHostToDevOp});
@@ -1,5 +1,4 @@
#include "src/Accelerators/PIM/Conversion/SpatialToPim/TensorPackingPatterns.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace mlir;
@@ -75,16 +74,14 @@ struct PackSpatialConcatInputsPattern final : OpRewritePattern<spatial::SpatConc
return failure();
auto outputType = cast<ShapedType>(concatOp.getOutput().getType());
auto newConcat = pim::PimConcatOp::create(rewriter,
concatOp.getLoc(),
concatOp.getOutput().getType(),
concatOp.getAxisAttr(),
ValueRange(packedInputs),
tensor::EmptyOp::create(rewriter,
concatOp.getLoc(),
outputType.getShape(),
outputType.getElementType())
.getResult());
auto newConcat = pim::PimConcatOp::create(
rewriter,
concatOp.getLoc(),
concatOp.getOutput().getType(),
concatOp.getAxisAttr(),
ValueRange(packedInputs),
tensor::EmptyOp::create(rewriter, concatOp.getLoc(), outputType.getShape(), outputType.getElementType())
.getResult());
rewriter.replaceOp(concatOp, newConcat.getOutput());
return success();
}
@@ -1,7 +1,7 @@
#pragma once
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
+34 -16
View File
@@ -2,6 +2,7 @@
#define PIM_DIALECT_H
include "mlir/IR/OpBase.td"
include "mlir/IR/OpAsmInterface.td"
include "mlir/IR/AttrTypeBase.td"
include "mlir/Dialect/MemRef/IR/MemRefBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -24,7 +25,8 @@ def PimTensor :
// Execution
//===----------------------------------------------------------------------===//
def PimCoreOp : PimOp<"core", [SingleBlock, IsolatedFromAbove]> {
def PimCoreOp : PimOp<"core", [SingleBlock,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmBlockArgumentNames"]>]> {
let summary = "Execute a block on a PIM core";
let regions = (region SizedRegion<1>:$body);
@@ -34,12 +36,16 @@ def PimCoreOp : PimOp<"core", [SingleBlock, IsolatedFromAbove]> {
I32Attr:$coreId
);
let assemblyFormat = [{
`(` $weights `)` attr-dict regions `:` type($weights) `->` `(` `)`
let extraClassDeclaration = [{
::mlir::BlockArgument getWeightArgument(unsigned idx);
}];
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
}
def PimCoreBatchOp : PimOp<"core_batch", [SingleBlock, IsolatedFromAbove, AttrSizedOperandSegments]> {
def PimCoreBatchOp : PimOp<"core_batch", [SingleBlock, AttrSizedOperandSegments,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmBlockArgumentNames"]>]> {
let summary = "Execute equivalent batched core bodies";
let regions = (region SizedRegion<1>:$body);
@@ -50,6 +56,13 @@ def PimCoreBatchOp : PimOp<"core_batch", [SingleBlock, IsolatedFromAbove, AttrSi
Variadic<PimTensor>:$inputs
);
let extraClassDeclaration = [{
::mlir::BlockArgument getLaneArgument();
::mlir::BlockArgument getWeightArgument(unsigned idx);
::mlir::BlockArgument getInputArgument(unsigned idx);
}];
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
}
@@ -81,11 +94,11 @@ def PimSendOp : PimOp<"send", []> {
let arguments = (ins
PimTensor:$input,
I32Attr:$size,
I32Attr:$targetCoreId
Index:$targetCoreId
);
let assemblyFormat = [{
`(` $input `)` attr-dict `:` type($input) `->` `(` `)`
`(` $input `,` $targetCoreId `)` attr-dict `:` type($input) `->` `(` `)`
}];
}
@@ -131,7 +144,7 @@ def PimReceiveOp : PimOp<"receive", [DestinationStyleOpInterface]> {
let arguments = (ins
PimTensor:$outputBuffer,
I32Attr:$size,
I32Attr:$sourceCoreId
Index:$sourceCoreId
);
let results = (outs
@@ -145,7 +158,7 @@ def PimReceiveOp : PimOp<"receive", [DestinationStyleOpInterface]> {
}];
let assemblyFormat = [{
`(` $outputBuffer `)` attr-dict `:` type($outputBuffer) `->` type($output)
`(` $outputBuffer `,` $sourceCoreId `)` attr-dict `:` type($outputBuffer) `->` type($output)
}];
}
@@ -219,10 +232,10 @@ def PimMemCopyHostToDevOp : PimOp<"memcp_hd", [DestinationStyleOpInterface]> {
let summary = "Copy a memory region from host memory into device memory";
let arguments = (ins
Index:$deviceTargetOffset,
Index:$hostSourceOffset,
PimTensor:$deviceTarget,
PimTensor:$hostSource,
I32Attr:$deviceTargetOffset,
I32Attr:$hostSourceOffset,
I32Attr:$size
);
@@ -237,7 +250,9 @@ def PimMemCopyHostToDevOp : PimOp<"memcp_hd", [DestinationStyleOpInterface]> {
}];
let assemblyFormat = [{
`(` $deviceTarget `,` $hostSource `)` attr-dict `:` `(` type($deviceTarget) `,` type($hostSource) `)` `->` type($output)
`[` $deviceTargetOffset `,` $hostSourceOffset `]`
`(` $deviceTarget `,` $hostSource `)` attr-dict
`:` type($deviceTarget) `,` type($hostSource) `->` type($output)
}];
}
@@ -271,10 +286,10 @@ def PimMemCopyDevToHostOp : PimOp<"memcp_dh", [DestinationStyleOpInterface]> {
let summary = "Copy a memory region from device memory into host memory";
let arguments = (ins
Index:$hostTargetOffset,
Index:$deviceSourceOffset,
PimTensor:$hostTarget,
PimTensor:$deviceSource,
I32Attr:$hostTargetOffset,
I32Attr:$deviceSourceOffset,
I32Attr:$size
);
@@ -289,7 +304,9 @@ def PimMemCopyDevToHostOp : PimOp<"memcp_dh", [DestinationStyleOpInterface]> {
}];
let assemblyFormat = [{
`(` $hostTarget `,` $deviceSource `)` attr-dict `:` `(` type($hostTarget) `,` type($deviceSource) `)` `->` type($output)
`[` $hostTargetOffset `,` $deviceSourceOffset `]`
`(` $hostTarget `,` $deviceSource `)` attr-dict
`:` type($hostTarget) `,` type($deviceSource) `->` type($output)
}];
}
@@ -374,7 +391,7 @@ def PimVMMOp : PimOp<"vmm", [DestinationStyleOpInterface]> {
let summary = "Vector-matrix multiplication: c = a * b";
let arguments = (ins
I32Attr:$weightIndex,
PimTensor:$weight,
PimTensor:$input,
PimTensor:$outputBuffer
);
@@ -391,7 +408,8 @@ def PimVMMOp : PimOp<"vmm", [DestinationStyleOpInterface]> {
let hasVerifier = 1;
let assemblyFormat = [{
`(` $input `,` $outputBuffer `)` attr-dict `:` `(` type($input) `,` type($outputBuffer) `)` `->` type($output)
`[` $weight `]` `(` $input `,` $outputBuffer `)` attr-dict `:` `(` type($weight) `,` type($input) `,`
type($outputBuffer) `)` `->` type($output)
}];
}
+33
View File
@@ -1,8 +1,41 @@
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include <string>
using namespace mlir;
namespace onnx_mlir {
namespace pim {
BlockArgument PimCoreOp::getWeightArgument(unsigned idx) { return getBody().front().getArgument(idx); }
void PimCoreOp::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) {
if (region.empty())
return;
for (unsigned index = 0; index < getWeights().size(); ++index)
setNameFn(getWeightArgument(index), ("w" + std::to_string(index)).c_str());
}
BlockArgument PimCoreBatchOp::getLaneArgument() { return getBody().front().getArgument(0); }
BlockArgument PimCoreBatchOp::getWeightArgument(unsigned idx) { return getBody().front().getArgument(1 + idx); }
BlockArgument PimCoreBatchOp::getInputArgument(unsigned idx) {
return getBody().front().getArgument(1 + getWeights().size() + idx);
}
void PimCoreBatchOp::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) {
if (region.empty())
return;
setNameFn(getLaneArgument(), "lane");
for (unsigned index = 0; index < getWeights().size(); ++index)
setNameFn(getWeightArgument(index), ("w" + std::to_string(index)).c_str());
for (unsigned index = 0; index < getInputs().size(); ++index)
setNameFn(getInputArgument(index), ("in" + std::to_string(index)).c_str());
}
void PimDialect::initialize() {
addOperations<
#define GET_OP_LIST
+180 -31
View File
@@ -20,6 +20,80 @@ static DenseI32ArrayAttr getDenseI32ArrayAttr(OpAsmParser& parser, ArrayRef<int3
return parser.getBuilder().getDenseI32ArrayAttr(values);
}
static IntegerAttr getI32Attr(OpAsmParser& parser, int32_t value) {
return parser.getBuilder().getI32IntegerAttr(value);
}
static bool parseOptionalKeywordAlias(OpAsmParser& parser, StringRef preferred, StringRef legacy) {
return succeeded(parser.parseOptionalKeyword(preferred)) || succeeded(parser.parseOptionalKeyword(legacy));
}
static void printBlockArgumentList(OpAsmPrinter& printer, ArrayRef<BlockArgument> arguments) {
printer << "(";
for (auto [index, argument] : llvm::enumerate(arguments)) {
if (index != 0)
printer << ", ";
printer.printOperand(argument);
}
printer << ")";
}
static ParseResult parseBlockArgumentList(OpAsmParser& parser, SmallVectorImpl<OpAsmParser::Argument>& arguments) {
if (parser.parseLParen())
return failure();
if (succeeded(parser.parseOptionalRParen()))
return success();
OpAsmParser::Argument argument;
if (parser.parseArgument(argument))
return failure();
arguments.push_back(argument);
while (succeeded(parser.parseOptionalComma())) {
if (parser.parseArgument(argument))
return failure();
arguments.push_back(argument);
}
return parser.parseRParen();
}
static void printBoundValueList(OpAsmPrinter& printer, ValueRange arguments, ValueRange operands, ListDelimiter delimiter) {
printCompressedValueList(printer, arguments, delimiter);
printer << " = ";
printCompressedValueList(printer, operands, delimiter);
}
static ParseResult parseBoundValueList(OpAsmParser& parser,
ListDelimiter delimiter,
SmallVectorImpl<OpAsmParser::Argument>& arguments,
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
if (parseOpenDelimiter(parser, delimiter))
return failure();
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter))) {
if (parser.parseEqual() || parseCompressedOperandList(parser, delimiter, operands))
return failure();
return success();
}
if (parseOneCompressedArgumentEntry(parser, arguments))
return failure();
while (succeeded(parser.parseOptionalComma()))
if (parseOneCompressedArgumentEntry(parser, arguments))
return failure();
auto parseCloseDelimiter = [&](ListDelimiter currentDelimiter) -> ParseResult {
switch (currentDelimiter) {
case ListDelimiter::Paren:
return parser.parseRParen();
case ListDelimiter::Square:
return parser.parseRSquare();
}
llvm_unreachable("unsupported delimiter");
};
if (parseCloseDelimiter(delimiter) || parser.parseEqual() || parseCompressedOperandList(parser, delimiter, operands))
return failure();
return success();
}
static void printCoreIdList(OpAsmPrinter& printer, StringRef keyword, ArrayRef<int32_t> coreIds) {
printer << " " << keyword << " ";
printCompressedIntegerList(printer, coreIds);
@@ -33,15 +107,76 @@ static ParseResult parseOptionalCoreIdList(OpAsmParser& parser, StringRef keywor
} // namespace
void PimCoreBatchOp::print(OpAsmPrinter& printer) {
printer << " lanes " << getLaneCount() << " ";
size_t weightsPerLane = getLaneCount() > 0 ? getWeights().size() / static_cast<size_t>(getLaneCount()) : 0;
if (getLaneCount() > 1 && hasRepeatedTuple(getWeights(), weightsPerLane))
printValueTupleRun(printer, getWeights(), weightsPerLane, ListDelimiter::Paren);
else
printCompressedValueList(printer, getWeights(), ListDelimiter::Paren);
void PimCoreOp::print(OpAsmPrinter& printer) {
SmallVector<Value> weightArgs;
weightArgs.reserve(getWeights().size());
for (unsigned index = 0; index < getWeights().size(); ++index)
weightArgs.push_back(getWeightArgument(index));
printer << " ";
printCompressedValueList(printer, getInputs(), ListDelimiter::Square);
printBoundValueList(printer, weightArgs, getWeights(), ListDelimiter::Square);
printer << " coreId " << getCoreId();
printer.printOptionalAttrDict((*this)->getAttrs(), {getCoreIdAttrName().getValue()});
printer << " : ";
printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Square);
printer << " -> () ";
printer.printRegion(getBody(), /*printEntryBlockArgs=*/false);
}
ParseResult PimCoreOp::parse(OpAsmParser& parser, OperationState& result) {
SmallVector<OpAsmParser::Argument> weightArgs;
SmallVector<OpAsmParser::UnresolvedOperand> weights;
SmallVector<Type> weightTypes;
int32_t coreId = 0;
if (parseBoundValueList(parser, ListDelimiter::Square, weightArgs, weights))
return failure();
bool hasCoreId = parseOptionalKeywordAlias(parser, "coreId", "core_id");
if (hasCoreId && parser.parseInteger(coreId))
return failure();
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|| parseCompressedRepeatedList(
parser, ListDelimiter::Square, weightTypes, [&](Type& type) { return parser.parseType(type); })
|| parser.parseArrow() || parser.parseLParen() || parser.parseRParen())
return failure();
if (weights.size() != weightTypes.size())
return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match");
if (weightArgs.size() != weights.size())
return parser.emitError(parser.getCurrentLocation(), "number of weight bindings and weight operands must match");
if (hasCoreId && result.attributes.get("coreId"))
return parser.emitError(parser.getCurrentLocation(),
"coreId cannot be specified both positionally and in attr-dict");
if (parser.resolveOperands(weights, weightTypes, parser.getCurrentLocation(), result.operands))
return failure();
if (hasCoreId)
result.addAttribute("coreId", getI32Attr(parser, coreId));
Region* body = result.addRegion();
applyArgumentTypes(weightTypes, weightArgs);
return parser.parseRegion(*body, weightArgs);
}
void PimCoreBatchOp::print(OpAsmPrinter& printer) {
printer << " ";
printer.printOperand(getLaneArgument());
printer << " = 0 to " << getLaneCount() << " ";
SmallVector<Value> weightArgs;
weightArgs.reserve(getWeights().size());
for (unsigned index = 0; index < getWeights().size(); ++index)
weightArgs.push_back(getWeightArgument(index));
printBoundValueList(printer, weightArgs, getWeights(), ListDelimiter::Square);
printer << " ";
SmallVector<Value> inputArgs;
inputArgs.reserve(getInputs().size());
for (unsigned index = 0; index < getInputs().size(); ++index)
inputArgs.push_back(getInputArgument(index));
printBoundValueList(printer, inputArgs, getInputs(), ListDelimiter::Paren);
if (auto coreIdsAttr = (*this)->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName))
printCoreIdList(printer, "coreIds", coreIdsAttr.asArrayRef());
@@ -49,51 +184,57 @@ void PimCoreBatchOp::print(OpAsmPrinter& printer) {
printer.printOptionalAttrDict(
(*this)->getAttrs(),
{getLaneCountAttrName().getValue(), getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdsAttrName});
printer << " ";
printer.printRegion(getBody(), /*printEntryBlockArgs=*/false);
printer << " : ";
if (getLaneCount() > 1 && hasRepeatedTuple(TypeRange(getWeights()), weightsPerLane))
printTypeTupleRun(printer, TypeRange(getWeights()), weightsPerLane, ListDelimiter::Paren);
else
printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Paren);
printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Square);
printer << " ";
printCompressedTypeList(printer, TypeRange(getInputs()), ListDelimiter::Square);
printer << " -> ()";
printCompressedTypeList(printer, TypeRange(getInputs()), ListDelimiter::Paren);
printer << " -> () ";
printer.printRegion(getBody(), /*printEntryBlockArgs=*/false);
}
ParseResult PimCoreBatchOp::parse(OpAsmParser& parser, OperationState& result) {
int64_t lowerBound = 0;
int32_t laneCount = 0;
OpAsmParser::Argument laneArg;
SmallVector<OpAsmParser::Argument> weightArgs;
SmallVector<OpAsmParser::Argument> inputArgs;
SmallVector<OpAsmParser::Argument> regionArgs;
SmallVector<OpAsmParser::UnresolvedOperand> weights;
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
SmallVector<Type> weightTypes;
SmallVector<Type> inputTypes;
SmallVector<int32_t> coreIds;
if (parser.parseKeyword("lanes") || parser.parseInteger(laneCount)
|| parseCompressedOrTupleOperandList(parser, ListDelimiter::Paren, weights)
|| parseCompressedOperandList(parser, ListDelimiter::Square, inputs))
if (parser.parseArgument(laneArg) || parser.parseEqual() || parser.parseInteger(lowerBound)
|| parser.parseKeyword("to") || parser.parseInteger(laneCount))
return failure();
if (lowerBound != 0)
return parser.emitError(parser.getCurrentLocation(), "core_batch currently requires a zero lower bound");
if (parseBoundValueList(parser, ListDelimiter::Square, weightArgs, weights)
|| parseBoundValueList(parser, ListDelimiter::Paren, inputArgs, inputs))
return failure();
bool hasCoreIds = succeeded(parser.parseOptionalKeyword("coreIds"));
bool hasCoreIds = parseOptionalKeywordAlias(parser, "coreIds", "core_ids");
if (hasCoreIds && parseCompressedIntegerList(parser, coreIds))
return failure();
if (parser.parseOptionalAttrDict(result.attributes))
return failure();
Region* body = result.addRegion();
if (parser.parseRegion(*body))
return failure();
if (parser.parseColon() || parseCompressedOrTupleTypeList(parser, ListDelimiter::Paren, weightTypes)
|| parseCompressedTypeList(parser, ListDelimiter::Square, inputTypes) || parser.parseArrow()
|| parser.parseLParen() || parser.parseRParen())
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|| parseCompressedRepeatedList(
parser, ListDelimiter::Square, weightTypes, [&](Type& type) { return parser.parseType(type); })
|| parseCompressedRepeatedList(
parser, ListDelimiter::Paren, inputTypes, [&](Type& type) { return parser.parseType(type); })
|| parser.parseArrow() || parser.parseLParen() || parser.parseRParen())
return failure();
if (weights.size() != weightTypes.size())
return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match");
if (weightArgs.size() != weights.size())
return parser.emitError(parser.getCurrentLocation(), "number of weight bindings and weight operands must match");
if (inputs.size() != inputTypes.size())
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
if (inputArgs.size() != inputs.size())
return parser.emitError(parser.getCurrentLocation(), "number of input bindings and input operands must match");
if (hasCoreIds && result.attributes.get(onnx_mlir::kCoreIdsAttrName))
return parser.emitError(parser.getCurrentLocation(),
"coreIds cannot be specified both positionally and in attr-dict");
@@ -110,7 +251,15 @@ ParseResult PimCoreBatchOp::parse(OpAsmParser& parser, OperationState& result) {
|| parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands)) {
return failure();
}
return success();
Region* body = result.addRegion();
laneArg.type = builder.getIndexType();
regionArgs.push_back(laneArg);
applyArgumentTypes(weightTypes, weightArgs);
llvm::append_range(regionArgs, weightArgs);
applyArgumentTypes(inputTypes, inputArgs);
llvm::append_range(regionArgs, inputArgs);
return parser.parseRegion(*body, regionArgs);
}
void PimYieldOp::print(OpAsmPrinter& printer) {
+86 -16
View File
@@ -1,5 +1,7 @@
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/Support/LogicalResult.h"
@@ -14,6 +16,52 @@ namespace pim {
namespace {
static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) {
if (isa<PimMemCopyHostToDevOp>(op))
return operandIndex == 3;
if (isa<PimMemCopyHostToDevBatchOp>(op))
return operandIndex == 1;
if (isa<PimMemCopyDevToHostOp>(op))
return operandIndex == 2;
return false;
}
static Region* getParentRegion(Value value) {
if (auto blockArgument = dyn_cast<BlockArgument>(value))
return blockArgument.getParentRegion();
Operation* definingOp = value.getDefiningOp();
return definingOp ? definingOp->getParentRegion() : nullptr;
}
static bool isDefinedInsideRegion(Value value, Region& region) {
Region* parentRegion = getParentRegion(value);
return parentRegion && (&region == parentRegion || region.isAncestor(parentRegion));
}
static bool isConstantExternalValue(Value value) {
Operation* definingOp = value.getDefiningOp();
return definingOp && definingOp->hasTrait<OpTrait::ConstantLike>();
}
static LogicalResult verifyOnlyConstantExternalValues(Operation* ownerOp, Region& region, StringRef kind) {
bool hasFailure = false;
region.walk([&](Operation* op) {
for (OpOperand& operand : op->getOpOperands()) {
Value value = operand.get();
if (isDefinedInsideRegion(value, region) || isConstantExternalValue(value)
|| isExplicitHostOperand(op, operand.getOperandNumber()))
continue;
InFlightDiagnostic diagnostic =
ownerOp->emitOpError() << kind << " body may only directly reference external constants";
diagnostic.attachNote(op->getLoc())
<< "non-constant external operand #" << operand.getOperandNumber() << " is used by " << op->getName();
hasFailure = true;
}
});
return success(!hasFailure);
}
static bool haveSameShapedContainerKind(Type lhs, Type rhs) {
return (isa<RankedTensorType>(lhs) && isa<RankedTensorType>(rhs)) || (isa<MemRefType>(lhs) && isa<MemRefType>(rhs));
}
@@ -78,24 +126,46 @@ verifyTensorBatchCommunication(Operation* op, Type type, ArrayRef<int32_t> coreI
return success();
}
static FailureOr<ArrayRef<int64_t>> getWeightShapeForVMM(Operation* op, size_t weightIndex) {
if (auto coreOp = op->getParentOfType<PimCoreOp>()) {
if (weightIndex >= coreOp.getWeights().size())
return failure();
return cast<ShapedType>(coreOp.getWeights()[weightIndex].getType()).getShape();
}
if (auto coreBatchOp = op->getParentOfType<PimCoreBatchOp>()) {
if (weightIndex >= coreBatchOp.getWeights().size())
return failure();
return cast<ShapedType>(coreBatchOp.getWeights()[weightIndex].getType()).getShape();
}
return failure();
static FailureOr<ArrayRef<int64_t>> getWeightShapeForVMM(Value weight) {
auto shapedType = dyn_cast<ShapedType>(weight.getType());
if (!shapedType)
return failure();
return shapedType.getShape();
}
} // namespace
LogicalResult PimCoreOp::verify() {
Block& block = getBody().front();
if (block.getNumArguments() != getWeights().size())
return emitError("core body must have one block argument per weight");
for (auto [weightIndex, weight] : llvm::enumerate(getWeights())) {
if (getWeightArgument(weightIndex).getType() != weight.getType())
return emitError("core weight block argument types must match weight operand types exactly");
}
return verifyOnlyConstantExternalValues(getOperation(), getBody(), "pim.core");
}
LogicalResult PimCoreBatchOp::verify() {
if (getLaneCount() <= 0)
return emitError("laneCount must be positive");
Block& block = getBody().front();
unsigned expectedArgCount = 1 + getWeights().size() + getInputs().size();
if (block.getNumArguments() != expectedArgCount)
return emitError("core_batch body must have lane, weight, and input block arguments");
if (!getLaneArgument().getType().isIndex())
return emitError("core_batch first block argument must have index type");
for (auto [weightIndex, weight] : llvm::enumerate(getWeights())) {
if (getWeightArgument(weightIndex).getType() != weight.getType())
return emitError("core_batch weight block argument types must match weight operand types exactly");
}
for (auto [inputIndex, input] : llvm::enumerate(getInputs())) {
if (getInputArgument(inputIndex).getType() != input.getType())
return emitError("core_batch input block argument types must match input operand types exactly");
}
return verifyOnlyConstantExternalValues(getOperation(), getBody(), "pim.core_batch");
}
LogicalResult PimSendTensorOp::verify() {
return verifyTensorCommunication(getOperation(), getInput().getType(), getTargetCoreIds(), "send_tensor");
}
@@ -126,9 +196,9 @@ LogicalResult PimVMMOp::verify() {
getOperation(), getOutputBuffer().getType(), getOutput().getType(), "output buffer and output must match")))
return failure();
auto matrixShapeOpt = getWeightShapeForVMM(getOperation(), getWeightIndex());
auto matrixShapeOpt = getWeightShapeForVMM(getWeight());
if (failed(matrixShapeOpt))
return emitError("must be nested inside pim.core or pim.core_batch with a valid weightIndex");
return emitError("weight must be a shaped value");
ArrayRef<int64_t> matrixShape = *matrixShapeOpt;
auto vectorType = dyn_cast<ShapedType>(getInput().getType());
@@ -38,10 +38,10 @@ struct MemCopyHostToDevOpInterface
replaceOpWithNewBufferizedOp<PimMemCopyHostToDevOp>(rewriter,
memCopyHostToDevOp,
deviceTargetMemRef.getType(),
memCopyHostToDevOp.getDeviceTargetOffset(),
memCopyHostToDevOp.getHostSourceOffset(),
deviceTargetMemRef,
hostSourceMemRef,
memCopyHostToDevOp.getDeviceTargetOffsetAttr(),
memCopyHostToDevOp.getHostSourceOffsetAttr(),
memCopyHostToDevOp.getSizeAttr());
return success();
}
@@ -96,10 +96,10 @@ struct MemCopyDevToHostOpInterface
replaceOpWithNewBufferizedOp<PimMemCopyDevToHostOp>(rewriter,
memCopyDevToHostOp,
hostTargetMemRef.getType(),
memCopyDevToHostOp.getHostTargetOffset(),
memCopyDevToHostOp.getDeviceSourceOffset(),
hostTargetMemRef,
deviceSourceMemRef,
memCopyDevToHostOp.getHostTargetOffsetAttr(),
memCopyDevToHostOp.getDeviceSourceOffsetAttr(),
memCopyDevToHostOp.getSizeAttr());
return success();
}
@@ -151,12 +151,8 @@ struct ReceiveOpInterface : DstBufferizableOpInterfaceExternalModel<ReceiveOpInt
if (failed(outputBufferOpt))
return failure();
replaceOpWithNewBufferizedOp<PimReceiveOp>(rewriter,
op,
outputBufferOpt->getType(),
*outputBufferOpt,
receiveOp.getSizeAttr(),
receiveOp.getSourceCoreIdAttr());
replaceOpWithNewBufferizedOp<PimReceiveOp>(
rewriter, op, outputBufferOpt->getType(), *outputBufferOpt, receiveOp.getSizeAttr(), receiveOp.getSourceCoreId());
return success();
}
};
@@ -302,7 +298,7 @@ struct SendOpInterface : BufferizableOpInterface::ExternalModel<SendOpInterface,
op,
materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter),
sendOp.getSizeAttr(),
sendOp.getTargetCoreIdAttr());
sendOp.getTargetCoreId());
return success();
}
};
@@ -368,6 +364,37 @@ struct CoreOpInterface : BufferizableOpInterface::ExternalModel<CoreOpInterface,
return {};
}
AliasingOpOperandList getAliasingOpOperands(Operation* op, Value value, const AnalysisState& state) const {
auto coreOp = cast<PimCoreOp>(op);
auto bbArg = dyn_cast<BlockArgument>(value);
if (!bbArg || bbArg.getOwner() != &coreOp.getBody().front())
return {};
unsigned weightIndex = bbArg.getArgNumber();
return {
{&coreOp->getOpOperand(weightIndex), BufferRelation::Equivalent}
};
}
bool isWritable(Operation* op, Value value, const AnalysisState& state) const { return false; }
FailureOr<BufferLikeType> getBufferType(Operation* op,
Value value,
const BufferizationOptions& options,
const BufferizationState& state,
SmallVector<Value>& invocationStack) const {
auto coreOp = cast<PimCoreOp>(op);
auto bbArg = dyn_cast<BlockArgument>(value);
if (!bbArg || bbArg.getOwner() != &coreOp.getBody().front())
return failure();
Value tiedWeight = coreOp.getWeights()[bbArg.getArgNumber()];
if (auto memRefType = dyn_cast<BufferLikeType>(tiedWeight.getType()))
return memRefType;
return bufferization::getBufferType(tiedWeight, options, state, invocationStack);
}
LogicalResult bufferize(Operation* op,
RewriterBase& rewriter,
const BufferizationOptions& options,
@@ -375,7 +402,10 @@ struct CoreOpInterface : BufferizableOpInterface::ExternalModel<CoreOpInterface,
auto coreOp = cast<PimCoreOp>(op);
bool alreadyBufferized =
llvm::all_of(coreOp.getWeights(), [](Value weight) { return isa<BufferLikeType>(weight.getType()); });
llvm::all_of(coreOp.getWeights(), [](Value weight) { return isa<BufferLikeType>(weight.getType()); })
&& llvm::all_of(coreOp.getBody().front().getArguments(), [](BlockArgument arg) {
return !isa<TensorType>(arg.getType()) || isa<BufferLikeType>(arg.getType());
});
if (alreadyBufferized)
return success();
@@ -420,9 +450,17 @@ struct CoreBatchOpInterface : BufferizableOpInterface::ExternalModel<CoreBatchOp
if (!bbArg || bbArg.getOwner() != &coreBatchOp.getBody().front())
return {};
unsigned inputOperandIndex = coreBatchOp.getWeights().size() + bbArg.getArgNumber();
unsigned argNumber = bbArg.getArgNumber();
if (argNumber == 0)
return {};
unsigned weightCount = coreBatchOp.getWeights().size();
unsigned operandIndex = argNumber - 1;
if (argNumber > weightCount + 1)
operandIndex = weightCount + (argNumber - 1 - weightCount);
return {
{&coreBatchOp->getOpOperand(inputOperandIndex), BufferRelation::Equivalent}
{&coreBatchOp->getOpOperand(operandIndex), BufferRelation::Equivalent}
};
}
@@ -438,11 +476,21 @@ struct CoreBatchOpInterface : BufferizableOpInterface::ExternalModel<CoreBatchOp
if (!bbArg || bbArg.getOwner() != &coreBatchOp.getBody().front())
return failure();
Value tiedInput = coreBatchOp.getInputs()[bbArg.getArgNumber()];
if (auto memRefType = dyn_cast<BufferLikeType>(tiedInput.getType()))
unsigned argNumber = bbArg.getArgNumber();
if (argNumber == 0)
return failure();
Value tiedOperand;
unsigned weightCount = coreBatchOp.getWeights().size();
if (argNumber <= weightCount)
tiedOperand = coreBatchOp.getWeights()[argNumber - 1];
else
tiedOperand = coreBatchOp.getInputs()[argNumber - 1 - weightCount];
if (auto memRefType = dyn_cast<BufferLikeType>(tiedOperand.getType()))
return memRefType;
return bufferization::getBufferType(tiedInput, options, state, invocationStack);
return bufferization::getBufferType(tiedOperand, options, state, invocationStack);
}
LogicalResult bufferize(Operation* op,
@@ -454,8 +502,9 @@ struct CoreBatchOpInterface : BufferizableOpInterface::ExternalModel<CoreBatchOp
bool alreadyBufferized =
llvm::all_of(coreBatchOp.getWeights(), [](Value weight) { return isa<BufferLikeType>(weight.getType()); })
&& llvm::all_of(coreBatchOp.getInputs(), [](Value input) { return isa<BufferLikeType>(input.getType()); })
&& llvm::all_of(coreBatchOp.getBody().front().getArguments(),
[](BlockArgument arg) { return isa<BufferLikeType>(arg.getType()); });
&& llvm::all_of(coreBatchOp.getBody().front().getArguments(), [](BlockArgument arg) {
return !isa<TensorType>(arg.getType()) || isa<BufferLikeType>(arg.getType());
});
if (alreadyBufferized)
return success();
@@ -553,6 +602,10 @@ struct VMMOpInterface : DstBufferizableOpInterfaceExternalModel<VMMOpInterface,
BufferizationState& state) const {
auto vmmOp = cast<PimVMMOp>(op);
auto weightOpt = getBufferOrValue(rewriter, vmmOp.getWeight(), options, state);
if (failed(weightOpt))
return failure();
auto inputOpt = getBufferOrValue(rewriter, vmmOp.getInput(), options, state);
if (failed(inputOpt))
return failure();
@@ -564,7 +617,7 @@ struct VMMOpInterface : DstBufferizableOpInterfaceExternalModel<VMMOpInterface,
Value contiguousInput = materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter);
replaceOpWithNewBufferizedOp<PimVMMOp>(
rewriter, op, outputBufferOpt->getType(), vmmOp.getWeightIndexAttr(), contiguousInput, *outputBufferOpt);
rewriter, op, outputBufferOpt->getType(), *weightOpt, contiguousInput, *outputBufferOpt);
return success();
}
};
@@ -79,6 +79,7 @@ void PimBufferizationPass::runOnOperation() {
return WalkResult::skip();
});
if (hasFailed) {
moduleOp.emitError("failed to lower memref.copy-like ops inside PIM core bodies during bufferization");
signalPassFailure();
return;
}
@@ -1,15 +1,16 @@
#include "src/Accelerators/PIM/Dialect/Pim/Transforms/StaticMemoryCoalescing/StaticMemoryCoalescing.hpp"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/STLExtras.h"
#include <limits>
#include "src/Accelerators/PIM/Dialect/Pim/Transforms/StaticMemoryCoalescing/StaticMemoryCoalescing.hpp"
using namespace mlir;
namespace onnx_mlir {
@@ -29,9 +30,8 @@ static uint64_t getTypeSizeBytes(MemRefType type) {
return static_cast<uint64_t>(type.getNumElements() * type.getElementTypeBitWidth() / 8);
}
static FailureOr<uint64_t> getLastUseInstruction(memref::AllocOp allocOp,
Block& body,
const DenseMap<Operation*, uint64_t>& opOrder) {
static FailureOr<uint64_t>
getLastUseInstruction(memref::AllocOp allocOp, Block& body, const DenseMap<Operation*, uint64_t>& opOrder) {
uint64_t endInstruction = opOrder.lookup(allocOp);
SmallPtrSet<Operation*, 16> visited;
SmallVector<Value> pendingValues;
@@ -45,9 +45,15 @@ static FailureOr<uint64_t> getLastUseInstruction(memref::AllocOp allocOp,
if (!visited.insert(user).second)
continue;
if (isSupportedAliasOp(user)) {
if (isSupportedAliasOp(user))
for (Value result : user->getResults())
pendingValues.push_back(result);
if (auto forOp = dyn_cast<scf::ForOp>(user)) {
for (auto [index, initArg] : llvm::enumerate(forOp.getInitArgs())) {
if (initArg == value)
pendingValues.push_back(forOp.getResult(index));
}
}
if (auto dpsOp = dyn_cast<DestinationStyleOpInterface>(user)) {
@@ -2,7 +2,6 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Operation.h"
#include "llvm/ADT/SmallVector.h"
@@ -45,9 +45,7 @@ struct CoalescingReportEntry {
CoalescingReportRow row;
};
static std::string formatMemory(uint64_t bytes) {
return formatReportMemory(bytes);
}
static std::string formatMemory(uint64_t bytes) { return formatReportMemory(bytes); }
static SmallVector<int32_t> getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp) {
auto coreIdsAttr = coreBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName);
@@ -58,9 +56,10 @@ static SmallVector<int32_t> getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp) {
static void printReportRow(raw_ostream& os, const CoalescingReportRow& row) {
llvm::SmallVector<ReportField, 4> fields = {
{"Number of candidates", std::to_string(row.numCandidates)},
{"Skipped allocations", std::to_string(row.numSkipped)},
{"Removed allocations", std::to_string(row.numRemoved)},
{"Saved memory", formatMemory(row.savedBytes)}};
{"Skipped allocations", std::to_string(row.numSkipped) },
{"Removed allocations", std::to_string(row.numRemoved) },
{"Saved memory", formatMemory(row.savedBytes) }
};
printReportFlatFields(os, fields);
}
@@ -87,10 +86,12 @@ static void emitReport(ArrayRef<CoalescingReportEntry> entries) {
totalRow.savedBytes += entryTotal.savedBytes;
}
llvm::SmallVector<ReportField, 4> totalFields = {{"Number of candidates", std::to_string(totalRow.numCandidates)},
{"Skipped allocations", std::to_string(totalRow.numSkipped)},
{"Removed allocations", std::to_string(totalRow.numRemoved)},
{"Saved memory", formatMemory(totalRow.savedBytes)}};
llvm::SmallVector<ReportField, 4> totalFields = {
{"Number of candidates", std::to_string(totalRow.numCandidates)},
{"Skipped allocations", std::to_string(totalRow.numSkipped) },
{"Removed allocations", std::to_string(totalRow.numRemoved) },
{"Saved memory", formatMemory(totalRow.savedBytes) }
};
printReportTotalsBlock(os, totalFields);
if (!entries.empty())
os << "\n";
@@ -127,15 +128,17 @@ static void emitReport(ArrayRef<CoalescingReportEntry> entries) {
if (sortedEntries[index].kind == CoalescingReportEntry::Kind::Batch) {
llvm::SmallVector<ReportField, 4> perCoreFields = {
{"Number of candidates", std::to_string(sortedEntries[index].row.numCandidates)},
{"Skipped allocations", std::to_string(sortedEntries[index].row.numSkipped)},
{"Removed allocations", std::to_string(sortedEntries[index].row.numRemoved)},
{"Saved memory", formatMemory(sortedEntries[index].row.savedBytes)}};
{"Skipped allocations", std::to_string(sortedEntries[index].row.numSkipped) },
{"Removed allocations", std::to_string(sortedEntries[index].row.numRemoved) },
{"Saved memory", formatMemory(sortedEntries[index].row.savedBytes) }
};
CoalescingReportRow totalRow = getTotalRow(sortedEntries[index]);
llvm::SmallVector<ReportField, 4> totalFields = {
{"Number of candidates", std::to_string(totalRow.numCandidates)},
{"Skipped allocations", std::to_string(totalRow.numSkipped)},
{"Removed allocations", std::to_string(totalRow.numRemoved)},
{"Saved memory", formatMemory(totalRow.savedBytes)}};
{"Skipped allocations", std::to_string(totalRow.numSkipped) },
{"Removed allocations", std::to_string(totalRow.numRemoved) },
{"Saved memory", formatMemory(totalRow.savedBytes) }
};
printReportPerCoreAndTotalFields(os, perCoreFields, totalFields);
}
else {
@@ -196,8 +199,6 @@ struct StaticMemoryCoalescingPass : PassWrapper<StaticMemoryCoalescingPass, Oper
} // namespace
std::unique_ptr<Pass> createPimStaticMemoryCoalescingPass() {
return std::make_unique<StaticMemoryCoalescingPass>();
}
std::unique_ptr<Pass> createPimStaticMemoryCoalescingPass() { return std::make_unique<StaticMemoryCoalescingPass>(); }
} // namespace onnx_mlir
+7
View File
@@ -8,7 +8,14 @@ add_pim_library(SpatialOps
SpatialOpsVerify.cpp
SpatialOpsCanonicalization.cpp
Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp
Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp
Transforms/MergeComputeNodes/PostMergeCompaction.cpp
Transforms/MergeComputeNodes/RegularOpCompaction.cpp
Transforms/MergeComputeNodes/Scheduling/ComputeGraph.cpp
Transforms/MergeComputeNodes/Scheduling/ComputeInstanceUtils.cpp
Transforms/MergeComputeNodes/Scheduling/DcpScheduler.cpp
Transforms/MergeComputeNodes/Scheduling/MergeSchedulingAnalysis.cpp
Transforms/MergeComputeNodes/Scheduling/PeftScheduler.cpp
Transforms/MergeComputeNodes/DCPGraph/Graph.cpp
Transforms/MergeComputeNodes/DCPGraph/GraphDebug.cpp
Transforms/MergeComputeNodes/DCPGraph/GraphSupport.cpp
+74 -16
View File
@@ -1,5 +1,6 @@
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Matchers.h"
#include "src/Accelerators/PIM/Dialect/Spatial/Channels.hpp"
@@ -9,19 +10,62 @@ namespace onnx_mlir::spatial {
namespace {
static Channels::ChannelId getChannelId(SpatChannelSendOp sendOp) { return sendOp.getChannelId(); }
static FailureOr<int64_t> getConstantI64(Value value) {
APInt constantValue;
if (!matchPattern(value, m_ConstantInt(&constantValue)))
return failure();
return constantValue.getSExtValue();
}
static Channels::ChannelId getChannelId(SpatChannelReceiveOp receiveOp) { return receiveOp.getChannelId(); }
static FailureOr<int32_t> getConstantI32(Value value) {
APInt constantValue;
if (!matchPattern(value, m_ConstantInt(&constantValue)))
return failure();
return static_cast<int32_t>(constantValue.getSExtValue());
}
static FailureOr<Channels::ChannelId> getChannelId(SpatChannelSendOp sendOp) {
return getConstantI64(sendOp.getChannelId());
}
static FailureOr<Channels::ChannelId> getChannelId(SpatChannelReceiveOp receiveOp) {
return getConstantI64(receiveOp.getChannelId());
}
static FailureOr<int32_t> getSourceCoreId(SpatChannelSendOp sendOp) { return getConstantI32(sendOp.getSourceCoreId()); }
static FailureOr<int32_t> getSourceCoreId(SpatChannelReceiveOp receiveOp) {
return getConstantI32(receiveOp.getSourceCoreId());
}
static FailureOr<int32_t> getTargetCoreId(SpatChannelSendOp sendOp) { return getConstantI32(sendOp.getTargetCoreId()); }
static FailureOr<int32_t> getTargetCoreId(SpatChannelReceiveOp receiveOp) {
return getConstantI32(receiveOp.getTargetCoreId());
}
static LogicalResult verifyEndpointPair(ChannelEndpoints endpoints) {
if (!endpoints.send || !endpoints.receive)
return failure();
if (endpoints.send.getSourceCoreId() != endpoints.receive.getSourceCoreId()) {
FailureOr<int32_t> sendSourceCoreId = getSourceCoreId(endpoints.send);
FailureOr<int32_t> receiveSourceCoreId = getSourceCoreId(endpoints.receive);
if (failed(sendSourceCoreId) || failed(receiveSourceCoreId)) {
endpoints.send.emitOpError("channel endpoints must use constant sourceCoreId operands");
return failure();
}
if (*sendSourceCoreId != *receiveSourceCoreId) {
endpoints.send.emitOpError("sourceCoreId does not match paired spat.channel_receive");
return failure();
}
if (endpoints.send.getTargetCoreId() != endpoints.receive.getTargetCoreId()) {
FailureOr<int32_t> sendTargetCoreId = getTargetCoreId(endpoints.send);
FailureOr<int32_t> receiveTargetCoreId = getTargetCoreId(endpoints.receive);
if (failed(sendTargetCoreId) || failed(receiveTargetCoreId)) {
endpoints.send.emitOpError("channel endpoints must use constant targetCoreId operands");
return failure();
}
if (*sendTargetCoreId != *receiveTargetCoreId) {
endpoints.send.emitOpError("targetCoreId does not match paired spat.channel_receive");
return failure();
}
@@ -46,20 +90,26 @@ Channels::Channels(func::FuncOp funcOp) {
Channels::ChannelId Channels::allocate() { return nextChannelId++; }
void Channels::insertSend(SpatChannelSendOp sendOp) {
ChannelId channelId = getChannelId(sendOp);
nextChannelId = std::max(nextChannelId, channelId + 1);
endpoints[channelId].send = sendOp;
FailureOr<ChannelId> channelId = getChannelId(sendOp);
if (failed(channelId))
return;
nextChannelId = std::max(nextChannelId, *channelId + 1);
endpoints[*channelId].send = sendOp;
}
void Channels::insertReceive(SpatChannelReceiveOp receiveOp) {
ChannelId channelId = getChannelId(receiveOp);
nextChannelId = std::max(nextChannelId, channelId + 1);
endpoints[channelId].receive = receiveOp;
FailureOr<ChannelId> channelId = getChannelId(receiveOp);
if (failed(channelId))
return;
nextChannelId = std::max(nextChannelId, *channelId + 1);
endpoints[*channelId].receive = receiveOp;
}
void Channels::eraseSend(SpatChannelSendOp sendOp) {
ChannelId channelId = getChannelId(sendOp);
auto it = endpoints.find(channelId);
FailureOr<ChannelId> channelId = getChannelId(sendOp);
if (failed(channelId))
return;
auto it = endpoints.find(*channelId);
if (it == endpoints.end())
return;
it->second.send = {};
@@ -68,8 +118,10 @@ void Channels::eraseSend(SpatChannelSendOp sendOp) {
}
void Channels::eraseReceive(SpatChannelReceiveOp receiveOp) {
ChannelId channelId = getChannelId(receiveOp);
auto it = endpoints.find(channelId);
FailureOr<ChannelId> channelId = getChannelId(receiveOp);
if (failed(channelId))
return;
auto it = endpoints.find(*channelId);
if (it == endpoints.end())
return;
it->second.receive = {};
@@ -85,14 +137,20 @@ FailureOr<ChannelEndpoints> Channels::lookup(ChannelId id) const {
}
FailureOr<SpatChannelReceiveOp> Channels::getReceiveFor(SpatChannelSendOp sendOp) const {
auto endpointsOr = lookup(getChannelId(sendOp));
FailureOr<ChannelId> channelId = getChannelId(sendOp);
if (failed(channelId))
return failure();
auto endpointsOr = lookup(*channelId);
if (failed(endpointsOr) || !endpointsOr->receive)
return failure();
return endpointsOr->receive;
}
FailureOr<SpatChannelSendOp> Channels::getSendFor(SpatChannelReceiveOp receiveOp) const {
auto endpointsOr = lookup(getChannelId(receiveOp));
FailureOr<ChannelId> channelId = getChannelId(receiveOp);
if (failed(channelId))
return failure();
auto endpointsOr = lookup(*channelId);
if (failed(endpointsOr) || !endpointsOr->send)
return failure();
return endpointsOr->send;
+100 -45
View File
@@ -2,8 +2,12 @@
#define SPATIAL_DIALECT_H
include "mlir/IR/OpBase.td"
include "mlir/IR/OpAsmInterface.td"
include "mlir/IR/BuiltinTypes.td"
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/RegionKindInterface.td"
include "mlir/Interfaces/ParallelCombiningOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
def SpatialDialect : Dialect {
let name = "spat";
@@ -22,7 +26,9 @@ def SpatTensor :
// Execution
//===----------------------------------------------------------------------===//
def SpatCompute : SpatOp<"compute", [SingleBlock, AttrSizedOperandSegments]> {
def SpatCompute : SpatOp<"compute",
[SingleBlock, AttrSizedOperandSegments,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmBlockArgumentNames"]>]> {
let summary = "Compute region with attached constant weights";
let arguments = (ins
@@ -36,14 +42,20 @@ def SpatCompute : SpatOp<"compute", [SingleBlock, AttrSizedOperandSegments]> {
let regions = (region SizedRegion<1>:$body);
let extraClassDeclaration = [{
::mlir::BlockArgument getWeightArgument(unsigned idx);
::mlir::BlockArgument getInputArgument(unsigned idx);
}];
let hasVerifier = 1;
let hasFolder = 1;
let hasCustomAssemblyFormat = 1;
}
def SpatComputeBatch : SpatOp<"compute_batch",
[SingleBlock, AttrSizedOperandSegments]> {
let summary = "Compressed batch of independent equivalent compute lanes";
[SingleBlock, AttrSizedOperandSegments,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmBlockArgumentNames"]>]> {
let summary = "Tensor-native batch of equivalent compute lanes with shared weights and packed inputs";
let arguments = (ins
I32Attr:$laneCount,
@@ -57,10 +69,41 @@ def SpatComputeBatch : SpatOp<"compute_batch",
let regions = (region SizedRegion<1>:$body);
let extraClassDeclaration = [{
::mlir::BlockArgument getLaneArgument();
::mlir::BlockArgument getWeightArgument(unsigned idx);
::mlir::BlockArgument getInputArgument(unsigned idx);
::mlir::BlockArgument getOutputArgument(unsigned idx);
}];
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
}
def SpatInParallelOp : SpatOp<"in_parallel", [
Pure,
Terminator,
DeclareOpInterfaceMethods<InParallelOpInterface>,
HasParent<"SpatComputeBatch">,
] # GraphRegionNoTerminator.traits> {
let summary = "Parallel combining terminator for resultful spat.compute_batch";
let regions = (region SizedRegion<1>:$region);
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
let skipDefaultBuilders = 1;
let builders = [
OpBuilder<(ins)>,
];
let extraClassDeclaration = [{
::llvm::iterator_range<::mlir::Block::iterator> getYieldingOps();
::mlir::OpResult getParentResult(int64_t idx);
}];
}
def SpatYieldOp : SpatOp<"yield", [Terminator]> {
let summary = "Yield results from a compute region";
@@ -110,14 +153,14 @@ def SpatChannelSendOp : SpatOp<"channel_send", []> {
let summary = "Send a tensor through a logical channel";
let arguments = (ins
I64Attr:$channelId,
I32Attr:$sourceCoreId,
I32Attr:$targetCoreId,
Index:$channelId,
Index:$sourceCoreId,
Index:$targetCoreId,
SpatTensor:$input
);
let assemblyFormat = [{
$input attr-dict `:` type($input)
$input `channel` $channelId `from` $sourceCoreId `to` $targetCoreId attr-dict `:` type($input)
}];
}
@@ -125,9 +168,9 @@ def SpatChannelReceiveOp : SpatOp<"channel_receive", []> {
let summary = "Receive a tensor from a logical channel";
let arguments = (ins
I64Attr:$channelId,
I32Attr:$sourceCoreId,
I32Attr:$targetCoreId
Index:$channelId,
Index:$sourceCoreId,
Index:$targetCoreId
);
let results = (outs
@@ -135,31 +178,33 @@ def SpatChannelReceiveOp : SpatOp<"channel_receive", []> {
);
let assemblyFormat = [{
attr-dict `:` type($output)
`channel` $channelId `from` $sourceCoreId `to` $targetCoreId attr-dict `:` type($output)
}];
}
def SpatChannelSendTensorOp : SpatOp<"channel_send_tensor", []> {
def SpatChannelSendTensorOp : SpatOp<"channel_send_tensor", [AttrSizedOperandSegments]> {
let summary = "Send equal contiguous chunks of one tensor through logical channels";
let arguments = (ins
DenseI64ArrayAttr:$channelIds,
DenseI32ArrayAttr:$sourceCoreIds,
DenseI32ArrayAttr:$targetCoreIds,
Variadic<Index>:$channelIds,
Variadic<Index>:$sourceCoreIds,
Variadic<Index>:$targetCoreIds,
SpatTensor:$input
);
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
let assemblyFormat = [{
$input `channels` `(` $channelIds `)` `from` `(` $sourceCoreIds `)` `to` `(` $targetCoreIds `)` attr-dict `:` type($input)
}];
}
def SpatChannelReceiveTensorOp : SpatOp<"channel_receive_tensor", []> {
def SpatChannelReceiveTensorOp : SpatOp<"channel_receive_tensor", [AttrSizedOperandSegments]> {
let summary = "Receive equal contiguous chunks of one tensor from logical channels";
let arguments = (ins
DenseI64ArrayAttr:$channelIds,
DenseI32ArrayAttr:$sourceCoreIds,
DenseI32ArrayAttr:$targetCoreIds
Variadic<Index>:$channelIds,
Variadic<Index>:$sourceCoreIds,
Variadic<Index>:$targetCoreIds
);
let results = (outs
@@ -167,44 +212,50 @@ def SpatChannelReceiveTensorOp : SpatOp<"channel_receive_tensor", []> {
);
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
let assemblyFormat = [{
`channels` `(` $channelIds `)` `from` `(` $sourceCoreIds `)` `to` `(` $targetCoreIds `)` attr-dict `:` type($output)
}];
}
def SpatChannelSendBatchOp : SpatOp<"channel_send_batch", []> {
def SpatChannelSendBatchOp : SpatOp<"channel_send_batch", [AttrSizedOperandSegments]> {
let summary = "Send per-lane tensors through logical channels in a batch body";
let arguments = (ins
DenseI64ArrayAttr:$channelIds,
DenseI32ArrayAttr:$sourceCoreIds,
DenseI32ArrayAttr:$targetCoreIds,
Variadic<Index>:$channelIds,
Variadic<Index>:$sourceCoreIds,
Variadic<Index>:$targetCoreIds,
SpatTensor:$input
);
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
let assemblyFormat = [{
$input `channels` `(` $channelIds `)` `from` `(` $sourceCoreIds `)` `to` `(` $targetCoreIds `)` attr-dict `:` type($input)
}];
}
def SpatChannelSendTensorBatchOp : SpatOp<"channel_send_tensor_batch", []> {
def SpatChannelSendTensorBatchOp : SpatOp<"channel_send_tensor_batch", [AttrSizedOperandSegments]> {
let summary = "Send equal contiguous chunks of one per-lane tensor through logical channels in a batch body";
let arguments = (ins
DenseI64ArrayAttr:$channelIds,
DenseI32ArrayAttr:$sourceCoreIds,
DenseI32ArrayAttr:$targetCoreIds,
Variadic<Index>:$channelIds,
Variadic<Index>:$sourceCoreIds,
Variadic<Index>:$targetCoreIds,
SpatTensor:$input
);
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
let assemblyFormat = [{
$input `channels` `(` $channelIds `)` `from` `(` $sourceCoreIds `)` `to` `(` $targetCoreIds `)` attr-dict `:` type($input)
}];
}
def SpatChannelReceiveBatchOp : SpatOp<"channel_receive_batch", []> {
def SpatChannelReceiveBatchOp : SpatOp<"channel_receive_batch", [AttrSizedOperandSegments]> {
let summary = "Receive a per-lane tensor through logical channels in a batch body";
let arguments = (ins
DenseI64ArrayAttr:$channelIds,
DenseI32ArrayAttr:$sourceCoreIds,
DenseI32ArrayAttr:$targetCoreIds
Variadic<Index>:$channelIds,
Variadic<Index>:$sourceCoreIds,
Variadic<Index>:$targetCoreIds
);
let results = (outs
@@ -212,16 +263,18 @@ def SpatChannelReceiveBatchOp : SpatOp<"channel_receive_batch", []> {
);
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
let assemblyFormat = [{
`channels` `(` $channelIds `)` `from` `(` $sourceCoreIds `)` `to` `(` $targetCoreIds `)` attr-dict `:` type($output)
}];
}
def SpatChannelReceiveTensorBatchOp : SpatOp<"channel_receive_tensor_batch", []> {
def SpatChannelReceiveTensorBatchOp : SpatOp<"channel_receive_tensor_batch", [AttrSizedOperandSegments]> {
let summary = "Receive equal contiguous chunks of one per-lane tensor through logical channels in a batch body";
let arguments = (ins
DenseI64ArrayAttr:$channelIds,
DenseI32ArrayAttr:$sourceCoreIds,
DenseI32ArrayAttr:$targetCoreIds
Variadic<Index>:$channelIds,
Variadic<Index>:$sourceCoreIds,
Variadic<Index>:$targetCoreIds
);
let results = (outs
@@ -229,7 +282,9 @@ def SpatChannelReceiveTensorBatchOp : SpatOp<"channel_receive_tensor_batch", []>
);
let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;
let assemblyFormat = [{
`channels` `(` $channelIds `)` `from` `(` $sourceCoreIds `)` `to` `(` $targetCoreIds `)` attr-dict `:` type($output)
}];
}
//===----------------------------------------------------------------------===//
@@ -240,7 +295,7 @@ def SpatVMMOp : SpatOp<"wvmm", []> {
let summary = "Vector-matrix multiplication within a weighted compute operation";
let arguments = (ins
I32Attr:$weightIndex,
SpatTensor:$weight,
SpatTensor:$input
);
@@ -251,7 +306,7 @@ def SpatVMMOp : SpatOp<"wvmm", []> {
let hasVerifier = 1;
let assemblyFormat = [{
`(` $input `)` attr-dict `:` type($input) `->` type($output)
`[` $weight `]` `(` $input `)` attr-dict `:` `(` type($weight) `,` type($input) `)` `->` type($output)
}];
}
@@ -259,7 +314,7 @@ def SpatMVMOp : SpatOp<"Wmvm", []> {
let summary = "Matrix-vector multiplication within a weighted compute operation";
let arguments = (ins
I32Attr:$weightIndex,
SpatTensor:$weight,
SpatTensor:$input
);
@@ -270,7 +325,7 @@ def SpatMVMOp : SpatOp<"Wmvm", []> {
let hasVerifier = 1;
let assemblyFormat = [{
`(` $input `)` attr-dict `:` type($input) `->` type($output)
`[` $weight `]` `(` $input `)` attr-dict `:` `(` type($weight) `,` type($input) `)` `->` type($output)
}];
}
+64
View File
@@ -1,10 +1,74 @@
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include <string>
using namespace mlir;
namespace onnx_mlir {
namespace spatial {
BlockArgument SpatCompute::getWeightArgument(unsigned idx) { return getBody().front().getArgument(idx); }
BlockArgument SpatCompute::getInputArgument(unsigned idx) {
return getBody().front().getArgument(getWeights().size() + idx);
}
void SpatCompute::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) {
if (region.empty())
return;
for (unsigned index = 0; index < getWeights().size(); ++index)
setNameFn(getWeightArgument(index), ("w" + std::to_string(index)).c_str());
for (unsigned index = 0; index < getInputs().size(); ++index)
setNameFn(getInputArgument(index), ("in" + std::to_string(index)).c_str());
}
BlockArgument SpatComputeBatch::getLaneArgument() { return getBody().front().getArgument(0); }
BlockArgument SpatComputeBatch::getWeightArgument(unsigned idx) { return getBody().front().getArgument(1 + idx); }
BlockArgument SpatComputeBatch::getInputArgument(unsigned idx) {
return getBody().front().getArgument(1 + getWeights().size() + idx);
}
BlockArgument SpatComputeBatch::getOutputArgument(unsigned idx) {
return getBody().front().getArgument(1 + getWeights().size() + getInputs().size() + idx);
}
void SpatComputeBatch::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) {
if (region.empty())
return;
setNameFn(getLaneArgument(), "lane");
for (unsigned index = 0; index < getWeights().size(); ++index)
setNameFn(getWeightArgument(index), ("w" + std::to_string(index)).c_str());
for (unsigned index = 0; index < getInputs().size(); ++index)
setNameFn(getInputArgument(index), ("in" + std::to_string(index)).c_str());
for (unsigned index = 0; index < getNumResults(); ++index) {
if (index == 0) {
setNameFn(getOutputArgument(index), "out");
continue;
}
setNameFn(getOutputArgument(index), ("out" + std::to_string(index)).c_str());
}
}
void SpatInParallelOp::build(OpBuilder& builder, OperationState& result) {
OpBuilder::InsertionGuard guard(builder);
Region* bodyRegion = result.addRegion();
builder.createBlock(bodyRegion);
}
OpResult SpatInParallelOp::getParentResult(int64_t idx) { return getOperation()->getParentOp()->getResult(idx); }
llvm::iterator_range<Block::iterator> SpatInParallelOp::getYieldingOps() {
return getRegion().front().getOperations();
}
void SpatialDialect::initialize() {
addTypes<
#define GET_TYPEDEF_LIST
+2
View File
@@ -5,7 +5,9 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/RegionKindInterface.h"
#include "mlir/IR/Types.h"
#include "mlir/Interfaces/ParallelCombiningOpInterface.h"
#include <map>
#include <string>
+152 -218
View File
@@ -23,22 +23,6 @@ static bool parseOptionalKeywordAlias(OpAsmParser& parser, StringRef preferred,
return succeeded(parser.parseOptionalKeyword(preferred)) || succeeded(parser.parseOptionalKeyword(legacy));
}
static void printChannelMetadata(OpAsmPrinter& printer,
ArrayRef<int64_t> channelIds,
ArrayRef<int32_t> sourceCoreIds,
ArrayRef<int32_t> targetCoreIds) {
printer << " channels ";
printCompressedIntegerList(printer, channelIds);
printer << " from ";
printCompressedIntegerList(printer, sourceCoreIds);
printer << " to ";
printCompressedIntegerList(printer, targetCoreIds);
}
static DenseI64ArrayAttr getDenseI64ArrayAttr(OpAsmParser& parser, ArrayRef<int64_t> values) {
return parser.getBuilder().getDenseI64ArrayAttr(values);
}
static DenseI32ArrayAttr getDenseI32ArrayAttr(OpAsmParser& parser, ArrayRef<int32_t> values) {
return parser.getBuilder().getDenseI32ArrayAttr(values);
}
@@ -47,94 +31,89 @@ static IntegerAttr getI32Attr(OpAsmParser& parser, int32_t value) {
return parser.getBuilder().getI32IntegerAttr(value);
}
template <typename TensorSendOpTy>
static void printTensorSendOp(OpAsmPrinter& printer, TensorSendOpTy op) {
printer << " ";
printer.printOperand(op.getInput());
printChannelMetadata(printer, op.getChannelIds(), op.getSourceCoreIds(), op.getTargetCoreIds());
printer.printOptionalAttrDict(op->getAttrs(),
{op.getChannelIdsAttrName().getValue(),
op.getSourceCoreIdsAttrName().getValue(),
op.getTargetCoreIdsAttrName().getValue()});
printer << " : ";
printer.printType(op.getInput().getType());
static void printBlockArgumentList(OpAsmPrinter& printer, ArrayRef<BlockArgument> arguments) {
printer << "(";
for (auto [index, argument] : llvm::enumerate(arguments)) {
if (index != 0)
printer << ", ";
printer.printOperand(argument);
}
printer << ")";
}
template <typename TensorReceiveOpTy>
static void printTensorReceiveOp(OpAsmPrinter& printer, TensorReceiveOpTy op) {
printChannelMetadata(printer, op.getChannelIds(), op.getSourceCoreIds(), op.getTargetCoreIds());
printer.printOptionalAttrDict(op->getAttrs(),
{op.getChannelIdsAttrName().getValue(),
op.getSourceCoreIdsAttrName().getValue(),
op.getTargetCoreIdsAttrName().getValue()});
printer << " : ";
printer.printType(op.getOutput().getType());
}
static ParseResult parseTensorSendOp(OpAsmParser& parser, OperationState& result) {
OpAsmParser::UnresolvedOperand input;
Type inputType;
SmallVector<int64_t> channelIds;
SmallVector<int32_t> sourceCoreIds;
SmallVector<int32_t> targetCoreIds;
if (parser.parseOperand(input))
static ParseResult parseBlockArgumentList(OpAsmParser& parser, SmallVectorImpl<OpAsmParser::Argument>& arguments) {
if (parser.parseLParen())
return failure();
if (succeeded(parser.parseOptionalRParen()))
return success();
bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels"));
if (hasMetadata) {
if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from")
|| parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to")
|| parseCompressedIntegerList(parser, targetCoreIds))
OpAsmParser::Argument argument;
if (parser.parseArgument(argument))
return failure();
arguments.push_back(argument);
while (succeeded(parser.parseOptionalComma())) {
if (parser.parseArgument(argument))
return failure();
arguments.push_back(argument);
}
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(inputType))
return failure();
if (hasMetadata
&& (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds")
|| result.attributes.get("targetCoreIds")))
return parser.emitError(parser.getCurrentLocation(),
"channel metadata cannot be specified both positionally and in attr-dict");
if (hasMetadata) {
result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds));
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
}
return parser.resolveOperand(input, inputType, result.operands);
return parser.parseRParen();
}
static ParseResult parseTensorReceiveOp(OpAsmParser& parser, OperationState& result) {
Type outputType;
SmallVector<int64_t> channelIds;
SmallVector<int32_t> sourceCoreIds;
SmallVector<int32_t> targetCoreIds;
static void applyBatchRegionArgumentTypes(ArrayRef<Type> inputTypes,
ArrayRef<Type> weightTypes,
ArrayRef<Type> outputTypes,
OpAsmParser::Argument& laneArg,
SmallVectorImpl<OpAsmParser::Argument>& weightArgs,
SmallVectorImpl<OpAsmParser::Argument>& inputArgs,
SmallVectorImpl<OpAsmParser::Argument>& outputArgs,
SmallVectorImpl<OpAsmParser::Argument>& regionArgs,
Builder& builder) {
laneArg.type = builder.getIndexType();
regionArgs.push_back(laneArg);
applyArgumentTypes(weightTypes, weightArgs);
llvm::append_range(regionArgs, weightArgs);
applyArgumentTypes(inputTypes, inputArgs);
applyArgumentTypes(outputTypes, outputArgs);
llvm::append_range(regionArgs, inputArgs);
llvm::append_range(regionArgs, outputArgs);
}
bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels"));
if (hasMetadata) {
if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from")
|| parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to")
|| parseCompressedIntegerList(parser, targetCoreIds))
return failure();
}
static void
printBoundValueList(OpAsmPrinter& printer, ValueRange arguments, ValueRange operands, ListDelimiter delimiter) {
printCompressedValueList(printer, arguments, delimiter);
printer << " = ";
printCompressedValueList(printer, operands, delimiter);
}
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(outputType))
static ParseResult parseBoundValueList(OpAsmParser& parser,
ListDelimiter delimiter,
SmallVectorImpl<OpAsmParser::Argument>& arguments,
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
if (parseOpenDelimiter(parser, delimiter))
return failure();
if (hasMetadata
&& (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds")
|| result.attributes.get("targetCoreIds")))
return parser.emitError(parser.getCurrentLocation(),
"channel metadata cannot be specified both positionally and in attr-dict");
if (hasMetadata) {
result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds));
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter))) {
if (parser.parseEqual() || parseCompressedOperandList(parser, delimiter, operands))
return failure();
return success();
}
result.addTypes(outputType);
if (parseOneCompressedArgumentEntry(parser, arguments))
return failure();
while (succeeded(parser.parseOptionalComma()))
if (parseOneCompressedArgumentEntry(parser, arguments))
return failure();
auto parseCloseDelimiter = [&](ListDelimiter currentDelimiter) -> ParseResult {
switch (currentDelimiter) {
case ListDelimiter::Paren:
return parser.parseRParen();
case ListDelimiter::Square:
return parser.parseRSquare();
}
llvm_unreachable("unsupported delimiter");
};
if (parseCloseDelimiter(delimiter) || parser.parseEqual() || parseCompressedOperandList(parser, delimiter, operands)) {
return failure();
}
return success();
}
@@ -243,9 +222,17 @@ ParseResult SpatConcatOp::parse(OpAsmParser& parser, OperationState& result) {
void SpatCompute::print(OpAsmPrinter& printer) {
printer << " ";
printCompressedValueList(printer, getWeights(), ListDelimiter::Square);
SmallVector<Value> weightArgs;
weightArgs.reserve(getWeights().size());
for (unsigned index = 0; index < getWeights().size(); ++index)
weightArgs.push_back(getWeightArgument(index));
printBoundValueList(printer, weightArgs, getWeights(), ListDelimiter::Square);
printer << " ";
printArgumentBindings(printer, getBody().front(), getInputs());
SmallVector<Value> inputArgs;
inputArgs.reserve(getInputs().size());
for (unsigned index = 0; index < getInputs().size(); ++index)
inputArgs.push_back(getInputArgument(index));
printBoundValueList(printer, inputArgs, getInputs(), ListDelimiter::Paren);
if (auto coreIdAttr = (*this)->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
printer << " coreId " << coreIdAttr.getInt();
@@ -264,6 +251,7 @@ void SpatCompute::print(OpAsmPrinter& printer) {
}
ParseResult SpatCompute::parse(OpAsmParser& parser, OperationState& result) {
SmallVector<OpAsmParser::Argument> weightArgs;
SmallVector<OpAsmParser::Argument> regionArgs;
SmallVector<OpAsmParser::UnresolvedOperand> weights;
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
@@ -272,10 +260,11 @@ ParseResult SpatCompute::parse(OpAsmParser& parser, OperationState& result) {
SmallVector<Type> outputTypes;
int32_t coreId = 0;
if (parseCompressedOperandList(parser, ListDelimiter::Square, weights))
if (parseBoundValueList(parser, ListDelimiter::Square, weightArgs, weights))
return failure();
if (parseArgumentBindings(parser, regionArgs, inputs))
SmallVector<OpAsmParser::Argument> inputArgs;
if (parseBoundValueList(parser, ListDelimiter::Paren, inputArgs, inputs))
return failure();
bool hasCoreId = parseOptionalKeywordAlias(parser, "coreId", "core_id");
@@ -292,9 +281,11 @@ ParseResult SpatCompute::parse(OpAsmParser& parser, OperationState& result) {
if (weights.size() != weightTypes.size())
return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match");
if (weightArgs.size() != weights.size())
return parser.emitError(parser.getCurrentLocation(), "number of weight bindings and weight operands must match");
if (inputs.size() != inputTypes.size())
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
if (regionArgs.size() != inputs.size())
if (inputArgs.size() != inputs.size())
return parser.emitError(parser.getCurrentLocation(), "number of argument bindings and input operands must match");
if (hasCoreId && result.attributes.get(onnx_mlir::kCoreIdAttrName))
return parser.emitError(parser.getCurrentLocation(),
@@ -313,19 +304,39 @@ ParseResult SpatCompute::parse(OpAsmParser& parser, OperationState& result) {
result.addTypes(outputTypes);
Region* body = result.addRegion();
applyArgumentTypes(inputTypes, regionArgs);
applyArgumentTypes(weightTypes, weightArgs);
applyArgumentTypes(inputTypes, inputArgs);
llvm::append_range(regionArgs, weightArgs);
llvm::append_range(regionArgs, inputArgs);
return parser.parseRegion(*body, regionArgs);
}
void SpatComputeBatch::print(OpAsmPrinter& printer) {
printer << " lanes " << getLaneCount() << " ";
size_t weightsPerLane = getLaneCount() > 0 ? getWeights().size() / static_cast<size_t>(getLaneCount()) : 0;
if (getLaneCount() > 1 && hasRepeatedTuple(getWeights(), weightsPerLane))
printValueTupleRun(printer, getWeights(), weightsPerLane, ListDelimiter::Square);
else
printCompressedValueList(printer, getWeights(), ListDelimiter::Square);
printer << " ";
printArgumentBindings(printer, getBody().front(), getInputs());
printer.printOperand(getLaneArgument());
printer << " = 0 to " << getLaneCount();
printer << " ";
SmallVector<Value> weightArgs;
weightArgs.reserve(getWeights().size());
for (unsigned index = 0; index < getWeights().size(); ++index)
weightArgs.push_back(getWeightArgument(index));
printBoundValueList(printer, weightArgs, getWeights(), ListDelimiter::Square);
printer << " ";
SmallVector<Value> inputArgs;
inputArgs.reserve(getInputs().size());
for (unsigned index = 0; index < getInputs().size(); ++index)
inputArgs.push_back(getInputArgument(index));
printBoundValueList(printer, inputArgs, getInputs(), ListDelimiter::Paren);
if (getNumResults() != 0) {
printer << " shared_outs";
SmallVector<BlockArgument> outputArgs;
outputArgs.reserve(getNumResults());
for (unsigned index = 0; index < getNumResults(); ++index)
outputArgs.push_back(getOutputArgument(index));
printBlockArgumentList(printer, outputArgs);
}
if (auto coreIdsAttr = (*this)->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName)) {
printer << " coreIds ";
@@ -337,10 +348,7 @@ void SpatComputeBatch::print(OpAsmPrinter& printer) {
{getLaneCountAttrName().getValue(), getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdsAttrName});
printer << " : ";
if (getLaneCount() > 1 && hasRepeatedTuple(TypeRange(getWeights()), weightsPerLane))
printTypeTupleRun(printer, TypeRange(getWeights()), weightsPerLane, ListDelimiter::Square);
else
printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Square);
printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Square);
printer << " ";
printCompressedTypeList(printer, TypeRange(getInputs()), ListDelimiter::Paren);
printer << " -> ";
@@ -350,7 +358,12 @@ void SpatComputeBatch::print(OpAsmPrinter& printer) {
}
ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result) {
int64_t lowerBound = 0;
int32_t laneCount = 0;
OpAsmParser::Argument laneArg;
SmallVector<OpAsmParser::Argument> weightArgs;
SmallVector<OpAsmParser::Argument> inputArgs;
SmallVector<OpAsmParser::Argument> outputArgs;
SmallVector<OpAsmParser::Argument> regionArgs;
SmallVector<OpAsmParser::UnresolvedOperand> weights;
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
@@ -359,14 +372,21 @@ ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result)
SmallVector<Type> outputTypes;
SmallVector<int32_t> coreIds;
if (parser.parseKeyword("lanes") || parser.parseInteger(laneCount))
if (parser.parseArgument(laneArg) || parser.parseEqual() || parser.parseInteger(lowerBound)
|| parser.parseKeyword("to") || parser.parseInteger(laneCount))
return failure();
if (lowerBound != 0)
return parser.emitError(parser.getCurrentLocation(), "compute_batch currently requires a zero lower bound");
if (parseBoundValueList(parser, ListDelimiter::Square, weightArgs, weights))
return failure();
if (parseCompressedOrTupleOperandList(parser, ListDelimiter::Square, weights))
if (parseBoundValueList(parser, ListDelimiter::Paren, inputArgs, inputs))
return failure();
if (parseArgumentBindings(parser, regionArgs, inputs))
return failure();
if (succeeded(parser.parseOptionalKeyword("shared_outs")))
if (parseBlockArgumentList(parser, outputArgs))
return failure();
bool hasCoreIds = parseOptionalKeywordAlias(parser, "coreIds", "core_ids");
if (hasCoreIds && parseCompressedIntegerList(parser, coreIds))
@@ -381,10 +401,15 @@ ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result)
if (weights.size() != weightTypes.size())
return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match");
if (weightArgs.size() != weights.size())
return parser.emitError(parser.getCurrentLocation(), "number of weight bindings and weight operands must match");
if (inputs.size() != inputTypes.size())
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
if (regionArgs.size() != inputs.size())
if (inputArgs.size() != inputs.size())
return parser.emitError(parser.getCurrentLocation(), "number of argument bindings and input operands must match");
if (outputArgs.size() != outputTypes.size())
return parser.emitError(parser.getCurrentLocation(),
"number of shared output bindings and result types must match");
if (hasCoreIds && result.attributes.get(onnx_mlir::kCoreIdsAttrName))
return parser.emitError(parser.getCurrentLocation(),
"coreIds cannot be specified both positionally and in attr-dict");
@@ -403,119 +428,28 @@ ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result)
result.addTypes(outputTypes);
Region* body = result.addRegion();
applyArgumentTypes(inputTypes, regionArgs);
applyBatchRegionArgumentTypes(
inputTypes, weightTypes, outputTypes, laneArg, weightArgs, inputArgs, outputArgs, regionArgs, parser.getBuilder());
return parser.parseRegion(*body, regionArgs);
}
void SpatChannelSendTensorOp::print(OpAsmPrinter& printer) { printTensorSendOp(printer, *this); }
ParseResult SpatChannelSendTensorOp::parse(OpAsmParser& parser, OperationState& result) {
return parseTensorSendOp(parser, result);
}
void SpatChannelSendBatchOp::print(OpAsmPrinter& printer) {
void SpatInParallelOp::print(OpAsmPrinter& printer) {
printer << " ";
printer.printOperand(getInput());
printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
printer.printOptionalAttrDict(
(*this)->getAttrs(),
{getChannelIdsAttrName().getValue(), getSourceCoreIdsAttrName().getValue(), getTargetCoreIdsAttrName().getValue()});
printer << " : ";
printer.printType(getInput().getType());
printer.printRegion(getRegion(), /*printEntryBlockArgs=*/false, /*printBlockTerminators=*/false);
printer.printOptionalAttrDict((*this)->getAttrs());
}
ParseResult SpatChannelSendBatchOp::parse(OpAsmParser& parser, OperationState& result) {
OpAsmParser::UnresolvedOperand input;
Type inputType;
SmallVector<int64_t> channelIds;
SmallVector<int32_t> sourceCoreIds;
SmallVector<int32_t> targetCoreIds;
if (parser.parseOperand(input))
ParseResult SpatInParallelOp::parse(OpAsmParser& parser, OperationState& result) {
auto& builder = parser.getBuilder();
std::unique_ptr<Region> region = std::make_unique<Region>();
SmallVector<OpAsmParser::Argument, 4> regionArgs;
if (parser.parseRegion(*region, regionArgs))
return failure();
bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels"));
if (hasMetadata) {
if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from")
|| parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to")
|| parseCompressedIntegerList(parser, targetCoreIds))
return failure();
}
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(inputType))
return failure();
if (hasMetadata
&& (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds")
|| result.attributes.get("targetCoreIds")))
return parser.emitError(parser.getCurrentLocation(),
"channel metadata cannot be specified both positionally and in attr-dict");
if (hasMetadata) {
result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds));
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
}
return parser.resolveOperand(input, inputType, result.operands);
}
void SpatChannelSendTensorBatchOp::print(OpAsmPrinter& printer) { printTensorSendOp(printer, *this); }
ParseResult SpatChannelSendTensorBatchOp::parse(OpAsmParser& parser, OperationState& result) {
return parseTensorSendOp(parser, result);
}
void SpatChannelReceiveTensorOp::print(OpAsmPrinter& printer) { printTensorReceiveOp(printer, *this); }
ParseResult SpatChannelReceiveTensorOp::parse(OpAsmParser& parser, OperationState& result) {
return parseTensorReceiveOp(parser, result);
}
void SpatChannelReceiveBatchOp::print(OpAsmPrinter& printer) {
printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
printer.printOptionalAttrDict(
(*this)->getAttrs(),
{getChannelIdsAttrName().getValue(), getSourceCoreIdsAttrName().getValue(), getTargetCoreIdsAttrName().getValue()});
printer << " : ";
printer.printType(getOutput().getType());
}
ParseResult SpatChannelReceiveBatchOp::parse(OpAsmParser& parser, OperationState& result) {
Type outputType;
SmallVector<int64_t> channelIds;
SmallVector<int32_t> sourceCoreIds;
SmallVector<int32_t> targetCoreIds;
bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels"));
if (hasMetadata) {
if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from")
|| parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to")
|| parseCompressedIntegerList(parser, targetCoreIds))
return failure();
}
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(outputType))
return failure();
if (hasMetadata
&& (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds")
|| result.attributes.get("targetCoreIds")))
return parser.emitError(parser.getCurrentLocation(),
"channel metadata cannot be specified both positionally and in attr-dict");
if (hasMetadata) {
result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds));
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
}
result.addTypes(outputType);
return success();
}
void SpatChannelReceiveTensorBatchOp::print(OpAsmPrinter& printer) { printTensorReceiveOp(printer, *this); }
ParseResult SpatChannelReceiveTensorBatchOp::parse(OpAsmParser& parser, OperationState& result) {
return parseTensorReceiveOp(parser, result);
if (region->empty())
OpBuilder(builder.getContext()).createBlock(region.get());
result.addRegion(std::move(region));
return parser.parseOptionalAttrDict(result.attributes);
}
} // namespace spatial
+250 -118
View File
@@ -1,8 +1,12 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/Support/LogicalResult.h"
@@ -81,20 +85,11 @@ inline LogicalResult mvmOpVerifySize4(SpatMVMOp* emitter,
return success();
}
static FailureOr<ArrayRef<int64_t>> getWeightShapeForWeightedOp(Operation* weightedOp, size_t weightIndex) {
if (auto computeOp = weightedOp->getParentOfType<SpatCompute>())
return cast<ShapedType>(computeOp.getWeights()[weightIndex].getType()).getShape();
if (auto coreOp = weightedOp->getParentOfType<pim::PimCoreOp>())
return cast<ShapedType>(coreOp.getWeights()[weightIndex].getType()).getShape();
if (auto batchOp = weightedOp->getParentOfType<SpatComputeBatch>()) {
if (batchOp.getWeights().empty() || weightIndex >= batchOp.getWeights().size())
return failure();
return cast<ShapedType>(batchOp.getWeights()[weightIndex].getType()).getShape();
}
return failure();
static FailureOr<ArrayRef<int64_t>> getWeightShapeForWeightedOp(Value weight) {
auto shapedType = dyn_cast<ShapedType>(weight.getType());
if (!shapedType)
return failure();
return shapedType.getShape();
}
static FailureOr<int32_t> getParentBatchLaneCount(Operation* op) {
@@ -104,15 +99,86 @@ static FailureOr<int32_t> getParentBatchLaneCount(Operation* op) {
return batchOp.getLaneCount();
}
static LogicalResult verifyTensorChannelSizes(Operation* op,
Type type,
ArrayRef<int64_t> channelIds,
ArrayRef<int32_t> sourceCoreIds,
ArrayRef<int32_t> targetCoreIds,
StringRef kind) {
if (channelIds.size() != sourceCoreIds.size() || channelIds.size() != targetCoreIds.size())
static bool isBatchOutputArgument(SpatComputeBatch batchOp, Value value) {
if (batchOp.getNumResults() == 0)
return false;
auto blockArg = dyn_cast<BlockArgument>(value);
if (!blockArg || blockArg.getOwner() != &batchOp.getBody().front())
return false;
unsigned argNumber = blockArg.getArgNumber();
unsigned firstOutputArg = batchOp.getOutputArgument(0).getArgNumber();
return argNumber >= firstOutputArg && argNumber < firstOutputArg + batchOp.getNumResults();
}
static bool isConstantIndexLike(Value value) {
APInt constantValue;
return matchPattern(value, m_ConstantInt(&constantValue));
}
static bool isSupportedLaneOffsetExpr(Value value, BlockArgument laneArg) {
if (value == laneArg || isConstantIndexLike(value))
return true;
auto addOp = value.getDefiningOp<arith::AddIOp>();
if (!addOp)
return false;
return (addOp.getLhs() == laneArg && isConstantIndexLike(addOp.getRhs()))
|| (addOp.getRhs() == laneArg && isConstantIndexLike(addOp.getLhs()));
}
static LogicalResult
verifyStaticUnitStrideExtractSliceOp(tensor::ExtractSliceOp sliceOp, BlockArgument laneArg, StringRef kind) {
auto sourceType = dyn_cast<RankedTensorType>(sliceOp.getSource().getType());
auto resultType = dyn_cast<RankedTensorType>(sliceOp.getResult().getType());
if (!sourceType || !resultType || !sourceType.hasStaticShape() || !resultType.hasStaticShape())
return sliceOp.emitOpError() << kind << " requires static ranked tensor types";
if (!sliceOp.hasUnitStride())
return sliceOp.emitOpError() << kind << " requires unit strides";
for (int64_t size : sliceOp.getStaticSizes())
if (ShapedType::isDynamic(size))
return sliceOp.emitOpError() << kind << " requires static slice sizes";
auto offsets = sliceOp.getOffsets();
for (auto [offsetIndex, offset] : llvm::enumerate(offsets)) {
bool supported = offsetIndex == 0 ? isSupportedLaneOffsetExpr(offset, laneArg) : isConstantIndexLike(offset);
if (!supported)
return sliceOp.emitOpError() << kind << " requires simple lane-dependent offsets";
}
return success();
}
static LogicalResult verifyStaticUnitStrideParallelInsertSliceOp(tensor::ParallelInsertSliceOp sliceOp,
BlockArgument laneArg,
StringRef kind) {
RankedTensorType sourceType = sliceOp.getSourceType();
RankedTensorType destType = sliceOp.getDestType();
if (!sourceType.hasStaticShape() || !destType.hasStaticShape())
return sliceOp.emitOpError() << kind << " requires static ranked tensor types";
if (!sliceOp.hasUnitStride())
return sliceOp.emitOpError() << kind << " requires unit strides";
for (int64_t size : sliceOp.getStaticSizes())
if (ShapedType::isDynamic(size))
return sliceOp.emitOpError() << kind << " requires static slice sizes";
auto offsets = sliceOp.getOffsets();
for (auto [offsetIndex, offset] : llvm::enumerate(offsets)) {
bool supported = offsetIndex == 0 ? isSupportedLaneOffsetExpr(offset, laneArg) : isConstantIndexLike(offset);
if (!supported)
return sliceOp.emitOpError() << kind << " requires simple lane-dependent offsets";
}
return success();
}
static LogicalResult verifyTensorChannelSizes(
Operation* op, Type type, size_t channelCount, size_t sourceCoreCount, size_t targetCoreCount, StringRef kind) {
if (channelCount != sourceCoreCount || channelCount != targetCoreCount)
return op->emitError("channelIds, sourceCoreIds, and targetCoreIds must have the same length");
if (channelIds.empty())
if (channelCount == 0)
return op->emitError() << kind << " must carry at least one chunk";
auto shapedType = dyn_cast<ShapedType>(type);
@@ -124,40 +190,34 @@ static LogicalResult verifyTensorChannelSizes(Operation* op,
return op->emitError() << kind << " requires byte-sized elements";
int64_t totalBytes = shapedType.getNumElements() * elementBits / 8;
if (totalBytes % static_cast<int64_t>(channelIds.size()) != 0)
if (totalBytes % static_cast<int64_t>(channelCount) != 0)
return op->emitError() << kind << " tensor byte size must be divisible by the number of channel ids";
return success();
}
static LogicalResult verifyBatchChannelSizes(Operation* op,
ArrayRef<int64_t> channelIds,
ArrayRef<int32_t> sourceCoreIds,
ArrayRef<int32_t> targetCoreIds) {
if (channelIds.size() != sourceCoreIds.size() || channelIds.size() != targetCoreIds.size())
static LogicalResult
verifyBatchChannelSizes(Operation* op, size_t channelCount, size_t sourceCoreCount, size_t targetCoreCount) {
if (channelCount != sourceCoreCount || channelCount != targetCoreCount)
return op->emitError("channelIds, sourceCoreIds, and targetCoreIds must have the same length");
auto laneCount = getParentBatchLaneCount(op);
if (failed(laneCount))
return op->emitError("must be nested inside spat.compute_batch");
if (channelIds.size() != static_cast<size_t>(*laneCount))
if (channelCount != static_cast<size_t>(*laneCount))
return op->emitError("channel metadata length must match parent laneCount");
return success();
}
static LogicalResult verifyTensorBatchChannelSizes(Operation* op,
Type type,
ArrayRef<int64_t> channelIds,
ArrayRef<int32_t> sourceCoreIds,
ArrayRef<int32_t> targetCoreIds,
StringRef kind) {
if (channelIds.size() != sourceCoreIds.size() || channelIds.size() != targetCoreIds.size())
static LogicalResult verifyTensorBatchChannelSizes(
Operation* op, Type type, size_t channelCount, size_t sourceCoreCount, size_t targetCoreCount, StringRef kind) {
if (channelCount != sourceCoreCount || channelCount != targetCoreCount)
return op->emitError("channelIds, sourceCoreIds, and targetCoreIds must have the same length");
auto laneCount = getParentBatchLaneCount(op);
if (failed(laneCount))
return op->emitError("must be nested inside spat.compute_batch");
if (channelIds.empty() || channelIds.size() % static_cast<size_t>(*laneCount) != 0)
if (channelCount == 0 || channelCount % static_cast<size_t>(*laneCount) != 0)
return op->emitError() << kind << " channel metadata length must be a positive multiple of parent laneCount";
auto shapedType = dyn_cast<ShapedType>(type);
@@ -168,7 +228,7 @@ static LogicalResult verifyTensorBatchChannelSizes(Operation* op,
if (elementBits <= 0 || elementBits % 8 != 0)
return op->emitError() << kind << " requires byte-sized elements";
int64_t chunkCount = static_cast<int64_t>(channelIds.size()) / *laneCount;
int64_t chunkCount = static_cast<int64_t>(channelCount) / *laneCount;
int64_t totalBytes = shapedType.getNumElements() * elementBits / 8;
if (totalBytes % chunkCount != 0)
return op->emitError() << kind << " tensor byte size must be divisible by the chunk count per lane";
@@ -176,28 +236,59 @@ static LogicalResult verifyTensorBatchChannelSizes(Operation* op,
return success();
}
static LogicalResult verifyBatchBody(Operation* op, Block& block, TypeRange outputTypes, size_t weightsPerLane) {
auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator());
if (!yieldOp)
return op->emitError("body must terminate with spat.yield");
if (outputTypes.empty()) {
static Region* getParentRegion(Value value) {
if (auto blockArg = dyn_cast<BlockArgument>(value))
return blockArg.getOwner()->getParent();
if (Operation* definingOp = value.getDefiningOp())
return definingOp->getParentRegion();
return nullptr;
}
static bool isDefinedInsideRegion(Value value, Region& region) {
Region* parentRegion = getParentRegion(value);
return parentRegion && (&region == parentRegion || region.isAncestor(parentRegion));
}
static bool isConstantExternalValue(Value value) {
Operation* definingOp = value.getDefiningOp();
return definingOp && definingOp->hasTrait<OpTrait::ConstantLike>();
}
static LogicalResult verifyOnlyConstantExternalValues(Operation* ownerOp, Region& region, StringRef kind) {
bool hasFailure = false;
region.walk([&](Operation* op) {
for (OpOperand& operand : op->getOpOperands()) {
Value value = operand.get();
if (isDefinedInsideRegion(value, region) || isConstantExternalValue(value))
continue;
InFlightDiagnostic diagnostic = ownerOp->emitOpError()
<< kind << " body may only directly reference external constants";
diagnostic.attachNote(op->getLoc()) << "non-constant external operand #" << operand.getOperandNumber()
<< " is used by " << op->getName();
hasFailure = true;
}
});
return success(!hasFailure);
}
static LogicalResult verifyBatchBody(SpatComputeBatch batchOp, Block& block) {
if (batchOp.getNumResults() == 0) {
auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator());
if (!yieldOp)
return batchOp.emitError("resultless compute_batch body must terminate with spat.yield");
if (yieldOp.getNumOperands() != 0)
return op->emitError("body yield must be empty when compute_batch has no results");
return batchOp.emitError("resultless compute_batch body yield must be empty");
}
else {
if (yieldOp.getNumOperands() != 1)
return op->emitError("body yield must produce exactly one value");
if (yieldOp.getOperand(0).getType() != outputTypes[0])
return op->emitError("body yield type must match output type");
else if (!isa_and_nonnull<SpatInParallelOp>(block.getTerminator())) {
return batchOp.emitError("resultful compute_batch body must terminate with spat.in_parallel");
}
BlockArgument laneArg = batchOp.getLaneArgument();
for (auto& bodyOp : block) {
if (auto wvmm = dyn_cast<SpatVMMOp>(&bodyOp))
if (wvmm.getWeightIndex() < 0 || static_cast<size_t>(wvmm.getWeightIndex()) >= weightsPerLane)
return op->emitError("compute_batch body Wvmm weightIndex is out of range for one lane");
if (auto wmvm = dyn_cast<SpatMVMOp>(&bodyOp))
if (wmvm.getWeightIndex() < 0 || static_cast<size_t>(wmvm.getWeightIndex()) >= weightsPerLane)
return op->emitError("compute_batch body Wmvm weightIndex is out of range for one lane");
if (auto extractSlice = dyn_cast<tensor::ExtractSliceOp>(&bodyOp))
if (failed(verifyStaticUnitStrideExtractSliceOp(extractSlice, laneArg, "tensor.extract_slice")))
return failure();
}
return success();
}
@@ -205,9 +296,9 @@ static LogicalResult verifyBatchBody(Operation* op, Block& block, TypeRange outp
} // namespace
LogicalResult SpatMVMOp::verify() {
auto matrixShapeOpt = getWeightShapeForWeightedOp(this->getOperation(), this->getWeightIndex());
auto matrixShapeOpt = getWeightShapeForWeightedOp(getWeight());
if (failed(matrixShapeOpt))
return emitError("SpatMVMOp was not within a SpatCompute or Core op");
return emitError("weight must be a shaped value");
auto matrixShape = *matrixShapeOpt;
auto vectorShape = getInput().getType().getShape();
auto outputShape = getOutput().getType().getShape();
@@ -220,9 +311,9 @@ LogicalResult SpatMVMOp::verify() {
}
LogicalResult SpatVMMOp::verify() {
auto matrixShapeOpt = getWeightShapeForWeightedOp(this->getOperation(), this->getWeightIndex());
auto matrixShapeOpt = getWeightShapeForWeightedOp(getWeight());
if (failed(matrixShapeOpt))
return emitError("SpatVMMOp was not within a SpatCompute or Core op");
return emitError("weight must be a shaped value");
auto matrixShape = *matrixShapeOpt;
auto vectorShape = getInput().getType().getShape();
auto outputShape = getOutput().getType().getShape();
@@ -338,8 +429,34 @@ LogicalResult SpatConcatOp::verify() {
return success();
}
LogicalResult verifyComputeResultsUses(Operation* op) {
if (!isa<SpatCompute, SpatComputeBatch>(op))
return op->emitError("verifyComputeResultUses: Op is not a SpatCompute/SpatComputeBatch operation");
if (!llvm::all_of(op->getResults(), [](Value result) {
return llvm::all_of(result.getUsers(), [](Operation* op) {
return !(op->getParentOfType<SpatCompute>() || op->getParentOfType<SpatComputeBatch>());
});
})) {
return op->emitError("ComputeResult used directly inside another Compute");
}
return success();
}
LogicalResult SpatCompute::verify() {
auto& block = getBody().front();
unsigned expectedArgCount = getWeights().size() + getInputs().size();
if (block.getNumArguments() != expectedArgCount)
return emitError("compute body must have weight and input block arguments");
for (auto [weightIndex, weight] : llvm::enumerate(getWeights())) {
if (getWeightArgument(weightIndex).getType() != weight.getType())
return emitError("compute weight block argument types must match weight operand types exactly");
}
for (auto [inputIndex, input] : llvm::enumerate(getInputs())) {
if (getInputArgument(inputIndex).getType() != input.getType())
return emitError("compute input block argument types must match input operand types exactly");
}
if (block.mightHaveTerminator()) {
auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator());
if (!yieldOp)
@@ -372,54 +489,59 @@ LogicalResult SpatCompute::verify() {
}
}
for (auto arg : block.getArguments())
if (arg.use_empty())
for (unsigned inputIndex = 0; inputIndex < getInputs().size(); ++inputIndex)
if (getInputArgument(inputIndex).use_empty())
return emitError("ComputeOp block argument is not used");
if (failed(verifyOnlyConstantExternalValues(this->getOperation(), getBody(), "spat.compute")))
return failure();
if (failed(verifyComputeResultsUses(this->getOperation())))
return failure();
return success();
}
LogicalResult SpatChannelSendTensorOp::verify() {
return verifyTensorChannelSizes(getOperation(),
getInput().getType(),
getChannelIds(),
getSourceCoreIds(),
getTargetCoreIds(),
getChannelIds().size(),
getSourceCoreIds().size(),
getTargetCoreIds().size(),
"channel_send_tensor");
}
LogicalResult SpatChannelReceiveTensorOp::verify() {
return verifyTensorChannelSizes(getOperation(),
getOutput().getType(),
getChannelIds(),
getSourceCoreIds(),
getTargetCoreIds(),
getChannelIds().size(),
getSourceCoreIds().size(),
getTargetCoreIds().size(),
"channel_receive_tensor");
}
LogicalResult SpatChannelSendBatchOp::verify() {
return verifyBatchChannelSizes(getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
return verifyBatchChannelSizes(
getOperation(), getChannelIds().size(), getSourceCoreIds().size(), getTargetCoreIds().size());
}
LogicalResult SpatChannelSendTensorBatchOp::verify() {
return verifyTensorBatchChannelSizes(getOperation(),
getInput().getType(),
getChannelIds(),
getSourceCoreIds(),
getTargetCoreIds(),
getChannelIds().size(),
getSourceCoreIds().size(),
getTargetCoreIds().size(),
"channel_send_tensor_batch");
}
LogicalResult SpatChannelReceiveBatchOp::verify() {
return verifyBatchChannelSizes(getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
return verifyBatchChannelSizes(
getOperation(), getChannelIds().size(), getSourceCoreIds().size(), getTargetCoreIds().size());
}
LogicalResult SpatChannelReceiveTensorBatchOp::verify() {
return verifyTensorBatchChannelSizes(getOperation(),
getOutput().getType(),
getChannelIds(),
getSourceCoreIds(),
getTargetCoreIds(),
getChannelIds().size(),
getSourceCoreIds().size(),
getTargetCoreIds().size(),
"channel_receive_tensor_batch");
}
@@ -429,35 +551,6 @@ LogicalResult SpatComputeBatch::verify() {
return emitError("laneCount must be positive");
auto laneCountSz = static_cast<size_t>(count);
if (getWeights().size() % laneCountSz != 0)
return emitError("number of weights must be a multiple of laneCount");
if (!getInputs().empty() && getInputs().size() != laneCountSz)
return emitError("number of inputs must be either 0 or laneCount");
if (!getOutputs().empty() && getOutputs().size() != laneCountSz)
return emitError("number of outputs must be either 0 or laneCount");
size_t weightsPerLane = getWeights().size() / laneCountSz;
for (size_t weightIndex = 0; weightIndex < weightsPerLane; ++weightIndex) {
Type weightType = getWeights()[weightIndex].getType();
for (size_t lane = 1; lane < laneCountSz; ++lane)
if (getWeights()[lane * weightsPerLane + weightIndex].getType() != weightType)
return emitError("corresponding weights across lanes must have the same type");
}
if (!getInputs().empty()) {
Type inputType = getInputs()[0].getType();
for (Value in : getInputs().drop_front())
if (in.getType() != inputType)
return emitError("all inputs must have the same type");
}
if (!getOutputs().empty()) {
Type outputType = getOutputs()[0].getType();
for (Value out : getOutputs().drop_front())
if (out.getType() != outputType)
return emitError("all outputs must have the same type");
}
if (auto coreIdAttr = (*this)->getAttr(kCoreIdsAttrName)) {
auto coreIdsAttr = dyn_cast<DenseI32ArrayAttr>(coreIdAttr);
@@ -465,27 +558,66 @@ LogicalResult SpatComputeBatch::verify() {
return emitError("compute_batch coreIds attribute must be a dense i32 array");
if (coreIdsAttr.size() != static_cast<int64_t>(laneCountSz))
return emitError("compute_batch coreIds array length must match laneCount");
if (llvm::any_of(coreIdsAttr.asArrayRef(), [](int32_t coreId) { return coreId <= 0; }))
return emitError("compute_batch coreIds values must be positive");
llvm::SmallDenseSet<int32_t, 8> seenCoreIds;
if (llvm::any_of(coreIdsAttr.asArrayRef(), [](int32_t coreId) { return coreId < 0; }))
return emitError("compute_batch coreIds values must be non-negative");
DenseSet<int32_t> seenCoreIds;
for (int32_t coreId : coreIdsAttr.asArrayRef())
if (!seenCoreIds.insert(coreId).second)
return emitError("compute_batch coreIds values must be distinct");
return emitError("compute_batch coreIds values must be unique");
}
Block& block = getBody().front();
if (getInputs().empty()) {
if (block.getNumArguments() != 0)
return emitError("compute_batch body must have no block arguments when there are no inputs");
unsigned expectedArgCount = 1 + getWeights().size() + getInputs().size() + getNumResults();
if (block.getNumArguments() != expectedArgCount)
return emitError("compute_batch body must have lane, weight, input, and output block arguments");
if (!getLaneArgument().getType().isIndex())
return emitError("compute_batch first block argument must have index type");
for (auto [weightIndex, weight] : llvm::enumerate(getWeights())) {
if (getWeightArgument(weightIndex).getType() != weight.getType())
return emitError("compute_batch weight block argument types must match weight operand types exactly");
}
else {
if (block.getNumArguments() != 1)
return emitError("compute_batch body must have exactly one block argument");
if (block.getArgument(0).getType() != getInputs()[0].getType())
return emitError("body block argument type must match input type");
for (auto [inputIndex, input] : llvm::enumerate(getInputs())) {
BlockArgument blockArg = getInputArgument(inputIndex);
if (blockArg.getType() != input.getType())
return emitError("compute_batch input block argument types must match input operand types exactly");
}
for (auto [resultIndex, resultType] : llvm::enumerate(getResultTypes())) {
BlockArgument blockArg = getOutputArgument(resultIndex);
if (blockArg.getType() != resultType)
return emitError("compute_batch output block argument types must match result types exactly");
}
return verifyBatchBody(getOperation(), block, getResultTypes(), weightsPerLane);
if (failed(verifyComputeResultsUses(this->getOperation())))
return failure();
if (failed(verifyOnlyConstantExternalValues(this->getOperation(), getBody(), "spat.compute_batch")))
return failure();
return verifyBatchBody(*this, block);
}
LogicalResult SpatInParallelOp::verify() {
auto batchOp = getOperation()->getParentOfType<SpatComputeBatch>();
if (!batchOp)
return emitOpError("expected spat.compute_batch parent");
if (batchOp.getNumResults() == 0)
return emitOpError("requires a resultful spat.compute_batch parent");
BlockArgument laneArg = batchOp.getLaneArgument();
for (Operation& op : getRegion().front().getOperations()) {
auto insertSliceOp = dyn_cast<tensor::ParallelInsertSliceOp>(&op);
if (!insertSliceOp)
return emitOpError("expected only tensor.parallel_insert_slice ops");
if (failed(verifyStaticUnitStrideParallelInsertSliceOp(insertSliceOp, laneArg, "tensor.parallel_insert_slice")))
return failure();
MutableOperandRange destinations = insertSliceOp.getUpdatedDestinations();
for (OpOperand& destination : destinations)
if (!isBatchOutputArgument(batchOp, destination.get()))
return op.emitOpError("may only insert into a compute_batch output block argument");
}
return success();
}
} // namespace spatial
@@ -1,802 +1,19 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"
#include <algorithm>
#include <cstdlib>
#include <iterator>
#include <numeric>
#include <optional>
#include <queue>
#include <utility>
#include <vector>
#include "DCPAnalysis.hpp"
#include "Graph.hpp"
#include "../Scheduling/ComputeGraph.hpp"
#include "../Scheduling/DcpScheduler.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Support/TypeUtilities.hpp"
namespace onnx_mlir {
namespace spatial {
using namespace mlir;
namespace {
using SpatCompute = onnx_mlir::spatial::SpatCompute;
using SpatComputeBatch = onnx_mlir::spatial::SpatComputeBatch;
bool isDcpCoarsenDebugEnabled() { return std::getenv("DCP_COARSEN_DEBUG") != nullptr; }
struct VirtualNode {
SmallVector<size_t, 4> originalComputeIndices;
Weight weight = 0;
CrossbarUsage crossbarUsage = 0;
};
struct VirtualGraph {
std::vector<VirtualNode> nodes;
std::vector<IndexedEdge> edges;
};
struct TimingInfo {
std::vector<Time> aest;
std::vector<Time> alst;
std::vector<size_t> topologicalOrder;
bool valid = false;
};
struct WindowScheduleResult {
std::vector<std::vector<size_t>> mergeGroups;
CPU cpuCount = 0;
size_t mergedNodeCount = 0;
size_t maxMergeGroupSize = 0;
};
size_t getSchedulingCpuBudget() {
if (coresCount.getValue() > 0)
return static_cast<size_t>(coresCount.getValue());
return std::numeric_limits<size_t>::max();
}
size_t getBatchChunkTargetCount(int32_t laneCount) {
assert(laneCount > 0 && "laneCount must be positive");
return std::min(static_cast<size_t>(laneCount), std::max<size_t>(1, getSchedulingCpuBudget()));
}
ComputeInstance getBatchChunkForIndex(SpatComputeBatch batch, size_t chunkIndex) {
size_t totalLanes = static_cast<size_t>(batch.getLaneCount());
size_t chunkCount = getBatchChunkTargetCount(batch.getLaneCount());
size_t baseChunkSize = totalLanes / chunkCount;
size_t largeChunkCount = totalLanes % chunkCount;
size_t laneStart = chunkIndex * baseChunkSize + std::min(chunkIndex, largeChunkCount);
size_t laneCount = baseChunkSize + (chunkIndex < largeChunkCount ? 1 : 0);
return {batch.getOperation(), static_cast<uint32_t>(laneStart), static_cast<uint32_t>(laneCount)};
}
ComputeInstance getBatchChunkForLane(SpatComputeBatch batch, uint32_t lane) {
size_t totalLanes = static_cast<size_t>(batch.getLaneCount());
size_t chunkCount = getBatchChunkTargetCount(batch.getLaneCount());
size_t baseChunkSize = totalLanes / chunkCount;
size_t largeChunkCount = totalLanes % chunkCount;
size_t largeChunkSpan = largeChunkCount * (baseChunkSize + 1);
size_t chunkIndex = 0;
if (static_cast<size_t>(lane) < largeChunkSpan)
chunkIndex = static_cast<size_t>(lane) / (baseChunkSize + 1);
else
chunkIndex = largeChunkCount + (static_cast<size_t>(lane) - largeChunkSpan) / baseChunkSize;
return getBatchChunkForIndex(batch, chunkIndex);
}
std::vector<IndexedEdge> aggregateEdges(ArrayRef<IndexedEdge> edges) {
llvm::DenseMap<std::pair<size_t, size_t>, Weight> edgeWeights;
for (auto [start, end, weight] : edges) {
size_t startIndex = static_cast<size_t>(start);
size_t endIndex = static_cast<size_t>(end);
if (startIndex == endIndex)
continue;
auto key = std::make_pair(startIndex, endIndex);
Weight edgeWeight = static_cast<Weight>(weight);
auto inserted = edgeWeights.try_emplace(key, edgeWeight);
if (!inserted.second)
inserted.first->second = std::max(inserted.first->second, edgeWeight);
}
std::vector<IndexedEdge> aggregatedEdges;
aggregatedEdges.reserve(edgeWeights.size());
for (auto [key, weight] : edgeWeights)
aggregatedEdges.push_back(
{static_cast<int64_t>(key.first), static_cast<int64_t>(key.second), static_cast<int64_t>(weight)});
llvm::sort(aggregatedEdges, [](const IndexedEdge& lhs, const IndexedEdge& rhs) {
if (std::get<0>(lhs) != std::get<0>(rhs))
return std::get<0>(lhs) < std::get<0>(rhs);
return std::get<1>(lhs) < std::get<1>(rhs);
});
return aggregatedEdges;
}
Weight getComputeBodyWeight(Region& body) {
constexpr Weight kOperationWeight = 100;
Weight numOperations = 0;
for (auto& block : body)
for ([[maybe_unused]] auto& op : block)
numOperations = checkedAdd(numOperations, static_cast<Weight>(1));
return checkedMultiply(numOperations, kOperationWeight);
}
CrossbarUsage getComputeBodyCrossbarUsage(Region& body) {
CrossbarUsage crossbarUsage = 0;
for (auto& block : body)
for (auto& op : block)
if (isa<SpatVMMOp>(op))
crossbarUsage = checkedAdd(crossbarUsage, static_cast<CrossbarUsage>(1));
return crossbarUsage;
}
Weight getComputeInstanceWeight(const ComputeInstance& instance) {
if (auto spatCompute = dyn_cast<SpatCompute>(instance.op))
return getSpatComputeWeight(spatCompute);
auto batch = cast<SpatComputeBatch>(instance.op);
return checkedMultiply(getComputeBodyWeight(batch.getBody()), static_cast<Weight>(instance.laneCount));
}
CrossbarUsage getComputeInstanceCrossbarUsage(const ComputeInstance& instance) {
if (auto spatCompute = dyn_cast<SpatCompute>(instance.op))
return getSpatComputeCrossbarUsage(spatCompute);
auto batch = cast<SpatComputeBatch>(instance.op);
return checkedMultiply(getComputeBodyCrossbarUsage(batch.getBody()), static_cast<CrossbarUsage>(instance.laneCount));
}
SmallVector<Value, 4> getComputeInstanceInputs(const ComputeInstance& instance) {
if (auto spatCompute = dyn_cast<SpatCompute>(instance.op))
return SmallVector<Value, 4>(spatCompute.getInputs().begin(), spatCompute.getInputs().end());
auto batch = cast<SpatComputeBatch>(instance.op);
SmallVector<Value, 4> inputs;
inputs.reserve(instance.laneCount);
for (uint32_t lane = instance.laneStart; lane < instance.laneStart + instance.laneCount; ++lane)
inputs.push_back(batch.getInputs()[lane]);
return inputs;
}
std::optional<ComputeInstance> getOriginalComputeInstance(Value value) {
Operation* op = value.getDefiningOp();
if (!op)
return std::nullopt;
while (auto extract = dyn_cast<tensor::ExtractSliceOp>(op)) {
value = extract.getSource();
op = value.getDefiningOp();
if (!op)
return std::nullopt;
}
if (auto spatCompute = dyn_cast<SpatCompute>(op))
return ComputeInstance {spatCompute.getOperation(), 0, 1};
if (auto batch = dyn_cast<SpatComputeBatch>(op))
return getBatchChunkForLane(batch, static_cast<uint32_t>(cast<OpResult>(value).getResultNumber()));
return std::nullopt;
}
SmallVector<ComputeInstance> collectComputeInstances(Operation* entryOp) {
SmallVector<ComputeInstance> instances;
auto isUsedAsWeightOnly = [](Operation* producerOp) {
if (producerOp->getNumResults() == 0)
return false;
for (Value result : producerOp->getResults()) {
if (result.use_empty())
return false;
for (Operation* user : result.getUsers()) {
if (auto compute = dyn_cast<SpatCompute>(user)) {
if (!llvm::is_contained(compute.getWeights(), result))
return false;
continue;
}
if (auto batch = dyn_cast<SpatComputeBatch>(user)) {
if (!llvm::is_contained(batch.getWeights(), result))
return false;
continue;
}
return false;
}
}
return true;
};
for (Region& region : entryOp->getRegions()) {
for (Block& block : region) {
for (Operation& op : block) {
if (auto spatCompute = dyn_cast<SpatCompute>(&op)) {
if (isUsedAsWeightOnly(spatCompute.getOperation()))
continue;
instances.push_back({spatCompute.getOperation(), 0, 1});
continue;
}
if (auto batch = dyn_cast<SpatComputeBatch>(&op)) {
if (isUsedAsWeightOnly(batch.getOperation()))
continue;
size_t chunkCount = getBatchChunkTargetCount(batch.getLaneCount());
for (size_t chunkIndex = 0; chunkIndex < chunkCount; ++chunkIndex)
instances.push_back(getBatchChunkForIndex(batch, chunkIndex));
}
}
}
}
return instances;
}
VirtualGraph buildInitialVirtualGraph(ArrayRef<ComputeInstance> computeInstances, ArrayRef<IndexedEdge> edges) {
VirtualGraph graph;
graph.nodes.reserve(computeInstances.size());
for (auto [index, computeInstance] : llvm::enumerate(computeInstances)) {
VirtualNode node;
node.originalComputeIndices.push_back(index);
node.weight = getComputeInstanceWeight(computeInstance);
node.crossbarUsage = getComputeInstanceCrossbarUsage(computeInstance);
graph.nodes.push_back(std::move(node));
}
graph.edges = aggregateEdges(edges);
return graph;
}
TimingInfo computeTiming(const VirtualGraph& graph) {
TimingInfo timing;
size_t nodeCount = graph.nodes.size();
timing.aest.assign(nodeCount, 0);
timing.alst.assign(nodeCount, 0);
timing.topologicalOrder.reserve(nodeCount);
std::vector<std::vector<std::pair<size_t, Weight>>> parents(nodeCount);
std::vector<std::vector<std::pair<size_t, Weight>>> children(nodeCount);
std::vector<size_t> incomingEdgeCount(nodeCount, 0);
for (auto [start, end, weight] : graph.edges) {
size_t startIndex = static_cast<size_t>(start);
size_t endIndex = static_cast<size_t>(end);
Weight edgeWeight = static_cast<Weight>(weight);
assert(startIndex < nodeCount && endIndex < nodeCount && "virtual edge endpoint out of range");
children[startIndex].push_back({endIndex, edgeWeight});
parents[endIndex].push_back({startIndex, edgeWeight});
incomingEdgeCount[endIndex]++;
}
auto getVirtualNodeOrderKey = [&](size_t nodeIndex) {
const VirtualNode& node = graph.nodes[nodeIndex];
if (!node.originalComputeIndices.empty())
return node.originalComputeIndices.front();
return nodeIndex;
};
auto readyNodeGreater = [&](size_t lhs, size_t rhs) {
size_t lhsKey = getVirtualNodeOrderKey(lhs);
size_t rhsKey = getVirtualNodeOrderKey(rhs);
if (lhsKey != rhsKey)
return lhsKey > rhsKey;
return lhs > rhs;
};
std::priority_queue<size_t, std::vector<size_t>, decltype(readyNodeGreater)> readyNodes(readyNodeGreater);
for (size_t i = 0; i < nodeCount; ++i)
if (incomingEdgeCount[i] == 0)
readyNodes.push(i);
while (!readyNodes.empty()) {
size_t current = readyNodes.top();
readyNodes.pop();
timing.topologicalOrder.push_back(current);
for (auto [child, weight] : children[current]) {
(void) weight;
assert(incomingEdgeCount[child] > 0 && "incoming edge count underflow");
incomingEdgeCount[child]--;
if (incomingEdgeCount[child] == 0)
readyNodes.push(child);
}
}
if (timing.topologicalOrder.size() != nodeCount)
return timing;
Time dcpl = 0;
for (size_t nodeIndex : timing.topologicalOrder) {
Time maxParentAest = 0;
for (auto [parent, transferCost] : parents[nodeIndex]) {
maxParentAest =
std::max(maxParentAest, addOrMax(addOrMax(timing.aest[parent], graph.nodes[parent].weight), transferCost));
}
timing.aest[nodeIndex] = maxParentAest;
dcpl = std::max(dcpl, addOrMax(maxParentAest, graph.nodes[nodeIndex].weight));
}
for (size_t nodeIndex : llvm::reverse(timing.topologicalOrder)) {
Time minAlst = std::numeric_limits<Time>::max();
if (children[nodeIndex].empty())
minAlst = subtractOrZero(dcpl, graph.nodes[nodeIndex].weight);
for (auto [child, transferCost] : children[nodeIndex]) {
minAlst =
std::min(minAlst, subtractOrZero(timing.alst[child], addOrMax(graph.nodes[nodeIndex].weight, transferCost)));
}
timing.alst[nodeIndex] = minAlst;
}
timing.valid = true;
return timing;
}
std::vector<std::vector<size_t>> buildUndirectedAdjacency(const VirtualGraph& graph) {
std::vector<std::vector<size_t>> adjacency(graph.nodes.size());
for (auto [start, end, weight] : graph.edges) {
(void) weight;
size_t startIndex = static_cast<size_t>(start);
size_t endIndex = static_cast<size_t>(end);
assert(startIndex < graph.nodes.size() && endIndex < graph.nodes.size() && "virtual edge endpoint out of range");
adjacency[startIndex].push_back(endIndex);
adjacency[endIndex].push_back(startIndex);
}
for (auto& neighbours : adjacency) {
llvm::sort(neighbours);
neighbours.erase(std::unique(neighbours.begin(), neighbours.end()), neighbours.end());
}
return adjacency;
}
std::vector<size_t> selectCriticalWindow(const VirtualGraph& graph, const TimingInfo& timing, size_t windowSize) {
std::vector<size_t> ranked(timing.aest.size());
std::iota(ranked.begin(), ranked.end(), 0);
auto isHigherPriority = [&](size_t lhs, size_t rhs) {
Time lhsSlack = slackOrZero(timing.aest[lhs], timing.alst[lhs]);
Time rhsSlack = slackOrZero(timing.aest[rhs], timing.alst[rhs]);
if (lhsSlack != rhsSlack)
return lhsSlack < rhsSlack;
if (timing.aest[lhs] != timing.aest[rhs])
return timing.aest[lhs] < timing.aest[rhs];
return lhs < rhs;
};
windowSize = std::min(windowSize, ranked.size());
if (windowSize == 0)
return {};
if (windowSize == ranked.size()) {
llvm::sort(ranked, isHigherPriority);
return ranked;
}
size_t criticalPoolSize = std::min(ranked.size(), std::max(windowSize, windowSize * 2));
if (criticalPoolSize < ranked.size())
std::nth_element(
ranked.begin(), ranked.begin() + static_cast<std::ptrdiff_t>(criticalPoolSize), ranked.end(), isHigherPriority);
std::vector<char> inCriticalPool(ranked.size(), false);
for (size_t i = 0; i < criticalPoolSize; ++i)
inCriticalPool[ranked[i]] = true;
size_t seed = *std::min_element(ranked.begin(), ranked.end(), isHigherPriority);
std::vector<std::vector<size_t>> adjacency = buildUndirectedAdjacency(graph);
std::vector<size_t> selected;
std::vector<char> inWindow(ranked.size(), false);
selected.reserve(windowSize);
struct FrontierEntry {
size_t node;
};
auto frontierCompare = [&](FrontierEntry lhs, FrontierEntry rhs) { return isHigherPriority(rhs.node, lhs.node); };
std::priority_queue<FrontierEntry, std::vector<FrontierEntry>, decltype(frontierCompare)> frontier(frontierCompare);
auto addToWindow = [&](size_t node, const std::vector<char>& eligible) {
if (inWindow[node])
return;
inWindow[node] = true;
selected.push_back(node);
for (size_t neighbour : adjacency[node])
if (!inWindow[neighbour] && eligible[neighbour])
frontier.push({neighbour});
};
addToWindow(seed, inCriticalPool);
while (!frontier.empty() && selected.size() < windowSize) {
size_t node = frontier.top().node;
frontier.pop();
if (!inWindow[node])
addToWindow(node, inCriticalPool);
}
if (selected.size() < windowSize) {
std::vector<char> anyNode(ranked.size(), true);
for (size_t node : selected)
for (size_t neighbour : adjacency[node])
if (!inWindow[neighbour])
frontier.push({neighbour});
while (!frontier.empty() && selected.size() < windowSize) {
size_t node = frontier.top().node;
frontier.pop();
if (!inWindow[node])
addToWindow(node, anyNode);
}
}
if (selected.size() < windowSize) {
llvm::sort(ranked, isHigherPriority);
for (size_t node : ranked) {
if (selected.size() == windowSize)
break;
if (!inWindow[node]) {
inWindow[node] = true;
selected.push_back(node);
}
}
}
llvm::sort(selected, isHigherPriority);
return selected;
}
std::vector<IndexedEdge> buildWindowEdges(const VirtualGraph& graph, const std::vector<int64_t>& nodeToWindowIndex) {
std::vector<IndexedEdge> windowEdges;
windowEdges.reserve(graph.edges.size());
for (auto [start, end, weight] : graph.edges) {
int64_t mappedStart = nodeToWindowIndex[static_cast<size_t>(start)];
int64_t mappedEnd = nodeToWindowIndex[static_cast<size_t>(end)];
if (mappedStart == -1 || mappedEnd == -1)
continue;
windowEdges.push_back({mappedStart, mappedEnd, weight});
}
return aggregateEdges(windowEdges);
}
WindowScheduleResult scheduleWindow(const VirtualGraph& graph, ArrayRef<size_t> selectedNodes, MLIRContext* context) {
std::vector<Weight> windowWeights;
std::vector<CrossbarUsage> windowCrossbarUsage;
std::vector<int64_t> windowNodeOrderKeys;
std::vector<int64_t> nodeToWindowIndex(graph.nodes.size(), -1);
windowWeights.reserve(selectedNodes.size());
windowCrossbarUsage.reserve(selectedNodes.size());
windowNodeOrderKeys.reserve(selectedNodes.size());
for (auto [windowIndex, nodeIndex] : llvm::enumerate(selectedNodes)) {
nodeToWindowIndex[nodeIndex] = static_cast<int64_t>(windowIndex);
windowWeights.push_back(graph.nodes[nodeIndex].weight);
windowCrossbarUsage.push_back(graph.nodes[nodeIndex].crossbarUsage);
windowNodeOrderKeys.push_back(static_cast<int64_t>(nodeIndex));
}
GraphDCP windowGraph(
windowWeights, buildWindowEdges(graph, nodeToWindowIndex), windowNodeOrderKeys, windowCrossbarUsage);
if (coresCount.getValue() > 0)
windowGraph.setMaxCpuCount(static_cast<int>(coresCount.getValue()));
windowGraph.setContext(context);
windowGraph.runDcp();
WindowScheduleResult result;
result.cpuCount = windowGraph.cpuCount();
for (CPU cpu = 0; cpu < windowGraph.cpuCount(); ++cpu) {
auto scheduledTasks = windowGraph.getScheduledTasks(cpu);
if (scheduledTasks.size() < 2)
continue;
result.mergedNodeCount += scheduledTasks.size();
result.maxMergeGroupSize = std::max(result.maxMergeGroupSize, scheduledTasks.size());
std::vector<size_t> mergeGroup;
mergeGroup.reserve(scheduledTasks.size());
for (const auto& task : scheduledTasks)
mergeGroup.push_back(selectedNodes[task.nodeIndex]);
result.mergeGroups.push_back(std::move(mergeGroup));
}
return result;
}
bool coarsenGraph(const VirtualGraph& graph,
ArrayRef<std::vector<size_t>> mergeGroups,
VirtualGraph& coarsenedGraph,
std::vector<size_t>& oldToNewNode) {
TimingInfo timing = computeTiming(graph);
std::vector<size_t> topologicalRank(graph.nodes.size());
std::iota(topologicalRank.begin(), topologicalRank.end(), 0);
if (timing.valid)
for (auto [rank, nodeIndex] : llvm::enumerate(timing.topologicalOrder))
topologicalRank[nodeIndex] = rank;
std::vector<std::vector<size_t>> orderedMergeGroups;
orderedMergeGroups.reserve(mergeGroups.size());
for (const auto& mergeGroup : mergeGroups) {
orderedMergeGroups.emplace_back(mergeGroup.begin(), mergeGroup.end());
std::stable_sort(orderedMergeGroups.back().begin(), orderedMergeGroups.back().end(), [&](size_t lhs, size_t rhs) {
if (topologicalRank[lhs] != topologicalRank[rhs])
return topologicalRank[lhs] < topologicalRank[rhs];
return lhs < rhs;
});
}
std::vector<int64_t> nodeToMergeGroup(graph.nodes.size(), -1);
for (auto [groupIndex, mergeGroup] : llvm::enumerate(orderedMergeGroups)) {
if (mergeGroup.size() < 2)
continue;
for (size_t nodeIndex : mergeGroup) {
assert(nodeIndex < graph.nodes.size() && "merge group node out of range");
nodeToMergeGroup[nodeIndex] = static_cast<int64_t>(groupIndex);
}
}
std::vector<std::optional<size_t>> mergeGroupToNewNode(orderedMergeGroups.size());
std::vector<size_t> newNodeRank;
oldToNewNode.assign(graph.nodes.size(), 0);
bool mergedAny = false;
coarsenedGraph.nodes.clear();
coarsenedGraph.edges.clear();
coarsenedGraph.nodes.reserve(graph.nodes.size());
newNodeRank.reserve(graph.nodes.size());
for (size_t nodeIndex = 0; nodeIndex < graph.nodes.size(); ++nodeIndex) {
int64_t mergeGroupIndex = nodeToMergeGroup[nodeIndex];
if (mergeGroupIndex == -1) {
oldToNewNode[nodeIndex] = coarsenedGraph.nodes.size();
coarsenedGraph.nodes.push_back(graph.nodes[nodeIndex]);
newNodeRank.push_back(topologicalRank[nodeIndex]);
continue;
}
auto& newNodeIndex = mergeGroupToNewNode[static_cast<size_t>(mergeGroupIndex)];
if (newNodeIndex.has_value()) {
oldToNewNode[nodeIndex] = *newNodeIndex;
continue;
}
VirtualNode mergedNode;
for (size_t memberIndex : orderedMergeGroups[static_cast<size_t>(mergeGroupIndex)]) {
const VirtualNode& memberNode = graph.nodes[memberIndex];
mergedNode.originalComputeIndices.append(memberNode.originalComputeIndices.begin(),
memberNode.originalComputeIndices.end());
mergedNode.weight = addOrMax(mergedNode.weight, memberNode.weight);
mergedNode.crossbarUsage = addOrMax(mergedNode.crossbarUsage, memberNode.crossbarUsage);
}
std::sort(mergedNode.originalComputeIndices.begin(), mergedNode.originalComputeIndices.end());
mergedAny = true;
newNodeIndex = coarsenedGraph.nodes.size();
for (size_t memberIndex : orderedMergeGroups[static_cast<size_t>(mergeGroupIndex)])
oldToNewNode[memberIndex] = *newNodeIndex;
newNodeRank.push_back(topologicalRank[orderedMergeGroups[static_cast<size_t>(mergeGroupIndex)].front()]);
coarsenedGraph.nodes.push_back(std::move(mergedNode));
}
if (!mergedAny)
return false;
std::vector<IndexedEdge> remappedEdges;
remappedEdges.reserve(graph.edges.size());
for (auto [start, end, weight] : graph.edges) {
size_t newStart = oldToNewNode[static_cast<size_t>(start)];
size_t newEnd = oldToNewNode[static_cast<size_t>(end)];
if (newStart == newEnd)
continue;
if (newNodeRank[newStart] >= newNodeRank[newEnd])
continue;
remappedEdges.push_back({static_cast<int64_t>(newStart), static_cast<int64_t>(newEnd), weight});
}
coarsenedGraph.edges = aggregateEdges(remappedEdges);
return true;
}
CPU getVirtualGraphMaxCpuCount() { return static_cast<CPU>(getSchedulingCpuBudget()); }
size_t getDcpCoarseningWindowSize(size_t nodeCount) {
size_t windowSize = std::min(dcpCriticalWindowSize.getValue(), nodeCount);
CPU maxCpuCount = std::max<CPU>(1, getVirtualGraphMaxCpuCount());
if (nodeCount > static_cast<size_t>(maxCpuCount))
windowSize = std::max(windowSize, std::min(nodeCount, static_cast<size_t>(maxCpuCount) + 1));
return windowSize;
}
DCPAnalysisResult buildResultFromVirtualGraph(const VirtualGraph& graph, ArrayRef<ComputeInstance> computeInstances) {
DCPAnalysisResult result;
TimingInfo timing = computeTiming(graph);
std::vector<size_t> virtualNodeOrder;
if (timing.valid) {
virtualNodeOrder = std::move(timing.topologicalOrder);
}
else {
virtualNodeOrder.resize(graph.nodes.size());
std::iota(virtualNodeOrder.begin(), virtualNodeOrder.end(), 0);
}
std::vector<size_t> originalComputeToCpu(computeInstances.size(), 0);
for (auto [cpu, virtualNodeIndex] : llvm::enumerate(virtualNodeOrder)) {
const VirtualNode& virtualNode = graph.nodes[virtualNodeIndex];
for (size_t originalIndex : virtualNode.originalComputeIndices)
originalComputeToCpu[originalIndex] = cpu;
}
result.dominanceOrderCompute.reserve(computeInstances.size());
llvm::DenseMap<size_t, size_t> nextCpuSlot;
for (auto [originalIndex, computeInstance] : llvm::enumerate(computeInstances)) {
size_t cpu = originalComputeToCpu[originalIndex];
result.dominanceOrderCompute.push_back(computeInstance);
result.computeToCpuMap[computeInstance] = cpu;
result.computeToCpuSlotMap[computeInstance] = nextCpuSlot[cpu]++;
result.computeToAestMap[computeInstance] = originalIndex;
result.cpuToLastComputeMap[cpu] = computeInstance;
}
for (const auto& [cpu, lastCompute] : result.cpuToLastComputeMap)
result.isLastComputeOfCpu.insert(lastCompute);
return result;
}
DCPAnalysisResult buildResultFromScheduledGraph(GraphDCP& graphDCP, ArrayRef<ComputeInstance> computeInstances) {
DCPAnalysisResult result;
result.dominanceOrderCompute.assign(computeInstances.begin(), computeInstances.end());
for (CPU cpu = 0; cpu < graphDCP.cpuCount(); ++cpu) {
auto scheduledTasks = graphDCP.getScheduledTasks(cpu);
if (scheduledTasks.empty())
continue;
for (auto [slot, task] : llvm::enumerate(scheduledTasks)) {
ComputeInstance instance = computeInstances[task.nodeIndex];
result.computeToCpuMap[instance] = cpu;
result.computeToCpuSlotMap[instance] = slot;
result.computeToAestMap[instance] = static_cast<uint64_t>(task.aest);
}
result.cpuToLastComputeMap[cpu] = computeInstances[scheduledTasks.back().nodeIndex];
result.isLastComputeOfCpu.insert(computeInstances[scheduledTasks.back().nodeIndex]);
}
return result;
}
DCPAnalysisResult
runLegacyDcp(ArrayRef<ComputeInstance> computeInstances, ArrayRef<IndexedEdge> edges, MLIRContext* context) {
SmallVector<Weight> nodeWeights;
SmallVector<CrossbarUsage> nodeCrossbarUsage;
SmallVector<int64_t> nodeOrderKeys;
nodeWeights.reserve(computeInstances.size());
nodeCrossbarUsage.reserve(computeInstances.size());
nodeOrderKeys.reserve(computeInstances.size());
for (auto [index, instance] : llvm::enumerate(computeInstances)) {
nodeWeights.push_back(getComputeInstanceWeight(instance));
nodeCrossbarUsage.push_back(getComputeInstanceCrossbarUsage(instance));
nodeOrderKeys.push_back(static_cast<int64_t>(index));
}
GraphDCP graphDCP(nodeWeights, edges, nodeOrderKeys, nodeCrossbarUsage);
if (coresCount.getValue() > 0)
graphDCP.setMaxCpuCount(static_cast<int>(coresCount.getValue()));
graphDCP.setContext(context);
graphDCP.runDcp();
return buildResultFromScheduledGraph(graphDCP, computeInstances);
}
} // namespace
SpatCompute getOriginalSpatCompute(Operation* op) {
if (!op)
return {};
while (auto extract = dyn_cast<tensor::ExtractSliceOp>(op)) {
op = extract.getSource().getDefiningOp();
if (!op)
return {};
}
if (auto res = dyn_cast<SpatCompute>(op))
return res;
return {};
}
DCPAnalysisResult DCPAnalysis::run() {
SmallVector<ComputeInstance> computeInstances = collectComputeInstances(entryOp);
SmallVector<IndexedEdge, 10> edges;
llvm::DenseMap<ComputeInstance, size_t> instanceToIndex;
instanceToIndex.reserve(computeInstances.size());
for (auto [index, instance] : llvm::enumerate(computeInstances))
instanceToIndex[instance] = index;
for (auto [indexEndEdge, computeInstance] : llvm::enumerate(computeInstances)) {
for (Value input : getComputeInstanceInputs(computeInstance)) {
if (auto producerInstance = getOriginalComputeInstance(input)) {
auto producerIt = instanceToIndex.find(*producerInstance);
assert(producerIt != instanceToIndex.end());
auto indexStartEdge = producerIt->second;
edges.push_back({static_cast<int64_t>(indexStartEdge),
static_cast<int64_t>(indexEndEdge),
static_cast<int64_t>(getSizeInBytes(cast<ShapedType>(input.getType())))});
}
}
}
if (coresCount.getValue() > 0) {
size_t schedulingCpuBudget = getSchedulingCpuBudget();
bool needsExactScheduledBatches = llvm::any_of(computeInstances, [&](const ComputeInstance& instance) {
auto batch = dyn_cast<SpatComputeBatch>(instance.op);
return batch && static_cast<size_t>(batch.getLaneCount()) > schedulingCpuBudget;
});
if (needsExactScheduledBatches)
return runLegacyDcp(computeInstances, edges, entryOp->getContext());
}
if (dcpCriticalWindowSize.getValue() == 0)
return runLegacyDcp(computeInstances, edges, entryOp->getContext());
VirtualGraph virtualGraph = buildInitialVirtualGraph(computeInstances, edges);
size_t iteration = 0;
bool debugCoarsening = isDcpCoarsenDebugEnabled();
auto tryCoarsenSelectedNodes = [&](ArrayRef<size_t> selectedNodes) {
size_t oldNodeCount = virtualGraph.nodes.size();
WindowScheduleResult windowSchedule = scheduleWindow(virtualGraph, selectedNodes, entryOp->getContext());
if (windowSchedule.mergeGroups.empty()) {
if (debugCoarsening && oldNodeCount >= 200)
llvm::errs() << llvm::formatv("[DCP-COARSEN] iter={0} old={1} selected={2} windowCpus={3} "
"groups=0 mergedNodes=0 maxGroup=0 new={1} changed=0\n",
iteration,
oldNodeCount,
selectedNodes.size(),
windowSchedule.cpuCount);
return false;
}
VirtualGraph coarsenedGraph;
std::vector<size_t> oldToNewNode;
if (!coarsenGraph(virtualGraph, windowSchedule.mergeGroups, coarsenedGraph, oldToNewNode))
return false;
if (debugCoarsening && (oldNodeCount >= 200 || coarsenedGraph.nodes.size() >= 200))
llvm::errs() << llvm::formatv("[DCP-COARSEN] iter={0} old={1} selected={2} windowCpus={3} "
"groups={4} mergedNodes={5} maxGroup={6} new={7} changed={8}\n",
iteration,
oldNodeCount,
selectedNodes.size(),
windowSchedule.cpuCount,
windowSchedule.mergeGroups.size(),
windowSchedule.mergedNodeCount,
windowSchedule.maxMergeGroupSize,
coarsenedGraph.nodes.size(),
oldNodeCount - coarsenedGraph.nodes.size());
virtualGraph = std::move(coarsenedGraph);
return true;
};
while (virtualGraph.nodes.size() > 1) {
if (virtualGraph.nodes.size() <= getSchedulingCpuBudget()) {
if (debugCoarsening && virtualGraph.nodes.size() >= 200)
llvm::errs() << llvm::formatv(
"[DCP-COARSEN] iter={0} old={1} stop=cpu-budget\n", iteration, virtualGraph.nodes.size());
break;
}
iteration++;
TimingInfo timing = computeTiming(virtualGraph);
if (!timing.valid) {
if (debugCoarsening && virtualGraph.nodes.size() >= 200)
llvm::errs() << llvm::formatv(
"[DCP-COARSEN] iter={0} old={1} invalid-timing\n", iteration, virtualGraph.nodes.size());
break;
}
SmallVector<size_t> selectedNodes;
auto criticalWindow =
selectCriticalWindow(virtualGraph, timing, getDcpCoarseningWindowSize(virtualGraph.nodes.size()));
selectedNodes.append(criticalWindow.begin(), criticalWindow.end());
if (selectedNodes.size() < 2) {
if (debugCoarsening && virtualGraph.nodes.size() >= 200)
llvm::errs() << llvm::formatv("[DCP-COARSEN] iter={0} old={1} selected={2} stop=small-window\n",
iteration,
virtualGraph.nodes.size(),
selectedNodes.size());
break;
}
if (tryCoarsenSelectedNodes(selectedNodes))
continue;
if (debugCoarsening && virtualGraph.nodes.size() >= 200)
llvm::errs() << llvm::formatv(
"[DCP-COARSEN] iter={0} old={1} stop=no-merge\n", iteration, virtualGraph.nodes.size());
break;
}
return buildResultFromVirtualGraph(virtualGraph, computeInstances);
ComputeGraph graph = buildComputeGraph(entryOp);
DcpScheduleOptions options;
if (coresCount.getValue() > 0)
options.processorCount = static_cast<size_t>(coresCount.getValue());
options.criticalWindowSize = dcpCriticalWindowSize.getValue();
options.allowFallbackForAutoCoreCount = true;
return runDcpScheduler(graph, options, entryOp->getContext());
}
} // namespace spatial
@@ -2,64 +2,27 @@
#include "mlir/IR/Operation.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include <cstdint>
#include <vector>
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
// A scheduling identity that covers both spat.compute and scheduled shards of
// spat.compute_batch.
struct ComputeInstance {
mlir::Operation* op = nullptr;
uint32_t laneStart = 0;
uint32_t laneCount = 1;
bool operator==(const ComputeInstance& other) const {
return op == other.op && laneStart == other.laneStart && laneCount == other.laneCount;
}
};
struct DCPAnalysisResult {
std::vector<ComputeInstance> dominanceOrderCompute;
llvm::DenseMap<ComputeInstance, size_t> computeToCpuMap;
llvm::DenseMap<ComputeInstance, size_t> computeToCpuSlotMap;
llvm::DenseMap<ComputeInstance, uint64_t> computeToAestMap;
llvm::DenseSet<ComputeInstance> isLastComputeOfCpu;
llvm::DenseMap<size_t, ComputeInstance> cpuToLastComputeMap;
};
#include "../Scheduling/MergeSchedule.hpp"
namespace onnx_mlir {
namespace spatial {
using DCPAnalysisResult = MergeScheduleResult;
struct DCPAnalysis {
private:
DCPAnalysisResult result;
mlir::Operation* entryOp;
mlir::Operation *entryOp;
DCPAnalysisResult run();
public:
DCPAnalysis(mlir::Operation* op)
DCPAnalysis(mlir::Operation *op)
: entryOp(op) {
result = run();
}
DCPAnalysisResult& getResult() { return result; }
DCPAnalysisResult &getResult() { return result; }
};
} // namespace spatial
} // namespace onnx_mlir
namespace llvm {
template <>
struct DenseMapInfo<ComputeInstance> {
static ComputeInstance getEmptyKey() {
return {DenseMapInfo<mlir::Operation*>::getEmptyKey(), UINT32_MAX, UINT32_MAX};
}
static ComputeInstance getTombstoneKey() {
return {DenseMapInfo<mlir::Operation*>::getTombstoneKey(), UINT32_MAX, UINT32_MAX};
}
static unsigned getHashValue(const ComputeInstance& v) { return llvm::hash_combine(v.op, v.laneStart, v.laneCount); }
static bool isEqual(const ComputeInstance& a, const ComputeInstance& b) { return a == b; }
};
} // namespace llvm
using DCPAnalysisResult = onnx_mlir::spatial::DCPAnalysisResult;
@@ -0,0 +1,18 @@
#pragma once
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Support/LogicalResult.h"
#include "Scheduling/MergeSchedule.hpp"
namespace onnx_mlir {
namespace spatial {
class MergeScheduleMaterializer {
public:
mlir::LogicalResult
run(mlir::func::FuncOp func, const MergeScheduleResult &schedule, int64_t &nextChannelId);
};
} // namespace spatial
} // namespace onnx_mlir
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,743 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"
#include <chrono>
#include <cstdlib>
#include <limits>
#include <optional>
#include "PostMergeCompaction.hpp"
#include "RegularOpCompaction.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
using SpatCompute = spatial::SpatCompute;
using SpatComputeBatch = spatial::SpatComputeBatch;
bool isMergeProfilingEnabled() { return std::getenv("RAPTOR_PROFILE_MERGE") != nullptr; }
class ScopedMergePhaseTimer {
public:
explicit ScopedMergePhaseTimer(StringRef phaseName)
: enabled(isMergeProfilingEnabled()), phase(phaseName.str()) {
if (enabled)
start = std::chrono::steady_clock::now();
}
~ScopedMergePhaseTimer() {
if (!enabled)
return;
auto elapsed = std::chrono::steady_clock::now() - start;
double millis = std::chrono::duration<double, std::milli>(elapsed).count();
llvm::errs() << "[merge-profile] " << phase << ": " << llvm::formatv("{0:F3}", millis) << " ms\n";
}
private:
bool enabled = false;
std::string phase;
std::chrono::steady_clock::time_point start;
};
std::optional<int32_t> getComputeCoreId(SpatCompute compute) {
if (auto coreIdAttr = compute->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
return static_cast<int32_t>(coreIdAttr.getInt());
return std::nullopt;
}
static constexpr StringLiteral kRebatchPhaseAttrName = "_pim_rebatch_phase";
static FailureOr<int64_t> getConstantI64Value(Value value) {
APInt constantValue;
if (!matchPattern(value, m_ConstantInt(&constantValue)))
return failure();
return constantValue.getSExtValue();
}
static FailureOr<int32_t> getConstantI32Value(Value value) {
APInt constantValue;
if (!matchPattern(value, m_ConstantInt(&constantValue)))
return failure();
return static_cast<int32_t>(constantValue.getSExtValue());
}
static bool getScalarChannelMetadata(spatial::SpatChannelSendOp op,
uint64_t& channelId,
uint32_t& sourceCoreId,
uint32_t& targetCoreId) {
FailureOr<int64_t> constantChannelId = getConstantI64Value(op.getChannelId());
FailureOr<int32_t> constantSourceCoreId = getConstantI32Value(op.getSourceCoreId());
FailureOr<int32_t> constantTargetCoreId = getConstantI32Value(op.getTargetCoreId());
if (failed(constantChannelId) || failed(constantSourceCoreId) || failed(constantTargetCoreId))
return false;
channelId = static_cast<uint64_t>(*constantChannelId);
sourceCoreId = static_cast<uint32_t>(*constantSourceCoreId);
targetCoreId = static_cast<uint32_t>(*constantTargetCoreId);
return true;
}
static bool getScalarChannelMetadata(spatial::SpatChannelReceiveOp op,
uint64_t& channelId,
uint32_t& sourceCoreId,
uint32_t& targetCoreId) {
FailureOr<int64_t> constantChannelId = getConstantI64Value(op.getChannelId());
FailureOr<int32_t> constantSourceCoreId = getConstantI32Value(op.getSourceCoreId());
FailureOr<int32_t> constantTargetCoreId = getConstantI32Value(op.getTargetCoreId());
if (failed(constantChannelId) || failed(constantSourceCoreId) || failed(constantTargetCoreId))
return false;
channelId = static_cast<uint64_t>(*constantChannelId);
sourceCoreId = static_cast<uint32_t>(*constantSourceCoreId);
targetCoreId = static_cast<uint32_t>(*constantTargetCoreId);
return true;
}
static SmallVector<Value> createIndexConstants(Operation* anchorOp, ArrayRef<int64_t> values, OperationFolder& folder) {
SmallVector<Value> constants;
constants.reserve(values.size());
for (int64_t value : values)
constants.push_back(getOrCreateHostIndexConstant(anchorOp, value, folder));
return constants;
}
static SmallVector<Value> createIndexConstants(Operation* anchorOp, ArrayRef<int32_t> values, OperationFolder& folder) {
SmallVector<Value> constants;
constants.reserve(values.size());
for (int32_t value : values)
constants.push_back(getOrCreateHostIndexConstant(anchorOp, value, folder));
return constants;
}
std::optional<uint64_t> getComputeRebatchPhase(SpatCompute compute) {
if (auto phaseAttr = compute->getAttrOfType<IntegerAttr>(kRebatchPhaseAttrName))
return static_cast<uint64_t>(phaseAttr.getInt());
return std::nullopt;
}
struct RebatchKey {
unsigned inputCount = 0;
unsigned resultCount = 0;
unsigned weightCount = 0;
uint64_t phase = 0;
bool hasPhase = false;
uint64_t structureHash = 0;
bool operator==(const RebatchKey& other) const {
return inputCount == other.inputCount && resultCount == other.resultCount && weightCount == other.weightCount
&& phase == other.phase && hasPhase == other.hasPhase && structureHash == other.structureHash;
}
};
struct RebatchKeyInfo {
static inline RebatchKey getEmptyKey() { return {std::numeric_limits<unsigned>::max(), 0, 0, 0, false, 0}; }
static inline RebatchKey getTombstoneKey() { return {std::numeric_limits<unsigned>::max() - 1, 0, 0, 0, false, 0}; }
static unsigned getHashValue(const RebatchKey& key) {
return static_cast<unsigned>(
llvm::hash_combine(key.inputCount, key.resultCount, key.weightCount, key.phase, key.hasPhase, key.structureHash));
}
static bool isEqual(const RebatchKey& lhs, const RebatchKey& rhs) { return lhs == rhs; }
};
uint64_t getTypeHash(Type type) { return reinterpret_cast<uintptr_t>(type.getAsOpaquePointer()); }
uint64_t getValueHash(Value value) { return reinterpret_cast<uintptr_t>(value.getAsOpaquePointer()); }
uint64_t getAttributeHash(Attribute attr) { return reinterpret_cast<uintptr_t>(attr.getAsOpaquePointer()); }
RebatchKey computeRebatchKey(SpatCompute compute) {
llvm::hash_code structureHash =
llvm::hash_combine(compute.getInputs().size(), compute.getResultTypes().size(), compute.getWeights().size());
for (Value weight : compute.getWeights())
structureHash = llvm::hash_combine(structureHash, getValueHash(weight));
if (std::optional<uint64_t> phase = getComputeRebatchPhase(compute))
structureHash = llvm::hash_combine(structureHash, *phase);
Block& body = compute.getBody().front();
structureHash = llvm::hash_combine(structureHash, body.getNumArguments());
for (BlockArgument arg : body.getArguments())
structureHash = llvm::hash_combine(structureHash, getTypeHash(arg.getType()));
for (Operation& op : body) {
structureHash = llvm::hash_combine(
structureHash, op.getName().getStringRef(), op.getNumOperands(), op.getNumResults(), op.getNumRegions());
for (Type type : op.getResultTypes())
structureHash = llvm::hash_combine(structureHash, getTypeHash(type));
for (NamedAttribute attr : op.getAttrs())
structureHash = llvm::hash_combine(structureHash, attr.getName().strref(), getAttributeHash(attr.getValue()));
}
std::optional<uint64_t> phase = getComputeRebatchPhase(compute);
return {static_cast<unsigned>(compute.getInputs().size()),
static_cast<unsigned>(compute.getResultTypes().size()),
static_cast<unsigned>(compute.getWeights().size()),
phase.value_or(0),
phase.has_value(),
static_cast<uint64_t>(structureHash)};
}
bool areEquivalentForRebatch(SpatCompute lhs, SpatCompute rhs) {
if (!lhs || !rhs)
return false;
if (lhs.getInputs().size() != rhs.getInputs().size())
return false;
if (lhs.getResultTypes() != rhs.getResultTypes())
return false;
if (lhs.getWeights().size() != rhs.getWeights().size())
return false;
if (getComputeRebatchPhase(lhs) != getComputeRebatchPhase(rhs))
return false;
if (!llvm::equal(lhs.getWeights(), rhs.getWeights()))
return false;
auto& lhsBlock = lhs.getBody().front();
auto& rhsBlock = rhs.getBody().front();
if (lhsBlock.getNumArguments() != rhsBlock.getNumArguments())
return false;
DenseMap<Value, Value> mappedValues;
for (auto [lhsArg, rhsArg] : llvm::zip(lhsBlock.getArguments(), rhsBlock.getArguments())) {
if (lhsArg.getType() != rhsArg.getType())
return false;
mappedValues[lhsArg] = rhsArg;
}
auto lhsIt = lhsBlock.begin();
auto rhsIt = rhsBlock.begin();
for (; lhsIt != lhsBlock.end() && rhsIt != rhsBlock.end(); ++lhsIt, ++rhsIt) {
Operation& lhsOp = *lhsIt;
Operation& rhsOp = *rhsIt;
if (lhsOp.getName() != rhsOp.getName())
return false;
if (lhsOp.getNumOperands() != rhsOp.getNumOperands())
return false;
if (lhsOp.getNumResults() != rhsOp.getNumResults())
return false;
if (lhsOp.getNumRegions() != 0 || rhsOp.getNumRegions() != 0)
return false;
for (auto [lhsOperand, rhsOperand] : llvm::zip(lhsOp.getOperands(), rhsOp.getOperands())) {
auto mapped = mappedValues.find(lhsOperand);
if (mapped != mappedValues.end()) {
if (mapped->second != rhsOperand)
return false;
continue;
}
if (lhsOperand != rhsOperand)
return false;
}
if (auto lhsReceive = dyn_cast<spatial::SpatChannelReceiveOp>(lhsOp)) {
auto rhsReceive = cast<spatial::SpatChannelReceiveOp>(rhsOp);
if (lhsReceive.getOutput().getType() != rhsReceive.getOutput().getType())
return false;
}
else if (auto lhsSend = dyn_cast<spatial::SpatChannelSendOp>(lhsOp)) {
auto rhsSend = cast<spatial::SpatChannelSendOp>(rhsOp);
if (lhsSend.getInput().getType() != rhsSend.getInput().getType())
return false;
}
else if (lhsOp.getAttrs() != rhsOp.getAttrs()) {
return false;
}
if (lhsOp.getResultTypes() != rhsOp.getResultTypes())
return false;
for (auto [lhsResult, rhsResult] : llvm::zip(lhsOp.getResults(), rhsOp.getResults()))
mappedValues[lhsResult] = rhsResult;
}
return lhsIt == lhsBlock.end() && rhsIt == rhsBlock.end();
}
struct BatchYieldInfo {
Value yieldedValue;
tensor::ParallelInsertSliceOp insertSlice;
};
static bool isHostOnlyBatchResultUser(Operation* user) {
return isa<func::ReturnOp,
spatial::SpatConcatOp,
tensor::ExtractSliceOp,
tensor::CastOp,
tensor::CollapseShapeOp,
tensor::ExpandShapeOp>(user);
}
static FailureOr<DenseMap<BlockArgument, BatchYieldInfo>> collectBatchYieldInfo(SpatComputeBatch batchOp) {
Block& block = batchOp.getBody().front();
auto inParallel = dyn_cast<spatial::SpatInParallelOp>(block.getTerminator());
if (!inParallel)
return failure();
DenseMap<BlockArgument, BatchYieldInfo> batchYieldByOutputArg;
for (Operation& op : inParallel.getRegion().front()) {
auto insertSlice = dyn_cast<tensor::ParallelInsertSliceOp>(&op);
if (!insertSlice)
return failure();
auto outputArg = dyn_cast<BlockArgument>(insertSlice.getDest());
if (!outputArg || outputArg.getOwner() != &block)
return failure();
batchYieldByOutputArg[outputArg] = {insertSlice.getSource(), insertSlice};
}
return batchYieldByOutputArg;
}
static FailureOr<SpatComputeBatch> cloneBatchAsResultless(SpatComputeBatch batchOp, IRRewriter& rewriter) {
auto coreIdsAttr = batchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName);
if (!coreIdsAttr)
return failure();
Block& oldBlock = batchOp.getBody().front();
rewriter.setInsertionPoint(batchOp);
auto newBatch = SpatComputeBatch::create(rewriter,
batchOp.getLoc(),
TypeRange {},
rewriter.getI32IntegerAttr(batchOp.getLaneCount()),
batchOp.getWeights(),
batchOp.getInputs());
newBatch.getProperties().setOperandSegmentSizes(
{static_cast<int>(batchOp.getWeights().size()), static_cast<int>(batchOp.getInputs().size())});
newBatch->setAttr(onnx_mlir::kCoreIdsAttrName, coreIdsAttr);
SmallVector<Type> blockArgTypes;
SmallVector<Location> blockArgLocs;
blockArgTypes.reserve(1 + batchOp.getWeights().size() + batchOp.getInputs().size());
blockArgLocs.reserve(1 + batchOp.getWeights().size() + batchOp.getInputs().size());
blockArgTypes.push_back(batchOp.getLaneArgument().getType());
blockArgLocs.push_back(batchOp.getLaneArgument().getLoc());
for (unsigned weightIndex = 0; weightIndex < batchOp.getWeights().size(); ++weightIndex) {
blockArgTypes.push_back(batchOp.getWeightArgument(weightIndex).getType());
blockArgLocs.push_back(batchOp.getWeightArgument(weightIndex).getLoc());
}
for (unsigned inputIndex = 0; inputIndex < batchOp.getInputs().size(); ++inputIndex) {
blockArgTypes.push_back(batchOp.getInputArgument(inputIndex).getType());
blockArgLocs.push_back(batchOp.getInputArgument(inputIndex).getLoc());
}
Block* newBlock =
rewriter.createBlock(&newBatch.getBody(), newBatch.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
rewriter.setInsertionPointToStart(newBlock);
IRMapping mapper;
mapper.map(batchOp.getLaneArgument(), newBatch.getLaneArgument());
for (unsigned weightIndex = 0; weightIndex < batchOp.getWeights().size(); ++weightIndex)
mapper.map(batchOp.getWeightArgument(weightIndex), newBatch.getWeightArgument(weightIndex));
for (unsigned inputIndex = 0; inputIndex < batchOp.getInputs().size(); ++inputIndex)
mapper.map(batchOp.getInputArgument(inputIndex), newBatch.getInputArgument(inputIndex));
for (Operation& op : oldBlock.without_terminator()) {
Operation* cloned = rewriter.clone(op, mapper);
for (auto [oldResult, newResult] : llvm::zip(op.getResults(), cloned->getResults()))
mapper.map(oldResult, newResult);
}
return newBatch;
}
static LogicalResult materializeBatchResultCommunication(func::FuncOp funcOp, int64_t& nextChannelId) {
IRRewriter rewriter(funcOp.getContext());
OperationFolder constantFolder(funcOp.getContext());
SmallVector<SpatComputeBatch> batches(funcOp.getOps<SpatComputeBatch>());
for (auto batchOp : batches) {
if (batchOp.getNumResults() == 0)
continue;
auto coreIdsAttr = batchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName);
if (!coreIdsAttr)
return batchOp.emitOpError("missing coreIds while materializing batch result communication");
FailureOr<DenseMap<BlockArgument, BatchYieldInfo>> batchYieldInfo = collectBatchYieldInfo(batchOp);
if (failed(batchYieldInfo))
return batchOp.emitOpError("failed to collect per-result yielded values from compute_batch body");
FailureOr<SpatComputeBatch> newBatch = cloneBatchAsResultless(batchOp, rewriter);
if (failed(newBatch))
return batchOp.emitOpError("failed to clone resultful compute_batch as resultless");
Block& oldBlock = batchOp.getBody().front();
Block& newBlock = newBatch->getBody().front();
IRMapping mapper;
mapper.map(batchOp.getLaneArgument(), newBatch->getLaneArgument());
for (unsigned weightIndex = 0; weightIndex < batchOp.getWeights().size(); ++weightIndex)
mapper.map(batchOp.getWeightArgument(weightIndex), newBatch->getWeightArgument(weightIndex));
for (unsigned inputIndex = 0; inputIndex < batchOp.getInputs().size(); ++inputIndex)
mapper.map(batchOp.getInputArgument(inputIndex), newBatch->getInputArgument(inputIndex));
auto oldIt = oldBlock.begin();
auto newIt = newBlock.begin();
for (; oldIt != oldBlock.end() && newIt != newBlock.end(); ++oldIt, ++newIt)
for (auto [oldResult, newResult] : llvm::zip(oldIt->getResults(), newIt->getResults()))
mapper.map(oldResult, newResult);
SmallVector<int32_t> sourceCoreIds(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end());
rewriter.setInsertionPointToEnd(&newBlock);
for (unsigned resultIndex = 0; resultIndex < batchOp.getNumResults(); ++resultIndex) {
BlockArgument outputArg = batchOp.getOutputArgument(resultIndex);
auto yieldInfoIt = batchYieldInfo->find(outputArg);
if (yieldInfoIt == batchYieldInfo->end())
return batchOp.emitOpError(
"missing yielded value for compute_batch result during communication materialization");
Value mappedYieldedValue = mapper.lookup(yieldInfoIt->second.yieldedValue);
DenseMap<int32_t, SmallVector<OpOperand*>> computeUsesByTargetCore;
SmallVector<OpOperand*> hostUses;
for (OpOperand& use : batchOp.getResult(resultIndex).getUses()) {
if (auto computeOp = dyn_cast<SpatCompute>(use.getOwner())) {
auto coreIdAttr = computeOp->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName);
if (!coreIdAttr)
return batchOp.emitOpError("compute user of compute_batch result is missing coreId");
computeUsesByTargetCore[static_cast<int32_t>(coreIdAttr.getInt())].push_back(&use);
continue;
}
if (isHostOnlyBatchResultUser(use.getOwner())) {
hostUses.push_back(&use);
continue;
}
return batchOp.emitOpError("unsupported user of compute_batch result during communication materialization")
<< ": " << use.getOwner()->getName();
}
auto createReceiveForUses = [&](ArrayRef<OpOperand*> uses, ArrayRef<int32_t> targetCoreIds) -> LogicalResult {
if (uses.empty())
return success();
SmallVector<int64_t> channelIds;
channelIds.reserve(sourceCoreIds.size());
for ([[maybe_unused]] int32_t sourceCoreId : sourceCoreIds)
channelIds.push_back(nextChannelId++);
SmallVector<Value> sendChannelIdValues = createIndexConstants(batchOp, channelIds, constantFolder);
SmallVector<Value> sendSourceCoreIdValues = createIndexConstants(batchOp, sourceCoreIds, constantFolder);
SmallVector<Value> sendTargetCoreIdValues = createIndexConstants(batchOp, targetCoreIds, constantFolder);
spatial::SpatChannelSendBatchOp::create(rewriter,
batchOp.getLoc(),
sendChannelIdValues,
sendSourceCoreIdValues,
sendTargetCoreIdValues,
mappedYieldedValue);
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointAfter(newBatch->getOperation());
SmallVector<Value> receiveChannelIdValues = createIndexConstants(batchOp, channelIds, constantFolder);
SmallVector<Value> receiveSourceCoreIdValues = createIndexConstants(batchOp, sourceCoreIds, constantFolder);
SmallVector<Value> receiveTargetCoreIdValues = createIndexConstants(batchOp, targetCoreIds, constantFolder);
auto received = spatial::SpatChannelReceiveTensorOp::create(rewriter,
batchOp.getLoc(),
batchOp.getResult(resultIndex).getType(),
receiveChannelIdValues,
receiveSourceCoreIdValues,
receiveTargetCoreIdValues);
for (OpOperand* use : uses)
use->set(received.getOutput());
rewriter.setInsertionPointToEnd(&newBlock);
return success();
};
for (auto& [targetCoreId, uses] : computeUsesByTargetCore) {
SmallVector<int32_t> targetCoreIds(static_cast<size_t>(batchOp.getLaneCount()), targetCoreId);
if (failed(createReceiveForUses(uses, targetCoreIds)))
return failure();
}
if (!hostUses.empty()) {
SmallVector<int32_t> hostTargetCoreIds(static_cast<size_t>(batchOp.getLaneCount()), 0);
if (failed(createReceiveForUses(hostUses, hostTargetCoreIds)))
return failure();
}
}
rewriter.setInsertionPointToEnd(&newBlock);
spatial::SpatYieldOp::create(rewriter, batchOp.getLoc(), ValueRange {});
rewriter.eraseOp(batchOp);
}
return success();
}
void rebatchEquivalentComputes(func::FuncOp funcOp) {
IRRewriter rewriter(funcOp.getContext());
OperationFolder constantFolder(funcOp.getContext());
SmallVector<SpatCompute> computes(funcOp.getOps<SpatCompute>());
DenseSet<Operation*> consumed;
DenseMap<Operation*, size_t> computeOrder;
DenseMap<RebatchKey, SmallVector<SpatCompute>, RebatchKeyInfo> candidatesByKey;
for (auto [index, compute] : llvm::enumerate(computes)) {
computeOrder[compute.getOperation()] = index;
if (compute.getInputs().size() <= 1 && compute.getResults().empty())
candidatesByKey[computeRebatchKey(compute)].push_back(compute);
}
for (size_t index = 0; index < computes.size(); ++index) {
auto anchor = computes[index];
if (consumed.contains(anchor))
continue;
if (anchor.getInputs().size() > 1)
continue;
if (!anchor.getResults().empty())
continue;
SmallVector<SpatCompute> group {anchor};
llvm::SmallDenseSet<int32_t, 8> usedCoreIds;
if (auto coreId = getComputeCoreId(anchor))
usedCoreIds.insert(*coreId);
auto bucketIt = candidatesByKey.find(computeRebatchKey(anchor));
if (bucketIt == candidatesByKey.end())
continue;
for (auto candidate : bucketIt->second) {
if (computeOrder.lookup(candidate.getOperation()) <= index)
continue;
if (consumed.contains(candidate))
continue;
if (!areEquivalentForRebatch(anchor, candidate))
continue;
if (auto coreId = getComputeCoreId(candidate))
if (!usedCoreIds.insert(*coreId).second)
continue;
group.push_back(candidate);
}
if (group.size() <= 1)
continue;
auto insertionAnchor = group.front();
if (llvm::all_of(group, [](SpatCompute compute) { return getComputeCoreId(compute).has_value(); })) {
llvm::stable_sort(
group, [](SpatCompute lhs, SpatCompute rhs) { return *getComputeCoreId(lhs) < *getComputeCoreId(rhs); });
}
SmallVector<Value> weights;
weights.reserve(group.size() * anchor.getWeights().size());
SmallVector<Value> inputs;
inputs.reserve(group.size() * anchor.getInputs().size());
SmallVector<int32_t> coreIds;
coreIds.reserve(group.size());
bool haveAllCoreIds = true;
for (auto compute : group) {
llvm::append_range(weights, compute.getWeights());
llvm::append_range(inputs, compute.getInputs());
auto coreIdAttr = compute->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName);
if (!coreIdAttr)
haveAllCoreIds = false;
else if (haveAllCoreIds)
coreIds.push_back(static_cast<int32_t>(coreIdAttr.getInt()));
}
rewriter.setInsertionPoint(insertionAnchor);
auto rebatched = SpatComputeBatch::create(rewriter,
insertionAnchor.getLoc(),
TypeRange {},
rewriter.getI32IntegerAttr(static_cast<int32_t>(group.size())),
ValueRange(weights),
ValueRange(inputs));
rebatched.getProperties().setOperandSegmentSizes(
{static_cast<int>(weights.size()), static_cast<int>(inputs.size())});
if (haveAllCoreIds)
rebatched->setAttr(onnx_mlir::kCoreIdsAttrName, rewriter.getDenseI32ArrayAttr(coreIds));
SmallVector<Type> blockArgTypes;
SmallVector<Location> blockArgLocs;
for (BlockArgument arg : anchor.getBody().front().getArguments()) {
blockArgTypes.push_back(arg.getType());
blockArgLocs.push_back(arg.getLoc());
}
auto* newBlock =
rewriter.createBlock(&rebatched.getBody(), rebatched.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
rewriter.setInsertionPointToEnd(newBlock);
IRMapping mapper;
auto& anchorBlock = anchor.getBody().front();
for (auto [oldArg, newArg] : llvm::zip(anchorBlock.getArguments(), newBlock->getArguments()))
mapper.map(oldArg, newArg);
auto opIts = llvm::map_to_vector(group, [](SpatCompute compute) { return compute.getBody().front().begin(); });
for (Operation& anchorOp : anchorBlock) {
if (auto receiveOp = dyn_cast<spatial::SpatChannelReceiveOp>(&anchorOp)) {
struct BatchReceiveEntry {
uint64_t channelId = 0;
uint32_t sourceCoreId = 0;
uint32_t targetCoreId = 0;
};
SmallVector<BatchReceiveEntry> entries;
entries.reserve(group.size());
for (auto [groupIndex, compute] : llvm::enumerate(group)) {
auto groupReceive = cast<spatial::SpatChannelReceiveOp>(&*opIts[groupIndex]);
BatchReceiveEntry entry;
if (!getScalarChannelMetadata(groupReceive, entry.channelId, entry.sourceCoreId, entry.targetCoreId))
return;
entries.push_back(entry);
++opIts[groupIndex];
}
SmallVector<int64_t> channelIds;
SmallVector<int32_t> sourceCoreIds;
SmallVector<int32_t> targetCoreIds;
channelIds.reserve(group.size());
sourceCoreIds.reserve(group.size());
targetCoreIds.reserve(group.size());
for (const BatchReceiveEntry& entry : entries) {
channelIds.push_back(static_cast<int64_t>(entry.channelId));
sourceCoreIds.push_back(static_cast<int32_t>(entry.sourceCoreId));
targetCoreIds.push_back(static_cast<int32_t>(entry.targetCoreId));
}
SmallVector<Value> channelIdValues = createIndexConstants(receiveOp, channelIds, constantFolder);
SmallVector<Value> sourceCoreIdValues = createIndexConstants(receiveOp, sourceCoreIds, constantFolder);
SmallVector<Value> targetCoreIdValues = createIndexConstants(receiveOp, targetCoreIds, constantFolder);
auto batchReceive = spatial::SpatChannelReceiveBatchOp::create(rewriter,
receiveOp.getLoc(),
receiveOp.getOutput().getType(),
channelIdValues,
sourceCoreIdValues,
targetCoreIdValues);
mapper.map(receiveOp.getOutput(), batchReceive.getOutput());
continue;
}
if (auto sendOp = dyn_cast<spatial::SpatChannelSendOp>(&anchorOp)) {
struct BatchSendEntry {
uint64_t channelId = 0;
uint32_t sourceCoreId = 0;
uint32_t targetCoreId = 0;
};
SmallVector<BatchSendEntry> entries;
entries.reserve(group.size());
for (auto [groupIndex, compute] : llvm::enumerate(group)) {
auto groupSend = cast<spatial::SpatChannelSendOp>(&*opIts[groupIndex]);
BatchSendEntry entry;
if (!getScalarChannelMetadata(groupSend, entry.channelId, entry.sourceCoreId, entry.targetCoreId))
return;
entries.push_back(entry);
++opIts[groupIndex];
}
SmallVector<int64_t> channelIds;
SmallVector<int32_t> sourceCoreIds;
SmallVector<int32_t> targetCoreIds;
channelIds.reserve(group.size());
sourceCoreIds.reserve(group.size());
targetCoreIds.reserve(group.size());
for (const BatchSendEntry& entry : entries) {
channelIds.push_back(static_cast<int64_t>(entry.channelId));
sourceCoreIds.push_back(static_cast<int32_t>(entry.sourceCoreId));
targetCoreIds.push_back(static_cast<int32_t>(entry.targetCoreId));
}
SmallVector<Value> channelIdValues = createIndexConstants(sendOp, channelIds, constantFolder);
SmallVector<Value> sourceCoreIdValues = createIndexConstants(sendOp, sourceCoreIds, constantFolder);
SmallVector<Value> targetCoreIdValues = createIndexConstants(sendOp, targetCoreIds, constantFolder);
spatial::SpatChannelSendBatchOp::create(rewriter,
sendOp.getLoc(),
channelIdValues,
sourceCoreIdValues,
targetCoreIdValues,
mapper.lookup(sendOp.getInput()));
continue;
}
if (isa<spatial::SpatYieldOp>(anchorOp)) {
for (auto& opIt : opIts)
++opIt;
spatial::SpatYieldOp::create(rewriter, anchorOp.getLoc(), ValueRange {});
continue;
}
Operation* cloned = rewriter.clone(anchorOp, mapper);
for (auto [originalResult, clonedResult] : llvm::zip(anchorOp.getResults(), cloned->getResults()))
mapper.map(originalResult, clonedResult);
for (auto& opIt : opIts)
++opIt;
}
for (auto compute : group) {
compute->removeAttr(kRebatchPhaseAttrName);
consumed.insert(compute);
rewriter.eraseOp(compute);
}
}
for (auto compute : funcOp.getOps<SpatCompute>())
compute->removeAttr(kRebatchPhaseAttrName);
}
void cleanupDeadPackingOps(func::FuncOp funcOp) {
auto eraseUnusedOps = [&](auto tag) {
using OpTy = decltype(tag);
SmallVector<OpTy> ops;
funcOp.walk([&](OpTy op) { ops.push_back(op); });
for (auto op : llvm::reverse(ops))
if (op->use_empty())
op.erase();
};
eraseUnusedOps(tensor::ExtractSliceOp {});
eraseUnusedOps(spatial::SpatConcatOp {});
eraseUnusedOps(spatial::SpatExtractRowsOp {});
}
} // namespace
LogicalResult runPostMergeCompactionPipeline(func::FuncOp funcOp, int64_t& nextChannelId) {
{
ScopedMergePhaseTimer timer("order-bilateral-channel-ops");
orderBilateralChannelOps(funcOp);
}
{
ScopedMergePhaseTimer timer("rebatch-equivalent-computes");
rebatchEquivalentComputes(funcOp);
}
{
ScopedMergePhaseTimer timer("compact-scalar-channel-runs-1");
compactScalarChannelRuns(funcOp, nextChannelId);
}
{
ScopedMergePhaseTimer timer("compact-batch-channel-runs-1");
compactBatchChannelRuns(funcOp);
}
{
ScopedMergePhaseTimer timer("compact-regular-op-runs");
compactRegularOpRuns(funcOp);
}
{
ScopedMergePhaseTimer timer("compact-row-wise-wvmm-runs");
compactRowWiseWvmmRuns(funcOp);
}
{
ScopedMergePhaseTimer timer("compact-scalar-channel-runs-2");
compactScalarChannelRuns(funcOp, nextChannelId);
}
{
ScopedMergePhaseTimer timer("compact-batch-channel-runs-2");
compactBatchChannelRuns(funcOp);
}
{
ScopedMergePhaseTimer timer("cleanup-dead-packing-ops");
cleanupDeadPackingOps(funcOp);
}
{
ScopedMergePhaseTimer timer("materialize-batch-result-communication");
if (failed(materializeBatchResultCommunication(funcOp, nextChannelId)))
return failure();
}
return success();
}
} // namespace onnx_mlir
@@ -0,0 +1,12 @@
#pragma once
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Support/LogicalResult.h"
#include <cstdint>
namespace onnx_mlir {
mlir::LogicalResult runPostMergeCompactionPipeline(mlir::func::FuncOp funcOp, int64_t &nextChannelId);
} // namespace onnx_mlir
@@ -3,16 +3,18 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include <tuple>
#include "RegularOpCompaction.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/TensorPackingPatterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
@@ -29,7 +31,7 @@ enum class RegularStepKind {
struct RegularStep {
RegularStepKind kind;
int32_t weightIndex = 0;
Value weight;
Value invariantOperand;
Type resultType;
};
@@ -42,6 +44,122 @@ struct RegularChunk {
Value output;
};
struct RegularCompactionResult {
bool changed = false;
Operation* resumeAfter = nullptr;
};
template <typename OpTy>
struct ConsecutiveRun {
SmallVector<OpTy> ops;
Block::iterator end;
};
template <typename OpTy, typename Predicate>
static ConsecutiveRun<OpTy>
collectConsecutiveRun(Block::iterator start, Block::iterator blockEnd, Predicate predicate) {
ConsecutiveRun<OpTy> run;
run.end = start;
while (run.end != blockEnd) {
auto current = dyn_cast<OpTy>(&*run.end);
if (!current || !predicate(current))
break;
run.ops.push_back(current);
++run.end;
}
return run;
}
static uint64_t getEndpointKey(uint32_t sourceCoreId, uint32_t targetCoreId) {
return (static_cast<uint64_t>(sourceCoreId) << 32) | static_cast<uint64_t>(targetCoreId);
}
static FailureOr<int64_t> getConstantI64Value(Value value) {
APInt constantValue;
if (!matchPattern(value, m_ConstantInt(&constantValue)))
return failure();
return constantValue.getSExtValue();
}
static FailureOr<int32_t> getConstantI32Value(Value value) {
APInt constantValue;
if (!matchPattern(value, m_ConstantInt(&constantValue)))
return failure();
return static_cast<int32_t>(constantValue.getSExtValue());
}
static bool getScalarChannelMetadata(spatial::SpatChannelSendOp op,
uint64_t& channelId,
uint32_t& sourceCoreId,
uint32_t& targetCoreId) {
FailureOr<int64_t> constantChannelId = getConstantI64Value(op.getChannelId());
FailureOr<int32_t> constantSourceCoreId = getConstantI32Value(op.getSourceCoreId());
FailureOr<int32_t> constantTargetCoreId = getConstantI32Value(op.getTargetCoreId());
if (failed(constantChannelId) || failed(constantSourceCoreId) || failed(constantTargetCoreId))
return false;
channelId = static_cast<uint64_t>(*constantChannelId);
sourceCoreId = static_cast<uint32_t>(*constantSourceCoreId);
targetCoreId = static_cast<uint32_t>(*constantTargetCoreId);
return true;
}
static bool getScalarChannelMetadata(spatial::SpatChannelReceiveOp op,
uint64_t& channelId,
uint32_t& sourceCoreId,
uint32_t& targetCoreId) {
FailureOr<int64_t> constantChannelId = getConstantI64Value(op.getChannelId());
FailureOr<int32_t> constantSourceCoreId = getConstantI32Value(op.getSourceCoreId());
FailureOr<int32_t> constantTargetCoreId = getConstantI32Value(op.getTargetCoreId());
if (failed(constantChannelId) || failed(constantSourceCoreId) || failed(constantTargetCoreId))
return false;
channelId = static_cast<uint64_t>(*constantChannelId);
sourceCoreId = static_cast<uint32_t>(*constantSourceCoreId);
targetCoreId = static_cast<uint32_t>(*constantTargetCoreId);
return true;
}
static SmallVector<Value> createIndexConstants(Operation* anchorOp, ArrayRef<int64_t> values, OperationFolder& folder) {
SmallVector<Value> constants;
constants.reserve(values.size());
for (int64_t value : values)
constants.push_back(getOrCreateHostIndexConstant(anchorOp, value, folder));
return constants;
}
static SmallVector<Value> createIndexConstants(Operation* anchorOp, ArrayRef<int32_t> values, OperationFolder& folder) {
SmallVector<Value> constants;
constants.reserve(values.size());
for (int32_t value : values)
constants.push_back(getOrCreateHostIndexConstant(anchorOp, value, folder));
return constants;
}
static SmallVector<Operation*> getScalarChannelMetadataDefs(Operation* channelOp, unsigned metadataOperandCount) {
SmallVector<Operation*> defs;
defs.reserve(metadataOperandCount);
for (unsigned operandIndex = 0; operandIndex < metadataOperandCount; ++operandIndex) {
Operation* def = channelOp->getOperand(operandIndex).getDefiningOp();
auto constantOp = dyn_cast_or_null<arith::ConstantOp>(def);
if (!constantOp || def->getBlock() != channelOp->getBlock())
continue;
defs.push_back(def);
}
llvm::sort(defs, [](Operation* lhs, Operation* rhs) { return lhs->isBeforeInBlock(rhs); });
return defs;
}
static void moveScalarChannelBundleBefore(Operation* channelOp, Operation* insertionPoint) {
for (Operation* metadataDef : getScalarChannelMetadataDefs(channelOp, /*metadataOperandCount=*/3))
metadataDef->moveBefore(insertionPoint);
channelOp->moveBefore(insertionPoint);
}
static void moveScalarChannelBundleBefore(Operation* channelOp, Block* block, Block::iterator insertionPoint) {
for (Operation* metadataDef : getScalarChannelMetadataDefs(channelOp, /*metadataOperandCount=*/3))
metadataDef->moveBefore(block, insertionPoint);
channelOp->moveBefore(block, insertionPoint);
}
static spatial::SpatConcatOp getContiguousConcatUse(ValueRange values, unsigned& startOperandIndex) {
if (values.empty() || !values.front().hasOneUse())
return {};
@@ -154,7 +272,7 @@ static Value createPackedTensorForValues(ValueRange values, IRRewriter& rewriter
}
static bool areEquivalentRegularSteps(const RegularStep& lhs, const RegularStep& rhs) {
return lhs.kind == rhs.kind && lhs.weightIndex == rhs.weightIndex && lhs.invariantOperand == rhs.invariantOperand
return lhs.kind == rhs.kind && lhs.weight == rhs.weight && lhs.invariantOperand == rhs.invariantOperand
&& lhs.resultType == rhs.resultType;
}
@@ -168,14 +286,24 @@ static bool areEquivalentRegularChunks(const RegularChunk& lhs, const RegularChu
[](auto pair) { return areEquivalentRegularSteps(std::get<0>(pair), std::get<1>(pair)); });
}
static bool isForwardedChannelPayload(Value value, Block& block) {
Operation* op = value.getDefiningOp();
if (!op || op->getBlock() != &block)
return true;
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op))
return isForwardedChannelPayload(extractSliceOp.getSource(), block);
return isa<spatial::SpatChannelReceiveOp, spatial::SpatChannelReceiveTensorOp>(op);
}
static FailureOr<RegularChunk> analyzeRegularChunk(spatial::SpatVMMOp startOp) {
RegularChunk chunk;
chunk.startOp = startOp.getOperation();
chunk.input = startOp.getInput();
chunk.output = startOp.getOutput();
chunk.ops.push_back(startOp.getOperation());
chunk.steps.push_back(
{RegularStepKind::Wvmm, static_cast<int32_t>(startOp.getWeightIndex()), Value(), startOp.getOutput().getType()});
chunk.steps.push_back({RegularStepKind::Wvmm, startOp.getWeight(), Value(), startOp.getOutput().getType()});
Value currentValue = startOp.getOutput();
while (currentValue.hasOneUse()) {
@@ -188,9 +316,9 @@ static FailureOr<RegularChunk> analyzeRegularChunk(spatial::SpatVMMOp startOp) {
break;
if (vaddOp.getLhs() == currentValue)
chunk.steps.push_back({RegularStepKind::VAddLhs, 0, vaddOp.getRhs(), vaddOp.getOutput().getType()});
chunk.steps.push_back({RegularStepKind::VAddLhs, Value(), vaddOp.getRhs(), vaddOp.getOutput().getType()});
else if (vaddOp.getRhs() == currentValue)
chunk.steps.push_back({RegularStepKind::VAddRhs, 0, vaddOp.getLhs(), vaddOp.getOutput().getType()});
chunk.steps.push_back({RegularStepKind::VAddRhs, Value(), vaddOp.getLhs(), vaddOp.getOutput().getType()});
else
break;
@@ -202,9 +330,11 @@ static FailureOr<RegularChunk> analyzeRegularChunk(spatial::SpatVMMOp startOp) {
return chunk;
}
static void compactRegularChunkRun(IRRewriter& rewriter, ArrayRef<RegularChunk> run) {
static RegularCompactionResult
compactRegularChunkRun(IRRewriter& rewriter, ArrayRef<RegularChunk> run, OperationFolder& constantFolder) {
assert(!run.empty() && "expected a non-empty regular chunk run");
const RegularChunk& anchorChunk = run.front();
RegularCompactionResult result;
SmallVector<Value> inputs;
inputs.reserve(run.size());
@@ -214,16 +344,16 @@ static void compactRegularChunkRun(IRRewriter& rewriter, ArrayRef<RegularChunk>
rewriter.setInsertionPoint(anchorChunk.startOp);
Value packedInput = createPackedTensorForValues(ValueRange(inputs), rewriter, anchorChunk.startOp->getLoc());
if (!packedInput)
return;
return result;
auto inputType = cast<RankedTensorType>(anchorChunk.input.getType());
auto outputType = cast<RankedTensorType>(anchorChunk.output.getType());
auto packedOutputType = getPackedTensorType(outputType, static_cast<int64_t>(run.size()));
auto packedInit = tensor::EmptyOp::create(
rewriter, anchorChunk.startOp->getLoc(), packedOutputType.getShape(), packedOutputType.getElementType());
auto zero = arith::ConstantIndexOp::create(rewriter, anchorChunk.startOp->getLoc(), 0);
auto upper = arith::ConstantIndexOp::create(rewriter, anchorChunk.startOp->getLoc(), run.size());
auto step = arith::ConstantIndexOp::create(rewriter, anchorChunk.startOp->getLoc(), 1);
auto zero = getOrCreateHostIndexConstant(anchorChunk.startOp, 0, constantFolder);
auto upper = getOrCreateHostIndexConstant(anchorChunk.startOp, static_cast<int64_t>(run.size()), constantFolder);
auto step = getOrCreateHostIndexConstant(anchorChunk.startOp, 1, constantFolder);
auto loop =
scf::ForOp::create(rewriter, anchorChunk.startOp->getLoc(), zero, upper, step, ValueRange {packedInit.getResult()});
@@ -236,8 +366,7 @@ static void compactRegularChunkRun(IRRewriter& rewriter, ArrayRef<RegularChunk>
Value inputRowOffset = iv;
if (inputType.getDimSize(0) != 1) {
auto rowsPerValue =
arith::ConstantIndexOp::create(rewriter, anchorChunk.startOp->getLoc(), inputType.getDimSize(0));
auto rowsPerValue = getOrCreateHostIndexConstant(anchorChunk.startOp, inputType.getDimSize(0), constantFolder);
inputRowOffset = arith::MulIOp::create(rewriter, anchorChunk.startOp->getLoc(), iv, rowsPerValue);
}
@@ -266,8 +395,7 @@ static void compactRegularChunkRun(IRRewriter& rewriter, ArrayRef<RegularChunk>
Value mappedOutput = mapping.lookup(anchorChunk.output);
Value outputRowOffset = iv;
if (outputType.getDimSize(0) != 1) {
auto rowsPerValue =
arith::ConstantIndexOp::create(rewriter, anchorChunk.startOp->getLoc(), outputType.getDimSize(0));
auto rowsPerValue = getOrCreateHostIndexConstant(anchorChunk.startOp, outputType.getDimSize(0), constantFolder);
outputRowOffset = arith::MulIOp::create(rewriter, anchorChunk.startOp->getLoc(), iv, rowsPerValue);
}
@@ -317,30 +445,141 @@ static void compactRegularChunkRun(IRRewriter& rewriter, ArrayRef<RegularChunk>
llvm::append_range(opsToErase, chunk.ops);
for (Operation* op : llvm::reverse(opsToErase))
rewriter.eraseOp(op);
result.changed = true;
result.resumeAfter = loop.getOperation()->getNextNode();
return result;
}
} // namespace
void orderBilateralChannelOps(func::FuncOp funcOp) {
for (auto compute : funcOp.getOps<spatial::SpatCompute>()) {
auto coreIdAttr = compute->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName);
if (!coreIdAttr)
continue;
int32_t coreId = static_cast<int32_t>(coreIdAttr.getInt());
Block& block = compute.getBody().front();
SmallVector<std::pair<spatial::SpatChannelReceiveOp, Operation*>> moves;
DenseMap<uint64_t, Operation*> firstForwardedSendByEndpoint;
Operation* firstForwardedSend = nullptr;
for (Operation& op : block) {
if (auto sendOp = dyn_cast<spatial::SpatChannelSendOp>(&op)) {
uint64_t channelId = 0;
uint32_t sourceCoreId = 0;
uint32_t targetCoreId = 0;
if (getScalarChannelMetadata(sendOp, channelId, sourceCoreId, targetCoreId)
&& sourceCoreId == static_cast<uint32_t>(coreId) && isForwardedChannelPayload(sendOp.getInput(), block)) {
if (!firstForwardedSend)
firstForwardedSend = sendOp.getOperation();
uint64_t key = getEndpointKey(sourceCoreId, targetCoreId);
firstForwardedSendByEndpoint.try_emplace(key, sendOp.getOperation());
}
continue;
}
auto receiveOp = dyn_cast<spatial::SpatChannelReceiveOp>(&op);
uint64_t channelId = 0;
uint32_t sourceCoreId = 0;
uint32_t targetCoreId = 0;
if (!receiveOp || !getScalarChannelMetadata(receiveOp, channelId, sourceCoreId, targetCoreId)
|| targetCoreId != static_cast<uint32_t>(coreId) || sourceCoreId >= static_cast<uint32_t>(coreId)) {
continue;
}
uint64_t key = getEndpointKey(static_cast<uint32_t>(coreId), sourceCoreId);
auto firstMatchingSend = firstForwardedSendByEndpoint.find(key);
if (firstMatchingSend != firstForwardedSendByEndpoint.end())
moves.push_back({receiveOp, firstMatchingSend->second});
else if (firstForwardedSend && firstForwardedSend->isBeforeInBlock(receiveOp))
moves.push_back({receiveOp, firstForwardedSend});
}
for (auto [receiveOp, insertionPoint] : moves)
moveScalarChannelBundleBefore(receiveOp, insertionPoint);
for (auto it = block.begin(); it != block.end();) {
auto receiveOp = dyn_cast<spatial::SpatChannelReceiveOp>(&*it);
uint64_t channelId = 0;
uint32_t sourceCoreId = 0;
uint32_t targetCoreId = 0;
if (!receiveOp || !getScalarChannelMetadata(receiveOp, channelId, sourceCoreId, targetCoreId)
|| sourceCoreId >= static_cast<uint32_t>(coreId)) {
++it;
continue;
}
Type outputType = receiveOp.getOutput().getType();
auto run = collectConsecutiveRun<spatial::SpatChannelReceiveOp>(
it, block.end(), [&](spatial::SpatChannelReceiveOp current) {
uint64_t currentChannelId = 0;
uint32_t currentSourceCoreId = 0;
uint32_t currentTargetCoreId = 0;
return current.getOutput().getType() == outputType
&& getScalarChannelMetadata(current, currentChannelId, currentSourceCoreId, currentTargetCoreId)
&& currentSourceCoreId < static_cast<uint32_t>(coreId);
});
if (run.ops.size() > 1) {
SmallVector<spatial::SpatChannelReceiveOp> sorted(run.ops);
llvm::stable_sort(sorted, [](spatial::SpatChannelReceiveOp lhs, spatial::SpatChannelReceiveOp rhs) {
uint64_t lhsChannelId = 0;
uint32_t lhsSourceCoreId = 0;
uint32_t lhsTargetCoreId = 0;
uint64_t rhsChannelId = 0;
uint32_t rhsSourceCoreId = 0;
uint32_t rhsTargetCoreId = 0;
bool lhsHasMetadata = getScalarChannelMetadata(lhs, lhsChannelId, lhsSourceCoreId, lhsTargetCoreId);
bool rhsHasMetadata = getScalarChannelMetadata(rhs, rhsChannelId, rhsSourceCoreId, rhsTargetCoreId);
if (!lhsHasMetadata || !rhsHasMetadata)
return false;
return lhsSourceCoreId > rhsSourceCoreId;
});
Block::iterator insertIt = run.end;
for (auto op : sorted)
moveScalarChannelBundleBefore(op, &block, insertIt);
}
it = run.end;
}
}
}
void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
IRRewriter rewriter(funcOp.getContext());
OperationFolder constantFolder(funcOp.getContext());
for (auto compute : funcOp.getOps<spatial::SpatCompute>()) {
Block& block = compute.getBody().front();
for (auto it = block.begin(); it != block.end();) {
auto receiveOp = dyn_cast<spatial::SpatChannelReceiveOp>(&*it);
if (receiveOp) {
SmallVector<spatial::SpatChannelReceiveOp> run;
Type outputType = receiveOp.getOutput().getType();
auto runIt = it;
while (runIt != block.end()) {
auto current = dyn_cast<spatial::SpatChannelReceiveOp>(&*runIt);
if (!current || current.getOutput().getType() != outputType)
auto run = collectConsecutiveRun<spatial::SpatChannelReceiveOp>(
it, block.end(), [&](spatial::SpatChannelReceiveOp current) {
return current.getOutput().getType() == outputType;
});
bool hasRepeatedEndpoint = false;
DenseSet<uint64_t> seenEndpoints;
for (auto op : run.ops) {
uint64_t channelId = 0;
uint32_t sourceCoreId = 0;
uint32_t targetCoreId = 0;
if (!getScalarChannelMetadata(op, channelId, sourceCoreId, targetCoreId)) {
hasRepeatedEndpoint = true;
break;
run.push_back(current);
++runIt;
}
uint64_t endpointKey = getEndpointKey(sourceCoreId, targetCoreId);
if (!seenEndpoints.insert(endpointKey).second) {
hasRepeatedEndpoint = true;
break;
}
}
if (run.size() > 1) {
if (run.ops.size() > 1 && !hasRepeatedEndpoint) {
struct ReceiveEntry {
spatial::SpatChannelReceiveOp op;
size_t originalIndex = 0;
@@ -349,13 +588,21 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
uint64_t channelId = 0;
};
SmallVector<ReceiveEntry> sortedEntries;
sortedEntries.reserve(run.size());
for (auto [originalIndex, op] : llvm::enumerate(run))
sortedEntries.push_back({op, originalIndex, op.getSourceCoreId(), op.getTargetCoreId(), op.getChannelId()});
llvm::stable_sort(sortedEntries, [](const ReceiveEntry& lhs, const ReceiveEntry& rhs) {
return std::tuple(lhs.sourceCoreId, lhs.targetCoreId, lhs.channelId)
< std::tuple(rhs.sourceCoreId, rhs.targetCoreId, rhs.channelId);
});
sortedEntries.reserve(run.ops.size());
for (auto [originalIndex, op] : llvm::enumerate(run.ops)) {
uint64_t channelId = 0;
uint32_t sourceCoreId = 0;
uint32_t targetCoreId = 0;
if (!getScalarChannelMetadata(op, channelId, sourceCoreId, targetCoreId)) {
sortedEntries.clear();
break;
}
sortedEntries.push_back({op, originalIndex, sourceCoreId, targetCoreId, channelId});
}
if (sortedEntries.empty()) {
++it;
continue;
}
SmallVector<int64_t> channelIds;
SmallVector<int32_t> sourceCoreIds;
@@ -364,13 +611,12 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
sourceCoreIds.reserve(sortedEntries.size());
targetCoreIds.reserve(sortedEntries.size());
for (ReceiveEntry& entry : sortedEntries) {
(void) entry;
channelIds.push_back(nextChannelId++);
channelIds.push_back(static_cast<int64_t>(entry.channelId));
sourceCoreIds.push_back(static_cast<int32_t>(entry.sourceCoreId));
targetCoreIds.push_back(static_cast<int32_t>(entry.targetCoreId));
}
auto rowType = cast<RankedTensorType>(run.front().getOutput().getType());
auto rowType = cast<RankedTensorType>(run.ops.front().getOutput().getType());
auto fallbackPackedType = getPackedTensorType(rowType, static_cast<int64_t>(sortedEntries.size()));
SmallVector<Value> sortedOutputs;
sortedOutputs.reserve(sortedEntries.size());
@@ -383,14 +629,12 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
concatOp ? getPackedConcatSliceType(concatOp, concatStartIndex, static_cast<unsigned>(sortedOutputs.size()))
: RankedTensorType {};
auto packedType = concatPackedType ? concatPackedType : fallbackPackedType;
rewriter.setInsertionPoint(run.front());
auto compactReceive =
spatial::SpatChannelReceiveTensorOp::create(rewriter,
run.front().getLoc(),
packedType,
rewriter.getDenseI64ArrayAttr(channelIds),
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
rewriter.getDenseI32ArrayAttr(targetCoreIds));
rewriter.setInsertionPoint(run.ops.front());
SmallVector<Value> channelIdValues = createIndexConstants(run.ops.front(), channelIds, constantFolder);
SmallVector<Value> sourceCoreIdValues = createIndexConstants(run.ops.front(), sourceCoreIds, constantFolder);
SmallVector<Value> targetCoreIdValues = createIndexConstants(run.ops.front(), targetCoreIds, constantFolder);
auto compactReceive = spatial::SpatChannelReceiveTensorOp::create(
rewriter, run.ops.front().getLoc(), packedType, channelIdValues, sourceCoreIdValues, targetCoreIdValues);
if (concatOp && concatPackedType) {
replaceConcatRunWithPackedValue(concatOp,
concatStartIndex,
@@ -403,7 +647,7 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
entry.op.getOutput().replaceAllUsesWith(extractPackedChunk(
compactReceive.getOutput(), rowType, static_cast<unsigned>(sortedIndex), rewriter, entry.op.getLoc()));
}
for (auto op : run)
for (auto op : run.ops)
rewriter.eraseOp(op);
it = compactReceive->getIterator();
@@ -414,18 +658,13 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
auto sendOp = dyn_cast<spatial::SpatChannelSendOp>(&*it);
if (sendOp) {
SmallVector<spatial::SpatChannelSendOp> run;
Type inputType = sendOp.getInput().getType();
auto runIt = it;
while (runIt != block.end()) {
auto current = dyn_cast<spatial::SpatChannelSendOp>(&*runIt);
if (!current || current.getInput().getType() != inputType)
break;
run.push_back(current);
++runIt;
}
auto run =
collectConsecutiveRun<spatial::SpatChannelSendOp>(it, block.end(), [&](spatial::SpatChannelSendOp current) {
return current.getInput().getType() == inputType;
});
if (run.size() > 1) {
if (run.ops.size() > 1) {
struct SendEntry {
spatial::SpatChannelSendOp op;
uint32_t sourceCoreId = 0;
@@ -433,13 +672,21 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
uint64_t channelId = 0;
};
SmallVector<SendEntry> sortedEntries;
sortedEntries.reserve(run.size());
for (auto op : run)
sortedEntries.push_back({op, op.getSourceCoreId(), op.getTargetCoreId(), op.getChannelId()});
llvm::stable_sort(sortedEntries, [](const SendEntry& lhs, const SendEntry& rhs) {
return std::tuple(lhs.sourceCoreId, lhs.targetCoreId, lhs.channelId)
< std::tuple(rhs.sourceCoreId, rhs.targetCoreId, rhs.channelId);
});
sortedEntries.reserve(run.ops.size());
for (auto op : run.ops) {
uint64_t channelId = 0;
uint32_t sourceCoreId = 0;
uint32_t targetCoreId = 0;
if (!getScalarChannelMetadata(op, channelId, sourceCoreId, targetCoreId)) {
sortedEntries.clear();
break;
}
sortedEntries.push_back({op, sourceCoreId, targetCoreId, channelId});
}
if (sortedEntries.empty()) {
++it;
continue;
}
SmallVector<int64_t> channelIds;
SmallVector<int32_t> sourceCoreIds;
@@ -450,26 +697,24 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
targetCoreIds.reserve(sortedEntries.size());
inputs.reserve(sortedEntries.size());
for (SendEntry& entry : sortedEntries) {
(void) entry;
channelIds.push_back(nextChannelId++);
channelIds.push_back(static_cast<int64_t>(entry.channelId));
sourceCoreIds.push_back(static_cast<int32_t>(entry.sourceCoreId));
targetCoreIds.push_back(static_cast<int32_t>(entry.targetCoreId));
inputs.push_back(entry.op.getInput());
}
rewriter.setInsertionPoint(run.front());
Value packedInput = createPackedTensorForValues(ValueRange(inputs), rewriter, run.front().getLoc());
rewriter.setInsertionPoint(run.ops.front());
Value packedInput = createPackedTensorForValues(ValueRange(inputs), rewriter, run.ops.front().getLoc());
if (packedInput) {
spatial::SpatChannelSendTensorOp::create(rewriter,
run.front().getLoc(),
rewriter.getDenseI64ArrayAttr(channelIds),
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
rewriter.getDenseI32ArrayAttr(targetCoreIds),
packedInput);
for (auto op : run)
SmallVector<Value> channelIdValues = createIndexConstants(run.ops.front(), channelIds, constantFolder);
SmallVector<Value> sourceCoreIdValues = createIndexConstants(run.ops.front(), sourceCoreIds, constantFolder);
SmallVector<Value> targetCoreIdValues = createIndexConstants(run.ops.front(), targetCoreIds, constantFolder);
spatial::SpatChannelSendTensorOp::create(
rewriter, run.ops.front().getLoc(), channelIdValues, sourceCoreIdValues, targetCoreIdValues, packedInput);
for (auto op : run.ops)
rewriter.eraseOp(op);
it = runIt;
it = run.end;
continue;
}
}
@@ -488,32 +733,27 @@ void compactBatchChannelRuns(func::FuncOp funcOp) {
for (auto it = block.begin(); it != block.end();) {
auto receiveOp = dyn_cast<spatial::SpatChannelReceiveBatchOp>(&*it);
if (receiveOp) {
SmallVector<spatial::SpatChannelReceiveBatchOp> run;
Type outputType = receiveOp.getOutput().getType();
auto runIt = it;
while (runIt != block.end()) {
auto current = dyn_cast<spatial::SpatChannelReceiveBatchOp>(&*runIt);
if (!current || current.getOutput().getType() != outputType)
break;
run.push_back(current);
++runIt;
}
auto run = collectConsecutiveRun<spatial::SpatChannelReceiveBatchOp>(
it, block.end(), [&](spatial::SpatChannelReceiveBatchOp current) {
return current.getOutput().getType() == outputType;
});
if (run.size() > 1) {
SmallVector<int64_t> channelIds;
SmallVector<int32_t> sourceCoreIds;
SmallVector<int32_t> targetCoreIds;
for (auto op : run) {
if (run.ops.size() > 1) {
SmallVector<Value> channelIds;
SmallVector<Value> sourceCoreIds;
SmallVector<Value> targetCoreIds;
for (auto op : run.ops) {
llvm::append_range(channelIds, op.getChannelIds());
llvm::append_range(sourceCoreIds, op.getSourceCoreIds());
llvm::append_range(targetCoreIds, op.getTargetCoreIds());
}
auto rowType = cast<RankedTensorType>(run.front().getOutput().getType());
auto fallbackPackedType = getPackedTensorType(rowType, static_cast<int64_t>(run.size()));
auto rowType = cast<RankedTensorType>(run.ops.front().getOutput().getType());
auto fallbackPackedType = getPackedTensorType(rowType, static_cast<int64_t>(run.ops.size()));
SmallVector<Value> outputs;
outputs.reserve(run.size());
for (auto op : run)
outputs.reserve(run.ops.size());
for (auto op : run.ops)
outputs.push_back(op.getOutput());
unsigned concatStartIndex = 0;
@@ -522,24 +762,19 @@ void compactBatchChannelRuns(func::FuncOp funcOp) {
concatOp ? getPackedConcatSliceType(concatOp, concatStartIndex, static_cast<unsigned>(outputs.size()))
: RankedTensorType {};
auto packedType = concatPackedType ? concatPackedType : fallbackPackedType;
rewriter.setInsertionPoint(run.front());
auto compactReceive =
spatial::SpatChannelReceiveTensorBatchOp::create(rewriter,
run.front().getLoc(),
packedType,
rewriter.getDenseI64ArrayAttr(channelIds),
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
rewriter.getDenseI32ArrayAttr(targetCoreIds));
rewriter.setInsertionPoint(run.ops.front());
auto compactReceive = spatial::SpatChannelReceiveTensorBatchOp::create(
rewriter, run.ops.front().getLoc(), packedType, channelIds, sourceCoreIds, targetCoreIds);
if (concatOp && concatPackedType) {
replaceConcatRunWithPackedValue(
concatOp, concatStartIndex, static_cast<unsigned>(outputs.size()), compactReceive.getOutput(), rewriter);
}
else {
for (auto [index, op] : llvm::enumerate(run))
for (auto [index, op] : llvm::enumerate(run.ops))
op.getOutput().replaceAllUsesWith(extractPackedChunk(
compactReceive.getOutput(), rowType, static_cast<unsigned>(index), rewriter, op.getLoc()));
}
for (auto op : run)
for (auto op : run.ops)
rewriter.eraseOp(op);
it = compactReceive->getIterator();
@@ -550,43 +785,34 @@ void compactBatchChannelRuns(func::FuncOp funcOp) {
auto sendOp = dyn_cast<spatial::SpatChannelSendBatchOp>(&*it);
if (sendOp) {
SmallVector<spatial::SpatChannelSendBatchOp> run;
Type inputType = sendOp.getInput().getType();
auto runIt = it;
while (runIt != block.end()) {
auto current = dyn_cast<spatial::SpatChannelSendBatchOp>(&*runIt);
if (!current || current.getInput().getType() != inputType)
break;
run.push_back(current);
++runIt;
}
auto run = collectConsecutiveRun<spatial::SpatChannelSendBatchOp>(
it, block.end(), [&](spatial::SpatChannelSendBatchOp current) {
return current.getInput().getType() == inputType;
});
if (run.size() > 1) {
SmallVector<int64_t> channelIds;
SmallVector<int32_t> sourceCoreIds;
SmallVector<int32_t> targetCoreIds;
if (run.ops.size() > 1) {
SmallVector<Value> channelIds;
SmallVector<Value> sourceCoreIds;
SmallVector<Value> targetCoreIds;
SmallVector<Value> inputs;
inputs.reserve(run.size());
for (auto op : run) {
inputs.reserve(run.ops.size());
for (auto op : run.ops) {
llvm::append_range(channelIds, op.getChannelIds());
llvm::append_range(sourceCoreIds, op.getSourceCoreIds());
llvm::append_range(targetCoreIds, op.getTargetCoreIds());
inputs.push_back(op.getInput());
}
rewriter.setInsertionPoint(run.front());
Value packedInput = createPackedTensorForValues(ValueRange(inputs), rewriter, run.front().getLoc());
rewriter.setInsertionPoint(run.ops.front());
Value packedInput = createPackedTensorForValues(ValueRange(inputs), rewriter, run.ops.front().getLoc());
if (packedInput) {
spatial::SpatChannelSendTensorBatchOp::create(rewriter,
run.front().getLoc(),
rewriter.getDenseI64ArrayAttr(channelIds),
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
rewriter.getDenseI32ArrayAttr(targetCoreIds),
packedInput);
for (auto op : run)
spatial::SpatChannelSendTensorBatchOp::create(
rewriter, run.ops.front().getLoc(), channelIds, sourceCoreIds, targetCoreIds, packedInput);
for (auto op : run.ops)
rewriter.eraseOp(op);
it = runIt;
it = run.end;
continue;
}
}
@@ -599,6 +825,7 @@ void compactBatchChannelRuns(func::FuncOp funcOp) {
void compactRegularOpRuns(func::FuncOp funcOp) {
IRRewriter rewriter(funcOp.getContext());
OperationFolder constantFolder(funcOp.getContext());
auto compactInBlock = [&](Block& block) {
for (auto it = block.begin(); it != block.end();) {
@@ -614,8 +841,9 @@ void compactRegularOpRuns(func::FuncOp funcOp) {
continue;
}
auto anchorEndIt = std::next(it, static_cast<std::ptrdiff_t>(anchorChunk->ops.size()));
SmallVector<RegularChunk> run {*anchorChunk};
auto runIt = std::next(it, static_cast<std::ptrdiff_t>(anchorChunk->ops.size()));
auto runIt = anchorEndIt;
while (runIt != block.end()) {
auto candidateStart = dyn_cast<spatial::SpatVMMOp>(&*runIt);
if (!candidateStart)
@@ -630,12 +858,26 @@ void compactRegularOpRuns(func::FuncOp funcOp) {
}
if (run.size() <= 1) {
++it;
it = anchorEndIt;
continue;
}
compactRegularChunkRun(rewriter, run);
it = runIt;
size_t originalOpCount = 0;
for (const RegularChunk& chunk : run)
originalOpCount += chunk.ops.size();
RegularCompactionResult result = compactRegularChunkRun(rewriter, run, constantFolder);
if (result.changed) {
assert(originalOpCount > anchorChunk->ops.size() && "successful regular compaction must consume the run");
if (!result.resumeAfter) {
it = block.end();
continue;
}
it = result.resumeAfter->getIterator();
continue;
}
it = anchorEndIt;
}
};
@@ -647,6 +889,7 @@ void compactRegularOpRuns(func::FuncOp funcOp) {
void compactRowWiseWvmmRuns(func::FuncOp funcOp) {
IRRewriter rewriter(funcOp.getContext());
OperationFolder constantFolder(funcOp.getContext());
for (auto compute : funcOp.getOps<spatial::SpatCompute>()) {
Block& block = compute.getBody().front();
@@ -666,37 +909,32 @@ void compactRowWiseWvmmRuns(func::FuncOp funcOp) {
continue;
}
SmallVector<spatial::SpatVMMOp> run;
auto runIt = it;
int64_t expectedRow = static_cast<int64_t>(rowResult.getResultNumber());
while (runIt != block.end()) {
auto current = dyn_cast<spatial::SpatVMMOp>(&*runIt);
if (!current || current.getWeightIndex() != wvmmOp.getWeightIndex()
auto run = collectConsecutiveRun<spatial::SpatVMMOp>(it, block.end(), [&](spatial::SpatVMMOp current) {
if (current.getWeight() != wvmmOp.getWeight()
|| current.getInput().getDefiningOp<spatial::SpatExtractRowsOp>() != extractRowsOp
|| current.getInput().getType() != wvmmOp.getInput().getType()
|| current.getOutput().getType() != wvmmOp.getOutput().getType()) {
break;
}
|| current.getOutput().getType() != wvmmOp.getOutput().getType())
return false;
auto currentRow = dyn_cast<OpResult>(current.getInput());
if (!currentRow || currentRow.getResultNumber() != static_cast<unsigned>(expectedRow))
break;
return false;
run.push_back(current);
++expectedRow;
++runIt;
}
return true;
});
if (run.size() <= 1) {
if (run.ops.size() <= 1) {
++it;
continue;
}
if (!run.front().getOutput().hasOneUse()) {
if (!run.ops.front().getOutput().hasOneUse()) {
++it;
continue;
}
auto concatUse = run.front().getOutput().getUses().begin();
auto concatUse = run.ops.front().getOutput().getUses().begin();
auto concatOp = dyn_cast<spatial::SpatConcatOp>(concatUse->getOwner());
if (!concatOp) {
++it;
@@ -705,7 +943,7 @@ void compactRowWiseWvmmRuns(func::FuncOp funcOp) {
unsigned concatStartIndex = concatUse->getOperandNumber();
bool validConcatRun = true;
for (auto [index, op] : llvm::enumerate(run)) {
for (auto [index, op] : llvm::enumerate(run.ops)) {
if (!op.getOutput().hasOneUse()) {
validConcatRun = false;
break;
@@ -736,17 +974,17 @@ void compactRowWiseWvmmRuns(func::FuncOp funcOp) {
}
int64_t firstRow = static_cast<int64_t>(rowResult.getResultNumber());
int64_t runLength = static_cast<int64_t>(run.size());
int64_t runLength = static_cast<int64_t>(run.ops.size());
auto packedType = RankedTensorType::get({runLength, outputCols}, outputType.getElementType());
rewriter.setInsertionPoint(run.front());
auto zero = arith::ConstantIndexOp::create(rewriter, run.front().getLoc(), 0);
auto upper = arith::ConstantIndexOp::create(rewriter, run.front().getLoc(), runLength);
auto step = arith::ConstantIndexOp::create(rewriter, run.front().getLoc(), 1);
rewriter.setInsertionPoint(run.ops.front());
auto zero = getOrCreateHostIndexConstant(run.ops.front(), 0, constantFolder);
auto upper = getOrCreateHostIndexConstant(run.ops.front(), runLength, constantFolder);
auto step = getOrCreateHostIndexConstant(run.ops.front(), 1, constantFolder);
auto packedInit =
tensor::EmptyOp::create(rewriter, run.front().getLoc(), packedType.getShape(), packedType.getElementType());
tensor::EmptyOp::create(rewriter, run.ops.front().getLoc(), packedType.getShape(), packedType.getElementType());
auto loop =
scf::ForOp::create(rewriter, run.front().getLoc(), zero, upper, step, ValueRange {packedInit.getResult()});
scf::ForOp::create(rewriter, run.ops.front().getLoc(), zero, upper, step, ValueRange {packedInit.getResult()});
{
OpBuilder::InsertionGuard guard(rewriter);
@@ -757,41 +995,41 @@ void compactRowWiseWvmmRuns(func::FuncOp funcOp) {
Value sourceRow = iv;
if (firstRow != 0) {
auto firstRowValue = arith::ConstantIndexOp::create(rewriter, run.front().getLoc(), firstRow);
sourceRow = arith::AddIOp::create(rewriter, run.front().getLoc(), iv, firstRowValue);
auto firstRowValue = getOrCreateHostIndexConstant(run.ops.front(), firstRow, constantFolder);
sourceRow = arith::AddIOp::create(rewriter, run.ops.front().getLoc(), iv, firstRowValue);
}
SmallVector<OpFoldResult> extractOffsets = {sourceRow, rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> extractSizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(inputCols)};
SmallVector<OpFoldResult> extractStrides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
auto extractedRow = tensor::ExtractSliceOp::create(rewriter,
run.front().getLoc(),
run.ops.front().getLoc(),
inputType,
extractRowsOp.getInput(),
extractOffsets,
extractSizes,
extractStrides);
auto loopWvmm = spatial::SpatVMMOp::create(
rewriter, run.front().getLoc(), outputType, wvmmOp.getWeightIndex(), extractedRow.getResult());
rewriter, run.ops.front().getLoc(), outputType, wvmmOp.getWeight(), extractedRow.getResult());
SmallVector<OpFoldResult> insertOffsets = {iv, rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> insertSizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(outputCols)};
SmallVector<OpFoldResult> insertStrides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
auto inserted = tensor::InsertSliceOp::create(
rewriter, run.front().getLoc(), loopWvmm.getResult(), acc, insertOffsets, insertSizes, insertStrides);
scf::YieldOp::create(rewriter, run.front().getLoc(), inserted.getResult());
rewriter, run.ops.front().getLoc(), loopWvmm.getResult(), acc, insertOffsets, insertSizes, insertStrides);
scf::YieldOp::create(rewriter, run.ops.front().getLoc(), inserted.getResult());
}
SmallVector<Value> newConcatInputs;
newConcatInputs.reserve(concatOp.getInputs().size() - run.size() + 1);
newConcatInputs.reserve(concatOp.getInputs().size() - run.ops.size() + 1);
for (auto [operandIndex, operand] : llvm::enumerate(concatOp.getInputs())) {
if (operandIndex == concatStartIndex)
newConcatInputs.push_back(loop.getResult(0));
if (operandIndex < concatStartIndex || operandIndex >= concatStartIndex + run.size())
if (operandIndex < concatStartIndex || operandIndex >= concatStartIndex + run.ops.size())
newConcatInputs.push_back(operand);
}
rewriter.modifyOpInPlace(concatOp, [&] { concatOp->setOperands(newConcatInputs); });
for (auto op : run)
for (auto op : run.ops)
rewriter.eraseOp(op);
it = loop->getIterator();
@@ -6,6 +6,7 @@
namespace onnx_mlir {
void orderBilateralChannelOps(mlir::func::FuncOp funcOp);
void compactScalarChannelRuns(mlir::func::FuncOp funcOp, int64_t& nextChannelId);
void compactBatchChannelRuns(mlir::func::FuncOp funcOp);
void compactRegularOpRuns(mlir::func::FuncOp funcOp);
@@ -0,0 +1,201 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Casting.h"
#include <algorithm>
#include <limits>
#include <optional>
#include <queue>
#include <utility>
#include <vector>
#include "ComputeGraph.hpp"
#include "src/Support/TypeUtilities.hpp"
namespace onnx_mlir {
namespace spatial {
using namespace mlir;
namespace {
Weight getComputeBodyWeight(Region& body) {
constexpr Weight kOperationWeight = 100;
Weight numOperations = 0;
for (auto& block : body)
for ([[maybe_unused]] auto& op : block)
numOperations = checkedAdd(numOperations, static_cast<Weight>(1));
return checkedMultiply(numOperations, kOperationWeight);
}
CrossbarUsage getComputeBodyCrossbarUsage(Region& body) {
CrossbarUsage crossbarUsage = 0;
for (auto& block : body)
for (auto& op : block)
if (isa<SpatVMMOp>(op))
crossbarUsage = checkedAdd(crossbarUsage, static_cast<CrossbarUsage>(1));
return crossbarUsage;
}
bool isUsedAsWeightOnly(Operation* producerOp) {
if (producerOp->getNumResults() == 0)
return false;
for (Value result : producerOp->getResults()) {
if (result.use_empty())
return false;
for (Operation* user : result.getUsers()) {
if (auto compute = dyn_cast<SpatCompute>(user)) {
if (!llvm::is_contained(compute.getWeights(), result))
return false;
continue;
}
if (auto batch = dyn_cast<SpatComputeBatch>(user)) {
if (!llvm::is_contained(batch.getWeights(), result))
return false;
continue;
}
return false;
}
}
return true;
}
std::vector<ComputeGraphEdge> aggregateEdges(llvm::ArrayRef<ComputeGraphEdge> edges) {
llvm::DenseMap<std::pair<size_t, size_t>, Weight> edgeWeights;
for (const ComputeGraphEdge& edge : edges) {
if (edge.source == edge.target)
continue;
auto inserted = edgeWeights.try_emplace({edge.source, edge.target}, edge.transferCost);
if (!inserted.second)
inserted.first->second = std::max(inserted.first->second, edge.transferCost);
}
std::vector<ComputeGraphEdge> aggregatedEdges;
aggregatedEdges.reserve(edgeWeights.size());
for (const auto& [key, weight] : edgeWeights)
aggregatedEdges.push_back({key.first, key.second, weight});
llvm::sort(aggregatedEdges, [](const ComputeGraphEdge& lhs, const ComputeGraphEdge& rhs) {
if (lhs.source != rhs.source)
return lhs.source < rhs.source;
return lhs.target < rhs.target;
});
return aggregatedEdges;
}
} // namespace
Weight getComputeInstanceWeight(const ComputeInstance& instance) {
if (auto spatCompute = dyn_cast<SpatCompute>(instance.op))
return getSpatComputeWeight(spatCompute);
auto batch = cast<SpatComputeBatch>(instance.op);
return checkedMultiply(getComputeBodyWeight(batch.getBody()), static_cast<Weight>(instance.laneCount));
}
CrossbarUsage getComputeInstanceCrossbarUsage(const ComputeInstance& instance) {
if (auto spatCompute = dyn_cast<SpatCompute>(instance.op))
return getSpatComputeCrossbarUsage(spatCompute);
auto batch = cast<SpatComputeBatch>(instance.op);
return checkedMultiply(getComputeBodyCrossbarUsage(batch.getBody()), static_cast<CrossbarUsage>(instance.laneCount));
}
ComputeGraph buildComputeGraph(Operation* entryOp) {
ComputeGraph graph;
for (Region& region : entryOp->getRegions()) {
for (Block& block : region) {
for (Operation& op : block) {
if (auto spatCompute = dyn_cast<SpatCompute>(&op)) {
if (isUsedAsWeightOnly(spatCompute.getOperation()))
continue;
ComputeInstance instance {spatCompute.getOperation(), 0, 1};
size_t index = graph.nodes.size();
graph.nodes.push_back(
{instance, getComputeInstanceWeight(instance), getComputeInstanceCrossbarUsage(instance), index});
graph.instanceToIndex[instance] = index;
continue;
}
if (auto batch = dyn_cast<SpatComputeBatch>(&op)) {
if (isUsedAsWeightOnly(batch.getOperation()))
continue;
size_t chunkCount = getBatchChunkTargetCount(batch.getLaneCount());
for (size_t chunkIndex = 0; chunkIndex < chunkCount; ++chunkIndex) {
ComputeInstance instance = getBatchChunkForIndex(batch, chunkIndex);
size_t index = graph.nodes.size();
graph.nodes.push_back(
{instance, getComputeInstanceWeight(instance), getComputeInstanceCrossbarUsage(instance), index});
graph.instanceToIndex[instance] = index;
}
}
}
}
}
llvm::SmallVector<ComputeGraphEdge, 16> rawEdges;
for (const auto& [targetIndex, node] : llvm::enumerate(graph.nodes)) {
for (Value input : getComputeInstanceInputs(node.instance)) {
if (auto producerBatch = dyn_cast_or_null<SpatComputeBatch>(input.getDefiningOp());
producerBatch && producerBatch.getNumResults() != 0 && !isa<SpatComputeBatch>(node.instance.op)) {
for (uint32_t lane = 0; lane < static_cast<uint32_t>(producerBatch.getLaneCount()); ++lane) {
auto producerIt = graph.instanceToIndex.find(getBatchChunkForLane(producerBatch, lane));
if (producerIt == graph.instanceToIndex.end())
continue;
rawEdges.push_back(
{producerIt->second, targetIndex, static_cast<Weight>(getSizeInBytes(cast<ShapedType>(input.getType())))});
}
continue;
}
auto producerInstance = getComputeProducerInstance(input, &node.instance);
if (!producerInstance)
continue;
auto producerIt = graph.instanceToIndex.find(*producerInstance);
if (producerIt == graph.instanceToIndex.end())
continue;
rawEdges.push_back(
{producerIt->second, targetIndex, static_cast<Weight>(getSizeInBytes(cast<ShapedType>(input.getType())))});
}
}
std::vector<ComputeGraphEdge> aggregatedEdges = aggregateEdges(rawEdges);
graph.edges.append(aggregatedEdges.begin(), aggregatedEdges.end());
graph.successors.assign(graph.nodes.size(), {});
graph.predecessors.assign(graph.nodes.size(), {});
for (const ComputeGraphEdge& edge : graph.edges) {
graph.successors[edge.source].push_back({edge.target, edge.transferCost});
graph.predecessors[edge.target].push_back({edge.source, edge.transferCost});
}
return graph;
}
bool verifyAcyclic(const ComputeGraph& graph) {
std::vector<size_t> remainingParents(graph.nodes.size(), 0);
std::queue<size_t> readyNodes;
for (size_t node = 0; node < graph.nodes.size(); ++node) {
remainingParents[node] = graph.predecessors[node].size();
if (remainingParents[node] == 0)
readyNodes.push(node);
}
size_t visited = 0;
while (!readyNodes.empty()) {
size_t node = readyNodes.front();
readyNodes.pop();
++visited;
for (const auto& [child, weight] : graph.successors[node]) {
(void) weight;
assert(remainingParents[child] > 0 && "remaining parent count underflow");
if (--remainingParents[child] == 0)
readyNodes.push(child);
}
}
return visited == graph.nodes.size();
}
} // namespace spatial
} // namespace onnx_mlir
@@ -0,0 +1,49 @@
#pragma once
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"
#include <cstddef>
#include <optional>
#include <utility>
#include <vector>
#include "../DCPGraph/Utils.hpp"
#include "ComputeInstance.hpp"
#include "ComputeInstanceUtils.hpp"
namespace onnx_mlir {
namespace spatial {
struct ComputeGraphNode {
ComputeInstance instance;
Weight weight = 0;
CrossbarUsage crossbarUsage = 0;
size_t originalOrder = 0;
};
struct ComputeGraphEdge {
size_t source = 0;
size_t target = 0;
Weight transferCost = 0;
};
struct ComputeGraph {
llvm::SmallVector<ComputeGraphNode> nodes;
llvm::SmallVector<ComputeGraphEdge> edges;
std::vector<std::vector<std::pair<size_t, Weight>>> successors;
std::vector<std::vector<std::pair<size_t, Weight>>> predecessors;
llvm::DenseMap<ComputeInstance, size_t> instanceToIndex;
};
ComputeGraph buildComputeGraph(mlir::Operation *entryOp);
bool verifyAcyclic(const ComputeGraph &graph);
Weight getComputeInstanceWeight(const ComputeInstance &instance);
CrossbarUsage getComputeInstanceCrossbarUsage(const ComputeInstance &instance);
} // namespace spatial
} // namespace onnx_mlir
@@ -0,0 +1,45 @@
#pragma once
#include "mlir/IR/Operation.h"
#include "llvm/ADT/DenseMapInfo.h"
#include "llvm/ADT/Hashing.h"
#include <cstdint>
namespace onnx_mlir {
namespace spatial {
struct ComputeInstance {
mlir::Operation *op = nullptr;
uint32_t laneStart = 0;
uint32_t laneCount = 1;
bool operator==(const ComputeInstance &other) const {
return op == other.op && laneStart == other.laneStart && laneCount == other.laneCount;
}
};
} // namespace spatial
} // namespace onnx_mlir
using ComputeInstance = onnx_mlir::spatial::ComputeInstance;
namespace llvm {
template <>
struct DenseMapInfo<onnx_mlir::spatial::ComputeInstance> {
static onnx_mlir::spatial::ComputeInstance getEmptyKey() {
return {DenseMapInfo<mlir::Operation *>::getEmptyKey(), UINT32_MAX, UINT32_MAX};
}
static onnx_mlir::spatial::ComputeInstance getTombstoneKey() {
return {DenseMapInfo<mlir::Operation *>::getTombstoneKey(), UINT32_MAX, UINT32_MAX};
}
static unsigned getHashValue(const onnx_mlir::spatial::ComputeInstance &value) {
return llvm::hash_combine(value.op, value.laneStart, value.laneCount);
}
static bool isEqual(const onnx_mlir::spatial::ComputeInstance &lhs,
const onnx_mlir::spatial::ComputeInstance &rhs) {
return lhs == rhs;
}
};
} // namespace llvm
@@ -0,0 +1,193 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include <limits>
#include <optional>
#include "ComputeInstanceUtils.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace spatial {
size_t getSchedulingCpuBudget() {
if (coresCount.getValue() > 0)
return static_cast<size_t>(coresCount.getValue());
return std::numeric_limits<size_t>::max();
}
size_t getBatchChunkTargetCount(int32_t laneCount) {
assert(laneCount > 0 && "laneCount must be positive");
return static_cast<size_t>(laneCount);
}
ComputeInstance getBatchChunkForIndex(SpatComputeBatch batch, size_t chunkIndex) {
assert(chunkIndex < static_cast<size_t>(batch.getLaneCount()) && "chunkIndex out of range");
return {batch.getOperation(), static_cast<uint32_t>(chunkIndex), 1};
}
ComputeInstance getBatchChunkForLane(SpatComputeBatch batch, uint32_t lane) {
assert(lane < static_cast<uint32_t>(batch.getLaneCount()) && "lane out of range");
return {batch.getOperation(), lane, 1};
}
static std::optional<uint32_t> getConstantExtractLane(tensor::ExtractSliceOp extract) {
if (extract.getMixedOffsets().empty())
return std::nullopt;
OpFoldResult offset = extract.getMixedOffsets().front();
if (Attribute attr = llvm::dyn_cast<Attribute>(offset)) {
auto intAttr = dyn_cast<IntegerAttr>(attr);
if (!intAttr || intAttr.getInt() < 0)
return std::nullopt;
return static_cast<uint32_t>(intAttr.getInt());
}
Value offsetValue = llvm::cast<Value>(offset);
if (auto constantIndex = offsetValue.getDefiningOp<arith::ConstantIndexOp>()) {
if (constantIndex.value() < 0)
return std::nullopt;
return static_cast<uint32_t>(constantIndex.value());
}
return std::nullopt;
}
static std::optional<ProducerValueRef> getResultfulBatchProducerValueRef(SpatComputeBatch batch,
const ComputeInstance* consumerInstance) {
if (!consumerInstance)
return std::nullopt;
if (!isa<SpatComputeBatch>(consumerInstance->op))
return std::nullopt;
if (consumerInstance->laneStart + consumerInstance->laneCount > static_cast<uint32_t>(batch.getLaneCount()))
return std::nullopt;
return ProducerValueRef {
{batch.getOperation(), consumerInstance->laneStart, consumerInstance->laneCount},
0
};
}
std::optional<ProducerValueRef> getProducerValueRef(Value value, const ComputeInstance* consumerInstance) {
Operation* op = value.getDefiningOp();
if (!op)
return std::nullopt;
while (auto extract = dyn_cast<tensor::ExtractSliceOp>(op)) {
Value source = extract.getSource();
auto batch = dyn_cast_or_null<SpatComputeBatch>(source.getDefiningOp());
if (batch && batch.getNumResults() != 0) {
if (std::optional<uint32_t> lane = getConstantExtractLane(extract)) {
if (*lane >= static_cast<uint32_t>(batch.getLaneCount()))
return std::nullopt;
return ProducerValueRef {
{batch.getOperation(), *lane, 1},
0
};
}
return getResultfulBatchProducerValueRef(batch, consumerInstance);
}
value = source;
op = value.getDefiningOp();
if (!op)
return std::nullopt;
}
if (auto compute = dyn_cast<SpatCompute>(op)) {
return ProducerValueRef {
ComputeInstance {compute.getOperation(), 0, 1},
static_cast<size_t>(cast<OpResult>(value).getResultNumber())
};
}
if (auto batch = dyn_cast<SpatComputeBatch>(op)) {
if (batch.getNumResults() != 0)
return getResultfulBatchProducerValueRef(batch, consumerInstance);
uint32_t lane = cast<OpResult>(value).getResultNumber();
ComputeInstance instance = getBatchChunkForLane(batch, lane);
size_t resultIndex = lane - instance.laneStart;
return ProducerValueRef {instance, resultIndex};
}
return std::nullopt;
}
std::optional<ComputeInstance> getComputeProducerInstance(Value value, const ComputeInstance* consumerInstance) {
if (std::optional<ProducerValueRef> producer = getProducerValueRef(value, consumerInstance))
return producer->instance;
return std::nullopt;
}
llvm::SmallVector<Value, 4> getComputeInstanceInputs(const ComputeInstance& instance) {
if (auto compute = dyn_cast<SpatCompute>(instance.op))
return llvm::SmallVector<Value, 4>(compute.getInputs().begin(), compute.getInputs().end());
auto batch = cast<SpatComputeBatch>(instance.op);
if (batch.getNumResults() != 0)
return llvm::SmallVector<Value, 4>(batch.getInputs().begin(), batch.getInputs().end());
assert(batch.getInputs().size() % static_cast<size_t>(batch.getLaneCount()) == 0
&& "resultless compute_batch inputs must be evenly partitioned by lane");
size_t inputsPerLane = batch.getInputs().size() / static_cast<size_t>(batch.getLaneCount());
llvm::SmallVector<Value, 4> inputs;
inputs.reserve(instance.laneCount * inputsPerLane);
for (uint32_t lane = instance.laneStart; lane < instance.laneStart + instance.laneCount; ++lane) {
size_t firstInput = static_cast<size_t>(lane) * inputsPerLane;
inputs.append(batch.getInputs().begin() + firstInput, batch.getInputs().begin() + firstInput + inputsPerLane);
}
return inputs;
}
llvm::SmallVector<Value, 4> getComputeInstanceWeights(const ComputeInstance& instance) {
if (auto compute = dyn_cast<SpatCompute>(instance.op))
return llvm::SmallVector<Value, 4>(compute.getWeights().begin(), compute.getWeights().end());
auto batch = cast<SpatComputeBatch>(instance.op);
if (batch.getNumResults() != 0)
return llvm::SmallVector<Value, 4>(batch.getWeights().begin(), batch.getWeights().end());
assert(batch.getWeights().size() % static_cast<size_t>(batch.getLaneCount()) == 0
&& "resultless compute_batch weights must be evenly partitioned by lane");
size_t weightsPerLane = batch.getWeights().size() / static_cast<size_t>(batch.getLaneCount());
llvm::SmallVector<Value, 4> weights;
weights.reserve(instance.laneCount * weightsPerLane);
for (uint32_t lane = instance.laneStart; lane < instance.laneStart + instance.laneCount; ++lane) {
size_t firstWeight = static_cast<size_t>(lane) * weightsPerLane;
weights.append(batch.getWeights().begin() + firstWeight, batch.getWeights().begin() + firstWeight + weightsPerLane);
}
return weights;
}
llvm::SmallVector<Value, 4> getComputeInstanceOutputValues(const ComputeInstance& instance) {
if (auto compute = dyn_cast<SpatCompute>(instance.op))
return llvm::SmallVector<Value, 4>(compute.getResults().begin(), compute.getResults().end());
auto batch = cast<SpatComputeBatch>(instance.op);
if (batch.getNumResults() != 0)
return llvm::SmallVector<Value, 4>(batch.getResults().begin(), batch.getResults().end());
llvm::SmallVector<Value, 4> outputs;
outputs.reserve(instance.laneCount);
for (uint32_t lane = instance.laneStart; lane < instance.laneStart + instance.laneCount; ++lane)
if (!batch.getOutputs().empty())
outputs.push_back(batch.getOutputs()[lane]);
return outputs;
}
llvm::SmallVector<Type, 4> getComputeInstanceOutputTypes(const ComputeInstance& instance) {
llvm::SmallVector<Type, 4> outputTypes;
for (Value output : getComputeInstanceOutputValues(instance))
outputTypes.push_back(output.getType());
return outputTypes;
}
Block& getComputeInstanceTemplateBlock(const ComputeInstance& instance) {
if (auto compute = dyn_cast<SpatCompute>(instance.op))
return compute.getBody().front();
return cast<SpatComputeBatch>(instance.op).getBody().front();
}
} // namespace spatial
} // namespace onnx_mlir
@@ -0,0 +1,41 @@
#pragma once
#include "mlir/IR/Block.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/SmallVector.h"
#include <cstddef>
#include <optional>
#include "ComputeInstance.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
namespace onnx_mlir {
namespace spatial {
struct ProducerValueRef {
ComputeInstance instance;
size_t resultIndex = 0;
};
size_t getSchedulingCpuBudget();
size_t getBatchChunkTargetCount(int32_t laneCount);
ComputeInstance getBatchChunkForIndex(SpatComputeBatch batch, size_t chunkIndex);
ComputeInstance getBatchChunkForLane(SpatComputeBatch batch, uint32_t lane);
std::optional<ProducerValueRef> getProducerValueRef(mlir::Value value,
const ComputeInstance *consumerInstance = nullptr);
std::optional<ComputeInstance> getComputeProducerInstance(mlir::Value value,
const ComputeInstance *consumerInstance = nullptr);
llvm::SmallVector<mlir::Value, 4> getComputeInstanceInputs(const ComputeInstance &instance);
llvm::SmallVector<mlir::Value, 4> getComputeInstanceWeights(const ComputeInstance &instance);
llvm::SmallVector<mlir::Value, 4> getComputeInstanceOutputValues(const ComputeInstance &instance);
llvm::SmallVector<mlir::Type, 4> getComputeInstanceOutputTypes(const ComputeInstance &instance);
mlir::Block &getComputeInstanceTemplateBlock(const ComputeInstance &instance);
} // namespace spatial
} // namespace onnx_mlir
@@ -0,0 +1,720 @@
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"
#include <algorithm>
#include <cstdlib>
#include <limits>
#include <numeric>
#include <optional>
#include <queue>
#include <vector>
#include "DcpScheduler.hpp"
#include "../DCPGraph/Graph.hpp"
namespace onnx_mlir {
namespace spatial {
using namespace mlir;
namespace {
bool isDcpCoarsenDebugEnabled() { return std::getenv("DCP_COARSEN_DEBUG") != nullptr; }
struct VirtualNode {
llvm::SmallVector<size_t, 4> originalNodeIndices;
Weight weight = 0;
CrossbarUsage crossbarUsage = 0;
};
struct VirtualGraph {
std::vector<VirtualNode> nodes;
std::vector<IndexedEdge> edges;
};
struct TimingInfo {
std::vector<Time> aest;
std::vector<Time> alst;
std::vector<size_t> topologicalOrder;
bool valid = false;
};
struct WindowScheduleResult {
std::vector<std::vector<size_t>> mergeGroups;
CPU cpuCount = 0;
size_t mergedNodeCount = 0;
size_t maxMergeGroupSize = 0;
};
size_t getSchedulingCpuBudget(const DcpScheduleOptions &options) {
if (options.processorCount > 0)
return options.processorCount;
return std::numeric_limits<size_t>::max();
}
std::vector<IndexedEdge> aggregateEdges(llvm::ArrayRef<IndexedEdge> edges) {
llvm::DenseMap<std::pair<size_t, size_t>, Weight> edgeWeights;
for (auto [start, end, weight] : edges) {
size_t startIndex = static_cast<size_t>(start);
size_t endIndex = static_cast<size_t>(end);
if (startIndex == endIndex)
continue;
auto key = std::make_pair(startIndex, endIndex);
Weight edgeWeight = static_cast<Weight>(weight);
auto inserted = edgeWeights.try_emplace(key, edgeWeight);
if (!inserted.second)
inserted.first->second = std::max(inserted.first->second, edgeWeight);
}
std::vector<IndexedEdge> aggregatedEdges;
aggregatedEdges.reserve(edgeWeights.size());
for (auto [key, weight] : edgeWeights)
aggregatedEdges.push_back(
{static_cast<int64_t>(key.first), static_cast<int64_t>(key.second), static_cast<int64_t>(weight)});
llvm::sort(aggregatedEdges, [](const IndexedEdge &lhs, const IndexedEdge &rhs) {
if (std::get<0>(lhs) != std::get<0>(rhs))
return std::get<0>(lhs) < std::get<0>(rhs);
return std::get<1>(lhs) < std::get<1>(rhs);
});
return aggregatedEdges;
}
VirtualGraph buildInitialVirtualGraph(const ComputeGraph &graph) {
VirtualGraph virtualGraph;
virtualGraph.nodes.reserve(graph.nodes.size());
for (auto [index, node] : llvm::enumerate(graph.nodes)) {
VirtualNode virtualNode;
virtualNode.originalNodeIndices.push_back(index);
virtualNode.weight = node.weight;
virtualNode.crossbarUsage = node.crossbarUsage;
virtualGraph.nodes.push_back(std::move(virtualNode));
}
std::vector<IndexedEdge> edges;
edges.reserve(graph.edges.size());
for (const ComputeGraphEdge &edge : graph.edges)
edges.push_back(
{static_cast<int64_t>(edge.source), static_cast<int64_t>(edge.target), static_cast<int64_t>(edge.transferCost)});
virtualGraph.edges = aggregateEdges(edges);
return virtualGraph;
}
TimingInfo computeTiming(const VirtualGraph &graph) {
TimingInfo timing;
size_t nodeCount = graph.nodes.size();
timing.aest.assign(nodeCount, 0);
timing.alst.assign(nodeCount, 0);
timing.topologicalOrder.reserve(nodeCount);
std::vector<std::vector<std::pair<size_t, Weight>>> parents(nodeCount);
std::vector<std::vector<std::pair<size_t, Weight>>> children(nodeCount);
std::vector<size_t> incomingEdgeCount(nodeCount, 0);
for (auto [start, end, weight] : graph.edges) {
size_t startIndex = static_cast<size_t>(start);
size_t endIndex = static_cast<size_t>(end);
Weight edgeWeight = static_cast<Weight>(weight);
assert(startIndex < nodeCount && endIndex < nodeCount && "virtual edge endpoint out of range");
children[startIndex].push_back({endIndex, edgeWeight});
parents[endIndex].push_back({startIndex, edgeWeight});
incomingEdgeCount[endIndex]++;
}
auto getVirtualNodeOrderKey = [&](size_t nodeIndex) {
const VirtualNode &node = graph.nodes[nodeIndex];
if (!node.originalNodeIndices.empty())
return node.originalNodeIndices.front();
return nodeIndex;
};
auto readyNodeGreater = [&](size_t lhs, size_t rhs) {
size_t lhsKey = getVirtualNodeOrderKey(lhs);
size_t rhsKey = getVirtualNodeOrderKey(rhs);
if (lhsKey != rhsKey)
return lhsKey > rhsKey;
return lhs > rhs;
};
std::priority_queue<size_t, std::vector<size_t>, decltype(readyNodeGreater)> readyNodes(readyNodeGreater);
for (size_t i = 0; i < nodeCount; ++i)
if (incomingEdgeCount[i] == 0)
readyNodes.push(i);
while (!readyNodes.empty()) {
size_t current = readyNodes.top();
readyNodes.pop();
timing.topologicalOrder.push_back(current);
for (auto [child, weight] : children[current]) {
(void) weight;
assert(incomingEdgeCount[child] > 0 && "incoming edge count underflow");
incomingEdgeCount[child]--;
if (incomingEdgeCount[child] == 0)
readyNodes.push(child);
}
}
if (timing.topologicalOrder.size() != nodeCount)
return timing;
Time dcpl = 0;
for (size_t nodeIndex : timing.topologicalOrder) {
Time maxParentAest = 0;
for (auto [parent, transferCost] : parents[nodeIndex]) {
maxParentAest =
std::max(maxParentAest, addOrMax(addOrMax(timing.aest[parent], graph.nodes[parent].weight), transferCost));
}
timing.aest[nodeIndex] = maxParentAest;
dcpl = std::max(dcpl, addOrMax(maxParentAest, graph.nodes[nodeIndex].weight));
}
for (size_t nodeIndex : llvm::reverse(timing.topologicalOrder)) {
Time minAlst = std::numeric_limits<Time>::max();
if (children[nodeIndex].empty())
minAlst = subtractOrZero(dcpl, graph.nodes[nodeIndex].weight);
for (auto [child, transferCost] : children[nodeIndex]) {
minAlst =
std::min(minAlst, subtractOrZero(timing.alst[child], addOrMax(graph.nodes[nodeIndex].weight, transferCost)));
}
timing.alst[nodeIndex] = minAlst;
}
timing.valid = true;
return timing;
}
std::vector<std::vector<size_t>> buildUndirectedAdjacency(const VirtualGraph &graph) {
std::vector<std::vector<size_t>> adjacency(graph.nodes.size());
for (auto [start, end, weight] : graph.edges) {
(void) weight;
size_t startIndex = static_cast<size_t>(start);
size_t endIndex = static_cast<size_t>(end);
assert(startIndex < graph.nodes.size() && endIndex < graph.nodes.size() && "virtual edge endpoint out of range");
adjacency[startIndex].push_back(endIndex);
adjacency[endIndex].push_back(startIndex);
}
for (auto &neighbours : adjacency) {
llvm::sort(neighbours);
neighbours.erase(std::unique(neighbours.begin(), neighbours.end()), neighbours.end());
}
return adjacency;
}
std::vector<size_t> selectCriticalWindow(const VirtualGraph &graph, const TimingInfo &timing, size_t windowSize) {
std::vector<size_t> ranked(timing.aest.size());
std::iota(ranked.begin(), ranked.end(), 0);
auto isHigherPriority = [&](size_t lhs, size_t rhs) {
Time lhsSlack = slackOrZero(timing.aest[lhs], timing.alst[lhs]);
Time rhsSlack = slackOrZero(timing.aest[rhs], timing.alst[rhs]);
if (lhsSlack != rhsSlack)
return lhsSlack < rhsSlack;
if (timing.aest[lhs] != timing.aest[rhs])
return timing.aest[lhs] < timing.aest[rhs];
return lhs < rhs;
};
windowSize = std::min(windowSize, ranked.size());
if (windowSize == 0)
return {};
if (windowSize == ranked.size()) {
llvm::sort(ranked, isHigherPriority);
return ranked;
}
size_t criticalPoolSize = std::min(ranked.size(), std::max(windowSize, windowSize * 2));
if (criticalPoolSize < ranked.size())
std::nth_element(
ranked.begin(), ranked.begin() + static_cast<std::ptrdiff_t>(criticalPoolSize), ranked.end(), isHigherPriority);
std::vector<char> inCriticalPool(ranked.size(), false);
for (size_t i = 0; i < criticalPoolSize; ++i)
inCriticalPool[ranked[i]] = true;
size_t seed = *std::min_element(ranked.begin(), ranked.end(), isHigherPriority);
std::vector<std::vector<size_t>> adjacency = buildUndirectedAdjacency(graph);
std::vector<size_t> selected;
std::vector<char> inWindow(ranked.size(), false);
selected.reserve(windowSize);
struct FrontierEntry {
size_t node;
};
auto frontierCompare = [&](FrontierEntry lhs, FrontierEntry rhs) { return isHigherPriority(rhs.node, lhs.node); };
std::priority_queue<FrontierEntry, std::vector<FrontierEntry>, decltype(frontierCompare)> frontier(frontierCompare);
auto addToWindow = [&](size_t node, const std::vector<char> &eligible) {
if (inWindow[node])
return;
inWindow[node] = true;
selected.push_back(node);
for (size_t neighbour : adjacency[node])
if (!inWindow[neighbour] && eligible[neighbour])
frontier.push({neighbour});
};
addToWindow(seed, inCriticalPool);
while (!frontier.empty() && selected.size() < windowSize) {
size_t node = frontier.top().node;
frontier.pop();
if (!inWindow[node])
addToWindow(node, inCriticalPool);
}
if (selected.size() < windowSize) {
std::vector<char> anyNode(ranked.size(), true);
for (size_t node : selected)
for (size_t neighbour : adjacency[node])
if (!inWindow[neighbour])
frontier.push({neighbour});
while (!frontier.empty() && selected.size() < windowSize) {
size_t node = frontier.top().node;
frontier.pop();
if (!inWindow[node])
addToWindow(node, anyNode);
}
}
if (selected.size() < windowSize) {
llvm::sort(ranked, isHigherPriority);
for (size_t node : ranked) {
if (selected.size() == windowSize)
break;
if (!inWindow[node]) {
inWindow[node] = true;
selected.push_back(node);
}
}
}
llvm::sort(selected, isHigherPriority);
return selected;
}
std::vector<IndexedEdge> buildWindowEdges(const VirtualGraph &graph, const std::vector<int64_t> &nodeToWindowIndex) {
std::vector<IndexedEdge> windowEdges;
windowEdges.reserve(graph.edges.size());
for (auto [start, end, weight] : graph.edges) {
int64_t mappedStart = nodeToWindowIndex[static_cast<size_t>(start)];
int64_t mappedEnd = nodeToWindowIndex[static_cast<size_t>(end)];
if (mappedStart == -1 || mappedEnd == -1)
continue;
windowEdges.push_back({mappedStart, mappedEnd, weight});
}
return aggregateEdges(windowEdges);
}
WindowScheduleResult scheduleWindow(const VirtualGraph &graph,
llvm::ArrayRef<size_t> selectedNodes,
const DcpScheduleOptions &options,
mlir::MLIRContext *context) {
std::vector<Weight> windowWeights;
std::vector<CrossbarUsage> windowCrossbarUsage;
std::vector<int64_t> windowNodeOrderKeys;
std::vector<int64_t> nodeToWindowIndex(graph.nodes.size(), -1);
windowWeights.reserve(selectedNodes.size());
windowCrossbarUsage.reserve(selectedNodes.size());
windowNodeOrderKeys.reserve(selectedNodes.size());
for (auto [windowIndex, nodeIndex] : llvm::enumerate(selectedNodes)) {
nodeToWindowIndex[nodeIndex] = static_cast<int64_t>(windowIndex);
windowWeights.push_back(graph.nodes[nodeIndex].weight);
windowCrossbarUsage.push_back(graph.nodes[nodeIndex].crossbarUsage);
windowNodeOrderKeys.push_back(static_cast<int64_t>(nodeIndex));
}
GraphDCP windowGraph(
windowWeights, buildWindowEdges(graph, nodeToWindowIndex), windowNodeOrderKeys, windowCrossbarUsage);
if (options.processorCount > 0)
windowGraph.setMaxCpuCount(static_cast<int>(options.processorCount));
windowGraph.setContext(context);
windowGraph.runDcp();
WindowScheduleResult result;
result.cpuCount = windowGraph.cpuCount();
for (CPU cpu = 0; cpu < windowGraph.cpuCount(); ++cpu) {
auto scheduledTasks = windowGraph.getScheduledTasks(cpu);
if (scheduledTasks.size() < 2)
continue;
result.mergedNodeCount += scheduledTasks.size();
result.maxMergeGroupSize = std::max(result.maxMergeGroupSize, scheduledTasks.size());
std::vector<size_t> mergeGroup;
mergeGroup.reserve(scheduledTasks.size());
for (const auto &task : scheduledTasks)
mergeGroup.push_back(selectedNodes[task.nodeIndex]);
result.mergeGroups.push_back(std::move(mergeGroup));
}
return result;
}
bool coarsenGraph(const VirtualGraph &graph,
llvm::ArrayRef<std::vector<size_t>> mergeGroups,
VirtualGraph &coarsenedGraph,
std::vector<size_t> &oldToNewNode) {
TimingInfo timing = computeTiming(graph);
std::vector<size_t> topologicalRank(graph.nodes.size());
std::iota(topologicalRank.begin(), topologicalRank.end(), 0);
if (timing.valid)
for (auto [rank, nodeIndex] : llvm::enumerate(timing.topologicalOrder))
topologicalRank[nodeIndex] = rank;
std::vector<std::vector<size_t>> orderedMergeGroups;
orderedMergeGroups.reserve(mergeGroups.size());
for (const auto &mergeGroup : mergeGroups) {
orderedMergeGroups.emplace_back(mergeGroup.begin(), mergeGroup.end());
std::stable_sort(orderedMergeGroups.back().begin(), orderedMergeGroups.back().end(), [&](size_t lhs, size_t rhs) {
if (topologicalRank[lhs] != topologicalRank[rhs])
return topologicalRank[lhs] < topologicalRank[rhs];
return lhs < rhs;
});
}
std::vector<int64_t> nodeToMergeGroup(graph.nodes.size(), -1);
for (auto [groupIndex, mergeGroup] : llvm::enumerate(orderedMergeGroups)) {
if (mergeGroup.size() < 2)
continue;
for (size_t nodeIndex : mergeGroup) {
assert(nodeIndex < graph.nodes.size() && "merge group node out of range");
nodeToMergeGroup[nodeIndex] = static_cast<int64_t>(groupIndex);
}
}
std::vector<std::optional<size_t>> mergeGroupToNewNode(orderedMergeGroups.size());
std::vector<size_t> newNodeRank;
oldToNewNode.assign(graph.nodes.size(), 0);
bool mergedAny = false;
coarsenedGraph.nodes.clear();
coarsenedGraph.edges.clear();
coarsenedGraph.nodes.reserve(graph.nodes.size());
newNodeRank.reserve(graph.nodes.size());
for (size_t nodeIndex = 0; nodeIndex < graph.nodes.size(); ++nodeIndex) {
int64_t mergeGroupIndex = nodeToMergeGroup[nodeIndex];
if (mergeGroupIndex == -1) {
oldToNewNode[nodeIndex] = coarsenedGraph.nodes.size();
coarsenedGraph.nodes.push_back(graph.nodes[nodeIndex]);
newNodeRank.push_back(topologicalRank[nodeIndex]);
continue;
}
auto &newNodeIndex = mergeGroupToNewNode[static_cast<size_t>(mergeGroupIndex)];
if (newNodeIndex.has_value()) {
oldToNewNode[nodeIndex] = *newNodeIndex;
continue;
}
VirtualNode mergedNode;
for (size_t memberIndex : orderedMergeGroups[static_cast<size_t>(mergeGroupIndex)]) {
const VirtualNode &memberNode = graph.nodes[memberIndex];
mergedNode.originalNodeIndices.append(memberNode.originalNodeIndices.begin(), memberNode.originalNodeIndices.end());
mergedNode.weight = addOrMax(mergedNode.weight, memberNode.weight);
mergedNode.crossbarUsage = addOrMax(mergedNode.crossbarUsage, memberNode.crossbarUsage);
}
std::sort(mergedNode.originalNodeIndices.begin(), mergedNode.originalNodeIndices.end());
mergedAny = true;
newNodeIndex = coarsenedGraph.nodes.size();
for (size_t memberIndex : orderedMergeGroups[static_cast<size_t>(mergeGroupIndex)])
oldToNewNode[memberIndex] = *newNodeIndex;
newNodeRank.push_back(topologicalRank[orderedMergeGroups[static_cast<size_t>(mergeGroupIndex)].front()]);
coarsenedGraph.nodes.push_back(std::move(mergedNode));
}
if (!mergedAny)
return false;
std::vector<IndexedEdge> remappedEdges;
remappedEdges.reserve(graph.edges.size());
for (auto [start, end, weight] : graph.edges) {
size_t newStart = oldToNewNode[static_cast<size_t>(start)];
size_t newEnd = oldToNewNode[static_cast<size_t>(end)];
if (newStart == newEnd)
continue;
if (newNodeRank[newStart] >= newNodeRank[newEnd])
continue;
remappedEdges.push_back({static_cast<int64_t>(newStart), static_cast<int64_t>(newEnd), weight});
}
coarsenedGraph.edges = aggregateEdges(remappedEdges);
return true;
}
size_t getDcpCoarseningWindowSize(size_t nodeCount, const DcpScheduleOptions &options) {
size_t windowSize = std::min(options.criticalWindowSize, nodeCount);
CPU maxCpuCount = std::max<CPU>(1, static_cast<CPU>(getSchedulingCpuBudget(options)));
if (nodeCount > static_cast<size_t>(maxCpuCount))
windowSize = std::max(windowSize, std::min(nodeCount, static_cast<size_t>(maxCpuCount) + 1));
return windowSize;
}
void assignFeasibleAest(const ComputeGraph &graph, MergeScheduleResult &result) {
llvm::DenseMap<ComputeInstance, size_t> nodeIndexByInstance;
nodeIndexByInstance.reserve(graph.nodes.size());
for (auto [nodeIndex, node] : llvm::enumerate(graph.nodes))
nodeIndexByInstance[node.instance] = nodeIndex;
struct ScheduledEdge {
size_t target = 0;
Time delay = 0;
};
std::vector<std::vector<ScheduledEdge>> scheduledChildren(graph.nodes.size());
std::vector<size_t> incomingEdgeCount(graph.nodes.size(), 0);
for (const ComputeGraphEdge &edge : graph.edges) {
const ComputeInstance sourceInstance = graph.nodes[edge.source].instance;
const ComputeInstance targetInstance = graph.nodes[edge.target].instance;
const size_t sourceCpu = result.computeToCpuMap.lookup(sourceInstance);
const size_t targetCpu = result.computeToCpuMap.lookup(targetInstance);
Time delay = graph.nodes[edge.source].weight;
if (sourceCpu != targetCpu)
delay = addOrMax(delay, edge.transferCost);
scheduledChildren[edge.source].push_back({edge.target, delay});
incomingEdgeCount[edge.target]++;
}
llvm::DenseMap<size_t, std::vector<std::pair<size_t, size_t>>> tasksByCpu;
for (const ComputeGraphNode &node : graph.nodes) {
size_t cpu = result.computeToCpuMap.lookup(node.instance);
size_t slot = result.computeToCpuSlotMap.lookup(node.instance);
tasksByCpu[cpu].push_back({slot, nodeIndexByInstance.lookup(node.instance)});
}
for (auto &entry : tasksByCpu) {
auto &scheduledTasks = entry.second;
llvm::sort(scheduledTasks, [](const auto &lhs, const auto &rhs) {
if (lhs.first != rhs.first)
return lhs.first < rhs.first;
return lhs.second < rhs.second;
});
for (size_t i = 1; i < scheduledTasks.size(); ++i) {
size_t sourceIndex = scheduledTasks[i - 1].second;
size_t targetIndex = scheduledTasks[i].second;
scheduledChildren[sourceIndex].push_back({targetIndex, graph.nodes[sourceIndex].weight});
incomingEdgeCount[targetIndex]++;
}
}
auto readyNodeGreater = [&](size_t lhs, size_t rhs) {
if (graph.nodes[lhs].originalOrder != graph.nodes[rhs].originalOrder)
return graph.nodes[lhs].originalOrder > graph.nodes[rhs].originalOrder;
return lhs > rhs;
};
std::priority_queue<size_t, std::vector<size_t>, decltype(readyNodeGreater)> readyNodes(readyNodeGreater);
for (size_t nodeIndex = 0; nodeIndex < graph.nodes.size(); ++nodeIndex)
if (incomingEdgeCount[nodeIndex] == 0)
readyNodes.push(nodeIndex);
std::vector<Time> startTimes(graph.nodes.size(), 0);
size_t processedNodeCount = 0;
while (!readyNodes.empty()) {
size_t sourceIndex = readyNodes.top();
readyNodes.pop();
processedNodeCount++;
for (const ScheduledEdge &edge : scheduledChildren[sourceIndex]) {
startTimes[edge.target] = std::max(startTimes[edge.target], addOrMax(startTimes[sourceIndex], edge.delay));
assert(incomingEdgeCount[edge.target] > 0 && "scheduled incoming edge count underflow");
incomingEdgeCount[edge.target]--;
if (incomingEdgeCount[edge.target] == 0)
readyNodes.push(edge.target);
}
}
if (processedNodeCount != graph.nodes.size())
llvm::report_fatal_error("merge scheduling: coarsened DCP schedule is cyclic");
for (auto [nodeIndex, node] : llvm::enumerate(graph.nodes))
result.computeToAestMap[node.instance] = startTimes[nodeIndex];
}
MergeScheduleResult buildResultFromVirtualGraph(const VirtualGraph &graph, const ComputeGraph &originalGraph) {
MergeScheduleResult result;
TimingInfo timing = computeTiming(graph);
std::vector<size_t> virtualNodeOrder;
if (timing.valid)
virtualNodeOrder = std::move(timing.topologicalOrder);
else {
virtualNodeOrder.resize(graph.nodes.size());
std::iota(virtualNodeOrder.begin(), virtualNodeOrder.end(), 0);
}
std::vector<size_t> originalNodeToCpu(originalGraph.nodes.size(), 0);
for (auto [cpu, virtualNodeIndex] : llvm::enumerate(virtualNodeOrder)) {
const VirtualNode &virtualNode = graph.nodes[virtualNodeIndex];
for (size_t originalIndex : virtualNode.originalNodeIndices)
originalNodeToCpu[originalIndex] = cpu;
}
result.dominanceOrderCompute.reserve(originalGraph.nodes.size());
llvm::DenseMap<size_t, size_t> nextCpuSlot;
for (auto [originalIndex, node] : llvm::enumerate(originalGraph.nodes)) {
size_t cpu = originalNodeToCpu[originalIndex];
result.dominanceOrderCompute.push_back(node.instance);
result.computeToCpuMap[node.instance] = cpu;
result.computeToCpuSlotMap[node.instance] = nextCpuSlot[cpu]++;
result.cpuToLastComputeMap[cpu] = node.instance;
}
for (const auto &[cpu, lastCompute] : result.cpuToLastComputeMap)
result.isLastComputeOfCpu.insert(lastCompute);
assignFeasibleAest(originalGraph, result);
return result;
}
MergeScheduleResult buildResultFromScheduledGraph(GraphDCP &graphDCP, const ComputeGraph &graph) {
MergeScheduleResult result;
result.dominanceOrderCompute.reserve(graph.nodes.size());
for (const ComputeGraphNode &node : graph.nodes)
result.dominanceOrderCompute.push_back(node.instance);
for (CPU cpu = 0; cpu < graphDCP.cpuCount(); ++cpu) {
auto scheduledTasks = graphDCP.getScheduledTasks(cpu);
if (scheduledTasks.empty())
continue;
for (auto [slot, task] : llvm::enumerate(scheduledTasks)) {
const ComputeInstance instance = graph.nodes[task.nodeIndex].instance;
result.computeToCpuMap[instance] = cpu;
result.computeToCpuSlotMap[instance] = slot;
result.computeToAestMap[instance] = static_cast<uint64_t>(task.aest);
}
const ComputeInstance lastInstance = graph.nodes[scheduledTasks.back().nodeIndex].instance;
result.cpuToLastComputeMap[cpu] = lastInstance;
result.isLastComputeOfCpu.insert(lastInstance);
}
return result;
}
MergeScheduleResult runLegacyDcp(const ComputeGraph &graph, const DcpScheduleOptions &options, mlir::MLIRContext *context) {
llvm::SmallVector<Weight> nodeWeights;
llvm::SmallVector<CrossbarUsage> nodeCrossbarUsage;
llvm::SmallVector<int64_t> nodeOrderKeys;
llvm::SmallVector<IndexedEdge> edges;
nodeWeights.reserve(graph.nodes.size());
nodeCrossbarUsage.reserve(graph.nodes.size());
nodeOrderKeys.reserve(graph.nodes.size());
edges.reserve(graph.edges.size());
for (const ComputeGraphNode &node : graph.nodes) {
nodeWeights.push_back(node.weight);
nodeCrossbarUsage.push_back(node.crossbarUsage);
nodeOrderKeys.push_back(static_cast<int64_t>(node.originalOrder));
}
for (const ComputeGraphEdge &edge : graph.edges) {
edges.push_back(
{static_cast<int64_t>(edge.source), static_cast<int64_t>(edge.target), static_cast<int64_t>(edge.transferCost)});
}
GraphDCP graphDCP(nodeWeights, edges, nodeOrderKeys, nodeCrossbarUsage);
if (options.processorCount > 0)
graphDCP.setMaxCpuCount(static_cast<int>(options.processorCount));
graphDCP.setContext(context);
graphDCP.runDcp();
return buildResultFromScheduledGraph(graphDCP, graph);
}
bool needsExactScheduledBatches(const ComputeGraph &graph, const DcpScheduleOptions &options) {
if (options.processorCount == 0 || !options.allowFallbackForAutoCoreCount)
return false;
size_t schedulingCpuBudget = getSchedulingCpuBudget(options);
return llvm::any_of(graph.nodes, [&](const ComputeGraphNode &node) {
auto batch = dyn_cast<SpatComputeBatch>(node.instance.op);
return batch && static_cast<size_t>(batch.getLaneCount()) > schedulingCpuBudget;
});
}
} // namespace
MergeScheduleResult
runDcpScheduler(const ComputeGraph &graph, const DcpScheduleOptions &options, mlir::MLIRContext *context) {
if (needsExactScheduledBatches(graph, options))
return runLegacyDcp(graph, options, context);
if (options.criticalWindowSize == 0)
return runLegacyDcp(graph, options, context);
VirtualGraph virtualGraph = buildInitialVirtualGraph(graph);
size_t iteration = 0;
bool debugCoarsening = isDcpCoarsenDebugEnabled();
auto tryCoarsenSelectedNodes = [&](llvm::ArrayRef<size_t> selectedNodes) {
size_t oldNodeCount = virtualGraph.nodes.size();
WindowScheduleResult windowSchedule = scheduleWindow(virtualGraph, selectedNodes, options, context);
if (windowSchedule.mergeGroups.empty()) {
if (debugCoarsening && oldNodeCount >= 200)
llvm::errs() << llvm::formatv("[DCP-COARSEN] iter={0} old={1} selected={2} windowCpus={3} "
"groups=0 mergedNodes=0 maxGroup=0 new={1} changed=0\n",
iteration,
oldNodeCount,
selectedNodes.size(),
windowSchedule.cpuCount);
return false;
}
VirtualGraph coarsenedGraph;
std::vector<size_t> oldToNewNode;
if (!coarsenGraph(virtualGraph, windowSchedule.mergeGroups, coarsenedGraph, oldToNewNode))
return false;
if (debugCoarsening && (oldNodeCount >= 200 || coarsenedGraph.nodes.size() >= 200))
llvm::errs() << llvm::formatv("[DCP-COARSEN] iter={0} old={1} selected={2} windowCpus={3} "
"groups={4} mergedNodes={5} maxGroup={6} new={7} changed={8}\n",
iteration,
oldNodeCount,
selectedNodes.size(),
windowSchedule.cpuCount,
windowSchedule.mergeGroups.size(),
windowSchedule.mergedNodeCount,
windowSchedule.maxMergeGroupSize,
coarsenedGraph.nodes.size(),
oldNodeCount - coarsenedGraph.nodes.size());
virtualGraph = std::move(coarsenedGraph);
return true;
};
while (virtualGraph.nodes.size() > 1) {
if (virtualGraph.nodes.size() <= getSchedulingCpuBudget(options)) {
if (debugCoarsening && virtualGraph.nodes.size() >= 200)
llvm::errs() << llvm::formatv(
"[DCP-COARSEN] iter={0} old={1} stop=cpu-budget\n", iteration, virtualGraph.nodes.size());
break;
}
iteration++;
TimingInfo timing = computeTiming(virtualGraph);
if (!timing.valid) {
if (debugCoarsening && virtualGraph.nodes.size() >= 200)
llvm::errs() << llvm::formatv(
"[DCP-COARSEN] iter={0} old={1} invalid-timing\n", iteration, virtualGraph.nodes.size());
break;
}
llvm::SmallVector<size_t> selectedNodes;
auto criticalWindow =
selectCriticalWindow(virtualGraph, timing, getDcpCoarseningWindowSize(virtualGraph.nodes.size(), options));
selectedNodes.append(criticalWindow.begin(), criticalWindow.end());
if (selectedNodes.size() < 2) {
if (debugCoarsening && virtualGraph.nodes.size() >= 200)
llvm::errs() << llvm::formatv("[DCP-COARSEN] iter={0} old={1} selected={2} stop=small-window\n",
iteration,
virtualGraph.nodes.size(),
selectedNodes.size());
break;
}
if (tryCoarsenSelectedNodes(selectedNodes))
continue;
if (debugCoarsening && virtualGraph.nodes.size() >= 200)
llvm::errs() << llvm::formatv(
"[DCP-COARSEN] iter={0} old={1} stop=no-merge\n", iteration, virtualGraph.nodes.size());
break;
}
return buildResultFromVirtualGraph(virtualGraph, graph);
}
} // namespace spatial
} // namespace onnx_mlir
@@ -0,0 +1,21 @@
#pragma once
#include "mlir/IR/MLIRContext.h"
#include "ComputeGraph.hpp"
#include "MergeSchedule.hpp"
namespace onnx_mlir {
namespace spatial {
struct DcpScheduleOptions {
size_t processorCount = 0;
size_t criticalWindowSize = 0;
bool allowFallbackForAutoCoreCount = true;
};
MergeScheduleResult
runDcpScheduler(const ComputeGraph &graph, const DcpScheduleOptions &options, mlir::MLIRContext *context);
} // namespace spatial
} // namespace onnx_mlir
@@ -0,0 +1,26 @@
#pragma once
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include <cstddef>
#include <cstdint>
#include <vector>
#include "ComputeInstance.hpp"
namespace onnx_mlir {
namespace spatial {
struct MergeScheduleResult {
std::vector<ComputeInstance> dominanceOrderCompute;
llvm::DenseMap<ComputeInstance, size_t> computeToCpuMap;
llvm::DenseMap<ComputeInstance, size_t> computeToCpuSlotMap;
llvm::DenseMap<ComputeInstance, uint64_t> computeToAestMap;
llvm::DenseSet<ComputeInstance> isLastComputeOfCpu;
llvm::DenseMap<size_t, ComputeInstance> cpuToLastComputeMap;
llvm::DenseMap<size_t, mlir::SmallVector<size_t, 5>> equivalentClass;
};
} // namespace spatial
} // namespace onnx_mlir
@@ -0,0 +1,139 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FormatVariadic.h"
#include <limits>
#include <vector>
#include "ComputeGraph.hpp"
#include "../DCPGraph/DCPAnalysis.hpp"
#include "DcpScheduler.hpp"
#include "MergeSchedulingAnalysis.hpp"
#include "PeftScheduler.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
namespace onnx_mlir {
namespace spatial {
namespace {
MergeSchedulerKind getSchedulerKind() {
switch (pimMergeScheduler.getValue()) {
case MergeSchedulerPeft:
return MergeSchedulerKind::Peft;
case MergeSchedulerDcp:
return MergeSchedulerKind::Dcp;
}
llvm_unreachable("unknown merge scheduler kind");
}
void verifySchedule(const ComputeGraph &graph, const MergeScheduleResult &result, CrossbarUsage crossbarCapacity) {
llvm::DenseMap<size_t, std::vector<std::pair<size_t, size_t>>> tasksByCpu;
tasksByCpu.reserve(result.cpuToLastComputeMap.size());
for (size_t nodeIndex = 0; nodeIndex < graph.nodes.size(); ++nodeIndex) {
const ComputeInstance instance = graph.nodes[nodeIndex].instance;
if (!result.computeToCpuMap.count(instance))
llvm::report_fatal_error("merge scheduling: missing CPU assignment");
if (!result.computeToCpuSlotMap.count(instance))
llvm::report_fatal_error("merge scheduling: missing CPU slot assignment");
if (!result.computeToAestMap.count(instance))
llvm::report_fatal_error("merge scheduling: missing start time");
tasksByCpu[result.computeToCpuMap.lookup(instance)].push_back(
{result.computeToCpuSlotMap.lookup(instance), nodeIndex});
}
for (auto &entry : tasksByCpu) {
auto &scheduledTasks = entry.second;
llvm::sort(scheduledTasks, [](const auto &lhs, const auto &rhs) {
if (lhs.first != rhs.first)
return lhs.first < rhs.first;
return lhs.second < rhs.second;
});
CrossbarUsage usedCrossbars = 0;
for (size_t slot = 0; slot < scheduledTasks.size(); ++slot) {
if (scheduledTasks[slot].first != slot)
llvm::report_fatal_error("merge scheduling: CPU slots are not contiguous");
usedCrossbars = addOrMax(usedCrossbars, graph.nodes[scheduledTasks[slot].second].crossbarUsage);
if (usedCrossbars > crossbarCapacity)
llvm::report_fatal_error("merge scheduling: CPU crossbar capacity exceeded");
}
const ComputeInstance expectedLast = graph.nodes[scheduledTasks.back().second].instance;
auto lastIt = result.cpuToLastComputeMap.find(entry.first);
if (lastIt == result.cpuToLastComputeMap.end() || !(lastIt->second == expectedLast))
llvm::report_fatal_error("merge scheduling: cpuToLastComputeMap does not match slot order");
if (!result.isLastComputeOfCpu.count(expectedLast))
llvm::report_fatal_error("merge scheduling: missing last-compute marker");
}
for (const ComputeGraphEdge &edge : graph.edges) {
const ComputeInstance source = graph.nodes[edge.source].instance;
const ComputeInstance target = graph.nodes[edge.target].instance;
const size_t sourceCpu = result.computeToCpuMap.lookup(source);
const size_t targetCpu = result.computeToCpuMap.lookup(target);
const size_t sourceSlot = result.computeToCpuSlotMap.lookup(source);
const size_t targetSlot = result.computeToCpuSlotMap.lookup(target);
const Time sourceStart = static_cast<Time>(result.computeToAestMap.lookup(source));
const Time targetStart = static_cast<Time>(result.computeToAestMap.lookup(target));
if (sourceCpu == targetCpu && sourceSlot >= targetSlot)
llvm::report_fatal_error("merge scheduling: same-CPU dependency order is invalid");
Time earliestTargetStart = addOrMax(sourceStart, graph.nodes[edge.source].weight);
if (sourceCpu != targetCpu)
earliestTargetStart = addOrMax(earliestTargetStart, edge.transferCost);
if (targetStart < earliestTargetStart) {
std::string message = llvm::formatv("merge scheduling: dependency legality failed between tasks {0} and {1}",
graph.nodes[edge.source].originalOrder,
graph.nodes[edge.target].originalOrder)
.str();
llvm::report_fatal_error(llvm::StringRef(message));
}
}
}
} // namespace
MergeSchedulingAnalysis::MergeSchedulingAnalysis(mlir::Operation *op)
: entryOp(op) {
result = run();
}
MergeScheduleResult MergeSchedulingAnalysis::run() {
verifyExplicitPimCoreCount();
ComputeGraph graph = buildComputeGraph(entryOp);
if (!verifyAcyclic(graph))
llvm::report_fatal_error("merge scheduling: compute graph is cyclic");
MergeSchedulingOptions options;
options.kind = getSchedulerKind();
if (coresCount.getValue() > 0)
options.processorCount = static_cast<size_t>(coresCount.getValue());
MergeScheduleResult schedule;
if (options.kind == MergeSchedulerKind::Peft) {
schedule = runPeftScheduler(
graph,
PeftScheduleOptions {options.processorCount, static_cast<CrossbarUsage>(crossbarCountInCore.getValue()),
entryOp->getContext()});
}
else {
schedule = runDcpScheduler(
graph,
DcpScheduleOptions {
options.processorCount,
dcpCriticalWindowSize.getValue(),
options.allowDcpFallbackForAutoCoreCount
},
entryOp->getContext());
}
verifySchedule(graph, schedule, static_cast<CrossbarUsage>(crossbarCountInCore.getValue()));
return schedule;
}
} // namespace spatial
} // namespace onnx_mlir

Some files were not shown because too many files have changed in this diff Show More