Compare commits
34 Commits
628dc630a4
..
main
| Author | SHA1 | Date | |
|---|---|---|---|
| 5637c861b4 | |||
| 94157a8404 | |||
| 68a3521978 | |||
| e263e05f56 | |||
| 34c29fdec4 | |||
| aa088e2ba5 | |||
| 2836e759ab | |||
| 8071ebab0b | |||
| f1602c0550 | |||
| de0a2f4561 | |||
| 1c4a5bde76 | |||
| 78242e2887 | |||
| fe244d5aa1 | |||
| d09e76c8f9 | |||
| c5e608fa5b | |||
| 43f3ccdd21 | |||
| 8d95c604a6 | |||
| 55eda487dc | |||
| 061139aefb | |||
| ea61540e08 | |||
| 324178cba8 | |||
| e71ba07cd5 | |||
| 64a3805619 | |||
| 9f9e7c0892 | |||
| 03eab42971 | |||
| c15aba5d96 | |||
| 4821e8a55e | |||
| 88bb223bb1 | |||
| 623ee62a04 | |||
| ad56888b0b | |||
| f993840641 | |||
| 0c7db55a24 | |||
| 41de3cb150 | |||
| 4f3570520c |
+1
-1
@@ -3,4 +3,4 @@
|
||||
url = https://github.com/onnx/onnx-mlir.git
|
||||
[submodule "backend-simulators/pim/pimsim-nn"]
|
||||
path = backend-simulators/pim/pimsim-nn
|
||||
url = https://github.com/wangxy-2000/pimsim-nn.git
|
||||
url = https://github.com/HEAPLab/pimsim-nn.git
|
||||
|
||||
@@ -81,7 +81,11 @@ ONNX-MLIR ──► Spatial ──► Pim (tensor) ──► Pim (bufferized)
|
||||
standard MLIR `BufferizableOpInterface` machinery
|
||||
(`OpBufferizationInterfaces.*`, `PimBufferization.td`).
|
||||
|
||||
5. **PIM code generation** (`src/PIM/Pass/PimCodegen`):
|
||||
5. **Static memory coalescing** (`src/PIM/Dialect/Pim/Transforms/StaticMemoryCoalescing`).
|
||||
Conservatively reuses same-typed local memref allocations inside PIM cores
|
||||
after bufferization and before code generation.
|
||||
|
||||
6. **PIM code generation** (`src/PIM/Pass/PimCodegen`):
|
||||
- `HostConstantFolding` — folds host-side constants.
|
||||
- `MaterializeHostConstantsPass` — materializes the remaining host
|
||||
constants for emission.
|
||||
@@ -110,7 +114,9 @@ Pass these on the `onnx-mlir` command line when compiling for PIM:
|
||||
run only the codegen tail.
|
||||
- `--crossbar-size=<N>` / `--crossbar-count=<N>` — crossbar dimensions and
|
||||
per-core count.
|
||||
- `--core-count=<N>` — number of cores (`-1` picks the minimum).
|
||||
- `--core-count=<N>` — number of cores. Required for PIM compilation.
|
||||
- `--pim-merge-scheduler={peft,dcp}` — scheduler used by the Spatial
|
||||
merge-compute-nodes pass (default: `peft`).
|
||||
- `--dcp-critical-window-size=<N>` — DCP coarsening window (0 = legacy).
|
||||
- `--use-experimental-conv-impl` — alternative convolution lowering.
|
||||
- `--ignore-concat-error` — soft-fail corner case in `ConcatOp`.
|
||||
@@ -125,7 +131,8 @@ Per-operation validation (from `validation/`):
|
||||
```
|
||||
validate.py \
|
||||
--raptor-path ../cmake-build-release/Release/bin/onnx-mlir \
|
||||
--onnx-include-dir ../onnx-mlir/include
|
||||
--onnx-include-dir ../onnx-mlir/include \
|
||||
--core-count 1000
|
||||
```
|
||||
|
||||
End-to-end network validation (example: first 4 layers of YOLOv11n):
|
||||
|
||||
+19
@@ -1030,6 +1030,15 @@ version = "0.2.16"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b6d2cec3eae94f9f509c767b45932f1ada8350c4bdb85af2fcab4a3c14807981"
|
||||
|
||||
[[package]]
|
||||
name = "libmimalloc-sys"
|
||||
version = "0.1.47"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2d1eacfa31c33ec25e873c136ba5669f00f9866d0688bea7be4d3f7e43067df6"
|
||||
dependencies = [
|
||||
"cc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "linux-raw-sys"
|
||||
version = "0.12.1"
|
||||
@@ -1095,6 +1104,15 @@ version = "2.8.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79"
|
||||
|
||||
[[package]]
|
||||
name = "mimalloc"
|
||||
version = "0.1.50"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b3627c4272df786b9260cabaa46aec1d59c93ede723d4c3ef646c503816b0640"
|
||||
dependencies = [
|
||||
"libmimalloc-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mime"
|
||||
version = "0.3.17"
|
||||
@@ -1414,6 +1432,7 @@ dependencies = [
|
||||
"faer-traits",
|
||||
"glob",
|
||||
"hex",
|
||||
"mimalloc",
|
||||
"paste",
|
||||
"plotly",
|
||||
"rayon",
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
|
||||
[package]
|
||||
name = "pim-simulator"
|
||||
version = "0.1.0"
|
||||
@@ -34,3 +33,4 @@ plotly = {version="0.8", optional=true}
|
||||
rayon = "1.12.0"
|
||||
faer = "0.24.0"
|
||||
faer-traits = "0.24.0"
|
||||
mimalloc = "0.1.50"
|
||||
|
||||
@@ -1,14 +1,20 @@
|
||||
use mimalloc::MiMalloc;
|
||||
|
||||
#[global_allocator]
|
||||
static GLOBAL: MiMalloc = MiMalloc;
|
||||
|
||||
use anyhow::{Context, Result, bail};
|
||||
use clap::Parser;
|
||||
use glob::glob;
|
||||
use pimcore::binary_to_instruction::binary_to_executor;
|
||||
use pimcore::cpu::crossbar::Crossbar;
|
||||
use pimcore::json_to_instruction::json_to_executor;
|
||||
use pimcore::memory_manager::CoreMemory;
|
||||
use pimcore::tracing::TRACER;
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
use std::fs::{self, read_link};
|
||||
use std::io::Write;
|
||||
use std::fs::{self, File, read_link};
|
||||
use std::io::{BufReader, Write};
|
||||
use std::path::PathBuf;
|
||||
|
||||
/// Program to simulate core execution configuration
|
||||
@@ -44,18 +50,24 @@ fn main() -> Result<()> {
|
||||
let args = Args::parse();
|
||||
|
||||
let config_json = retrive_config(&args)?;
|
||||
let core_jsons = retrive_cores(&args)?;
|
||||
let mut core_inputs = retrive_cores(&args)?;
|
||||
let memory = retrive_memory(&args)?;
|
||||
let global_crossbars = get_crossbars(&config_json, &args).unwrap();
|
||||
let crossbars = map_crossbars_to_cores(&config_json, &args, &global_crossbars);
|
||||
let mut executor =
|
||||
json_to_executor::json_to_executor(config_json, core_jsons.iter(), crossbars);
|
||||
let mut executor = match &mut core_inputs {
|
||||
CoreInputs::Json(core_jsons) => {
|
||||
json_to_executor::json_to_executor(config_json, core_jsons, crossbars)
|
||||
}
|
||||
CoreInputs::Binary(core_bins) => {
|
||||
binary_to_executor(config_json, core_bins.iter(), crossbars)?
|
||||
}
|
||||
};
|
||||
set_memory(&mut executor, memory);
|
||||
TRACER
|
||||
.lock()
|
||||
.unwrap()
|
||||
.init(executor.cpu().num_core(), args.output.clone());
|
||||
executor.execute();
|
||||
executor.execute()?;
|
||||
dump_memory(executor, &args)?;
|
||||
Ok(())
|
||||
}
|
||||
@@ -65,7 +77,7 @@ fn map_crossbars_to_cores<'c>(
|
||||
args: &Args,
|
||||
global_crossbars: &'c HashMap<String, Crossbar>,
|
||||
) -> Vec<Vec<&'c Crossbar>> {
|
||||
let mut res = Vec::new();
|
||||
let mut res = vec![Vec::new()];
|
||||
let num_cores = config.get("core_cnt").unwrap().as_i64().unwrap() as i32;
|
||||
|
||||
if let Some(folder) = args.folder.as_ref() {
|
||||
@@ -140,8 +152,7 @@ fn get_crossbars(config: &Value, args: &Args) -> anyhow::Result<HashMap<String,
|
||||
}
|
||||
|
||||
let bytes = std::fs::read(weight_file.path()).expect("Failed to read binary file");
|
||||
let mut crossbar =
|
||||
Crossbar::new(column_corssbar * 4, rows_crossbar, CoreMemory::new());
|
||||
let mut crossbar = Crossbar::new(column_corssbar * 4, rows_crossbar, CoreMemory::new());
|
||||
crossbar.execute_store(&bytes).unwrap();
|
||||
res.insert(
|
||||
weight_file
|
||||
@@ -214,45 +225,82 @@ fn retrive_memory(args: &Args) -> Result<Vec<u8>> {
|
||||
Ok(memory_vector)
|
||||
}
|
||||
|
||||
fn retrive_cores(args: &Args) -> Result<Vec<Value>, anyhow::Error> {
|
||||
let mut core_jsons: Vec<Value> = Vec::new();
|
||||
enum CoreInputs {
|
||||
Json(Vec<BufReader<File>>),
|
||||
Binary(Vec<Vec<u8>>),
|
||||
}
|
||||
|
||||
fn retrive_cores(args: &Args) -> Result<CoreInputs, anyhow::Error> {
|
||||
if let Some(cores_override) = &args.cores {
|
||||
let first_extension = cores_override
|
||||
.first()
|
||||
.and_then(|path| path.extension())
|
||||
.and_then(|ext| ext.to_str())
|
||||
.unwrap_or_default();
|
||||
if first_extension == "pim" {
|
||||
let mut core_bins = Vec::with_capacity(cores_override.len());
|
||||
for core in cores_override {
|
||||
let content = fs::read_to_string(core)
|
||||
.with_context(|| format!("Failed to read core file: {:?}", cores_override))?;
|
||||
let json: Value =
|
||||
serde_json::from_str(&content).context("Failed to parse core json override")?;
|
||||
core_jsons.push(json);
|
||||
core_bins.push(
|
||||
fs::read(core)
|
||||
.with_context(|| format!("Failed to read binary core file: {:?}", core))?,
|
||||
);
|
||||
}
|
||||
} else if let Some(folder) = args.folder.as_ref() {
|
||||
let pattern = folder.join("core*.json");
|
||||
let pattern_str = pattern.to_str().context("Invalid path encoding")?;
|
||||
let mut paths: Vec<_> = glob(pattern_str)?.map(|x| x.unwrap()).collect();
|
||||
paths.sort_by_cached_key(|x| {
|
||||
let mut x = x
|
||||
return Ok(CoreInputs::Binary(core_bins));
|
||||
}
|
||||
let mut core_jsons_reader: Vec<BufReader<File>> = Vec::with_capacity(cores_override.len());
|
||||
for core in cores_override {
|
||||
let file = File::open(core)?;
|
||||
let reader = BufReader::new(file);
|
||||
core_jsons_reader.push(reader);
|
||||
}
|
||||
return Ok(CoreInputs::Json(core_jsons_reader));
|
||||
}
|
||||
|
||||
if let Some(folder) = args.folder.as_ref() {
|
||||
let binary_pattern = folder.join("core*.pim");
|
||||
let binary_pattern_str = binary_pattern.to_str().context("Invalid path encoding")?;
|
||||
let mut binary_paths: Vec<_> = glob(binary_pattern_str)?.map(|x| x.unwrap()).collect();
|
||||
binary_paths.sort_by_cached_key(core_sort_key);
|
||||
if !binary_paths.is_empty() {
|
||||
let mut core_bins = Vec::with_capacity(binary_paths.len());
|
||||
for path in binary_paths {
|
||||
core_bins.push(
|
||||
fs::read(&path)
|
||||
.with_context(|| format!("Failed to read core file: {:?}", path))?,
|
||||
);
|
||||
}
|
||||
return Ok(CoreInputs::Binary(core_bins));
|
||||
}
|
||||
|
||||
let json_pattern = folder.join("core*.json");
|
||||
let json_pattern_str = json_pattern.to_str().context("Invalid path encoding")?;
|
||||
let mut json_paths: Vec<_> = glob(json_pattern_str)?.map(|x| x.unwrap()).collect();
|
||||
json_paths.sort_by_cached_key(core_sort_key);
|
||||
|
||||
if json_paths.is_empty() {
|
||||
bail!("No core*.pim or core*.json files found in {:?}", folder);
|
||||
}
|
||||
|
||||
let mut core_json_reader: Vec<BufReader<File>> = Vec::with_capacity(json_paths.len());
|
||||
for path in json_paths {
|
||||
let file = File::open(path)?;
|
||||
let reader = BufReader::new(file);
|
||||
core_json_reader.push(reader);
|
||||
}
|
||||
return Ok(CoreInputs::Json(core_json_reader));
|
||||
}
|
||||
|
||||
bail!("Either --core or --folder must be provided to find core definitions.");
|
||||
}
|
||||
|
||||
fn core_sort_key(path: &PathBuf) -> i32 {
|
||||
let mut stem = path
|
||||
.file_stem()
|
||||
.expect("Extracting the stem")
|
||||
.to_str()
|
||||
.expect("File not utf-8");
|
||||
x = &x[5..];
|
||||
x.parse::<i32>().unwrap()
|
||||
});
|
||||
|
||||
if paths.is_empty() {
|
||||
bail!("No core*.json files found in {:?}", folder);
|
||||
}
|
||||
for entry in paths {
|
||||
let path = entry;
|
||||
let content = fs::read_to_string(&path)
|
||||
.with_context(|| format!("Failed to read core file: {:?}", path))?;
|
||||
let json: Value = serde_json::from_str(&content)
|
||||
.with_context(|| format!("Failed to parse JSON in {:?}", path))?;
|
||||
core_jsons.push(json);
|
||||
}
|
||||
} else {
|
||||
bail!("Either --core or --folder must be provided to find core definitions.");
|
||||
}
|
||||
Ok(core_jsons)
|
||||
stem = &stem[5..];
|
||||
stem.parse::<i32>().unwrap()
|
||||
}
|
||||
|
||||
fn retrive_config(args: &Args) -> Result<Value, anyhow::Error> {
|
||||
|
||||
@@ -0,0 +1,497 @@
|
||||
use crate::{
|
||||
CoreInstructionsBuilder, Executable,
|
||||
cpu::{CPU, crossbar::Crossbar},
|
||||
instruction_set::{InstructionsBuilder, instruction_data::InstructionDataBuilder, isa::*},
|
||||
};
|
||||
use anyhow::{Context, Result, bail, ensure};
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
use std::convert::TryFrom;
|
||||
use std::sync::LazyLock;
|
||||
|
||||
const MAGIC: &[u8; 4] = b"PIMB";
|
||||
const VERSION: u32 = 1;
|
||||
const HEADER_SIZE: usize = 12;
|
||||
const RECORD_SIZE: usize = 20;
|
||||
|
||||
macro_rules! add_name {
|
||||
($storage:ident, $opcode:literal, $name:literal) => {
|
||||
$storage.insert($opcode, $name);
|
||||
};
|
||||
}
|
||||
|
||||
static INSTRUCTIONS: LazyLock<HashMap<usize, &'static str>> = LazyLock::new(|| {
|
||||
let mut hash = HashMap::new();
|
||||
add_name!(hash, 0, "nop");
|
||||
add_name!(hash, 1, "sldi");
|
||||
add_name!(hash, 2, "sld");
|
||||
add_name!(hash, 3, "sadd");
|
||||
add_name!(hash, 4, "ssub");
|
||||
add_name!(hash, 5, "smul");
|
||||
add_name!(hash, 6, "saddi");
|
||||
add_name!(hash, 7, "smuli");
|
||||
add_name!(hash, 8, "setbw");
|
||||
add_name!(hash, 9, "mvmul");
|
||||
add_name!(hash, 10, "vvadd");
|
||||
add_name!(hash, 11, "vvsub");
|
||||
add_name!(hash, 12, "vvmul");
|
||||
add_name!(hash, 13, "vvdmul");
|
||||
add_name!(hash, 14, "vvmax");
|
||||
add_name!(hash, 15, "vvsll");
|
||||
add_name!(hash, 16, "vvsra");
|
||||
add_name!(hash, 17, "vavg");
|
||||
add_name!(hash, 18, "vrelu");
|
||||
add_name!(hash, 19, "vtanh");
|
||||
add_name!(hash, 20, "vsigm");
|
||||
add_name!(hash, 21, "vsoftmax");
|
||||
add_name!(hash, 22, "vmv");
|
||||
add_name!(hash, 23, "vrsu");
|
||||
add_name!(hash, 24, "vrsl");
|
||||
add_name!(hash, 25, "ld");
|
||||
add_name!(hash, 26, "st");
|
||||
add_name!(hash, 27, "lldi");
|
||||
add_name!(hash, 28, "lmv");
|
||||
add_name!(hash, 29, "send");
|
||||
add_name!(hash, 30, "recv");
|
||||
add_name!(hash, 31, "wait");
|
||||
add_name!(hash, 32, "sync");
|
||||
hash
|
||||
});
|
||||
|
||||
#[derive(Clone, Copy, Debug, Default)]
|
||||
struct InstructionRecord {
|
||||
opcode: u8,
|
||||
rd: u8,
|
||||
r1: u8,
|
||||
r2_or_imm: i32,
|
||||
generic1: i32,
|
||||
generic2: i32,
|
||||
generic3: i32,
|
||||
flags: u8,
|
||||
}
|
||||
|
||||
fn read_u32_le(bytes: &[u8], offset: usize) -> u32 {
|
||||
u32::from_le_bytes(bytes[offset..offset + 4].try_into().unwrap())
|
||||
}
|
||||
|
||||
fn read_i32_le(bytes: &[u8], offset: usize) -> i32 {
|
||||
i32::from_le_bytes(bytes[offset..offset + 4].try_into().unwrap())
|
||||
}
|
||||
|
||||
fn parse_binary_records(bytes: &[u8]) -> Result<Vec<InstructionRecord>> {
|
||||
ensure!(bytes.len() >= HEADER_SIZE, "binary core file too small");
|
||||
ensure!(&bytes[0..4] == MAGIC, "invalid PIM binary magic");
|
||||
|
||||
let version = read_u32_le(bytes, 4);
|
||||
ensure!(
|
||||
version == VERSION,
|
||||
"unsupported PIM binary version {version}"
|
||||
);
|
||||
|
||||
let instruction_count = read_u32_le(bytes, 8) as usize;
|
||||
let expected_len = HEADER_SIZE + instruction_count * RECORD_SIZE;
|
||||
ensure!(
|
||||
bytes.len() == expected_len,
|
||||
"PIM binary size mismatch: expected {expected_len} bytes, got {}",
|
||||
bytes.len()
|
||||
);
|
||||
|
||||
let mut records = Vec::with_capacity(instruction_count);
|
||||
for index in 0..instruction_count {
|
||||
let base = HEADER_SIZE + index * RECORD_SIZE;
|
||||
records.push(InstructionRecord {
|
||||
opcode: bytes[base],
|
||||
rd: bytes[base + 1],
|
||||
r1: bytes[base + 2],
|
||||
flags: bytes[base + 3],
|
||||
r2_or_imm: read_i32_le(bytes, base + 4),
|
||||
generic1: read_i32_le(bytes, base + 8),
|
||||
generic2: read_i32_le(bytes, base + 12),
|
||||
generic3: read_i32_le(bytes, base + 16),
|
||||
});
|
||||
}
|
||||
|
||||
Ok(records)
|
||||
}
|
||||
|
||||
fn append_record(
|
||||
inst_builder: &mut InstructionsBuilder,
|
||||
inst_data_builder: &mut InstructionDataBuilder,
|
||||
record: InstructionRecord,
|
||||
) -> Result<()> {
|
||||
let InstructionRecord {
|
||||
opcode,
|
||||
rd,
|
||||
r1,
|
||||
r2_or_imm,
|
||||
generic1,
|
||||
generic2,
|
||||
generic3,
|
||||
flags: _,
|
||||
} = record;
|
||||
|
||||
match opcode {
|
||||
0 => {}
|
||||
1 => {
|
||||
inst_data_builder.set_rd_u8(rd).set_imm(r2_or_imm);
|
||||
inst_builder.make_inst(sldi, inst_data_builder.build());
|
||||
}
|
||||
2 => {
|
||||
inst_data_builder
|
||||
.set_rd_u8(rd)
|
||||
.set_r1_u8(r1)
|
||||
.set_offset_select(generic1)
|
||||
.set_offset_value(generic2);
|
||||
inst_builder.make_inst(sld, inst_data_builder.build());
|
||||
}
|
||||
3 => {
|
||||
inst_data_builder.set_rdr1r2_u8(rd, r1, r2_or_imm);
|
||||
inst_builder.make_inst(sadd, inst_data_builder.build());
|
||||
}
|
||||
4 => {
|
||||
inst_data_builder.set_rdr1r2_u8(rd, r1, r2_or_imm);
|
||||
inst_builder.make_inst(ssub, inst_data_builder.build());
|
||||
}
|
||||
5 => {
|
||||
inst_data_builder.set_rdr1r2_u8(rd, r1, r2_or_imm);
|
||||
inst_builder.make_inst(smul, inst_data_builder.build());
|
||||
}
|
||||
6 => {
|
||||
inst_data_builder.set_rdr1imm_u8(rd, r1, r2_or_imm);
|
||||
inst_builder.make_inst(saddi, inst_data_builder.build());
|
||||
}
|
||||
7 => {
|
||||
inst_data_builder.set_rdr1imm_u8(rd, r1, r2_or_imm);
|
||||
inst_builder.make_inst(smuli, inst_data_builder.build());
|
||||
}
|
||||
8 => {
|
||||
inst_data_builder.set_ibiw_obiw(generic1, generic2);
|
||||
inst_builder.make_inst(setbw, inst_data_builder.build());
|
||||
}
|
||||
9 => {
|
||||
inst_data_builder
|
||||
.set_rd_u8(rd)
|
||||
.set_r1_u8(r1)
|
||||
.set_mbiw_immrelu_immgroup(r2_or_imm, generic1, generic2);
|
||||
inst_builder.make_inst(mvmul, inst_data_builder.build());
|
||||
}
|
||||
10 => {
|
||||
inst_data_builder
|
||||
.set_rdr1r2_u8(rd, r1, r2_or_imm)
|
||||
.set_imm_len(generic3)
|
||||
.set_offset_select_value(generic1, generic2);
|
||||
inst_builder.make_inst(vvadd, inst_data_builder.build());
|
||||
}
|
||||
11 => {
|
||||
inst_data_builder
|
||||
.set_rdr1r2_u8(rd, r1, r2_or_imm)
|
||||
.set_imm_len(generic3)
|
||||
.set_offset_select_value(generic1, generic2);
|
||||
inst_builder.make_inst(vvsub, inst_data_builder.build());
|
||||
}
|
||||
12 => {
|
||||
inst_data_builder
|
||||
.set_rdr1r2_u8(rd, r1, r2_or_imm)
|
||||
.set_imm_len(generic3)
|
||||
.set_offset_select_value(generic1, generic2);
|
||||
inst_builder.make_inst(vvmul, inst_data_builder.build());
|
||||
}
|
||||
13 => {
|
||||
inst_data_builder
|
||||
.set_rdr1r2_u8(rd, r1, r2_or_imm)
|
||||
.set_imm_len(generic3)
|
||||
.set_offset_select_value(generic1, generic2);
|
||||
inst_builder.make_inst(vvdmul, inst_data_builder.build());
|
||||
}
|
||||
14 => {
|
||||
inst_data_builder
|
||||
.set_rdr1r2_u8(rd, r1, r2_or_imm)
|
||||
.set_imm_len(generic3)
|
||||
.set_offset_select_value(generic1, generic2);
|
||||
inst_builder.make_inst(vvmax, inst_data_builder.build());
|
||||
}
|
||||
15 => {
|
||||
inst_data_builder
|
||||
.set_rdr1r2_u8(rd, r1, r2_or_imm)
|
||||
.set_imm_len(generic3)
|
||||
.set_offset_select_value(generic1, generic2);
|
||||
inst_builder.make_inst(vvsll, inst_data_builder.build());
|
||||
}
|
||||
16 => {
|
||||
inst_data_builder
|
||||
.set_rdr1r2_u8(rd, r1, r2_or_imm)
|
||||
.set_imm_len(generic3)
|
||||
.set_offset_select_value(generic1, generic2);
|
||||
inst_builder.make_inst(vvsra, inst_data_builder.build());
|
||||
}
|
||||
17 => {
|
||||
inst_data_builder
|
||||
.set_rdr1r2_u8(rd, r1, r2_or_imm)
|
||||
.set_imm_len(generic3)
|
||||
.set_offset_select_value(generic1, generic2);
|
||||
inst_builder.make_inst(vavg, inst_data_builder.build());
|
||||
}
|
||||
18 => {
|
||||
inst_data_builder
|
||||
.set_rdr1_u8(rd, r1)
|
||||
.set_imm_len(generic3)
|
||||
.set_offset_select_value(generic1, generic2);
|
||||
inst_builder.make_inst(vrelu, inst_data_builder.build());
|
||||
}
|
||||
19 => {
|
||||
inst_data_builder
|
||||
.set_rdr1_u8(rd, r1)
|
||||
.set_imm_len(generic3)
|
||||
.set_offset_select_value(generic1, generic2);
|
||||
inst_builder.make_inst(vtanh, inst_data_builder.build());
|
||||
}
|
||||
20 => {
|
||||
inst_data_builder
|
||||
.set_rdr1_u8(rd, r1)
|
||||
.set_imm_len(generic3)
|
||||
.set_offset_select_value(generic1, generic2);
|
||||
inst_builder.make_inst(vsigm, inst_data_builder.build());
|
||||
}
|
||||
21 => {
|
||||
inst_data_builder
|
||||
.set_rdr1_u8(rd, r1)
|
||||
.set_imm_len(generic3)
|
||||
.set_offset_select_value(generic1, generic2);
|
||||
inst_builder.make_inst(vsoftmax, inst_data_builder.build());
|
||||
}
|
||||
22 => {
|
||||
inst_data_builder
|
||||
.set_rdr1r2_u8(rd, r1, r2_or_imm)
|
||||
.set_imm_len(generic3)
|
||||
.set_offset_select_value(generic1, generic2);
|
||||
inst_builder.make_inst(vmv, inst_data_builder.build());
|
||||
}
|
||||
23 => {
|
||||
inst_data_builder
|
||||
.set_rdr1r2_u8(rd, r1, r2_or_imm)
|
||||
.set_imm_len(generic3)
|
||||
.set_offset_select_value(generic1, generic2);
|
||||
inst_builder.make_inst(vrsu, inst_data_builder.build());
|
||||
}
|
||||
24 => {
|
||||
inst_data_builder
|
||||
.set_rdr1r2_u8(rd, r1, r2_or_imm)
|
||||
.set_imm_len(generic3)
|
||||
.set_offset_select_value(generic1, generic2);
|
||||
inst_builder.make_inst(vrsl, inst_data_builder.build());
|
||||
}
|
||||
25 => {
|
||||
inst_data_builder
|
||||
.set_rdr1_u8(rd, r1)
|
||||
.set_imm_len(generic3)
|
||||
.set_offset_select_value(generic1, generic2);
|
||||
inst_builder.make_inst(ld, inst_data_builder.build());
|
||||
}
|
||||
26 => {
|
||||
inst_data_builder
|
||||
.set_rdr1_u8(rd, r1)
|
||||
.set_imm_len(generic3)
|
||||
.set_offset_select_value(generic1, generic2);
|
||||
inst_builder.make_inst(st, inst_data_builder.build());
|
||||
}
|
||||
27 => {
|
||||
inst_data_builder
|
||||
.set_rd_u8(rd)
|
||||
.set_imm(r2_or_imm)
|
||||
.set_imm_len(generic3)
|
||||
.set_offset_select_value(generic1, generic2);
|
||||
inst_builder.make_inst(lldi, inst_data_builder.build());
|
||||
}
|
||||
28 => {
|
||||
inst_data_builder
|
||||
.set_rdr1_u8(rd, r1)
|
||||
.set_imm_len(generic3)
|
||||
.set_offset_select_value(generic1, generic2);
|
||||
inst_builder.make_inst(lmv, inst_data_builder.build());
|
||||
}
|
||||
29 => {
|
||||
inst_data_builder
|
||||
.set_rd_u8(rd)
|
||||
.set_imm_core(r2_or_imm + 1)
|
||||
.set_imm_len(generic3)
|
||||
.set_offset_select_value(generic1, generic2);
|
||||
inst_builder.make_inst(send, inst_data_builder.build());
|
||||
}
|
||||
30 => {
|
||||
inst_data_builder
|
||||
.set_rd_u8(rd)
|
||||
.set_imm_core(r2_or_imm + 1)
|
||||
.set_imm_len(generic3)
|
||||
.set_offset_select_value(generic1, generic2);
|
||||
inst_builder.make_inst(recv, inst_data_builder.build());
|
||||
}
|
||||
31 => {
|
||||
inst_builder.make_inst(wait, inst_data_builder.build());
|
||||
}
|
||||
32 => {
|
||||
inst_builder.make_inst(sync, inst_data_builder.build());
|
||||
}
|
||||
_ => bail!("unsupported PIM binary opcode {opcode}"),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn binary_to_instructions(
|
||||
core_bytes: &[u8],
|
||||
core_index: i32,
|
||||
) -> Result<Vec<crate::instruction_set::Instruction>> {
|
||||
let records = parse_binary_records(core_bytes)?;
|
||||
let mut insts_builder = InstructionsBuilder::new();
|
||||
let mut inst_data_builder = InstructionDataBuilder::new();
|
||||
inst_data_builder
|
||||
.set_core_indx_u16(u16::try_from(core_index).expect("core index does not fit in u16"))
|
||||
.fix_core_indx();
|
||||
|
||||
for record in records {
|
||||
let opcode = record.opcode;
|
||||
let name = INSTRUCTIONS
|
||||
.get(&(opcode as usize))
|
||||
.copied()
|
||||
.unwrap_or("<unknown>");
|
||||
|
||||
append_record(&mut insts_builder, &mut inst_data_builder, record).with_context(|| {
|
||||
format!(
|
||||
"while decoding binary instruction for core {core_index}: opcode {opcode} ({name})"
|
||||
)
|
||||
})?;
|
||||
}
|
||||
|
||||
Ok(insts_builder.build())
|
||||
}
|
||||
|
||||
pub fn binary_to_executor<'a, 'b>(
|
||||
config: Value,
|
||||
cores: impl Iterator<Item = &'b Vec<u8>>,
|
||||
crossbars: Vec<Vec<&'a Crossbar>>,
|
||||
) -> Result<Executable<'a>> {
|
||||
let core_cnt = config
|
||||
.get("core_cnt")
|
||||
.context("missing core_cnt in config")?
|
||||
.as_i64()
|
||||
.context("core_cnt is not an integer")? as i32;
|
||||
|
||||
let cpu = CPU::new(core_cnt, crossbars);
|
||||
let mut core_insts_builder = CoreInstructionsBuilder::new(core_cnt as usize);
|
||||
for (external_core_indx, core_bytes) in cores.enumerate() {
|
||||
let core_indx = external_core_indx as i32 + 1;
|
||||
let instructions = binary_to_instructions(core_bytes, core_indx)?;
|
||||
core_insts_builder.set_core(core_indx, instructions);
|
||||
}
|
||||
|
||||
Ok(Executable::new(cpu, core_insts_builder.build()))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{
|
||||
HEADER_SIZE, InstructionRecord, MAGIC, RECORD_SIZE, VERSION, binary_to_instructions,
|
||||
};
|
||||
use crate::{
|
||||
functor_to_name,
|
||||
instruction_set::{InstructionsBuilder, instruction_data::InstructionDataBuilder},
|
||||
json_to_instruction::json_isa::json_to_instruction,
|
||||
};
|
||||
|
||||
fn encode_record(record: InstructionRecord, dst: &mut Vec<u8>) {
|
||||
dst.push(record.opcode);
|
||||
dst.push(record.rd);
|
||||
dst.push(record.r1);
|
||||
dst.push(record.flags);
|
||||
dst.extend_from_slice(&record.r2_or_imm.to_le_bytes());
|
||||
dst.extend_from_slice(&record.generic1.to_le_bytes());
|
||||
dst.extend_from_slice(&record.generic2.to_le_bytes());
|
||||
dst.extend_from_slice(&record.generic3.to_le_bytes());
|
||||
}
|
||||
|
||||
fn binary_blob(records: &[InstructionRecord]) -> Vec<u8> {
|
||||
let mut blob = Vec::with_capacity(HEADER_SIZE + records.len() * RECORD_SIZE);
|
||||
blob.extend_from_slice(MAGIC);
|
||||
blob.extend_from_slice(&VERSION.to_le_bytes());
|
||||
blob.extend_from_slice(&(records.len() as u32).to_le_bytes());
|
||||
for &record in records {
|
||||
encode_record(record, &mut blob);
|
||||
}
|
||||
blob
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn json_and_binary_decoders_match_for_representative_ops() {
|
||||
let json_program = [
|
||||
r#"{"imm":64,"op":"sldi","rd":0}"#,
|
||||
r#"{"imm":128,"op":"sldi","rd":1}"#,
|
||||
r#"{"len":16,"offset":{"offset_select":0,"offset_value":0},"op":"lmv","rd":0,"rs1":1}"#,
|
||||
r#"{"group":3,"mbiw":8,"op":"mvmul","rd":0,"relu":0,"rs1":1}"#,
|
||||
r#"{"len":16,"offset":{"offset_select":0,"offset_value":0},"op":"vvadd","rd":0,"rs1":1,"rs2":2}"#,
|
||||
r#"{"core":2,"offset":{"offset_select":0,"offset_value":0},"op":"send","rd":0,"size":16}"#,
|
||||
];
|
||||
|
||||
let binary_program = binary_blob(&[
|
||||
InstructionRecord {
|
||||
opcode: 1,
|
||||
rd: 0,
|
||||
r2_or_imm: 64,
|
||||
..Default::default()
|
||||
},
|
||||
InstructionRecord {
|
||||
opcode: 1,
|
||||
rd: 1,
|
||||
r2_or_imm: 128,
|
||||
..Default::default()
|
||||
},
|
||||
InstructionRecord {
|
||||
opcode: 28,
|
||||
rd: 0,
|
||||
r1: 1,
|
||||
generic3: 16,
|
||||
..Default::default()
|
||||
},
|
||||
InstructionRecord {
|
||||
opcode: 9,
|
||||
rd: 0,
|
||||
r1: 1,
|
||||
r2_or_imm: 8,
|
||||
generic2: 3,
|
||||
..Default::default()
|
||||
},
|
||||
InstructionRecord {
|
||||
opcode: 10,
|
||||
rd: 0,
|
||||
r1: 1,
|
||||
r2_or_imm: 2,
|
||||
generic3: 16,
|
||||
..Default::default()
|
||||
},
|
||||
InstructionRecord {
|
||||
opcode: 29,
|
||||
rd: 0,
|
||||
r2_or_imm: 2,
|
||||
generic3: 16,
|
||||
..Default::default()
|
||||
},
|
||||
]);
|
||||
|
||||
let mut json_builder = InstructionsBuilder::new();
|
||||
let mut json_data_builder = InstructionDataBuilder::new();
|
||||
json_data_builder.set_core_indx(1).fix_core_indx();
|
||||
for inst in json_program {
|
||||
let value = serde_json::from_str(inst).unwrap();
|
||||
json_to_instruction(&mut json_builder, &mut json_data_builder, &value);
|
||||
}
|
||||
let json_instructions = json_builder.build();
|
||||
let binary_instructions = binary_to_instructions(&binary_program, 1).unwrap();
|
||||
|
||||
assert_eq!(json_instructions.len(), binary_instructions.len());
|
||||
for (json_inst, binary_inst) in json_instructions.iter().zip(binary_instructions.iter()) {
|
||||
assert_eq!(
|
||||
functor_to_name(json_inst.functor as usize),
|
||||
functor_to_name(binary_inst.functor as usize)
|
||||
);
|
||||
assert_eq!(json_inst.data, binary_inst.data);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,10 +1,11 @@
|
||||
use paste::paste;
|
||||
use std::convert::TryFrom;
|
||||
|
||||
#[derive(Clone, Copy, Debug, Default)]
|
||||
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
|
||||
pub struct InstructionData {
|
||||
core_indx: i32,
|
||||
rd: i32,
|
||||
r1: i32,
|
||||
core_indx: u16,
|
||||
rd: u8,
|
||||
r1: u8,
|
||||
//r2 imm mbiw imm_core
|
||||
r2_or_imm: i32,
|
||||
//offset_select imm_relu ibiw
|
||||
@@ -16,18 +17,30 @@ pub struct InstructionData {
|
||||
}
|
||||
|
||||
impl InstructionData {
|
||||
pub fn core_indx(&self) -> i32 {
|
||||
pub fn core_indx_u16(&self) -> u16 {
|
||||
self.core_indx
|
||||
}
|
||||
|
||||
pub fn rd(&self) -> i32 {
|
||||
pub fn core_indx(&self) -> i32 {
|
||||
i32::from(self.core_indx)
|
||||
}
|
||||
|
||||
pub fn rd_u8(&self) -> u8 {
|
||||
self.rd
|
||||
}
|
||||
|
||||
pub fn r1(&self) -> i32 {
|
||||
pub fn rd(&self) -> i32 {
|
||||
i32::from(self.rd)
|
||||
}
|
||||
|
||||
pub fn r1_u8(&self) -> u8 {
|
||||
self.r1
|
||||
}
|
||||
|
||||
pub fn r1(&self) -> i32 {
|
||||
i32::from(self.r1)
|
||||
}
|
||||
|
||||
pub fn r2(&self) -> i32 {
|
||||
self.r2_or_imm
|
||||
}
|
||||
@@ -49,26 +62,26 @@ impl InstructionData {
|
||||
}
|
||||
|
||||
pub fn get_core_rd_r1(&self) -> (i32, i32, i32) {
|
||||
(self.core_indx, self.rd, self.r1)
|
||||
(self.core_indx(), self.rd(), self.r1())
|
||||
}
|
||||
|
||||
pub fn get_core_rd_r1_r2(&self) -> (i32, i32, i32, i32) {
|
||||
(self.core_indx, self.rd, self.r1, self.r2_or_imm)
|
||||
(self.core_indx(), self.rd(), self.r1(), self.r2_or_imm)
|
||||
}
|
||||
|
||||
pub fn get_core_rd_imm(&self) -> (i32, i32, i32) {
|
||||
(self.core_indx, self.rd, self.r2_or_imm)
|
||||
(self.core_indx(), self.rd(), self.r2_or_imm)
|
||||
}
|
||||
|
||||
pub fn get_core_rd_r1_imm(&self) -> (i32, i32, i32, i32) {
|
||||
(self.core_indx, self.rd, self.r1, self.r2_or_imm)
|
||||
(self.core_indx(), self.rd(), self.r1(), self.r2_or_imm)
|
||||
}
|
||||
|
||||
pub fn get_core_rd_r1_r2_immlen_offset(&self) -> (i32, i32, i32, i32, i32, i32, i32) {
|
||||
(
|
||||
self.core_indx,
|
||||
self.rd,
|
||||
self.r1,
|
||||
self.core_indx(),
|
||||
self.rd(),
|
||||
self.r1(),
|
||||
self.r2_or_imm,
|
||||
self.generic3,
|
||||
self.generic1,
|
||||
@@ -78,9 +91,9 @@ impl InstructionData {
|
||||
|
||||
pub fn get_core_rd_r1_mbiw_immrelu_immgroup(&self) -> (i32, i32, i32, i32, i32, i32) {
|
||||
(
|
||||
self.core_indx,
|
||||
self.rd,
|
||||
self.r1,
|
||||
self.core_indx(),
|
||||
self.rd(),
|
||||
self.r1(),
|
||||
self.r2_or_imm,
|
||||
self.generic1,
|
||||
self.generic2,
|
||||
@@ -100,7 +113,7 @@ impl InstructionData {
|
||||
}
|
||||
|
||||
pub(crate) fn get_core_immcore(&self) -> (i32, i32) {
|
||||
(self.core_indx, self.r2_or_imm)
|
||||
(self.core_indx(), self.r2_or_imm)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -216,6 +229,18 @@ impl InstructionDataBuilder {
|
||||
common_getter_setter![imm_group];
|
||||
common_getter_setter![imm_core];
|
||||
|
||||
pub fn set_core_indx_u16(&mut self, val: u16) -> &mut Self {
|
||||
self.set_core_indx(i32::from(val))
|
||||
}
|
||||
|
||||
pub fn set_rd_u8(&mut self, val: u8) -> &mut Self {
|
||||
self.set_rd(i32::from(val))
|
||||
}
|
||||
|
||||
pub fn set_r1_u8(&mut self, val: u8) -> &mut Self {
|
||||
self.set_r1(i32::from(val))
|
||||
}
|
||||
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
core_indx: Fixer::Edit(0),
|
||||
@@ -254,20 +279,16 @@ impl InstructionDataBuilder {
|
||||
|
||||
fn check_sanity(&self) {
|
||||
assert!(!(self.get_r2() != 0 && self.get_imm() != 0 && self.get_mbiw() != 0 && self.get_imm_core() != 0));
|
||||
assert!(
|
||||
!(self.get_ibiw() != 0 && self.get_offset_select() != 0 && self.get_imm_relu() != 0)
|
||||
);
|
||||
assert!(
|
||||
!(self.get_obiw() != 0 && self.get_offset_value() != 0 && self.get_imm_group() != 0)
|
||||
);
|
||||
assert!(!(self.get_ibiw() != 0 && self.get_offset_select() != 0 && self.get_imm_relu() != 0));
|
||||
assert!(!(self.get_obiw() != 0 && self.get_offset_value() != 0 && self.get_imm_group() != 0));
|
||||
}
|
||||
|
||||
pub fn build(&mut self) -> InstructionData {
|
||||
self.check_sanity();
|
||||
let inst_data = InstructionData {
|
||||
core_indx: self.get_core_indx(),
|
||||
rd: self.get_rd(),
|
||||
r1: self.get_r1(),
|
||||
core_indx: u16::try_from(self.get_core_indx()).expect("core index does not fit in u16"),
|
||||
rd: u8::try_from(self.get_rd()).expect("rd does not fit in u8"),
|
||||
r1: u8::try_from(self.get_r1()).expect("r1 does not fit in u8"),
|
||||
r2_or_imm: self.get_r2() + self.get_imm() + self.get_mbiw() + self.get_imm_core(),
|
||||
generic1: self.get_offset_select() + self.get_ibiw() + self.get_imm_relu(),
|
||||
generic2: self.get_offset_value() + self.get_obiw() + self.get_imm_group(),
|
||||
@@ -281,6 +302,10 @@ impl InstructionDataBuilder {
|
||||
self.set_rd(rd).set_r1(r1).set_r2(r2)
|
||||
}
|
||||
|
||||
pub fn set_rdr1r2_u8(&mut self, rd: u8, r1: u8, r2: i32) -> &mut Self {
|
||||
self.set_rd_u8(rd).set_r1_u8(r1).set_r2(r2)
|
||||
}
|
||||
|
||||
pub fn set_offset_select_value(&mut self, offset_select: i32, offset_value: i32) -> &mut Self {
|
||||
self.set_offset_select(offset_select)
|
||||
.set_offset_value(offset_value)
|
||||
@@ -290,14 +315,26 @@ impl InstructionDataBuilder {
|
||||
self.set_rd(rd).set_r1(r1).set_imm(imm)
|
||||
}
|
||||
|
||||
pub fn set_rdr1imm_u8(&mut self, rd: u8, r1: u8, imm: i32) -> &mut Self {
|
||||
self.set_rd_u8(rd).set_r1_u8(r1).set_imm(imm)
|
||||
}
|
||||
|
||||
pub fn set_rdr1(&mut self, rd: i32, r1: i32) -> &mut Self {
|
||||
self.set_rd(rd).set_r1(r1)
|
||||
}
|
||||
|
||||
pub fn set_rdr1_u8(&mut self, rd: u8, r1: u8) -> &mut Self {
|
||||
self.set_rd_u8(rd).set_r1_u8(r1)
|
||||
}
|
||||
|
||||
pub fn set_rdimm(&mut self, rd: i32, imm: i32) -> &mut Self {
|
||||
self.set_rd(rd).set_imm(imm)
|
||||
}
|
||||
|
||||
pub fn set_rdimm_u8(&mut self, rd: u8, imm: i32) -> &mut Self {
|
||||
self.set_rd_u8(rd).set_imm(imm)
|
||||
}
|
||||
|
||||
pub fn set_ibiw_obiw(&mut self, ibiw: i32, obiw: i32) -> &mut Self {
|
||||
self.set_ibiw(ibiw).set_obiw(obiw)
|
||||
}
|
||||
|
||||
@@ -16,7 +16,7 @@ use anyhow::{Context, Result, ensure};
|
||||
use rayon::prelude::*;
|
||||
|
||||
use paste::paste;
|
||||
use std::{borrow::Cow, cell::OnceCell, collections::HashMap};
|
||||
use std::{borrow::Cow, cell::OnceCell, collections::HashMap };
|
||||
use std::{collections::HashSet, sync::LazyLock};
|
||||
|
||||
macro_rules! add_name {
|
||||
@@ -35,7 +35,7 @@ macro_rules! add_name_simd {
|
||||
};
|
||||
}
|
||||
|
||||
static NAMES: LazyLock<HashMap<usize, &'static str>> = LazyLock::new(|| {
|
||||
pub static NAMES: LazyLock<HashMap<usize, &'static str>> = LazyLock::new(|| {
|
||||
let mut hash = HashMap::new();
|
||||
add_name!(hash, sldi);
|
||||
add_name!(hash, sld);
|
||||
@@ -81,6 +81,7 @@ pub fn functor_to_name(functor: usize) -> &'static str {
|
||||
///////////////////////////////////////////////////////////////
|
||||
/////////////////Scalar/register Instructions//////////////////
|
||||
///////////////////////////////////////////////////////////////
|
||||
#[inline(never)]
|
||||
pub fn sldi(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
TRACER.lock().unwrap().pre_sldi(cores, data);
|
||||
let (core_indx, rd, imm) = data.get_core_rd_imm();
|
||||
@@ -90,6 +91,7 @@ pub fn sldi(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
||||
Ok(InstructionStatus::Completed)
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn sld(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
TRACER.lock().unwrap().pre_sld(cores, data);
|
||||
let (core_indx, rd, r1) = data.get_core_rd_r1();
|
||||
@@ -104,6 +106,7 @@ pub fn sld(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
||||
Ok(InstructionStatus::Completed)
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn sadd(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
TRACER.lock().unwrap().pre_sadd(cores, data);
|
||||
let (core_indx, rd, r1, r2) = data.get_core_rd_r1_r2();
|
||||
@@ -114,6 +117,7 @@ pub fn sadd(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
||||
Ok(InstructionStatus::Completed)
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn ssub(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
TRACER.lock().unwrap().pre_ssub(cores, data);
|
||||
let (core_indx, rd, r1, r2) = data.get_core_rd_r1_r2();
|
||||
@@ -124,6 +128,7 @@ pub fn ssub(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
||||
Ok(InstructionStatus::Completed)
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn smul(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
TRACER.lock().unwrap().pre_smul(cores, data);
|
||||
let (core_indx, rd, r1, r2) = data.get_core_rd_r1_r2();
|
||||
@@ -134,6 +139,7 @@ pub fn smul(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
||||
Ok(InstructionStatus::Completed)
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn saddi(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
TRACER.lock().unwrap().pre_saddi(cores, data);
|
||||
let (core_indx, rd, r1, imm) = data.get_core_rd_r1_imm();
|
||||
@@ -143,6 +149,7 @@ pub fn saddi(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus
|
||||
Ok(InstructionStatus::Completed)
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn smuli(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
TRACER.lock().unwrap().pre_smuli(cores, data);
|
||||
let (core_indx, rd, r1, imm) = data.get_core_rd_r1_imm();
|
||||
@@ -217,14 +224,17 @@ pub fn is_setbw(functor: InstructionType) -> bool {
|
||||
functor as usize == setbw as *const () as usize
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn setbw(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
panic!("You are calling a placeholder, this instruction is resolved in the construction phase");
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn mvmul(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
panic!("You are calling a placeholder, the real call is the generic version");
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub(super) fn mvm_impl_internal<F, M, T>(
|
||||
cores: &mut CPU,
|
||||
data: InstructionData,
|
||||
@@ -309,6 +319,7 @@ where
|
||||
Ok(InstructionStatus::Completed)
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub(super) fn mvmul_impl<F, T>(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
||||
where
|
||||
[F]: UpcastSlice<T> + UpcastSlice<f32> + UpcastSlice<f64>,
|
||||
@@ -329,10 +340,12 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn vvadd(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
panic!("You are calling a placeholder, the real call is the generic version");
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub(super) fn vvadd_impl<F, T>(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
||||
where
|
||||
[F]: UpcastSlice<T>,
|
||||
@@ -371,10 +384,12 @@ where
|
||||
Ok(InstructionStatus::Completed)
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn vvsub(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
panic!("You are calling a placeholder, the real call is the generic version");
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub(super) fn vvsub_impl<F, T>(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
||||
where
|
||||
[F]: UpcastSlice<T>,
|
||||
@@ -416,6 +431,7 @@ pub fn vvmul(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus
|
||||
panic!("You are calling a placeholder, the real call is the generic version");
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub(super) fn vvmul_impl<F, T>(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
||||
where
|
||||
[F]: UpcastSlice<T>,
|
||||
@@ -452,10 +468,12 @@ where
|
||||
Ok(InstructionStatus::Completed)
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn vvdmul(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
panic!("You are calling a placeholder, the real call is the generic version");
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub(super) fn vvdmul_impl<F, T>(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
||||
where
|
||||
[F]: UpcastSlice<T>,
|
||||
@@ -488,10 +506,12 @@ where
|
||||
Ok(InstructionStatus::Completed)
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn vvmax(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
panic!("You are calling a placeholder, the real call is the generic version");
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub(super) fn vvmax_impl<F, T>(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
||||
where
|
||||
[F]: UpcastSlice<T>,
|
||||
@@ -525,22 +545,26 @@ where
|
||||
Ok(InstructionStatus::Completed)
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn vvsll(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
panic!(
|
||||
"Shift left on floating point what does it means? who has generated this instruction???"
|
||||
);
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn vvsra(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
panic!(
|
||||
"Shift right on floating point what does it means? who has generated this instruction???"
|
||||
);
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn vavg(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
panic!("You are calling a placeholder, the real call is the generic version");
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub(super) fn vavg_impl<F, T>(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
||||
where
|
||||
[F]: UpcastSlice<T>,
|
||||
@@ -570,10 +594,12 @@ where
|
||||
Ok(InstructionStatus::Completed)
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn vrelu(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
panic!("You are calling a placeholder, the real call is the generic version");
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub(super) fn vrelu_impl<F, T>(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
||||
where
|
||||
[F]: UpcastSlice<T>,
|
||||
@@ -600,10 +626,12 @@ where
|
||||
Ok(InstructionStatus::Completed)
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn vtanh(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
panic!("You are calling a placeholder, the real call is the generic version");
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub(super) fn vtanh_impl<F, T>(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
||||
where
|
||||
[F]: UpcastSlice<T>,
|
||||
@@ -628,10 +656,12 @@ where
|
||||
Ok(InstructionStatus::Completed)
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn vsigm(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
panic!("You are calling a placeholder, the real call is the generic version");
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub(super) fn vsigm_impl<F, T>(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
||||
where
|
||||
[F]: UpcastSlice<T>,
|
||||
@@ -654,10 +684,12 @@ where
|
||||
Ok(InstructionStatus::Completed)
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn vsoftmax(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
panic!("You are calling a placeholder, the real call is the generic version");
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub(super) fn vsoftmax_impl<F, T>(
|
||||
cores: &mut CPU,
|
||||
data: InstructionData,
|
||||
@@ -696,14 +728,17 @@ where
|
||||
Ok(InstructionStatus::Completed)
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn vmv(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn vrsu(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn vrsl(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
todo!()
|
||||
}
|
||||
@@ -711,6 +746,7 @@ pub fn vrsl(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
||||
///////////////////////////////////////////////////////////////
|
||||
///Communication/synchronization Instructions/////////////////
|
||||
///////////////////////////////////////////////////////////////
|
||||
#[inline(never)]
|
||||
pub fn ld(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
TRACER.lock().unwrap().pre_ld(cores, data);
|
||||
let (core, rd, r1, _, imm_len, offset_select, offset_value) =
|
||||
@@ -727,6 +763,7 @@ pub fn ld(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
Ok(InstructionStatus::Completed)
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn st(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
TRACER.lock().unwrap().pre_st(cores, data);
|
||||
let (core, rd, r1, _, imm_len, offset_select, offset_value) =
|
||||
@@ -743,6 +780,7 @@ pub fn st(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
Ok(InstructionStatus::Completed)
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn lldi(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
TRACER.lock().unwrap().pre_lldi(cores, data);
|
||||
let (core, rd, imm) = data.get_core_rd_imm();
|
||||
@@ -759,6 +797,7 @@ pub fn lldi(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
||||
Ok(InstructionStatus::Completed)
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn lmv(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
TRACER.lock().unwrap().pre_lmv(cores, data);
|
||||
let (core, rd, r1, _, imm_len, offset_select, offset_value) =
|
||||
@@ -775,18 +814,32 @@ pub fn lmv(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
||||
Ok(InstructionStatus::Completed)
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn isa_send(functor : usize) -> bool{
|
||||
(send as *const () as usize) == functor
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn send(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
Ok(InstructionStatus::Sending(data))
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn isa_recv(functor : usize) -> bool{
|
||||
(recv as *const () as usize) == functor
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn recv(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
Ok(InstructionStatus::Reciving(data))
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn wait(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
Ok(InstructionStatus::Waiting(data))
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn sync(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
Ok(InstructionStatus::Sync(data))
|
||||
}
|
||||
|
||||
@@ -14,7 +14,7 @@ pub mod helper;
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
pub struct Instruction {
|
||||
pub data: InstructionData,
|
||||
functor: InstructionType,
|
||||
pub functor: InstructionType,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, Default)]
|
||||
|
||||
@@ -567,7 +567,7 @@ fn json_to_send(
|
||||
let (offset_select, offset_value) = json_to_offset(json.get("offset").unwrap());
|
||||
inst_data_builder
|
||||
.set_rd(rd)
|
||||
.set_imm_core(core)
|
||||
.set_imm_core(core + 1)
|
||||
.set_imm_len(size)
|
||||
.set_offset_select(offset_select)
|
||||
.set_offset_value(offset_value);
|
||||
@@ -588,7 +588,7 @@ fn json_to_recv(
|
||||
let (offset_select, offset_value) = json_to_offset(json.get("offset").unwrap());
|
||||
inst_data_builder
|
||||
.set_rd(rd)
|
||||
.set_imm_core(core)
|
||||
.set_imm_core(core + 1)
|
||||
.set_imm_len(size)
|
||||
.set_offset_select(offset_select)
|
||||
.set_offset_value(offset_value);
|
||||
|
||||
+15
-28
@@ -1,45 +1,32 @@
|
||||
use core::panic;
|
||||
use std::collections::HashMap;
|
||||
|
||||
use serde_json::{Map, Value};
|
||||
use serde_json::Value;
|
||||
use std::{fs::File, io::BufReader};
|
||||
|
||||
use crate::{
|
||||
CoreInstructionsBuilder, Executable,
|
||||
cpu::{CPU, crossbar::{self, Crossbar}},
|
||||
instruction_set::{
|
||||
InstructionsBuilder,
|
||||
instruction_data::{self, InstructionData, InstructionDataBuilder},
|
||||
},
|
||||
json_to_instruction::{self, json_isa},
|
||||
memory_manager::type_traits::TryToUsize,
|
||||
cpu::{CPU, crossbar::Crossbar},
|
||||
instruction_set::{InstructionsBuilder, instruction_data::InstructionDataBuilder},
|
||||
json_to_instruction::json_isa,
|
||||
};
|
||||
|
||||
|
||||
pub fn json_to_executor<'a>(
|
||||
pub fn json_to_executor<'a, 'b>(
|
||||
config: Value,
|
||||
mut cores: impl Iterator<Item = &'a Value>,
|
||||
crossbars : Vec<Vec<&'a Crossbar>>
|
||||
cores: &'b mut Vec<BufReader<File>>,
|
||||
crossbars: Vec<Vec<&'a Crossbar>>,
|
||||
) -> Executable<'a> {
|
||||
let cell_precision = config.get("cell_precision").unwrap().as_i64().unwrap() as i32;
|
||||
let core_cnt = config.get("core_cnt").unwrap().as_i64().unwrap() as i32 - 1;
|
||||
let xbar_count = config.get("xbar_array_count").unwrap().as_i64().unwrap() as i32;
|
||||
let xbar_size = config.get("xbar_size").unwrap().as_array().unwrap();
|
||||
let rows_crossbar = xbar_size[0].as_i64().unwrap() as i32;
|
||||
let column_corssbar = xbar_size[1].as_i64().unwrap() as i32;
|
||||
let core_cnt = config.get("core_cnt").unwrap().as_i64().unwrap() as i32;
|
||||
|
||||
let mut cpu = CPU::new(core_cnt, crossbars);
|
||||
let cpu = CPU::new(core_cnt, crossbars);
|
||||
let mut core_insts_builder = CoreInstructionsBuilder::new(core_cnt as usize);
|
||||
cores.next();
|
||||
for core_indx in 1..=core_cnt {
|
||||
for (external_core_indx, json_core_reader) in cores.iter_mut().enumerate() {
|
||||
let core_indx = external_core_indx as i32 + 1;
|
||||
let mut insts_builder = InstructionsBuilder::new();
|
||||
let mut inst_data_builder = InstructionDataBuilder::new();
|
||||
inst_data_builder.set_core_indx(core_indx).fix_core_indx();
|
||||
let json_core = cores
|
||||
.next()
|
||||
.unwrap_or_else(|| panic!("cores files less than {}", core_indx ));
|
||||
let json_core: Value = serde_json::from_reader(json_core_reader)
|
||||
.unwrap_or_else(|err| panic!("failed to parse core{}: {}", external_core_indx, err));
|
||||
let json_core_insts = json_core
|
||||
.as_array()
|
||||
.unwrap_or_else(|| panic!("core{} has not a list of instruction", core_indx));
|
||||
.unwrap_or_else(|| panic!("core{} has not a list of instruction", external_core_indx));
|
||||
for json_inst in json_core_insts {
|
||||
json_isa::json_to_instruction(&mut insts_builder, &mut inst_data_builder, json_inst);
|
||||
}
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
mod json_isa;
|
||||
pub(crate) mod json_isa;
|
||||
pub mod json_to_executor;
|
||||
|
||||
@@ -1,14 +1,22 @@
|
||||
#![allow(unused)]
|
||||
|
||||
use std::time::{Duration, SystemTime};
|
||||
use anyhow::{Result, bail};
|
||||
use std::{
|
||||
collections::{HashMap, HashSet},
|
||||
time::{Duration, SystemTime},
|
||||
};
|
||||
|
||||
use crate::{
|
||||
cpu::CPU,
|
||||
instruction_set::{Instruction, InstructionStatus, Instructions, isa::functor_to_name},
|
||||
instruction_set::{
|
||||
Instruction, InstructionStatus, Instructions,
|
||||
isa::{NAMES, functor_to_name, isa_recv, isa_send},
|
||||
},
|
||||
memory_manager::type_traits::TryToUsize,
|
||||
send_recv::{SendRecv, handle_send_recv},
|
||||
tracing::TRACER,
|
||||
};
|
||||
pub mod binary_to_instruction;
|
||||
pub mod cpu;
|
||||
pub mod instruction_set;
|
||||
pub mod json_to_instruction;
|
||||
@@ -80,6 +88,11 @@ pub struct Executable<'a> {
|
||||
send_recv: SendRecv,
|
||||
}
|
||||
|
||||
struct DeadlockInfo {
|
||||
cycle: String,
|
||||
states: String,
|
||||
}
|
||||
|
||||
fn print_status(core_instructions: &[CoreInstructions]) {
|
||||
let mut tot_instructions = 0;
|
||||
let mut progress = 0;
|
||||
@@ -111,7 +124,7 @@ impl<'a> Executable<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn execute<'b>(&'b mut self)
|
||||
pub fn execute<'b>(&'b mut self) -> Result<()>
|
||||
where
|
||||
'a: 'b,
|
||||
{
|
||||
@@ -144,8 +157,15 @@ impl<'a> Executable<'a> {
|
||||
cpu_progressed = 0;
|
||||
*program_counter += 1;
|
||||
}
|
||||
if (now.elapsed().unwrap() > Duration::from_secs(1)) {
|
||||
print_status(&cores_instructions);
|
||||
if (now.elapsed().unwrap() > Duration::from_secs(5)) {
|
||||
print_status(cores_instructions);
|
||||
if let Some(deadlock) = detect_deadlock(cores_instructions) {
|
||||
bail!(
|
||||
"Deadlock cycle detected: {} [{}]",
|
||||
deadlock.cycle,
|
||||
deadlock.states
|
||||
);
|
||||
}
|
||||
now = SystemTime::now();
|
||||
}
|
||||
}
|
||||
@@ -170,8 +190,23 @@ impl<'a> Executable<'a> {
|
||||
}
|
||||
print_status(cores_instructions);
|
||||
|
||||
if let Some(deadlock) = detect_deadlock(cores_instructions) {
|
||||
bail!(
|
||||
"Deadlock cycle detected: {} [{}]",
|
||||
deadlock.cycle,
|
||||
deadlock.states
|
||||
);
|
||||
}
|
||||
if cores_instructions
|
||||
.iter()
|
||||
.any(|core_inst| core_inst.program_counter < core_inst.instructions.len())
|
||||
{
|
||||
bail!("Execution stalled with unfinished instructions");
|
||||
}
|
||||
|
||||
#[cfg(feature = "profile_time")]
|
||||
TRACER.lock().unwrap().report();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn cpu(&self) -> &CPU<'a> {
|
||||
@@ -193,6 +228,124 @@ impl<'a> Executable<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
fn detect_deadlock(cores_instructions: &[CoreInstructions]) -> Option<DeadlockInfo> {
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
enum CoreState {
|
||||
SendingTo(i32, i32),
|
||||
ReceivingFrom(i32, i32),
|
||||
Working,
|
||||
Halted,
|
||||
}
|
||||
|
||||
let mut states = HashMap::new();
|
||||
|
||||
for core_inst in cores_instructions.iter() {
|
||||
if core_inst.program_counter >= core_inst.instructions.len() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let Instruction { data, functor } = core_inst.instructions[core_inst.program_counter];
|
||||
let functor_address = functor as usize;
|
||||
|
||||
let (this_core, target_core) = data.get_core_immcore();
|
||||
|
||||
if isa_recv(functor_address) {
|
||||
states.insert(this_core, CoreState::ReceivingFrom(target_core, data.imm_len()));
|
||||
} else if isa_send(functor_address) {
|
||||
states.insert(this_core, CoreState::SendingTo(target_core, data.imm_len()));
|
||||
} else {
|
||||
states.insert(this_core, CoreState::Working);
|
||||
}
|
||||
}
|
||||
|
||||
let mut wait_for = HashMap::new();
|
||||
|
||||
for (&core_id, state) in states.iter() {
|
||||
match state {
|
||||
CoreState::SendingTo(target_core, size) => {
|
||||
let target_state = states.get(target_core).unwrap_or(&CoreState::Halted);
|
||||
if target_state != &CoreState::ReceivingFrom(core_id, *size) {
|
||||
wait_for.insert(core_id, *target_core);
|
||||
}
|
||||
}
|
||||
CoreState::ReceivingFrom(target_core, size) => {
|
||||
let target_state = states.get(target_core).unwrap_or(&CoreState::Halted);
|
||||
if target_state != &CoreState::SendingTo(core_id, *size) {
|
||||
wait_for.insert(core_id, *target_core);
|
||||
}
|
||||
}
|
||||
CoreState::Working | CoreState::Halted => {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut visited = HashSet::new();
|
||||
|
||||
for &start_core in wait_for.keys() {
|
||||
if visited.contains(&start_core) {
|
||||
continue;
|
||||
}
|
||||
|
||||
let mut path = Vec::new();
|
||||
let mut current_core = start_core;
|
||||
let mut in_path = HashSet::new();
|
||||
|
||||
while let Some(&waiting_for) = wait_for.get(¤t_core) {
|
||||
path.push(current_core);
|
||||
in_path.insert(current_core);
|
||||
visited.insert(current_core);
|
||||
|
||||
// Found a closed loop!
|
||||
if in_path.contains(&waiting_for) {
|
||||
let cycle_start = path.iter().position(|&c| c == waiting_for).unwrap();
|
||||
let cycle = &path[cycle_start..];
|
||||
|
||||
let cycle_str = cycle
|
||||
.iter()
|
||||
.map(|c| c.to_string())
|
||||
.collect::<Vec<_>>()
|
||||
.join(" -> ");
|
||||
|
||||
let cycle = cycle
|
||||
.iter()
|
||||
.copied()
|
||||
.chain(std::iter::once(waiting_for))
|
||||
.collect::<Vec<_>>();
|
||||
let cycle_msg = format!("{} -> {}", cycle_str, waiting_for);
|
||||
let states_msg = cycle
|
||||
.iter()
|
||||
.filter_map(|core| {
|
||||
states.get(core).map(|state| match state {
|
||||
CoreState::SendingTo(target, size) => {
|
||||
format!("core {} send {}B -> {}", core, size, target)
|
||||
}
|
||||
CoreState::ReceivingFrom(source, size) => {
|
||||
format!("core {} recv {}B <- {}", core, size, source)
|
||||
}
|
||||
CoreState::Working => format!("core {} working", core),
|
||||
CoreState::Halted => format!("core {} halted", core),
|
||||
})
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ");
|
||||
|
||||
return Some(DeadlockInfo {
|
||||
cycle: cycle_msg,
|
||||
states: states_msg,
|
||||
});
|
||||
}
|
||||
|
||||
// Hit a known branch that didn't result in a cycle
|
||||
if visited.contains(&waiting_for) {
|
||||
break;
|
||||
}
|
||||
|
||||
current_core = waiting_for;
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
fn handle_wait_sync<'a, 'b, 'c>(
|
||||
cpu: &'b mut CPU<'a>,
|
||||
core_instructions: &'c mut [CoreInstructions],
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
use std::path::Path;
|
||||
|
||||
use pimcore::{Executable, cpu::CPU, instruction_set::{InstructionsBuilder, instruction_data::InstructionDataBuilder, isa::*}};
|
||||
use pimcore::{
|
||||
Executable,
|
||||
cpu::crossbar::Crossbar,
|
||||
instruction_set::{InstructionsBuilder, instruction_data::InstructionDataBuilder, isa::*},
|
||||
memory_manager::CoreMemory,
|
||||
};
|
||||
|
||||
fn simple_read(path: &Path) -> Vec<f32> {
|
||||
if !path.exists() {
|
||||
@@ -17,14 +22,12 @@ fn simple_read(path: &Path) -> Vec<f32> {
|
||||
fn mvmul_f32(err: &str)
|
||||
where
|
||||
{
|
||||
let mut cpu = CPU::new(0);
|
||||
cpu.reserve_crossbar(1, 1024 * size_of::<f32>(), 1024);
|
||||
let (memory, crossbars) = cpu.host().get_memory_crossbar();
|
||||
let matrix = simple_read(Path::new("B.txt")) ;
|
||||
|
||||
|
||||
crossbars.get_mut(0).unwrap().execute_store( &matrix).unwrap();
|
||||
let vector = simple_read(Path::new("A.txt"));
|
||||
let matrix = simple_read(Path::new("tests/B.txt"));
|
||||
let mut crossbar = Crossbar::new(1024 * size_of::<f32>(), 1024, CoreMemory::new());
|
||||
crossbar.execute_store(&matrix).unwrap();
|
||||
let mut cpu = pimcore::cpu::CPU::new(0, vec![vec![&crossbar]]);
|
||||
let (memory, _) = cpu.host().get_memory_crossbar();
|
||||
let vector = simple_read(Path::new("tests/A.txt"));
|
||||
memory.execute_store(0, &vector).unwrap();
|
||||
|
||||
let mut inst_builder = InstructionsBuilder::new();
|
||||
@@ -57,7 +60,7 @@ where
|
||||
.cpu_mut()
|
||||
.host()
|
||||
.load::<f32>(1024 * size_of::<f32>(), 1024*size_of::<f32>()).unwrap()[0].iter().zip(
|
||||
simple_read(Path::new("X.txt")) ).all(|(&a,b) : (&f32, f32)| {a-b < 0.001}),
|
||||
simple_read(Path::new("tests/X.txt")) ).all(|(&a,b) : (&f32, f32)| {a-b < 0.001}),
|
||||
"Wrong result for {}",
|
||||
err
|
||||
);
|
||||
@@ -69,5 +72,3 @@ fn mvmul_big_test() {
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
use pimcore::cpu::CPU;
|
||||
|
||||
pub fn empty_cpu(num_cores: usize) -> CPU<'static> {
|
||||
CPU::new(num_cores, vec![Vec::new(); num_cores + 1])
|
||||
}
|
||||
@@ -1,51 +1,103 @@
|
||||
use std::{fs, io::BufReader, path::Path};
|
||||
use std::{
|
||||
fs::{self, File},
|
||||
io::BufReader,
|
||||
path::{Path, PathBuf},
|
||||
};
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use pimcore::json_to_instruction::json_to_executor;
|
||||
use pimcore::{
|
||||
cpu::crossbar::Crossbar,
|
||||
json_to_instruction::json_to_executor,
|
||||
memory_manager::CoreMemory,
|
||||
};
|
||||
use serde_json::Value;
|
||||
|
||||
fn collect_json_from_subfolders<P: AsRef<Path>>(root: P) -> Result<Vec<(Value, Vec<Value>)>> {
|
||||
fn collect_examples<P: AsRef<Path>>(root: P) -> Result<Vec<PathBuf>> {
|
||||
let mut result = Vec::new();
|
||||
for entry in fs::read_dir(root)? {
|
||||
let entry = entry.context("Root not found")?;
|
||||
let path = entry.path();
|
||||
if path.is_dir() {
|
||||
let mut cores = Vec::new();
|
||||
let mut config: Option<Value> = None;
|
||||
for sub_entry in fs::read_dir(&path)
|
||||
.with_context(|| format!("File {} not readable", path.display()))?
|
||||
{
|
||||
let sub_entry =
|
||||
sub_entry.with_context(|| format!("File {} not readable", path.display()))?;
|
||||
let sub_path = sub_entry.path();
|
||||
if sub_path.is_file()
|
||||
&& sub_path.extension().and_then(|s| s.to_str()) == Some("json")
|
||||
{
|
||||
let file = fs::File::open(&sub_path)
|
||||
.with_context(|| format!("Subpath {} not opened", sub_path.display()))?;
|
||||
let reader = BufReader::new(file);
|
||||
let val: Value = serde_json::from_reader(reader).with_context(|| format!(
|
||||
"Serde reader fail for subpath {}",
|
||||
sub_path.display()
|
||||
))?;
|
||||
if sub_path.file_name().unwrap() == "config.json" {
|
||||
config = Some(val);
|
||||
} else {
|
||||
cores.push(val);
|
||||
}
|
||||
}
|
||||
}
|
||||
result.push((config.unwrap(), cores));
|
||||
result.push(path);
|
||||
}
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
fn core_sort_key(path: &Path) -> i32 {
|
||||
let stem = path.file_stem().unwrap().to_str().unwrap();
|
||||
stem[5..].parse::<i32>().unwrap()
|
||||
}
|
||||
|
||||
fn crossbar_sort_key(path: &Path) -> i32 {
|
||||
let stem = path.file_stem().unwrap().to_str().unwrap();
|
||||
stem[9..].parse::<i32>().unwrap()
|
||||
}
|
||||
|
||||
fn load_crossbars(folder: &Path, config: &Value) -> Result<Vec<Vec<Crossbar>>> {
|
||||
let xbar_size = config.get("xbar_size").unwrap().as_array().unwrap();
|
||||
let rows = xbar_size[0].as_i64().unwrap() as usize;
|
||||
let cols = xbar_size[1].as_i64().unwrap() as usize;
|
||||
let core_cnt = config.get("core_cnt").unwrap().as_i64().unwrap() as usize;
|
||||
let mut owned_crossbars = Vec::with_capacity(core_cnt + 1);
|
||||
owned_crossbars.push(Vec::new());
|
||||
|
||||
for core_idx in 0..core_cnt {
|
||||
let core_folder = folder.join(format!("core_{core_idx}"));
|
||||
let mut core_crossbars = Vec::new();
|
||||
if core_folder.is_dir() {
|
||||
let mut paths: Vec<_> = fs::read_dir(&core_folder)?
|
||||
.map(|entry| entry.map(|entry| entry.path()))
|
||||
.collect::<std::io::Result<Vec<_>>>()?;
|
||||
paths.sort_by_cached_key(|path| crossbar_sort_key(path));
|
||||
for path in paths {
|
||||
if path.extension().and_then(|ext| ext.to_str()) != Some("bin") {
|
||||
continue;
|
||||
}
|
||||
let bytes = fs::read(&path)
|
||||
.with_context(|| format!("failed to read crossbar {}", path.display()))?;
|
||||
let mut crossbar = Crossbar::new(cols * 4, rows, CoreMemory::new());
|
||||
crossbar.execute_store(&bytes)?;
|
||||
core_crossbars.push(crossbar);
|
||||
}
|
||||
}
|
||||
owned_crossbars.push(core_crossbars);
|
||||
}
|
||||
Ok(owned_crossbars)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn json_folder_tester() {
|
||||
let examples = collect_json_from_subfolders("data").unwrap();
|
||||
for example in examples {
|
||||
let (config, cores) = example;
|
||||
json_to_executor::json_to_executor(config, cores.iter()).execute();
|
||||
let examples = collect_examples("tests/data").unwrap();
|
||||
for folder in examples {
|
||||
let config_path = folder.join("config.json");
|
||||
let config_file = File::open(&config_path).unwrap();
|
||||
let config: Value = serde_json::from_reader(BufReader::new(config_file)).unwrap();
|
||||
|
||||
let core_cnt = config.get("core_cnt").unwrap().as_i64().unwrap() as usize;
|
||||
let mut core_paths: Vec<_> = fs::read_dir(&folder)
|
||||
.unwrap()
|
||||
.map(|entry| entry.unwrap().path())
|
||||
.filter(|path| path.extension().and_then(|ext| ext.to_str()) == Some("json"))
|
||||
.filter(|path| path.file_name().unwrap() != "config.json")
|
||||
.collect();
|
||||
core_paths.sort_by_cached_key(|path| core_sort_key(path));
|
||||
assert_eq!(core_paths.len(), core_cnt);
|
||||
|
||||
let mut core_readers: Vec<_> = core_paths
|
||||
.into_iter()
|
||||
.map(|path| BufReader::new(File::open(path).unwrap()))
|
||||
.collect();
|
||||
|
||||
let owned_crossbars = load_crossbars(&folder, &config).unwrap();
|
||||
let crossbars = owned_crossbars
|
||||
.iter()
|
||||
.map(|core_crossbars| core_crossbars.iter().collect())
|
||||
.collect();
|
||||
|
||||
let mut executable = json_to_executor::json_to_executor(config, &mut core_readers, crossbars);
|
||||
let memory = fs::read(folder.join("memory.bin")).unwrap();
|
||||
executable.cpu_mut().host().execute_store(0, &memory).unwrap();
|
||||
executable.execute();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,11 +1,17 @@
|
||||
mod common;
|
||||
|
||||
use pimcore::{Executable, cpu::CPU, instruction_set::{InstructionType, InstructionsBuilder, instruction_data::InstructionDataBuilder, isa::*}};
|
||||
use pimcore::{
|
||||
Executable,
|
||||
instruction_set::{
|
||||
InstructionType, InstructionsBuilder, instruction_data::InstructionDataBuilder, isa::*,
|
||||
},
|
||||
};
|
||||
|
||||
|
||||
#[test]
|
||||
#[should_panic(expected = "Function not found for the requested size") ]
|
||||
fn wrong_size_place_holder() {
|
||||
let cpu = CPU::new(0);
|
||||
let cpu = common::empty_cpu(0);
|
||||
let mut inst_builder = InstructionsBuilder::new();
|
||||
let mut idata_build = InstructionDataBuilder::new();
|
||||
idata_build.set_core_indx(0).fix_core_indx();
|
||||
@@ -30,7 +36,7 @@ fn wrong_size_place_holder() {
|
||||
|
||||
|
||||
fn place_holder(inst : InstructionType) {
|
||||
let mut cpu = CPU::new(0);
|
||||
let mut cpu = common::empty_cpu(0);
|
||||
let mut idata_build = InstructionDataBuilder::new();
|
||||
idata_build.set_core_indx(0).fix_core_indx();
|
||||
inst(&mut cpu, idata_build.build()).unwrap();
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
mod common;
|
||||
|
||||
use pimcore::{
|
||||
Executable,
|
||||
cpu::CPU,
|
||||
cpu::crossbar::Crossbar,
|
||||
instruction_set::{InstructionsBuilder, instruction_data::InstructionDataBuilder, isa::*},
|
||||
memory_manager::{MemoryStorable, type_traits::UpcastDestTraits},
|
||||
memory_manager::{CoreMemory, MemoryStorable, type_traits::UpcastDestTraits},
|
||||
};
|
||||
|
||||
/// VVADD Test
|
||||
@@ -11,7 +13,7 @@ where
|
||||
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
|
||||
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable,
|
||||
{
|
||||
let mut cpu = CPU::new(0);
|
||||
let mut cpu = common::empty_cpu(0);
|
||||
let buff: [F; _] = [
|
||||
1.0.into(),
|
||||
2.0.into(),
|
||||
@@ -115,7 +117,7 @@ where
|
||||
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
|
||||
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable,
|
||||
{
|
||||
let mut cpu = CPU::new(0);
|
||||
let mut cpu = common::empty_cpu(0);
|
||||
let buff: [F; _] = [
|
||||
1.0.into(),
|
||||
2.0.into(),
|
||||
@@ -219,7 +221,7 @@ where
|
||||
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
|
||||
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable,
|
||||
{
|
||||
let mut cpu = CPU::new(0);
|
||||
let mut cpu = common::empty_cpu(0);
|
||||
let buff: [F; _] = [
|
||||
1.0.into(),
|
||||
2.0.into(),
|
||||
@@ -323,7 +325,7 @@ where
|
||||
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
|
||||
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable,
|
||||
{
|
||||
let mut cpu = CPU::new(0);
|
||||
let mut cpu = common::empty_cpu(0);
|
||||
let buff: [F; _] = [
|
||||
1.0.into(),
|
||||
2.0.into(),
|
||||
@@ -420,7 +422,7 @@ where
|
||||
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
|
||||
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable,
|
||||
{
|
||||
let mut cpu = CPU::new(0);
|
||||
let mut cpu = common::empty_cpu(0);
|
||||
let buff: [F; _] = [
|
||||
9.0.into(),
|
||||
2.0.into(),
|
||||
@@ -524,7 +526,7 @@ where
|
||||
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
|
||||
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable,
|
||||
{
|
||||
let mut cpu = CPU::new(0);
|
||||
let mut cpu = common::empty_cpu(0);
|
||||
let buff: [F; _] = [
|
||||
9.0.into(),
|
||||
2.0.into(),
|
||||
@@ -562,6 +564,7 @@ where
|
||||
vavg,
|
||||
idata_build
|
||||
.set_rdr1r2(3, 1, 1)
|
||||
.set_offset_select(1)
|
||||
.set_imm_len(8 * size_of::<F>() as i32)
|
||||
.build(),
|
||||
);
|
||||
@@ -617,7 +620,7 @@ where
|
||||
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
|
||||
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable,
|
||||
{
|
||||
let mut cpu = CPU::new(0);
|
||||
let mut cpu = common::empty_cpu(0);
|
||||
let buff: [F; _] = [
|
||||
(-9.0).into(),
|
||||
2.0.into(),
|
||||
@@ -717,7 +720,7 @@ where
|
||||
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
|
||||
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable + UpcastDestTraits<T>,
|
||||
{
|
||||
let mut cpu = CPU::new(0);
|
||||
let mut cpu = common::empty_cpu(0);
|
||||
let buff: [F; _] = [
|
||||
0.1.into(),
|
||||
0.2.into(),
|
||||
@@ -819,7 +822,7 @@ where
|
||||
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
|
||||
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable + UpcastDestTraits<T>,
|
||||
{
|
||||
let mut cpu = CPU::new(0);
|
||||
let mut cpu = common::empty_cpu(0);
|
||||
let buff: [F; _] = [
|
||||
0.1.into(),
|
||||
0.2.into(),
|
||||
@@ -923,9 +926,6 @@ where
|
||||
M: From<f32> + std::fmt::Debug + PartialEq<M> + MemoryStorable,
|
||||
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable + UpcastDestTraits<T>,
|
||||
{
|
||||
let mut cpu = CPU::new(0);
|
||||
cpu.reserve_crossbar(1, 4 * size_of::<M>(), 4);
|
||||
let (memory, crossbars) = cpu.host().get_memory_crossbar();
|
||||
let matrix: [M; _] = [
|
||||
1.0.into(),
|
||||
2.0.into(),
|
||||
@@ -944,7 +944,10 @@ where
|
||||
15.0.into(),
|
||||
16.0.into(),
|
||||
];
|
||||
crossbars.get_mut(0).unwrap().execute_store( &matrix).unwrap();
|
||||
let mut crossbar = Crossbar::new(4 * size_of::<M>(), 4, CoreMemory::new());
|
||||
crossbar.execute_store(&matrix).unwrap();
|
||||
let mut cpu = pimcore::cpu::CPU::new(0, vec![vec![&crossbar]]);
|
||||
let (memory, _) = cpu.host().get_memory_crossbar();
|
||||
let vector: [F; _] = [
|
||||
1.0.into(),
|
||||
2.0.into(),
|
||||
@@ -1054,5 +1057,3 @@ fn mvmul_test() {
|
||||
mvmul_test_generic::<f64,f64,f64>("mvmul<f64,f64,f64>",1);
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
mod common;
|
||||
|
||||
use pimcore::{
|
||||
Executable, CoreInstructionsBuilder,
|
||||
cpu::CPU,
|
||||
instruction_set::{InstructionsBuilder, instruction_data::InstructionDataBuilder, isa::*},
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn ld_test() {
|
||||
let mut cpu = CPU::new(1);
|
||||
let mut cpu = common::empty_cpu(1);
|
||||
let mut core_instruction_builder = CoreInstructionsBuilder::new(1);
|
||||
let buff: [f32; _] = [
|
||||
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
|
||||
@@ -41,7 +42,7 @@ fn ld_test() {
|
||||
|
||||
#[test]
|
||||
fn st_test() {
|
||||
let mut cpu = CPU::new(1);
|
||||
let mut cpu = common::empty_cpu(1);
|
||||
let mut core_instruction_builder = CoreInstructionsBuilder::new(1);
|
||||
let buff: [f32; _] = [
|
||||
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
|
||||
@@ -76,7 +77,7 @@ fn st_test() {
|
||||
|
||||
#[test]
|
||||
fn lldi_test() {
|
||||
let cpu = CPU::new(1);
|
||||
let cpu = common::empty_cpu(1);
|
||||
let mut core_instruction_builder = CoreInstructionsBuilder::new(1);
|
||||
let mut inst_builder = InstructionsBuilder::new();
|
||||
let mut idata_build = InstructionDataBuilder::new();
|
||||
@@ -106,7 +107,7 @@ fn lldi_test() {
|
||||
|
||||
#[test]
|
||||
fn lmv_test() {
|
||||
let mut cpu = CPU::new(1);
|
||||
let mut cpu = common::empty_cpu(1);
|
||||
let mut core_instruction_builder = CoreInstructionsBuilder::new(1);
|
||||
let buff: [f32; _] = [
|
||||
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
|
||||
@@ -148,7 +149,7 @@ fn lmv_test() {
|
||||
|
||||
#[test]
|
||||
fn simple_send_recv_test() {
|
||||
let mut cpu = CPU::new(2);
|
||||
let mut cpu = common::empty_cpu(2);
|
||||
let mut core_instruction_builder = CoreInstructionsBuilder::new(2);
|
||||
let buff: [f32; _] = [
|
||||
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
|
||||
@@ -207,7 +208,7 @@ fn simple_send_recv_test() {
|
||||
|
||||
#[test]
|
||||
fn multiple_send_recv_test() {
|
||||
let mut cpu = CPU::new(4);
|
||||
let mut cpu = common::empty_cpu(4);
|
||||
let mut core_instruction_builder = CoreInstructionsBuilder::new(4);
|
||||
let buff: [f32; _] = [
|
||||
1.0, 1.0, 1.0, 1.0, 1.0
|
||||
@@ -226,7 +227,7 @@ fn multiple_send_recv_test() {
|
||||
];
|
||||
cpu.core(4).execute_store(0, &buff).unwrap();
|
||||
|
||||
let send_inst = |cpu :&mut CPU, core_instruction_builder: &mut CoreInstructionsBuilder, inst_builder: &mut InstructionsBuilder, from : i32, to : i32| {
|
||||
let send_inst = |inst_builder: &mut InstructionsBuilder, from: i32, to: i32| {
|
||||
let mut idata_build = InstructionDataBuilder::new();
|
||||
idata_build.set_core_indx(from).fix_core_indx();
|
||||
inst_builder.make_inst(sldi, idata_build.set_rdimm(1, from*size_of::<f32>() as i32).build());
|
||||
@@ -240,7 +241,7 @@ fn multiple_send_recv_test() {
|
||||
);
|
||||
};
|
||||
|
||||
let recv_inst = |cpu :&mut CPU, core_instruction_builder: &mut CoreInstructionsBuilder, mut inst_builder: &mut InstructionsBuilder, to : i32, from : i32| {
|
||||
let recv_inst = |inst_builder: &mut InstructionsBuilder, to: i32, from: i32| {
|
||||
let mut idata_build = InstructionDataBuilder::new();
|
||||
idata_build.set_core_indx(to).fix_core_indx();
|
||||
inst_builder.make_inst(sldi, idata_build.set_rdimm(1, from*size_of::<f32>() as i32).build());
|
||||
@@ -257,26 +258,26 @@ fn multiple_send_recv_test() {
|
||||
|
||||
|
||||
// 1 -> 3
|
||||
send_inst(&mut cpu,&mut core_instruction_builder,&mut inst_builder,1, 3);
|
||||
send_inst(&mut inst_builder, 1, 3);
|
||||
core_instruction_builder.set_core(1, inst_builder.build());
|
||||
|
||||
// 2 -> 3
|
||||
// 2 <- 4
|
||||
send_inst(&mut cpu,&mut core_instruction_builder,&mut inst_builder,2, 3);
|
||||
recv_inst(&mut cpu,&mut core_instruction_builder,&mut inst_builder,2, 4);
|
||||
send_inst(&mut inst_builder, 2, 3);
|
||||
recv_inst(&mut inst_builder, 2, 4);
|
||||
core_instruction_builder.set_core(2, inst_builder.build());
|
||||
|
||||
// 3 <- 2
|
||||
// 3 <- 4
|
||||
// 3 <- 1
|
||||
recv_inst(&mut cpu,&mut core_instruction_builder,&mut inst_builder,3, 2);
|
||||
recv_inst(&mut cpu,&mut core_instruction_builder,&mut inst_builder,3, 4);
|
||||
recv_inst(&mut cpu,&mut core_instruction_builder,&mut inst_builder,3, 1);
|
||||
recv_inst(&mut inst_builder, 3, 2);
|
||||
recv_inst(&mut inst_builder, 3, 4);
|
||||
recv_inst(&mut inst_builder, 3, 1);
|
||||
core_instruction_builder.set_core(3, inst_builder.build());
|
||||
// 4 -> 2
|
||||
// 4 -> 3
|
||||
send_inst(&mut cpu,&mut core_instruction_builder,&mut inst_builder,4, 2);
|
||||
send_inst(&mut cpu,&mut core_instruction_builder,&mut inst_builder,4, 3);
|
||||
send_inst(&mut inst_builder, 4, 2);
|
||||
send_inst(&mut inst_builder, 4, 3);
|
||||
core_instruction_builder.set_core(4, inst_builder.build());
|
||||
|
||||
let mut executable = Executable::new(cpu, core_instruction_builder.build());
|
||||
|
||||
Submodule backend-simulators/pim/pimsim-nn updated: 3e3442b663...6d3b898e6b
@@ -68,5 +68,6 @@ add_pim_library(OMPIMAccel
|
||||
OMSpatialToPim
|
||||
OMPimCommon
|
||||
OMPimBufferization
|
||||
OMPimStaticMemoryCoalescing
|
||||
MLIRTensorInferTypeOpInterfaceImpl
|
||||
)
|
||||
|
||||
@@ -3,10 +3,12 @@ add_pim_library(OMPimCommon
|
||||
IR/CoreBlockUtils.cpp
|
||||
IR/EntryPointUtils.cpp
|
||||
IR/ShapeUtils.cpp
|
||||
IR/SubviewUtils.cpp
|
||||
IR/WeightUtils.cpp
|
||||
Support/DebugDump.cpp
|
||||
Support/Diagnostics.cpp
|
||||
Support/FileSystemUtils.cpp
|
||||
Support/ReportUtils.cpp
|
||||
|
||||
EXCLUDE_FROM_OM_LIBS
|
||||
|
||||
|
||||
@@ -110,6 +110,14 @@ llvm::FailureOr<int64_t> resolveIndexValueImpl(mlir::Value value, const StaticVa
|
||||
return static_cast<int64_t>(static_cast<uint64_t>(*lhs) / static_cast<uint64_t>(*rhs));
|
||||
}
|
||||
|
||||
if (auto minOp = mlir::dyn_cast<mlir::arith::MinUIOp>(definingOp)) {
|
||||
auto lhs = resolveIndexValueImpl(minOp.getLhs(), knowledge);
|
||||
auto rhs = resolveIndexValueImpl(minOp.getRhs(), knowledge);
|
||||
if (failed(lhs) || failed(rhs))
|
||||
return mlir::failure();
|
||||
return static_cast<int64_t>(std::min(static_cast<uint64_t>(*lhs), static_cast<uint64_t>(*rhs)));
|
||||
}
|
||||
|
||||
if (auto remOp = mlir::dyn_cast<mlir::arith::RemUIOp>(definingOp)) {
|
||||
auto lhs = resolveIndexValueImpl(remOp.getLhs(), knowledge);
|
||||
auto rhs = resolveIndexValueImpl(remOp.getRhs(), knowledge);
|
||||
|
||||
@@ -12,6 +12,7 @@ bool isCoreStaticAddressOp(mlir::Operation* op) {
|
||||
mlir::arith::SubIOp,
|
||||
mlir::arith::MulIOp,
|
||||
mlir::arith::DivUIOp,
|
||||
mlir::arith::MinUIOp,
|
||||
mlir::arith::RemUIOp,
|
||||
mlir::arith::IndexCastOp,
|
||||
mlir::memref::AllocOp,
|
||||
|
||||
@@ -0,0 +1,84 @@
|
||||
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
Value stripMemRefCasts(Value value) {
|
||||
while (auto castOp = value.getDefiningOp<memref::CastOp>())
|
||||
value = castOp.getSource();
|
||||
return value;
|
||||
}
|
||||
|
||||
Value stripMemRefViewOps(Value value) {
|
||||
while (true) {
|
||||
if (auto castOp = value.getDefiningOp<memref::CastOp>()) {
|
||||
value = castOp.getSource();
|
||||
continue;
|
||||
}
|
||||
if (auto collapseOp = value.getDefiningOp<memref::CollapseShapeOp>()) {
|
||||
value = collapseOp.getSrc();
|
||||
continue;
|
||||
}
|
||||
if (auto expandOp = value.getDefiningOp<memref::ExpandShapeOp>()) {
|
||||
value = expandOp.getSrc();
|
||||
continue;
|
||||
}
|
||||
return value;
|
||||
}
|
||||
}
|
||||
|
||||
bool hasAllStaticSubviewParts(memref::SubViewOp subview) {
|
||||
return llvm::all_of(subview.getStaticOffsets(), [](int64_t value) { return !ShapedType::isDynamic(value); })
|
||||
&& llvm::all_of(subview.getStaticSizes(), [](int64_t value) { return !ShapedType::isDynamic(value); })
|
||||
&& llvm::all_of(subview.getStaticStrides(), [](int64_t value) { return !ShapedType::isDynamic(value); });
|
||||
}
|
||||
|
||||
FailureOr<StaticSubviewInfo> getStaticSubviewInfo(Value value) {
|
||||
value = stripMemRefViewOps(value);
|
||||
auto subviewOp = value.getDefiningOp<memref::SubViewOp>();
|
||||
if (!subviewOp)
|
||||
return failure();
|
||||
|
||||
auto source = stripMemRefCasts(subviewOp.getSource());
|
||||
auto sourceType = dyn_cast<MemRefType>(source.getType());
|
||||
auto subviewType = dyn_cast<MemRefType>(subviewOp.getType());
|
||||
if (!sourceType || !subviewType || !sourceType.hasStaticShape() || !subviewType.hasStaticShape())
|
||||
return failure();
|
||||
|
||||
StaticSubviewInfo info;
|
||||
info.source = source;
|
||||
info.sourceShape.assign(sourceType.getShape().begin(), sourceType.getShape().end());
|
||||
SmallVector<OpFoldResult> mixedOffsets = subviewOp.getMixedOffsets();
|
||||
info.offsets.assign(mixedOffsets.begin(), mixedOffsets.end());
|
||||
for (OpFoldResult size : subviewOp.getMixedSizes()) {
|
||||
auto staticSize = getConstantIntValue(size);
|
||||
if (!staticSize)
|
||||
return failure();
|
||||
info.sizes.push_back(*staticSize);
|
||||
}
|
||||
for (OpFoldResult stride : subviewOp.getMixedStrides()) {
|
||||
auto staticStride = getConstantIntValue(stride);
|
||||
if (!staticStride)
|
||||
return failure();
|
||||
info.strides.push_back(*staticStride);
|
||||
}
|
||||
return info;
|
||||
}
|
||||
|
||||
FailureOr<SmallVector<int64_t>> getStaticSubviewOffsets(const StaticSubviewInfo& info) {
|
||||
SmallVector<int64_t> staticOffsets;
|
||||
staticOffsets.reserve(info.offsets.size());
|
||||
for (OpFoldResult offset : info.offsets) {
|
||||
auto staticOffset = getConstantIntValue(offset);
|
||||
if (!staticOffset)
|
||||
return failure();
|
||||
staticOffsets.push_back(*staticOffset);
|
||||
}
|
||||
return staticOffsets;
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,30 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/Support/LogicalResult.h"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
struct StaticSubviewInfo {
|
||||
mlir::Value source;
|
||||
llvm::SmallVector<int64_t> sourceShape;
|
||||
llvm::SmallVector<mlir::OpFoldResult> offsets;
|
||||
llvm::SmallVector<int64_t> sizes;
|
||||
llvm::SmallVector<int64_t> strides;
|
||||
};
|
||||
|
||||
mlir::Value stripMemRefCasts(mlir::Value value);
|
||||
|
||||
mlir::Value stripMemRefViewOps(mlir::Value value);
|
||||
|
||||
bool hasAllStaticSubviewParts(mlir::memref::SubViewOp subview);
|
||||
|
||||
llvm::FailureOr<StaticSubviewInfo> getStaticSubviewInfo(mlir::Value value);
|
||||
|
||||
/// Returns the offsets in `info` as int64_t, failing if any offset is dynamic.
|
||||
llvm::FailureOr<llvm::SmallVector<int64_t>> getStaticSubviewOffsets(const StaticSubviewInfo& info);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -7,10 +7,34 @@
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
|
||||
#include <cstdint>
|
||||
#include <system_error>
|
||||
|
||||
namespace onnx_mlir::pim {
|
||||
|
||||
struct CappedDiagnosticReporter {
|
||||
explicit CappedDiagnosticReporter(int64_t maxReportedFailures = 8) : maxReportedFailures(maxReportedFailures) {}
|
||||
|
||||
template <typename EmitFn>
|
||||
void report(mlir::Operation* op, EmitFn&& emit) {
|
||||
numFailures++;
|
||||
if (numFailures <= maxReportedFailures)
|
||||
emit(op);
|
||||
}
|
||||
|
||||
void emitSuppressedSummary(mlir::Operation* op, llvm::StringRef failureDescription) const {
|
||||
if (numFailures > maxReportedFailures)
|
||||
op->emitError() << "suppressed " << (numFailures - maxReportedFailures) << " additional "
|
||||
<< failureDescription;
|
||||
}
|
||||
|
||||
bool hasFailure() const { return numFailures != 0; }
|
||||
|
||||
private:
|
||||
int64_t maxReportedFailures;
|
||||
int64_t numFailures = 0;
|
||||
};
|
||||
|
||||
/// Emits a consistent diagnostic for target paths that require static shapes.
|
||||
mlir::InFlightDiagnostic emitUnsupportedStaticShapeDiagnostic(mlir::Operation* op, llvm::StringRef valueDescription);
|
||||
|
||||
|
||||
@@ -0,0 +1,62 @@
|
||||
#include "llvm/Support/Format.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/Support/FileSystemUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/Support/ReportUtils.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
std::fstream openReportFile(const std::string& name) {
|
||||
std::string outputDir = getOutputDir();
|
||||
if (outputDir.empty())
|
||||
return {};
|
||||
|
||||
std::string reportsDir = outputDir + "/reports";
|
||||
createDirectory(reportsDir);
|
||||
return std::fstream(reportsDir + "/" + name + ".txt", std::ios::out);
|
||||
}
|
||||
|
||||
std::string formatReportMemory(uint64_t bytes) {
|
||||
const char* units[] = {"B", "KB", "MB", "GB", "TB", "PB", "EB"};
|
||||
int i = 0;
|
||||
double size = static_cast<double>(bytes);
|
||||
while (size >= 1024 && i < 6) {
|
||||
size /= 1024;
|
||||
i++;
|
||||
}
|
||||
|
||||
std::string out;
|
||||
llvm::raw_string_ostream rss(out);
|
||||
rss << llvm::format("%.2f ", size) << units[i];
|
||||
return rss.str();
|
||||
}
|
||||
|
||||
void printReportFlatFields(llvm::raw_ostream& os, llvm::ArrayRef<ReportField> fields) {
|
||||
for (const ReportField& field : fields)
|
||||
os << "\t" << field.label << ": " << field.value << "\n";
|
||||
}
|
||||
|
||||
void printReportFieldBlock(llvm::raw_ostream& os, llvm::StringRef title, llvm::ArrayRef<ReportField> fields) {
|
||||
os << "\t" << title << ":\n";
|
||||
for (const ReportField& field : fields)
|
||||
os << "\t " << field.label << ": " << field.value << "\n";
|
||||
}
|
||||
|
||||
void printReportTotalsBlock(llvm::raw_ostream& os, llvm::ArrayRef<ReportField> fields) {
|
||||
os << "Totals:\n";
|
||||
for (const ReportField& field : fields)
|
||||
os << "\t" << field.label << ": " << field.value << "\n";
|
||||
}
|
||||
|
||||
void printReportPerCoreAndTotalFields(llvm::raw_ostream& os,
|
||||
llvm::ArrayRef<ReportField> perCoreFields,
|
||||
llvm::ArrayRef<ReportField> totalFields) {
|
||||
printReportFieldBlock(os, "Per core", perCoreFields);
|
||||
printReportFieldBlock(os, "Total", totalFields);
|
||||
}
|
||||
|
||||
void printReportEntrySeparator(llvm::raw_ostream& os, bool hasNextEntry) {
|
||||
if (hasNextEntry)
|
||||
os << "\n";
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,47 @@
|
||||
#pragma once
|
||||
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
#include <fstream>
|
||||
#include <limits>
|
||||
#include <string>
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
std::fstream openReportFile(const std::string& name);
|
||||
std::string formatReportMemory(uint64_t bytes);
|
||||
|
||||
struct ReportField {
|
||||
std::string label;
|
||||
std::string value;
|
||||
};
|
||||
|
||||
void printReportFlatFields(llvm::raw_ostream& os, llvm::ArrayRef<ReportField> fields);
|
||||
void printReportFieldBlock(llvm::raw_ostream& os, llvm::StringRef title, llvm::ArrayRef<ReportField> fields);
|
||||
void printReportTotalsBlock(llvm::raw_ostream& os, llvm::ArrayRef<ReportField> fields);
|
||||
void printReportPerCoreAndTotalFields(llvm::raw_ostream& os,
|
||||
llvm::ArrayRef<ReportField> perCoreFields,
|
||||
llvm::ArrayRef<ReportField> totalFields);
|
||||
void printReportEntrySeparator(llvm::raw_ostream& os, bool hasNextEntry);
|
||||
|
||||
template <typename EntryTy>
|
||||
int32_t getFirstReportCoreId(const EntryTy& entry) {
|
||||
if (entry.coreIds.empty())
|
||||
return std::numeric_limits<int32_t>::max();
|
||||
return entry.coreIds.front();
|
||||
}
|
||||
|
||||
template <typename EntryRange>
|
||||
void sortReportEntriesByFirstCore(EntryRange& entries) {
|
||||
llvm::stable_sort(entries, [](const auto& lhs, const auto& rhs) {
|
||||
int32_t lhsFirstCore = getFirstReportCoreId(lhs);
|
||||
int32_t rhsFirstCore = getFirstReportCoreId(rhs);
|
||||
if (lhsFirstCore != rhsFirstCore)
|
||||
return lhsFirstCore < rhsFirstCore;
|
||||
return lhs.id < rhs.id;
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -29,6 +29,7 @@ add_pim_library(OMPimCompilerUtils
|
||||
OMPimCompilerOptions
|
||||
OMPimCommon
|
||||
OMPimBufferization
|
||||
OMPimStaticMemoryCoalescing
|
||||
OMPimPasses
|
||||
OMONNXToSpatial
|
||||
OMSpatialToPim
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimArtifactWriter.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimBinaryFormat.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
|
||||
@@ -19,21 +20,6 @@ using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
OnnxMlirCompilerErrorCodes writeHostCoreJson(StringRef outputDirPath) {
|
||||
std::error_code errorCode;
|
||||
std::string outputHostCorePath = outputDirPath.str() + "/core_0.json";
|
||||
raw_fd_ostream hostFileStream(outputHostCorePath, errorCode);
|
||||
if (errorCode) {
|
||||
errs() << "Error while opening host core file `" << outputHostCorePath << "`: " << errorCode.message() << '\n';
|
||||
return InvalidOutputFileAccess;
|
||||
}
|
||||
|
||||
// The host core json contains two no-op-like instructions to satisfy pimsim-nn.
|
||||
hostFileStream << "[{\"imm\":0,\"op\":\"sldi\",\"rd\":0},{\"imm\":0,\"op\":\"sldi\",\"rd\":0}]";
|
||||
hostFileStream.close();
|
||||
return CompilerSuccess;
|
||||
}
|
||||
|
||||
OnnxMlirCompilerErrorCodes
|
||||
writeMemoryBinary(ModuleOp moduleOp, func::FuncOp funcOp, PimAcceleratorMemory& memory, StringRef outputDirPath) {
|
||||
auto memoryFilePath = (outputDirPath + "/memory.bin").str();
|
||||
@@ -91,9 +77,6 @@ OnnxMlirCompilerErrorCodes writeConfigJson(func::FuncOp funcOp,
|
||||
json::Object configJson;
|
||||
|
||||
configJson["core_cnt"] = maxCoreId + 1;
|
||||
configJson["adc_count"] = 16;
|
||||
configJson["cell_precision"] = 2;
|
||||
configJson["xbar_array_count"] = crossbarCountInCore.getValue();
|
||||
configJson["xbar_size"] = {crossbarSize.getValue(), crossbarSize.getValue()};
|
||||
configJson["array_group_map"] = std::move(xbarsPerArrayGroup);
|
||||
|
||||
|
||||
@@ -12,7 +12,6 @@ namespace onnx_mlir {
|
||||
|
||||
class PimAcceleratorMemory;
|
||||
|
||||
OnnxMlirCompilerErrorCodes writeHostCoreJson(llvm::StringRef outputDirPath);
|
||||
OnnxMlirCompilerErrorCodes writeMemoryBinary(mlir::ModuleOp moduleOp,
|
||||
mlir::func::FuncOp funcOp,
|
||||
PimAcceleratorMemory& memory,
|
||||
|
||||
@@ -24,6 +24,78 @@ static SmallVector<int32_t> getLaneChunkCoreIds(ArrayRef<int32_t> coreIds, size_
|
||||
return laneCoreIds;
|
||||
}
|
||||
|
||||
static void scalarizeBatchOpsInCore(pim::PimCoreOp scalarCore, size_t laneCount, unsigned lane) {
|
||||
IRRewriter rewriter(scalarCore.getContext());
|
||||
SmallVector<Operation*> batchOps;
|
||||
scalarCore.walk([&](Operation* op) {
|
||||
if (isa<pim::PimSendBatchOp,
|
||||
pim::PimSendTensorBatchOp,
|
||||
pim::PimReceiveBatchOp,
|
||||
pim::PimReceiveTensorBatchOp,
|
||||
pim::PimMemCopyHostToDevBatchOp>(op)) {
|
||||
batchOps.push_back(op);
|
||||
}
|
||||
});
|
||||
|
||||
for (Operation* op : batchOps) {
|
||||
rewriter.setInsertionPoint(op);
|
||||
|
||||
if (auto sendBatchOp = dyn_cast<pim::PimSendBatchOp>(op)) {
|
||||
pim::PimSendOp::create(rewriter,
|
||||
sendBatchOp.getLoc(),
|
||||
sendBatchOp.getInput(),
|
||||
sendBatchOp.getSizeAttr(),
|
||||
rewriter.getI32IntegerAttr(sendBatchOp.getTargetCoreIds()[lane]));
|
||||
rewriter.eraseOp(op);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto sendTensorBatchOp = dyn_cast<pim::PimSendTensorBatchOp>(op)) {
|
||||
pim::PimSendTensorOp::create(
|
||||
rewriter,
|
||||
sendTensorBatchOp.getLoc(),
|
||||
sendTensorBatchOp.getInput(),
|
||||
rewriter.getDenseI32ArrayAttr(getLaneChunkCoreIds(sendTensorBatchOp.getTargetCoreIds(), laneCount, lane)));
|
||||
rewriter.eraseOp(op);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto receiveBatchOp = dyn_cast<pim::PimReceiveBatchOp>(op)) {
|
||||
auto scalarReceive =
|
||||
pim::PimReceiveOp::create(rewriter,
|
||||
receiveBatchOp.getLoc(),
|
||||
receiveBatchOp.getOutput().getType(),
|
||||
receiveBatchOp.getOutputBuffer(),
|
||||
receiveBatchOp.getSizeAttr(),
|
||||
rewriter.getI32IntegerAttr(receiveBatchOp.getSourceCoreIds()[lane]));
|
||||
rewriter.replaceOp(op, scalarReceive->getResults());
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto receiveTensorBatchOp = dyn_cast<pim::PimReceiveTensorBatchOp>(op)) {
|
||||
auto scalarReceive = pim::PimReceiveTensorOp::create(
|
||||
rewriter,
|
||||
receiveTensorBatchOp.getLoc(),
|
||||
receiveTensorBatchOp.getOutput().getType(),
|
||||
receiveTensorBatchOp.getOutputBuffer(),
|
||||
rewriter.getDenseI32ArrayAttr(getLaneChunkCoreIds(receiveTensorBatchOp.getSourceCoreIds(), laneCount, lane)));
|
||||
rewriter.replaceOp(op, scalarReceive->getResults());
|
||||
continue;
|
||||
}
|
||||
|
||||
auto memcpBatchOp = cast<pim::PimMemCopyHostToDevBatchOp>(op);
|
||||
auto scalarCopy = pim::PimMemCopyHostToDevOp::create(rewriter,
|
||||
memcpBatchOp.getLoc(),
|
||||
memcpBatchOp.getOutput().getType(),
|
||||
memcpBatchOp.getDeviceTarget(),
|
||||
memcpBatchOp.getHostSource(),
|
||||
memcpBatchOp.getDeviceTargetOffsetAttr(),
|
||||
memcpBatchOp.getHostSourceOffsetAttr(),
|
||||
memcpBatchOp.getSizeAttr());
|
||||
rewriter.replaceOp(op, scalarCopy->getResults());
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
LogicalResult withScalarCoreFromBatchLane(pim::PimCoreBatchOp coreBatchOp,
|
||||
@@ -50,69 +122,6 @@ LogicalResult withScalarCoreFromBatchLane(pim::PimCoreBatchOp coreBatchOp,
|
||||
|
||||
builder.setInsertionPointToEnd(block);
|
||||
for (Operation& op : coreBatchOp.getBody().front()) {
|
||||
if (isa<pim::PimHaltOp>(op)) {
|
||||
pim::PimHaltOp::create(builder, op.getLoc());
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto sendBatchOp = dyn_cast<pim::PimSendBatchOp>(op)) {
|
||||
pim::PimSendOp::create(builder,
|
||||
sendBatchOp.getLoc(),
|
||||
mapper.lookup(sendBatchOp.getInput()),
|
||||
sendBatchOp.getSizeAttr(),
|
||||
builder.getI32IntegerAttr(sendBatchOp.getTargetCoreIds()[lane]));
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto sendTensorBatchOp = dyn_cast<pim::PimSendTensorBatchOp>(op)) {
|
||||
pim::PimSendTensorOp::create(
|
||||
builder,
|
||||
sendTensorBatchOp.getLoc(),
|
||||
mapper.lookup(sendTensorBatchOp.getInput()),
|
||||
builder.getDenseI32ArrayAttr(getLaneChunkCoreIds(sendTensorBatchOp.getTargetCoreIds(), laneCount, lane)));
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto receiveBatchOp = dyn_cast<pim::PimReceiveBatchOp>(op)) {
|
||||
auto scalarReceive =
|
||||
pim::PimReceiveOp::create(builder,
|
||||
receiveBatchOp.getLoc(),
|
||||
receiveBatchOp.getOutput().getType(),
|
||||
mapper.lookup(receiveBatchOp.getOutputBuffer()),
|
||||
receiveBatchOp.getSizeAttr(),
|
||||
builder.getI32IntegerAttr(receiveBatchOp.getSourceCoreIds()[lane]));
|
||||
mapper.map(receiveBatchOp.getOutput(), scalarReceive.getOutput());
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto receiveTensorBatchOp = dyn_cast<pim::PimReceiveTensorBatchOp>(op)) {
|
||||
auto scalarReceive = pim::PimReceiveTensorOp::create(
|
||||
builder,
|
||||
receiveTensorBatchOp.getLoc(),
|
||||
receiveTensorBatchOp.getOutput().getType(),
|
||||
mapper.lookup(receiveTensorBatchOp.getOutputBuffer()),
|
||||
builder.getDenseI32ArrayAttr(getLaneChunkCoreIds(receiveTensorBatchOp.getSourceCoreIds(), laneCount, lane)));
|
||||
mapper.map(receiveTensorBatchOp.getOutput(), scalarReceive.getOutput());
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto memcpBatchOp = dyn_cast<pim::PimMemCopyHostToDevBatchOp>(op)) {
|
||||
Value hostSource = mapper.lookupOrNull(memcpBatchOp.getHostSource());
|
||||
if (!hostSource)
|
||||
hostSource = memcpBatchOp.getHostSource();
|
||||
|
||||
auto scalarCopy = pim::PimMemCopyHostToDevOp::create(builder,
|
||||
memcpBatchOp.getLoc(),
|
||||
memcpBatchOp.getOutput().getType(),
|
||||
mapper.lookup(memcpBatchOp.getDeviceTarget()),
|
||||
hostSource,
|
||||
memcpBatchOp.getDeviceTargetOffsetAttr(),
|
||||
memcpBatchOp.getHostSourceOffsetAttr(),
|
||||
memcpBatchOp.getSizeAttr());
|
||||
mapper.map(memcpBatchOp.getOutput(), scalarCopy.getOutput());
|
||||
continue;
|
||||
}
|
||||
|
||||
Operation* cloned = builder.clone(op, mapper);
|
||||
for (auto [originalResult, clonedResult] : llvm::zip(op.getResults(), cloned->getResults()))
|
||||
mapper.map(originalResult, clonedResult);
|
||||
@@ -120,6 +129,7 @@ LogicalResult withScalarCoreFromBatchLane(pim::PimCoreBatchOp coreBatchOp,
|
||||
|
||||
if (block->empty() || !isa<pim::PimHaltOp>(block->back()))
|
||||
pim::PimHaltOp::create(builder, coreBatchOp.getLoc());
|
||||
scalarizeBatchOpsInCore(scalarCore, laneCount, lane);
|
||||
return callback(scalarCore);
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,374 @@
|
||||
#pragma once
|
||||
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/Support/Endian.h"
|
||||
#include "llvm/Support/JSON.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
#include <array>
|
||||
#include <cassert>
|
||||
#include <limits>
|
||||
|
||||
namespace onnx_mlir::pim_binary {
|
||||
|
||||
inline constexpr char kMagic[4] = {'P', 'I', 'M', 'B'};
|
||||
inline constexpr uint32_t kVersion = 1;
|
||||
inline constexpr uint64_t kCountOffset = 8;
|
||||
inline constexpr size_t kHeaderSize = 12;
|
||||
inline constexpr size_t kRecordSize = 20;
|
||||
|
||||
enum class Opcode : uint32_t {
|
||||
nop = 0,
|
||||
sldi = 1,
|
||||
sld = 2,
|
||||
sadd = 3,
|
||||
ssub = 4,
|
||||
smul = 5,
|
||||
saddi = 6,
|
||||
smuli = 7,
|
||||
setbw = 8,
|
||||
mvmul = 9,
|
||||
vvadd = 10,
|
||||
vvsub = 11,
|
||||
vvmul = 12,
|
||||
vvdmul = 13,
|
||||
vvmax = 14,
|
||||
vvsll = 15,
|
||||
vvsra = 16,
|
||||
vavg = 17,
|
||||
vrelu = 18,
|
||||
vtanh = 19,
|
||||
vsigm = 20,
|
||||
vsoftmax = 21,
|
||||
vmv = 22,
|
||||
vrsu = 23,
|
||||
vrsl = 24,
|
||||
ld = 25,
|
||||
st = 26,
|
||||
lldi = 27,
|
||||
lmv = 28,
|
||||
send = 29,
|
||||
recv = 30,
|
||||
wait = 31,
|
||||
sync = 32,
|
||||
};
|
||||
|
||||
struct InstructionRecord {
|
||||
Opcode opcode = Opcode::nop;
|
||||
uint8_t rd = 0;
|
||||
uint8_t r1 = 0;
|
||||
int32_t r2OrImm = 0;
|
||||
int32_t generic1 = 0;
|
||||
int32_t generic2 = 0;
|
||||
int32_t generic3 = 0;
|
||||
uint8_t flags = 0;
|
||||
};
|
||||
|
||||
inline void writeUint32LE(llvm::raw_ostream& os, uint32_t value) {
|
||||
std::array<char, sizeof(uint32_t)> bytes;
|
||||
llvm::support::endian::write32le(bytes.data(), value);
|
||||
os.write(bytes.data(), bytes.size());
|
||||
}
|
||||
|
||||
inline void writeInt32LE(llvm::raw_ostream& os, int32_t value) { writeUint32LE(os, static_cast<uint32_t>(value)); }
|
||||
|
||||
inline void writeHeader(llvm::raw_ostream& os) {
|
||||
os.write(kMagic, sizeof(kMagic));
|
||||
writeUint32LE(os, kVersion);
|
||||
writeUint32LE(os, 0);
|
||||
}
|
||||
|
||||
inline void patchInstructionCount(llvm::raw_pwrite_stream& os, uint32_t instructionCount) {
|
||||
std::array<char, sizeof(uint32_t)> bytes;
|
||||
llvm::support::endian::write32le(bytes.data(), instructionCount);
|
||||
os.pwrite(bytes.data(), bytes.size(), kCountOffset);
|
||||
}
|
||||
|
||||
inline void writeInstructionRecord(llvm::raw_ostream& os, const InstructionRecord& record) {
|
||||
os << static_cast<char>(static_cast<uint8_t>(record.opcode));
|
||||
os << static_cast<char>(record.rd);
|
||||
os << static_cast<char>(record.r1);
|
||||
os << static_cast<char>(record.flags);
|
||||
writeInt32LE(os, record.r2OrImm);
|
||||
writeInt32LE(os, record.generic1);
|
||||
writeInt32LE(os, record.generic2);
|
||||
writeInt32LE(os, record.generic3);
|
||||
}
|
||||
|
||||
inline int32_t toI32(int64_t value) {
|
||||
assert(value >= std::numeric_limits<int32_t>::min() && value <= std::numeric_limits<int32_t>::max()
|
||||
&& "PIM binary field out of int32 range");
|
||||
return static_cast<int32_t>(value);
|
||||
}
|
||||
|
||||
inline uint8_t toU8(int64_t value) {
|
||||
assert(value >= 0 && value <= std::numeric_limits<uint8_t>::max() && "PIM binary field out of uint8 range");
|
||||
return static_cast<uint8_t>(value);
|
||||
}
|
||||
|
||||
inline int32_t getOptionalInt(const llvm::json::Object& object, llvm::StringRef key, int32_t defaultValue = 0) {
|
||||
if (std::optional<int64_t> value = object.getInteger(key))
|
||||
return toI32(*value);
|
||||
return defaultValue;
|
||||
}
|
||||
|
||||
inline Opcode opcodeFromString(llvm::StringRef opName) {
|
||||
if (opName == "nop")
|
||||
return Opcode::nop;
|
||||
if (opName == "sldi")
|
||||
return Opcode::sldi;
|
||||
if (opName == "sld")
|
||||
return Opcode::sld;
|
||||
if (opName == "sadd")
|
||||
return Opcode::sadd;
|
||||
if (opName == "ssub")
|
||||
return Opcode::ssub;
|
||||
if (opName == "smul")
|
||||
return Opcode::smul;
|
||||
if (opName == "saddi")
|
||||
return Opcode::saddi;
|
||||
if (opName == "smuli")
|
||||
return Opcode::smuli;
|
||||
if (opName == "setbw")
|
||||
return Opcode::setbw;
|
||||
if (opName == "mvmul")
|
||||
return Opcode::mvmul;
|
||||
if (opName == "vvadd")
|
||||
return Opcode::vvadd;
|
||||
if (opName == "vvsub")
|
||||
return Opcode::vvsub;
|
||||
if (opName == "vvmul")
|
||||
return Opcode::vvmul;
|
||||
if (opName == "vvdmul")
|
||||
return Opcode::vvdmul;
|
||||
if (opName == "vvmax")
|
||||
return Opcode::vvmax;
|
||||
if (opName == "vvsll")
|
||||
return Opcode::vvsll;
|
||||
if (opName == "vvsra")
|
||||
return Opcode::vvsra;
|
||||
if (opName == "vavg")
|
||||
return Opcode::vavg;
|
||||
if (opName == "vrelu")
|
||||
return Opcode::vrelu;
|
||||
if (opName == "vtanh")
|
||||
return Opcode::vtanh;
|
||||
if (opName == "vsigm")
|
||||
return Opcode::vsigm;
|
||||
if (opName == "vsoftmax")
|
||||
return Opcode::vsoftmax;
|
||||
if (opName == "vmv")
|
||||
return Opcode::vmv;
|
||||
if (opName == "vrsu")
|
||||
return Opcode::vrsu;
|
||||
if (opName == "vrsl")
|
||||
return Opcode::vrsl;
|
||||
if (opName == "ld")
|
||||
return Opcode::ld;
|
||||
if (opName == "st")
|
||||
return Opcode::st;
|
||||
if (opName == "lldi")
|
||||
return Opcode::lldi;
|
||||
if (opName == "lmv")
|
||||
return Opcode::lmv;
|
||||
if (opName == "send")
|
||||
return Opcode::send;
|
||||
if (opName == "recv")
|
||||
return Opcode::recv;
|
||||
if (opName == "wait")
|
||||
return Opcode::wait;
|
||||
if (opName == "sync")
|
||||
return Opcode::sync;
|
||||
llvm_unreachable("Unsupported PIM binary opcode");
|
||||
}
|
||||
|
||||
inline llvm::StringRef opcodeToString(Opcode opcode) {
|
||||
switch (opcode) {
|
||||
case Opcode::nop: return "nop";
|
||||
case Opcode::sldi: return "sldi";
|
||||
case Opcode::sld: return "sld";
|
||||
case Opcode::sadd: return "sadd";
|
||||
case Opcode::ssub: return "ssub";
|
||||
case Opcode::smul: return "smul";
|
||||
case Opcode::saddi: return "saddi";
|
||||
case Opcode::smuli: return "smuli";
|
||||
case Opcode::setbw: return "setbw";
|
||||
case Opcode::mvmul: return "mvmul";
|
||||
case Opcode::vvadd: return "vvadd";
|
||||
case Opcode::vvsub: return "vvsub";
|
||||
case Opcode::vvmul: return "vvmul";
|
||||
case Opcode::vvdmul: return "vvdmul";
|
||||
case Opcode::vvmax: return "vvmax";
|
||||
case Opcode::vvsll: return "vvsll";
|
||||
case Opcode::vvsra: return "vvsra";
|
||||
case Opcode::vavg: return "vavg";
|
||||
case Opcode::vrelu: return "vrelu";
|
||||
case Opcode::vtanh: return "vtanh";
|
||||
case Opcode::vsigm: return "vsigm";
|
||||
case Opcode::vsoftmax: return "vsoftmax";
|
||||
case Opcode::vmv: return "vmv";
|
||||
case Opcode::vrsu: return "vrsu";
|
||||
case Opcode::vrsl: return "vrsl";
|
||||
case Opcode::ld: return "ld";
|
||||
case Opcode::st: return "st";
|
||||
case Opcode::lldi: return "lldi";
|
||||
case Opcode::lmv: return "lmv";
|
||||
case Opcode::send: return "send";
|
||||
case Opcode::recv: return "recv";
|
||||
case Opcode::wait: return "wait";
|
||||
case Opcode::sync: return "sync";
|
||||
}
|
||||
llvm_unreachable("Unsupported PIM binary opcode");
|
||||
}
|
||||
|
||||
inline InstructionRecord makeInstructionRecord(const llvm::json::Object& instruction) {
|
||||
InstructionRecord record;
|
||||
std::optional<llvm::StringRef> opName = instruction.getString("op");
|
||||
assert(opName && "Missing op field in PIM instruction");
|
||||
record.opcode = opcodeFromString(*opName);
|
||||
record.rd = toU8(getOptionalInt(instruction, "rd"));
|
||||
record.r1 = toU8(getOptionalInt(instruction, "rs1"));
|
||||
|
||||
switch (record.opcode) {
|
||||
case Opcode::sldi:
|
||||
case Opcode::saddi:
|
||||
case Opcode::smuli:
|
||||
case Opcode::lldi: record.r2OrImm = getOptionalInt(instruction, "imm"); break;
|
||||
case Opcode::mvmul:
|
||||
record.r2OrImm = getOptionalInt(instruction, "mbiw");
|
||||
record.generic1 = getOptionalInt(instruction, "relu");
|
||||
record.generic2 = getOptionalInt(instruction, "group");
|
||||
break;
|
||||
case Opcode::setbw:
|
||||
record.generic1 = getOptionalInt(instruction, "ibiw");
|
||||
record.generic2 = getOptionalInt(instruction, "obiw");
|
||||
break;
|
||||
case Opcode::send:
|
||||
case Opcode::recv:
|
||||
record.r2OrImm = getOptionalInt(instruction, "core");
|
||||
record.generic3 = getOptionalInt(instruction, "size");
|
||||
break;
|
||||
default: record.r2OrImm = getOptionalInt(instruction, "rs2"); break;
|
||||
}
|
||||
|
||||
if (record.opcode != Opcode::mvmul && record.opcode != Opcode::setbw) {
|
||||
if (auto* offsetValue = instruction.getObject("offset")) {
|
||||
record.generic1 = getOptionalInt(*offsetValue, "offset_select");
|
||||
record.generic2 = getOptionalInt(*offsetValue, "offset_value");
|
||||
}
|
||||
}
|
||||
|
||||
if (instruction.get("len"))
|
||||
record.generic3 = getOptionalInt(instruction, "len");
|
||||
else if (instruction.get("size") && record.opcode != Opcode::send && record.opcode != Opcode::recv)
|
||||
record.generic3 = getOptionalInt(instruction, "size");
|
||||
|
||||
return record;
|
||||
}
|
||||
|
||||
inline llvm::json::Object makeInstructionJson(const InstructionRecord& record) {
|
||||
llvm::json::Object instruction;
|
||||
instruction["op"] = opcodeToString(record.opcode).str();
|
||||
|
||||
auto addOffset = [&](int32_t offsetSelect, int32_t offsetValue) {
|
||||
llvm::json::Object offset;
|
||||
offset["offset_select"] = offsetSelect;
|
||||
offset["offset_value"] = offsetValue;
|
||||
instruction["offset"] = std::move(offset);
|
||||
};
|
||||
|
||||
switch (record.opcode) {
|
||||
case Opcode::sldi:
|
||||
instruction["rd"] = static_cast<int64_t>(record.rd);
|
||||
instruction["imm"] = record.r2OrImm;
|
||||
break;
|
||||
case Opcode::sld:
|
||||
instruction["rd"] = static_cast<int64_t>(record.rd);
|
||||
instruction["rs1"] = static_cast<int64_t>(record.r1);
|
||||
addOffset(record.generic1, record.generic2);
|
||||
break;
|
||||
case Opcode::sadd:
|
||||
case Opcode::ssub:
|
||||
case Opcode::smul:
|
||||
instruction["rd"] = static_cast<int64_t>(record.rd);
|
||||
instruction["rs1"] = static_cast<int64_t>(record.r1);
|
||||
instruction["rs2"] = record.r2OrImm;
|
||||
break;
|
||||
case Opcode::saddi:
|
||||
case Opcode::smuli:
|
||||
instruction["rd"] = static_cast<int64_t>(record.rd);
|
||||
instruction["rs1"] = static_cast<int64_t>(record.r1);
|
||||
instruction["imm"] = record.r2OrImm;
|
||||
break;
|
||||
case Opcode::setbw:
|
||||
instruction["ibiw"] = record.generic1;
|
||||
instruction["obiw"] = record.generic2;
|
||||
break;
|
||||
case Opcode::mvmul:
|
||||
instruction["rd"] = static_cast<int64_t>(record.rd);
|
||||
instruction["rs1"] = static_cast<int64_t>(record.r1);
|
||||
instruction["mbiw"] = record.r2OrImm;
|
||||
instruction["relu"] = record.generic1;
|
||||
instruction["group"] = record.generic2;
|
||||
break;
|
||||
case Opcode::vvadd:
|
||||
case Opcode::vvsub:
|
||||
case Opcode::vvmul:
|
||||
case Opcode::vvdmul:
|
||||
case Opcode::vvmax:
|
||||
case Opcode::vvsll:
|
||||
case Opcode::vvsra:
|
||||
case Opcode::vavg:
|
||||
case Opcode::vmv:
|
||||
case Opcode::vrsu:
|
||||
case Opcode::vrsl:
|
||||
instruction["rd"] = static_cast<int64_t>(record.rd);
|
||||
instruction["rs1"] = static_cast<int64_t>(record.r1);
|
||||
instruction["rs2"] = record.r2OrImm;
|
||||
addOffset(record.generic1, record.generic2);
|
||||
instruction["len"] = record.generic3;
|
||||
break;
|
||||
case Opcode::vrelu:
|
||||
case Opcode::vtanh:
|
||||
case Opcode::vsigm:
|
||||
case Opcode::vsoftmax:
|
||||
instruction["rd"] = static_cast<int64_t>(record.rd);
|
||||
instruction["rs1"] = static_cast<int64_t>(record.r1);
|
||||
addOffset(record.generic1, record.generic2);
|
||||
instruction["len"] = record.generic3;
|
||||
break;
|
||||
case Opcode::ld:
|
||||
case Opcode::st:
|
||||
instruction["rd"] = static_cast<int64_t>(record.rd);
|
||||
instruction["rs1"] = static_cast<int64_t>(record.r1);
|
||||
addOffset(record.generic1, record.generic2);
|
||||
instruction["size"] = record.generic3;
|
||||
break;
|
||||
case Opcode::lldi:
|
||||
instruction["rd"] = static_cast<int64_t>(record.rd);
|
||||
instruction["imm"] = record.r2OrImm;
|
||||
addOffset(record.generic1, record.generic2);
|
||||
instruction["len"] = record.generic3;
|
||||
break;
|
||||
case Opcode::lmv:
|
||||
instruction["rd"] = static_cast<int64_t>(record.rd);
|
||||
instruction["rs1"] = static_cast<int64_t>(record.r1);
|
||||
addOffset(record.generic1, record.generic2);
|
||||
instruction["len"] = record.generic3;
|
||||
break;
|
||||
case Opcode::send:
|
||||
case Opcode::recv:
|
||||
instruction["rd"] = static_cast<int64_t>(record.rd);
|
||||
instruction["core"] = record.r2OrImm;
|
||||
addOffset(record.generic1, record.generic2);
|
||||
instruction["size"] = record.generic3;
|
||||
break;
|
||||
case Opcode::wait:
|
||||
case Opcode::sync:
|
||||
case Opcode::nop: break;
|
||||
}
|
||||
|
||||
return instruction;
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir::pim_binary
|
||||
+271
-205
@@ -24,11 +24,13 @@
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#include "Common/PimCommon.hpp"
|
||||
#include "Common/IR/CompactAsmUtils.hpp"
|
||||
#include "Common/PimCommon.hpp"
|
||||
#include "Common/Support/ReportUtils.hpp"
|
||||
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimArtifactWriter.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimBinaryFormat.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimWeightEmitter.hpp"
|
||||
@@ -65,6 +67,7 @@ void PimMemory::allocateMemoryForValue(mlir::Value value, MemEntry& memEntry) {
|
||||
if (size_t remainder = firstAvailableAddress % minAlignment)
|
||||
firstAvailableAddress += minAlignment - remainder;
|
||||
|
||||
ownedMemEntriesMap[value] = memEntry;
|
||||
globalMemEntriesMap[value] = memEntry;
|
||||
}
|
||||
|
||||
@@ -112,26 +115,32 @@ void PimMemory::allocateCore(Operation* op) {
|
||||
allocateGatheredMemory();
|
||||
}
|
||||
|
||||
std::string formatMemory(uint64_t bytes) {
|
||||
const char* units[] = {"B", "KB", "MB", "GB", "TB", "PB", "EB"};
|
||||
int i = 0;
|
||||
double size = static_cast<double>(bytes);
|
||||
while (size >= 1024 && i < 6) {
|
||||
size /= 1024;
|
||||
i++;
|
||||
}
|
||||
// Formats to 2 decimal places
|
||||
std::string out;
|
||||
llvm::raw_string_ostream rss(out);
|
||||
rss << llvm::format("%.2f ", size) << units[i];
|
||||
return rss.str();
|
||||
static void printHostMemoryReportRow(raw_ostream& os, const MemoryReportRow& row) {
|
||||
llvm::SmallVector<ReportField, 2> fields = {
|
||||
{"Number of globals", std::to_string(row.numGlobal) },
|
||||
{"Global memory", formatReportMemory(row.sizeGlobal)}
|
||||
};
|
||||
printReportFlatFields(os, fields);
|
||||
}
|
||||
|
||||
static void printMemoryReportRow(raw_ostream& os, const MemoryReportRow& row) {
|
||||
os << "\tNumber of allocas: " << row.numAlloca << "\n";
|
||||
os << "\tAllocated memory: " << formatMemory(row.sizeAlloca) << "\n";
|
||||
os << "\tNumber of globals: " << row.numGlobal << "\n";
|
||||
os << "\tGlobal memory: " << formatMemory(row.sizeGlobal) << "\n";
|
||||
static void printCoreMemoryReportRow(raw_ostream& os, const MemoryReportEntry& entry) {
|
||||
llvm::SmallVector<ReportField, 2> fields = {
|
||||
{"Number of allocas", std::to_string(entry.row.numAlloca) },
|
||||
{"Allocated memory", formatReportMemory(entry.row.sizeAlloca)}
|
||||
};
|
||||
printReportFlatFields(os, fields);
|
||||
}
|
||||
|
||||
static void printBatchMemoryReportRow(raw_ostream& os, const MemoryReportEntry& entry) {
|
||||
llvm::SmallVector<ReportField, 2> perCoreFields = {
|
||||
{"Number of allocas", std::to_string(entry.row.numAlloca) },
|
||||
{"Allocated memory", formatReportMemory(entry.row.sizeAlloca)}
|
||||
};
|
||||
llvm::SmallVector<ReportField, 2> totalFields = {
|
||||
{"Number of allocas", std::to_string(entry.totalAllocaCount) },
|
||||
{"Batch memory", formatReportMemory(entry.totalAllocaBytes)}
|
||||
};
|
||||
printReportPerCoreAndTotalFields(os, perCoreFields, totalFields);
|
||||
}
|
||||
|
||||
static MemoryReportRow addMemoryReportRows(const MemoryReportRow& lhs, const MemoryReportRow& rhs) {
|
||||
@@ -145,7 +154,7 @@ static MemoryReportRow addMemoryReportRows(const MemoryReportRow& lhs, const Mem
|
||||
|
||||
MemoryReportRow PimMemory::getReportRow() const {
|
||||
MemoryReportRow row;
|
||||
for (auto& [val, memEntry] : globalMemEntriesMap) {
|
||||
for (auto& [val, memEntry] : ownedMemEntriesMap) {
|
||||
if (auto op = val.getDefiningOp()) {
|
||||
if (isa<memref::AllocOp>(op)) {
|
||||
row.numAlloca++;
|
||||
@@ -162,6 +171,8 @@ MemoryReportRow PimMemory::getReportRow() const {
|
||||
}
|
||||
|
||||
void PimMemory::remove(mlir::Value val) {
|
||||
if (auto removeIter = ownedMemEntriesMap.find(val); removeIter != ownedMemEntriesMap.end())
|
||||
ownedMemEntriesMap.erase(removeIter);
|
||||
if (auto removeIter = globalMemEntriesMap.find(val); removeIter != globalMemEntriesMap.end())
|
||||
globalMemEntriesMap.erase(removeIter);
|
||||
}
|
||||
@@ -206,20 +217,25 @@ size_t PimAcceleratorMemory::getValueAddress(mlir::Value value, const StaticValu
|
||||
return iter->second.address + resolvedAddress->byteOffset;
|
||||
}
|
||||
|
||||
void PimAcceleratorMemory::reportHost() {
|
||||
hostReportRow = hostMem.getReportRow();
|
||||
}
|
||||
void PimAcceleratorMemory::reportHost() { hostReportRow = hostMem.getReportRow(); }
|
||||
|
||||
void PimAcceleratorMemory::recordCoreReport(size_t coreId, const MemoryReportRow& row) {
|
||||
reportEntries.push_back({MemoryReportEntry::Kind::Core, coreId, {static_cast<int32_t>(coreId)}, row});
|
||||
reportEntries.push_back(
|
||||
{MemoryReportEntry::Kind::Core, coreId, {static_cast<int32_t>(coreId)}, row, row.numAlloca, row.sizeAlloca});
|
||||
}
|
||||
|
||||
void PimAcceleratorMemory::recordBatchReport(uint64_t batchId, ArrayRef<int32_t> coreIds, const MemoryReportRow& row) {
|
||||
void PimAcceleratorMemory::recordBatchReport(uint64_t batchId,
|
||||
ArrayRef<int32_t> coreIds,
|
||||
const MemoryReportRow& perCoreRow,
|
||||
uint64_t totalAllocaCount,
|
||||
uint64_t totalAllocaBytes) {
|
||||
MemoryReportEntry entry;
|
||||
entry.kind = MemoryReportEntry::Kind::Batch;
|
||||
entry.id = batchId;
|
||||
llvm::append_range(entry.coreIds, coreIds);
|
||||
entry.row = row;
|
||||
entry.row = perCoreRow;
|
||||
entry.totalAllocaCount = totalAllocaCount;
|
||||
entry.totalAllocaBytes = totalAllocaBytes;
|
||||
reportEntries.push_back(std::move(entry));
|
||||
}
|
||||
|
||||
@@ -228,36 +244,33 @@ void PimAcceleratorMemory::flushReport() {
|
||||
return;
|
||||
|
||||
llvm::raw_os_ostream os(fileReport);
|
||||
uint64_t totalGlobalMemory = hostReportRow.has_value() ? hostReportRow->sizeGlobal : 0;
|
||||
uint64_t totalCoresMemory = 0;
|
||||
for (const MemoryReportEntry& entry : reportEntries)
|
||||
totalCoresMemory += entry.totalAllocaBytes;
|
||||
|
||||
llvm::SmallVector<ReportField, 2> totalFields = {
|
||||
{"Global memory", formatReportMemory(totalGlobalMemory)},
|
||||
{"Cores memory", formatReportMemory(totalCoresMemory) }
|
||||
};
|
||||
printReportTotalsBlock(os, totalFields);
|
||||
|
||||
if (hostReportRow.has_value()) {
|
||||
os << "Host:\n";
|
||||
printMemoryReportRow(os, *hostReportRow);
|
||||
os << "\nHost:\n";
|
||||
printHostMemoryReportRow(os, *hostReportRow);
|
||||
}
|
||||
|
||||
if (!reportEntries.empty()) {
|
||||
if (hostReportRow.has_value())
|
||||
os << "\n";
|
||||
|
||||
llvm::stable_sort(reportEntries, [](const MemoryReportEntry& lhs, const MemoryReportEntry& rhs) {
|
||||
if (lhs.kind != rhs.kind)
|
||||
return lhs.kind == MemoryReportEntry::Kind::Batch;
|
||||
|
||||
const MemoryReportRow& lhsRow = lhs.row;
|
||||
const MemoryReportRow& rhsRow = rhs.row;
|
||||
if (lhsRow.sizeAlloca != rhsRow.sizeAlloca)
|
||||
return lhsRow.sizeAlloca > rhsRow.sizeAlloca;
|
||||
if (lhsRow.numAlloca != rhsRow.numAlloca)
|
||||
return lhsRow.numAlloca > rhsRow.numAlloca;
|
||||
if (lhsRow.sizeGlobal != rhsRow.sizeGlobal)
|
||||
return lhsRow.sizeGlobal > rhsRow.sizeGlobal;
|
||||
if (lhsRow.numGlobal != rhsRow.numGlobal)
|
||||
return lhsRow.numGlobal > rhsRow.numGlobal;
|
||||
return lhs.id < rhs.id;
|
||||
});
|
||||
sortReportEntriesByFirstCore(reportEntries);
|
||||
|
||||
for (size_t index = 0; index < reportEntries.size();) {
|
||||
size_t runEnd = index + 1;
|
||||
while (runEnd < reportEntries.size() && reportEntries[runEnd].kind == reportEntries[index].kind
|
||||
&& reportEntries[runEnd].row == reportEntries[index].row) {
|
||||
&& reportEntries[runEnd].row == reportEntries[index].row
|
||||
&& reportEntries[runEnd].totalAllocaCount == reportEntries[index].totalAllocaCount
|
||||
&& reportEntries[runEnd].totalAllocaBytes == reportEntries[index].totalAllocaBytes) {
|
||||
++runEnd;
|
||||
}
|
||||
|
||||
@@ -279,9 +292,11 @@ void PimAcceleratorMemory::flushReport() {
|
||||
printCompressedIntegerEntries(os, ArrayRef<int32_t>(coreIds));
|
||||
}
|
||||
os << ":\n";
|
||||
printMemoryReportRow(os, reportEntries[index].row);
|
||||
if (runEnd < reportEntries.size())
|
||||
os << "\n";
|
||||
if (reportEntries[index].kind == MemoryReportEntry::Kind::Batch)
|
||||
printBatchMemoryReportRow(os, reportEntries[index]);
|
||||
else
|
||||
printCoreMemoryReportRow(os, reportEntries[index]);
|
||||
printReportEntrySeparator(os, runEnd < reportEntries.size());
|
||||
|
||||
index = runEnd;
|
||||
}
|
||||
@@ -299,36 +314,25 @@ void PimAcceleratorMemory::clean(mlir::Operation* op) {
|
||||
}
|
||||
}
|
||||
|
||||
json::Object PimCodeGen::createEmptyOffset() {
|
||||
json::Object offset;
|
||||
offset["offset_select"] = 0;
|
||||
offset["offset_value"] = 0;
|
||||
return offset;
|
||||
}
|
||||
|
||||
size_t PimCodeGen::remapCoreId(size_t coreId) const {
|
||||
auto it = emittedCoreIds.find(coreId);
|
||||
assert(it != emittedCoreIds.end() && "Missing emitted core id remapping");
|
||||
return it->second;
|
||||
}
|
||||
|
||||
static json::Object createRs1OnlyOffset() {
|
||||
json::Object offset;
|
||||
offset["offset_select"] = 1;
|
||||
offset["offset_value"] = 0;
|
||||
return offset;
|
||||
}
|
||||
|
||||
void PimCodeGen::emitInstruction(json::Object instruction) const {
|
||||
coreFileStream << json::Value(std::move(instruction)) << ',';
|
||||
void PimCodeGen::emitInstruction(const pim_binary::InstructionRecord& instruction) const {
|
||||
pim_binary::writeInstructionRecord(coreBinaryStream, instruction);
|
||||
++emittedInstructionCount;
|
||||
if (coreJsonStream)
|
||||
*coreJsonStream << json::Value(pim_binary::makeInstructionJson(instruction)) << ',';
|
||||
}
|
||||
|
||||
void PimCodeGen::genSetRegisterImmediateUnsigned(size_t registerNumber, size_t immediate) const {
|
||||
json::Object json;
|
||||
json["op"] = "sldi";
|
||||
json["rd"] = registerNumber;
|
||||
json["imm"] = immediate;
|
||||
emitInstruction(std::move(json));
|
||||
pim_binary::InstructionRecord instruction;
|
||||
instruction.opcode = pim_binary::Opcode::sldi;
|
||||
instruction.rd = static_cast<uint8_t>(registerNumber);
|
||||
instruction.r2OrImm = static_cast<int32_t>(immediate);
|
||||
emitInstruction(instruction);
|
||||
}
|
||||
|
||||
void PimCodeGen::setupRd(size_t rdAddress, size_t rdOffset) const {
|
||||
@@ -356,38 +360,41 @@ void PimCodeGen::emitMemCopyOp(StringRef opName,
|
||||
StringRef sizeFieldName) const {
|
||||
setupRdRs1(rdAddr, rdOffset, rs1Addr, rs1Offset);
|
||||
|
||||
json::Object json;
|
||||
json["op"] = opName;
|
||||
json["rd"] = 0;
|
||||
json["rs1"] = 1;
|
||||
json[sizeFieldName] = size;
|
||||
json["offset"] = createEmptyOffset();
|
||||
emitInstruction(std::move(json));
|
||||
pim_binary::InstructionRecord instruction;
|
||||
instruction.opcode = pim_binary::opcodeFromString(opName);
|
||||
instruction.rd = 0;
|
||||
instruction.r1 = 1;
|
||||
instruction.generic1 = 0;
|
||||
instruction.generic2 = 0;
|
||||
instruction.generic3 = static_cast<int32_t>(size);
|
||||
(void) sizeFieldName;
|
||||
emitInstruction(instruction);
|
||||
}
|
||||
|
||||
void PimCodeGen::emitCommunicationOp(StringRef opName, size_t bufferAddr, size_t coreId, size_t size) const {
|
||||
setupRd(bufferAddr, 0);
|
||||
|
||||
json::Object json;
|
||||
json["op"] = opName;
|
||||
json["rd"] = 0;
|
||||
json["core"] = remapCoreId(coreId);
|
||||
json["size"] = size;
|
||||
json["offset"] = createEmptyOffset();
|
||||
emitInstruction(std::move(json));
|
||||
pim_binary::InstructionRecord instruction;
|
||||
instruction.opcode = pim_binary::opcodeFromString(opName);
|
||||
instruction.rd = 0;
|
||||
instruction.r2OrImm = static_cast<int32_t>(remapCoreId(coreId));
|
||||
instruction.generic1 = 0;
|
||||
instruction.generic2 = 0;
|
||||
instruction.generic3 = static_cast<int32_t>(size);
|
||||
emitInstruction(instruction);
|
||||
}
|
||||
|
||||
void PimCodeGen::emitMvmOp(size_t groupId, size_t rdAddr, size_t rdOffset, size_t rs1Addr, size_t rs1Offset) const {
|
||||
setupRdRs1(rdAddr, rdOffset, rs1Addr, rs1Offset);
|
||||
|
||||
json::Object json;
|
||||
json["op"] = "mvmul";
|
||||
json["rd"] = 0;
|
||||
json["rs1"] = 1;
|
||||
json["group"] = groupId;
|
||||
json["relu"] = 0;
|
||||
json["mbiw"] = 8;
|
||||
emitInstruction(std::move(json));
|
||||
pim_binary::InstructionRecord instruction;
|
||||
instruction.opcode = pim_binary::Opcode::mvmul;
|
||||
instruction.rd = 0;
|
||||
instruction.r1 = 1;
|
||||
instruction.r2OrImm = 8;
|
||||
instruction.generic1 = 0;
|
||||
instruction.generic2 = static_cast<int32_t>(groupId);
|
||||
emitInstruction(instruction);
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp, const StaticValueKnowledge& knowledge) const {
|
||||
@@ -495,14 +502,13 @@ void PimCodeGen::codeGenVVAddOp(pim::PimVVAddOp vvaddOp, const StaticValueKnowle
|
||||
auto rhsAddr = addressOf(vvaddOp.getRhs(), knowledge);
|
||||
setupRdRs1Rs2(outputBufferAddr, 0, lhsAddr, 0, rhsAddr, 0);
|
||||
|
||||
json::Object json;
|
||||
json["op"] = "vvadd";
|
||||
json["rd"] = 0;
|
||||
json["rs1"] = 1;
|
||||
json["rs2"] = 2;
|
||||
json["offset"] = createEmptyOffset();
|
||||
json["len"] = getValueSizeInBytes(vvaddOp.getLhs());
|
||||
emitInstruction(std::move(json));
|
||||
pim_binary::InstructionRecord instruction;
|
||||
instruction.opcode = pim_binary::Opcode::vvadd;
|
||||
instruction.rd = 0;
|
||||
instruction.r1 = 1;
|
||||
instruction.r2OrImm = 2;
|
||||
instruction.generic3 = static_cast<int32_t>(getValueSizeInBytes(vvaddOp.getLhs()));
|
||||
emitInstruction(instruction);
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenVVSubOp(pim::PimVVSubOp vvsubOp, const StaticValueKnowledge& knowledge) const {
|
||||
@@ -511,14 +517,13 @@ void PimCodeGen::codeGenVVSubOp(pim::PimVVSubOp vvsubOp, const StaticValueKnowle
|
||||
auto rhsAddr = addressOf(vvsubOp.getRhs(), knowledge);
|
||||
setupRdRs1Rs2(outputBufferAddr, 0, lhsAddr, 0, rhsAddr, 0);
|
||||
|
||||
json::Object json;
|
||||
json["op"] = "vvsub";
|
||||
json["rd"] = 0;
|
||||
json["rs1"] = 1;
|
||||
json["rs2"] = 2;
|
||||
json["offset"] = createEmptyOffset();
|
||||
json["len"] = getValueSizeInBytes(vvsubOp.getLhs());
|
||||
emitInstruction(std::move(json));
|
||||
pim_binary::InstructionRecord instruction;
|
||||
instruction.opcode = pim_binary::Opcode::vvsub;
|
||||
instruction.rd = 0;
|
||||
instruction.r1 = 1;
|
||||
instruction.r2OrImm = 2;
|
||||
instruction.generic3 = static_cast<int32_t>(getValueSizeInBytes(vvsubOp.getLhs()));
|
||||
emitInstruction(instruction);
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenVVMulOp(pim::PimVVMulOp vvmulOp, const StaticValueKnowledge& knowledge) const {
|
||||
@@ -527,14 +532,13 @@ void PimCodeGen::codeGenVVMulOp(pim::PimVVMulOp vvmulOp, const StaticValueKnowle
|
||||
auto rhsAddr = addressOf(vvmulOp.getRhs(), knowledge);
|
||||
setupRdRs1Rs2(outputBufferAddr, 0, lhsAddr, 0, rhsAddr, 0);
|
||||
|
||||
json::Object json;
|
||||
json["op"] = "vvmul";
|
||||
json["rd"] = 0;
|
||||
json["rs1"] = 1;
|
||||
json["rs2"] = 2;
|
||||
json["offset"] = createEmptyOffset();
|
||||
json["len"] = getValueSizeInBytes(vvmulOp.getLhs());
|
||||
emitInstruction(std::move(json));
|
||||
pim_binary::InstructionRecord instruction;
|
||||
instruction.opcode = pim_binary::Opcode::vvmul;
|
||||
instruction.rd = 0;
|
||||
instruction.r1 = 1;
|
||||
instruction.r2OrImm = 2;
|
||||
instruction.generic3 = static_cast<int32_t>(getValueSizeInBytes(vvmulOp.getLhs()));
|
||||
emitInstruction(instruction);
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenVVMaxOp(pim::PimVVMaxOp vvmaxOp, const StaticValueKnowledge& knowledge) const {
|
||||
@@ -543,14 +547,13 @@ void PimCodeGen::codeGenVVMaxOp(pim::PimVVMaxOp vvmaxOp, const StaticValueKnowle
|
||||
auto rhsAddr = addressOf(vvmaxOp.getRhs(), knowledge);
|
||||
setupRdRs1Rs2(outputBufferAddr, 0, lhsAddr, 0, rhsAddr, 0);
|
||||
|
||||
json::Object json;
|
||||
json["op"] = "vvmax";
|
||||
json["rd"] = 0;
|
||||
json["rs1"] = 1;
|
||||
json["rs2"] = 2;
|
||||
json["offset"] = createEmptyOffset();
|
||||
json["len"] = getValueSizeInBytes(vvmaxOp.getLhs());
|
||||
emitInstruction(std::move(json));
|
||||
pim_binary::InstructionRecord instruction;
|
||||
instruction.opcode = pim_binary::Opcode::vvmax;
|
||||
instruction.rd = 0;
|
||||
instruction.r1 = 1;
|
||||
instruction.r2OrImm = 2;
|
||||
instruction.generic3 = static_cast<int32_t>(getValueSizeInBytes(vvmaxOp.getLhs()));
|
||||
emitInstruction(instruction);
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenVVDMulOp(pim::PimVVDMulOp vvdmulOp, const StaticValueKnowledge& knowledge) const {
|
||||
@@ -559,14 +562,13 @@ void PimCodeGen::codeGenVVDMulOp(pim::PimVVDMulOp vvdmulOp, const StaticValueKno
|
||||
auto rhsAddr = addressOf(vvdmulOp.getRhs(), knowledge);
|
||||
setupRdRs1Rs2(outputBufferAddr, 0, lhsAddr, 0, rhsAddr, 0);
|
||||
|
||||
json::Object json;
|
||||
json["op"] = "vvdmul";
|
||||
json["rd"] = 0;
|
||||
json["rs1"] = 1;
|
||||
json["rs2"] = 2;
|
||||
json["offset"] = createEmptyOffset();
|
||||
json["len"] = getValueSizeInBytes(vvdmulOp.getLhs());
|
||||
emitInstruction(std::move(json));
|
||||
pim_binary::InstructionRecord instruction;
|
||||
instruction.opcode = pim_binary::Opcode::vvdmul;
|
||||
instruction.rd = 0;
|
||||
instruction.r1 = 1;
|
||||
instruction.r2OrImm = 2;
|
||||
instruction.generic3 = static_cast<int32_t>(getValueSizeInBytes(vvdmulOp.getLhs()));
|
||||
emitInstruction(instruction);
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenVAvgOp(pim::PimVAvgOp vavgOp, const StaticValueKnowledge& knowledge) const {
|
||||
@@ -574,14 +576,14 @@ void PimCodeGen::codeGenVAvgOp(pim::PimVAvgOp vavgOp, const StaticValueKnowledge
|
||||
auto inputAddr = addressOf(vavgOp.getInput(), knowledge);
|
||||
setupRdRs1(outputBufferAddr, 0, inputAddr, 0);
|
||||
|
||||
json::Object json;
|
||||
json["op"] = "vavg";
|
||||
json["rd"] = 0;
|
||||
json["rs1"] = 1;
|
||||
json["rs2"] = 1;
|
||||
json["offset"] = createRs1OnlyOffset();
|
||||
json["len"] = getValueSizeInBytes(vavgOp.getInput());
|
||||
emitInstruction(std::move(json));
|
||||
pim_binary::InstructionRecord instruction;
|
||||
instruction.opcode = pim_binary::Opcode::vavg;
|
||||
instruction.rd = 0;
|
||||
instruction.r1 = 1;
|
||||
instruction.r2OrImm = 1;
|
||||
instruction.generic1 = 1;
|
||||
instruction.generic3 = static_cast<int32_t>(getValueSizeInBytes(vavgOp.getInput()));
|
||||
emitInstruction(instruction);
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenVReluOp(pim::PimVReluOp vreluOp, const StaticValueKnowledge& knowledge) const {
|
||||
@@ -589,13 +591,12 @@ void PimCodeGen::codeGenVReluOp(pim::PimVReluOp vreluOp, const StaticValueKnowle
|
||||
auto inputAddr = addressOf(vreluOp.getInput(), knowledge);
|
||||
setupRdRs1(outputBufferAddr, 0, inputAddr, 0);
|
||||
|
||||
json::Object json;
|
||||
json["op"] = "vrelu";
|
||||
json["rd"] = 0;
|
||||
json["rs1"] = 1;
|
||||
json["offset"] = createEmptyOffset();
|
||||
json["len"] = getValueSizeInBytes(vreluOp.getInput());
|
||||
emitInstruction(std::move(json));
|
||||
pim_binary::InstructionRecord instruction;
|
||||
instruction.opcode = pim_binary::Opcode::vrelu;
|
||||
instruction.rd = 0;
|
||||
instruction.r1 = 1;
|
||||
instruction.generic3 = static_cast<int32_t>(getValueSizeInBytes(vreluOp.getInput()));
|
||||
emitInstruction(instruction);
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenVTanhOp(pim::PimVTanhOp vtanhOp, const StaticValueKnowledge& knowledge) const {
|
||||
@@ -603,13 +604,12 @@ void PimCodeGen::codeGenVTanhOp(pim::PimVTanhOp vtanhOp, const StaticValueKnowle
|
||||
auto inputAddr = addressOf(vtanhOp.getInput(), knowledge);
|
||||
setupRdRs1(outputBufferAddr, 0, inputAddr, 0);
|
||||
|
||||
json::Object json;
|
||||
json["op"] = "vtanh";
|
||||
json["rd"] = 0;
|
||||
json["rs1"] = 1;
|
||||
json["offset"] = createEmptyOffset();
|
||||
json["len"] = getValueSizeInBytes(vtanhOp.getInput());
|
||||
emitInstruction(std::move(json));
|
||||
pim_binary::InstructionRecord instruction;
|
||||
instruction.opcode = pim_binary::Opcode::vtanh;
|
||||
instruction.rd = 0;
|
||||
instruction.r1 = 1;
|
||||
instruction.generic3 = static_cast<int32_t>(getValueSizeInBytes(vtanhOp.getInput()));
|
||||
emitInstruction(instruction);
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenVSigmOp(pim::PimVSigmOp vsigmOp, const StaticValueKnowledge& knowledge) const {
|
||||
@@ -617,13 +617,12 @@ void PimCodeGen::codeGenVSigmOp(pim::PimVSigmOp vsigmOp, const StaticValueKnowle
|
||||
auto inputAddr = addressOf(vsigmOp.getInput(), knowledge);
|
||||
setupRdRs1(outputBufferAddr, 0, inputAddr, 0);
|
||||
|
||||
json::Object json;
|
||||
json["op"] = "vsigm";
|
||||
json["rd"] = 0;
|
||||
json["rs1"] = 1;
|
||||
json["offset"] = createEmptyOffset();
|
||||
json["len"] = getValueSizeInBytes(vsigmOp.getInput());
|
||||
emitInstruction(std::move(json));
|
||||
pim_binary::InstructionRecord instruction;
|
||||
instruction.opcode = pim_binary::Opcode::vsigm;
|
||||
instruction.rd = 0;
|
||||
instruction.r1 = 1;
|
||||
instruction.generic3 = static_cast<int32_t>(getValueSizeInBytes(vsigmOp.getInput()));
|
||||
emitInstruction(instruction);
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenVSoftmaxOp(pim::PimVSoftmaxOp vsoftmaxOp, const StaticValueKnowledge& knowledge) const {
|
||||
@@ -631,13 +630,12 @@ void PimCodeGen::codeGenVSoftmaxOp(pim::PimVSoftmaxOp vsoftmaxOp, const StaticVa
|
||||
auto inputAddr = addressOf(vsoftmaxOp.getInput(), knowledge);
|
||||
setupRdRs1(outputBufferAddr, 0, inputAddr, 0);
|
||||
|
||||
json::Object json;
|
||||
json["op"] = "vsoftmax";
|
||||
json["rd"] = 0;
|
||||
json["rs1"] = 1;
|
||||
json["offset"] = createEmptyOffset();
|
||||
json["len"] = getValueSizeInBytes(vsoftmaxOp.getInput());
|
||||
emitInstruction(std::move(json));
|
||||
pim_binary::InstructionRecord instruction;
|
||||
instruction.opcode = pim_binary::Opcode::vsoftmax;
|
||||
instruction.rd = 0;
|
||||
instruction.r1 = 1;
|
||||
instruction.generic3 = static_cast<int32_t>(getValueSizeInBytes(vsoftmaxOp.getInput()));
|
||||
emitInstruction(instruction);
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGetGlobalOp(memref::GetGlobalOp getGlobalOp, const StaticValueKnowledge& knowledge) const {}
|
||||
@@ -669,6 +667,30 @@ void PimCodeGen::codeGenTransposeOp(pim::PimTransposeOp transposeOp, const Stati
|
||||
dstStrides[i] = dstStrides[i + 1] * dstShape[i + 1];
|
||||
}
|
||||
|
||||
bool storagePreserving = true;
|
||||
for (size_t srcFlat = 0; srcFlat < totalElements; srcFlat++) {
|
||||
SmallVector<size_t> srcIdx(rank);
|
||||
size_t remaining = srcFlat;
|
||||
for (size_t d = 0; d < rank; d++) {
|
||||
srcIdx[d] = remaining / srcStrides[d];
|
||||
remaining %= srcStrides[d];
|
||||
}
|
||||
|
||||
size_t dstFlat = 0;
|
||||
for (size_t d = 0; d < rank; d++)
|
||||
dstFlat += srcIdx[perm[d]] * dstStrides[d];
|
||||
|
||||
if (dstFlat != srcFlat) {
|
||||
storagePreserving = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (storagePreserving) {
|
||||
emitMemCopyOp("lmv", dstAddr, 0, srcAddr, 0, totalElements * elementSize, "len");
|
||||
return;
|
||||
}
|
||||
|
||||
// Emit element-by-element copy with transposed addressing
|
||||
for (size_t srcFlat = 0; srcFlat < totalElements; srcFlat++) {
|
||||
// Decompose flat source index into multi-dimensional index
|
||||
@@ -734,9 +756,25 @@ static SmallVector<Operation*> collectTopLevelCoreLikeOps(func::FuncOp funcOp) {
|
||||
return coreLikeOps;
|
||||
}
|
||||
|
||||
static SmallDenseMap<memref::GlobalOp, MemEntry, 16>
|
||||
collectMaterializedHostGlobals(ModuleOp moduleOp, func::FuncOp funcOp, const PimAcceleratorMemory& memory) {
|
||||
SmallDenseMap<memref::GlobalOp, MemEntry, 16> materializedHostGlobals;
|
||||
funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
|
||||
if (hasWeightAlways(getGlobalOp))
|
||||
return;
|
||||
auto targetGlobal = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
||||
if (!targetGlobal || materializedHostGlobals.contains(targetGlobal))
|
||||
return;
|
||||
auto it = memory.memEntriesMap.find(getGlobalOp.getResult());
|
||||
if (it != memory.memEntriesMap.end())
|
||||
materializedHostGlobals[targetGlobal] = it->second;
|
||||
});
|
||||
return materializedHostGlobals;
|
||||
}
|
||||
|
||||
static void aliasMaterializedHostGlobals(ModuleOp moduleOp,
|
||||
func::FuncOp funcOp,
|
||||
pim::PimCoreOp coreOp,
|
||||
const SmallDenseMap<memref::GlobalOp, MemEntry, 16>& materializedHostGlobals,
|
||||
PimAcceleratorMemory& memory) {
|
||||
coreOp.walk([&](memref::GetGlobalOp getGlobalOp) {
|
||||
if (hasWeightAlways(getGlobalOp) || memory.memEntriesMap.contains(getGlobalOp.getResult()))
|
||||
@@ -746,16 +784,9 @@ static void aliasMaterializedHostGlobals(ModuleOp moduleOp,
|
||||
if (!targetGlobal)
|
||||
return;
|
||||
|
||||
mlir::Value aliasedValue;
|
||||
funcOp.walk([&](memref::GetGlobalOp candidate) {
|
||||
if (aliasedValue || candidate == getGlobalOp || !memory.memEntriesMap.contains(candidate.getResult()))
|
||||
return;
|
||||
if (lookupGlobalForGetGlobal(moduleOp, candidate) == targetGlobal)
|
||||
aliasedValue = candidate.getResult();
|
||||
});
|
||||
|
||||
if (aliasedValue)
|
||||
memory.memEntriesMap[getGlobalOp.getResult()] = memory.memEntriesMap[aliasedValue];
|
||||
auto it = materializedHostGlobals.find(targetGlobal);
|
||||
if (it != materializedHostGlobals.end())
|
||||
memory.memEntriesMap[getGlobalOp.getResult()] = it->second;
|
||||
});
|
||||
}
|
||||
|
||||
@@ -810,7 +841,12 @@ static int64_t codeGenCoreOps(Block& block, PimCodeGen& coreCodeGen) {
|
||||
else if (auto getGlobalOp = dyn_cast<memref::GetGlobalOp>(op))
|
||||
coreCodeGen.codeGetGlobalOp(getGlobalOp, knowledge);
|
||||
else {
|
||||
op.emitError("Unsupported codegen for this operation");
|
||||
InFlightDiagnostic diag = op.emitError()
|
||||
<< "unsupported codegen for op '" << op.getName().getStringRef() << "'";
|
||||
if (auto coreOp = op.getParentOfType<pim::PimCoreOp>())
|
||||
diag << " inside pim.core " << coreOp.getCoreId();
|
||||
else if (auto coreBatchOp = op.getParentOfType<pim::PimCoreBatchOp>())
|
||||
diag << " inside pim.core_batch with laneCount " << coreBatchOp.getLaneCount();
|
||||
return failure();
|
||||
}
|
||||
processedOperations++;
|
||||
@@ -819,7 +855,7 @@ static int64_t codeGenCoreOps(Block& block, PimCodeGen& coreCodeGen) {
|
||||
return failed(result) ? -1 : static_cast<int64_t>(processedOperations);
|
||||
}
|
||||
|
||||
OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::string& outputDirPath) {
|
||||
OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std::string& outputDirPath) {
|
||||
if (!outputDirPath.empty()) {
|
||||
if (auto error = sys::fs::create_directory(outputDirPath)) {
|
||||
errs() << "Error creating output directory: " << outputDirPath << ": " << error.message() << '\n';
|
||||
@@ -839,11 +875,6 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::
|
||||
if (auto err = writeMemoryBinary(moduleOp, funcOp, memory, outputDirPath))
|
||||
return err;
|
||||
|
||||
if (auto err = writeHostCoreJson(outputDirPath))
|
||||
return err;
|
||||
|
||||
// For each core, specify the number of crossbar per array group.
|
||||
// This implementation always assigns one crossbar per group.
|
||||
json::Object xbarsPerArrayGroup;
|
||||
size_t maxCoreId = 0;
|
||||
uint64_t nextBatchReportId = 0;
|
||||
@@ -852,8 +883,10 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::
|
||||
auto mapCoreWeightToFileName = createAndPopulateWeightFolder(funcOp, outputDirPath);
|
||||
|
||||
SmallVector<Operation*> coreLikeOps = collectTopLevelCoreLikeOps(funcOp);
|
||||
SmallDenseMap<memref::GlobalOp, MemEntry, 16> materializedHostGlobals =
|
||||
collectMaterializedHostGlobals(moduleOp, funcOp, memory);
|
||||
llvm::DenseMap<size_t, size_t> emittedCoreIds;
|
||||
size_t nextEmittedCoreId = 1;
|
||||
size_t nextEmittedCoreId = 0;
|
||||
|
||||
for (Operation* op : coreLikeOps) {
|
||||
if (auto coreOp = dyn_cast<pim::PimCoreOp>(op)) {
|
||||
@@ -873,32 +906,57 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::
|
||||
}
|
||||
|
||||
for (Operation* op : coreLikeOps) {
|
||||
auto emitCore = [&](pim::PimCoreOp coreOp, bool temporaryCore) -> OnnxMlirCompilerErrorCodes {
|
||||
auto emitCore = [&](pim::PimCoreOp coreOp,
|
||||
bool temporaryCore,
|
||||
MemoryReportRow* reportRow = nullptr) -> OnnxMlirCompilerErrorCodes {
|
||||
size_t originalCoreId = static_cast<size_t>(coreOp.getCoreId());
|
||||
size_t coreId = emittedCoreIds.lookup(originalCoreId);
|
||||
maxCoreId = std::max(maxCoreId, coreId);
|
||||
|
||||
std::error_code errorCode;
|
||||
auto outputCorePath = outputDirPath + "/core_" + std::to_string(coreId) + ".json";
|
||||
raw_fd_ostream coreFileStream(outputCorePath, errorCode);
|
||||
auto outputCorePath = outputDirPath + "/core_" + std::to_string(coreId) + ".pim";
|
||||
raw_fd_ostream coreBinaryStream(outputCorePath, errorCode, sys::fs::OF_None);
|
||||
if (errorCode) {
|
||||
errs() << "Error while opening core file `" << outputCorePath << "`: " << errorCode.message() << '\n';
|
||||
return InvalidOutputFileAccess;
|
||||
}
|
||||
coreFileStream << '[';
|
||||
|
||||
PimCodeGen coreCodeGen(memory, coreFileStream, emittedCoreIds);
|
||||
aliasMaterializedHostGlobals(moduleOp, funcOp, coreOp, memory);
|
||||
memory.getOrCreateDeviceMem(coreId).allocateCore(coreOp);
|
||||
std::unique_ptr<raw_fd_ostream> coreJsonStream;
|
||||
if (pimEmitJson.getValue()) {
|
||||
std::string outputCoreJsonPath = outputDirPath + "/core_" + std::to_string(coreId) + ".json";
|
||||
errorCode = std::error_code();
|
||||
coreJsonStream = std::make_unique<raw_fd_ostream>(outputCoreJsonPath, errorCode);
|
||||
if (errorCode) {
|
||||
errs() << "Error while opening core json file `" << outputCoreJsonPath << "`: " << errorCode.message()
|
||||
<< '\n';
|
||||
return InvalidOutputFileAccess;
|
||||
}
|
||||
*coreJsonStream << '[';
|
||||
}
|
||||
|
||||
pim_binary::writeHeader(coreBinaryStream);
|
||||
|
||||
PimCodeGen coreCodeGen(memory, coreBinaryStream, coreJsonStream.get(), emittedCoreIds);
|
||||
aliasMaterializedHostGlobals(moduleOp, coreOp, materializedHostGlobals, memory);
|
||||
auto& deviceMemory = memory.getOrCreateDeviceMem(coreId);
|
||||
deviceMemory.allocateCore(coreOp);
|
||||
|
||||
int64_t processedOperations = codeGenCoreOps(coreOp.getBody().front(), coreCodeGen);
|
||||
if (processedOperations < 0)
|
||||
return CompilerFailure;
|
||||
assert(processedOperations > 0);
|
||||
|
||||
coreFileStream.seek(coreFileStream.tell() - 1);
|
||||
coreFileStream << ']';
|
||||
coreFileStream.close();
|
||||
if (reportRow)
|
||||
*reportRow = deviceMemory.getReportRow();
|
||||
|
||||
pim_binary::patchInstructionCount(coreBinaryStream, coreCodeGen.getEmittedInstructionCount());
|
||||
coreBinaryStream.close();
|
||||
|
||||
if (coreJsonStream) {
|
||||
coreJsonStream->seek(coreJsonStream->tell() - 1);
|
||||
*coreJsonStream << ']';
|
||||
coreJsonStream->close();
|
||||
}
|
||||
|
||||
auto coreWeightsDirPath = outputDirPath + "/core_" + std::to_string(coreId);
|
||||
if (auto error = sys::fs::create_directory(coreWeightsDirPath)) {
|
||||
@@ -933,11 +991,10 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::
|
||||
};
|
||||
|
||||
if (auto coreOp = dyn_cast<pim::PimCoreOp>(op)) {
|
||||
if (auto err = emitCore(coreOp, false))
|
||||
MemoryReportRow coreRow;
|
||||
if (auto err = emitCore(coreOp, false, &coreRow))
|
||||
return err;
|
||||
memory.recordCoreReport(emittedCoreIds.lookup(static_cast<size_t>(coreOp.getCoreId())),
|
||||
memory.getOrCreateDeviceMem(emittedCoreIds.lookup(static_cast<size_t>(coreOp.getCoreId())))
|
||||
.getReportRow());
|
||||
memory.recordCoreReport(emittedCoreIds.lookup(static_cast<size_t>(coreOp.getCoreId())), coreRow);
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -946,20 +1003,29 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::
|
||||
SmallVector<int32_t> reportedCoreIds;
|
||||
reportedCoreIds.reserve(batchCoreIds.size());
|
||||
MemoryReportRow batchRow;
|
||||
std::optional<MemoryReportRow> batchPerCoreRow;
|
||||
for (unsigned lane = 0; lane < static_cast<unsigned>(coreBatchOp.getLaneCount()); ++lane) {
|
||||
OnnxMlirCompilerErrorCodes laneResult = CompilerSuccess;
|
||||
if (failed(withScalarCoreFromBatchLane(coreBatchOp, lane, [&](pim::PimCoreOp coreOp) {
|
||||
size_t originalCoreId = static_cast<size_t>(batchCoreIds[lane]);
|
||||
size_t coreId = emittedCoreIds.lookup(originalCoreId);
|
||||
reportedCoreIds.push_back(static_cast<int32_t>(coreId));
|
||||
laneResult = emitCore(coreOp, true);
|
||||
if (laneResult == CompilerSuccess)
|
||||
batchRow = addMemoryReportRows(batchRow, memory.getOrCreateDeviceMem(coreId).getReportRow());
|
||||
MemoryReportRow laneRow;
|
||||
laneResult = emitCore(coreOp, true, &laneRow);
|
||||
if (laneResult == CompilerSuccess) {
|
||||
if (!batchPerCoreRow.has_value())
|
||||
batchPerCoreRow = laneRow;
|
||||
batchRow = addMemoryReportRows(batchRow, laneRow);
|
||||
}
|
||||
return laneResult == CompilerSuccess ? success() : failure();
|
||||
})))
|
||||
return laneResult == CompilerSuccess ? CompilerFailure : laneResult;
|
||||
}
|
||||
memory.recordBatchReport(nextBatchReportId++, reportedCoreIds, batchRow);
|
||||
memory.recordBatchReport(nextBatchReportId++,
|
||||
reportedCoreIds,
|
||||
batchPerCoreRow.value_or(MemoryReportRow {}),
|
||||
batchRow.numAlloca,
|
||||
batchRow.sizeAlloca);
|
||||
}
|
||||
|
||||
memory.flushReport();
|
||||
|
||||
@@ -12,6 +12,8 @@
|
||||
|
||||
#include "onnx-mlir/Compiler/OMCompilerTypes.h"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Common/Support/ReportUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimBinaryFormat.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
@@ -43,11 +45,14 @@ struct MemoryReportEntry {
|
||||
uint64_t id = 0;
|
||||
llvm::SmallVector<int32_t, 8> coreIds;
|
||||
MemoryReportRow row;
|
||||
uint64_t totalAllocaCount = 0;
|
||||
uint64_t totalAllocaBytes = 0;
|
||||
};
|
||||
|
||||
class PimMemory {
|
||||
llvm::SmallVector<std::pair<MemEntry, mlir::Value>, 32> memEntries;
|
||||
llvm::SmallDenseMap<mlir::Value, MemEntry, 32>& globalMemEntriesMap;
|
||||
llvm::SmallDenseMap<mlir::Value, MemEntry, 32> ownedMemEntriesMap;
|
||||
|
||||
size_t minAlignment = 4;
|
||||
size_t firstAvailableAddress = 0;
|
||||
@@ -82,40 +87,35 @@ private:
|
||||
|
||||
public:
|
||||
PimAcceleratorMemory()
|
||||
: hostMem(memEntriesMap) {
|
||||
|
||||
std::string outputDir = getOutputDir();
|
||||
if (outputDir.empty())
|
||||
return;
|
||||
|
||||
std::string dialectsDir = outputDir + "/reports/";
|
||||
createDirectory(dialectsDir);
|
||||
std::fstream file(dialectsDir + "/memory_report.txt", std::ios::out);
|
||||
fileReport = std::move(file);
|
||||
}
|
||||
: hostMem(memEntriesMap), fileReport(openReportFile("memory_report")) {}
|
||||
|
||||
PimMemory& getOrCreateDeviceMem(size_t id);
|
||||
|
||||
size_t getValueAddress(mlir::Value value, const StaticValueKnowledge& knowledge = {}) const;
|
||||
void reportHost();
|
||||
void recordCoreReport(size_t coreId, const MemoryReportRow& row);
|
||||
void recordBatchReport(uint64_t batchId, llvm::ArrayRef<int32_t> coreIds, const MemoryReportRow& row);
|
||||
void recordBatchReport(uint64_t batchId,
|
||||
llvm::ArrayRef<int32_t> coreIds,
|
||||
const MemoryReportRow& perCoreRow,
|
||||
uint64_t totalAllocaCount,
|
||||
uint64_t totalAllocaBytes);
|
||||
void flushReport();
|
||||
void clean(mlir::Operation* op);
|
||||
};
|
||||
|
||||
class PimCodeGen {
|
||||
PimAcceleratorMemory& memory;
|
||||
llvm::raw_fd_ostream& coreFileStream;
|
||||
llvm::raw_fd_ostream& coreBinaryStream;
|
||||
llvm::raw_fd_ostream* coreJsonStream;
|
||||
const llvm::DenseMap<size_t, size_t>& emittedCoreIds;
|
||||
mutable uint32_t emittedInstructionCount = 0;
|
||||
|
||||
size_t addressOf(mlir::Value value, const StaticValueKnowledge& knowledge) const {
|
||||
return memory.getValueAddress(value, knowledge);
|
||||
}
|
||||
size_t remapCoreId(size_t coreId) const;
|
||||
|
||||
static llvm::json::Object createEmptyOffset();
|
||||
void emitInstruction(llvm::json::Object instruction) const;
|
||||
void emitInstruction(const pim_binary::InstructionRecord& instruction) const;
|
||||
|
||||
void genSetRegisterImmediateUnsigned(size_t registerNumber, size_t immediate) const;
|
||||
void setupRd(size_t rdAddress, size_t rdOffset) const;
|
||||
@@ -135,9 +135,12 @@ class PimCodeGen {
|
||||
|
||||
public:
|
||||
PimCodeGen(PimAcceleratorMemory& memory,
|
||||
llvm::raw_fd_ostream& coreJson,
|
||||
llvm::raw_fd_ostream& coreBinary,
|
||||
llvm::raw_fd_ostream* coreJson,
|
||||
const llvm::DenseMap<size_t, size_t>& emittedCoreIds)
|
||||
: memory(memory), coreFileStream(coreJson), emittedCoreIds(emittedCoreIds) {}
|
||||
: memory(memory), coreBinaryStream(coreBinary), coreJsonStream(coreJson), emittedCoreIds(emittedCoreIds) {}
|
||||
|
||||
uint32_t getEmittedInstructionCount() const { return emittedInstructionCount; }
|
||||
|
||||
void codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp, const StaticValueKnowledge& knowledge) const;
|
||||
void codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp, const StaticValueKnowledge& knowledge) const;
|
||||
@@ -166,6 +169,6 @@ public:
|
||||
void codeGenTransposeOp(pim::PimTransposeOp transposeOp, const StaticValueKnowledge& knowledge) const;
|
||||
};
|
||||
|
||||
OnnxMlirCompilerErrorCodes compileToPimJson(mlir::ModuleOp& moduleOpRef, std::string& outputDirName);
|
||||
OnnxMlirCompilerErrorCodes compileToPimCode(mlir::ModuleOp& moduleOpRef, std::string& outputDirName);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
|
||||
#include "llvm/Support/ErrorHandling.h"
|
||||
|
||||
#define DEBUG_TYPE "PimCompilerOptions"
|
||||
|
||||
namespace onnx_mlir {
|
||||
@@ -13,6 +15,14 @@ llvm::cl::opt<PimEmissionTargetType> pimEmissionTarget(
|
||||
llvm::cl::init(EmitPimCodegen),
|
||||
llvm::cl::cat(OnnxMlirOptions));
|
||||
|
||||
llvm::cl::opt<PimMergeSchedulerType> pimMergeScheduler(
|
||||
"pim-merge-scheduler",
|
||||
llvm::cl::desc("Scheduler used by the Spatial merge-compute-nodes pass"),
|
||||
llvm::cl::values(clEnumValN(MergeSchedulerPeft, "peft", "Use PEFT scheduling")),
|
||||
llvm::cl::values(clEnumValN(MergeSchedulerDcp, "dcp", "Use the legacy DCP-inspired scheduler")),
|
||||
llvm::cl::init(MergeSchedulerPeft),
|
||||
llvm::cl::cat(OnnxMlirOptions));
|
||||
|
||||
llvm::cl::opt<bool>
|
||||
pimOnlyCodegen("pim-only-codegen",
|
||||
llvm::cl::desc("Only generate code for PIM (assume input is already in bufferized PIM IR)"),
|
||||
@@ -24,20 +34,25 @@ llvm::cl::opt<bool> useExperimentalConvImpl("use-experimental-conv-impl",
|
||||
llvm::cl::init(false),
|
||||
llvm::cl::cat(OnnxMlirOptions));
|
||||
|
||||
llvm::cl::opt<bool> pimEmitJson("pim-emit-json",
|
||||
llvm::cl::desc("Also emit per-core JSON instruction files alongside binary .pim files"),
|
||||
llvm::cl::init(false),
|
||||
llvm::cl::cat(OnnxMlirOptions));
|
||||
|
||||
llvm::cl::opt<size_t>
|
||||
crossbarSize("crossbar-size", llvm::cl::desc("Width and heigth of a single crossbar"), llvm::cl::init(2));
|
||||
crossbarSize("crossbar-size", llvm::cl::desc("Width and height of a single crossbar"), llvm::cl::init(2));
|
||||
|
||||
llvm::cl::opt<size_t>
|
||||
crossbarCountInCore("crossbar-count", llvm::cl::desc("Number of crossbars in each core"), llvm::cl::init(256));
|
||||
|
||||
llvm::cl::opt<long> coresCount("core-count",
|
||||
llvm::cl::desc("Number of cores in the chip. `-1` to use the minimum amount of cores."),
|
||||
llvm::cl::desc("Number of cores in the chip. Required for PIM compilation."),
|
||||
llvm::cl::init(-1));
|
||||
|
||||
llvm::cl::opt<size_t> dcpCriticalWindowSize(
|
||||
"dcp-critical-window-size",
|
||||
llvm::cl::desc("Number of lowest-slack virtual nodes considered by each DCP coarsening iteration. "
|
||||
"Use 0 to run the legacy full-graph DCP analysis."),
|
||||
"Use 0 to run the legacy full-graph DCP analysis. Only used by the DCP scheduler."),
|
||||
llvm::cl::init(4000));
|
||||
|
||||
llvm::cl::opt<bool>
|
||||
@@ -45,4 +60,13 @@ llvm::cl::opt<bool>
|
||||
llvm::cl::desc("Ignore ConcatOp corner case: do not assert and do a simplification"),
|
||||
llvm::cl::init(false));
|
||||
|
||||
bool hasExplicitPimCoreCount() { return coresCount.getNumOccurrences() != 0; }
|
||||
|
||||
void verifyExplicitPimCoreCount() {
|
||||
if (!hasExplicitPimCoreCount())
|
||||
llvm::report_fatal_error("PIM compilation requires an explicit --core-count=<positive integer>");
|
||||
if (coresCount.getValue() <= 0)
|
||||
llvm::report_fatal_error("PIM compilation requires --core-count to be a positive integer");
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -20,17 +20,27 @@ typedef enum {
|
||||
EmitPimCodegen = 3
|
||||
} PimEmissionTargetType;
|
||||
|
||||
typedef enum {
|
||||
MergeSchedulerPeft = 0,
|
||||
MergeSchedulerDcp = 1,
|
||||
} PimMergeSchedulerType;
|
||||
|
||||
extern llvm::cl::OptionCategory OnnxMlirOptions;
|
||||
extern llvm::cl::opt<PimEmissionTargetType> pimEmissionTarget;
|
||||
extern llvm::cl::opt<PimMergeSchedulerType> pimMergeScheduler;
|
||||
|
||||
extern llvm::cl::opt<bool> pimOnlyCodegen;
|
||||
extern llvm::cl::opt<bool> useExperimentalConvImpl;
|
||||
extern llvm::cl::opt<bool> pimEmitJson;
|
||||
|
||||
extern llvm::cl::opt<size_t> crossbarSize;
|
||||
extern llvm::cl::opt<size_t> crossbarCountInCore;
|
||||
extern llvm::cl::opt<long> coresCount;
|
||||
extern llvm::cl::opt<size_t> dcpCriticalWindowSize;
|
||||
|
||||
bool hasExplicitPimCoreCount();
|
||||
void verifyExplicitPimCoreCount();
|
||||
|
||||
// This option, by default set to false, will ignore an error when resolving a
|
||||
// specific tiles of the operands of a concat. This specific case is when the
|
||||
// wanted tile is generated by two separate operands of the concat. If this is
|
||||
|
||||
@@ -17,6 +17,7 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
|
||||
PassManager& pm,
|
||||
EmissionTargetType& emissionTarget,
|
||||
std::string outputNameNoExt) {
|
||||
verifyExplicitPimCoreCount();
|
||||
|
||||
if (pimOnlyCodegen) {
|
||||
// Skip all the lowering passes and directly generate code for PIM.
|
||||
@@ -41,6 +42,7 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
|
||||
|
||||
if (pimEmissionTarget >= EmitPimBufferized) {
|
||||
pm.addPass(createPimBufferizationPass());
|
||||
pm.addPass(createPimStaticMemoryCoalescingPass());
|
||||
// pm.addPass(createCountInstructionPass());
|
||||
pm.addPass(createMessagePass("Pim bufferized"));
|
||||
}
|
||||
@@ -51,9 +53,9 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
|
||||
pm.addPass(createPimMaterializeHostConstantsPass());
|
||||
pm.addPass(createPimVerificationPass());
|
||||
pm.addPass(createMessagePass("Pim verified"));
|
||||
pm.addPass(createEmitPimJsonPass());
|
||||
pm.addPass(createEmitPimCodePass());
|
||||
// pm.addPass(createCountInstructionPass());
|
||||
pm.addPass(createMessagePass("Pim json code emitted"));
|
||||
pm.addPass(createMessagePass("Pim code emitted"));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -10,6 +10,8 @@
|
||||
#include <cassert>
|
||||
|
||||
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp"
|
||||
@@ -30,21 +32,8 @@ struct DenseWeightView {
|
||||
int64_t offset = 0;
|
||||
};
|
||||
|
||||
SmallVector<int64_t> computeRowMajorStridesForShape(ArrayRef<int64_t> shape) {
|
||||
SmallVector<int64_t> strides(shape.size(), 1);
|
||||
for (int64_t index = static_cast<int64_t>(shape.size()) - 2; index >= 0; --index)
|
||||
strides[index] = strides[index + 1] * shape[index + 1];
|
||||
return strides;
|
||||
}
|
||||
|
||||
bool allStaticSubviewParts(memref::SubViewOp subview) {
|
||||
return llvm::all_of(subview.getStaticOffsets(), [](int64_t value) { return !ShapedType::isDynamic(value); })
|
||||
&& llvm::all_of(subview.getStaticSizes(), [](int64_t value) { return !ShapedType::isDynamic(value); })
|
||||
&& llvm::all_of(subview.getStaticStrides(), [](int64_t value) { return !ShapedType::isDynamic(value); });
|
||||
}
|
||||
|
||||
FailureOr<DenseWeightView> resolveDenseWeightView(ModuleOp moduleOp, mlir::Value weight) {
|
||||
SmallVector<memref::SubViewOp> subviews;
|
||||
SmallVector<Operation*> viewOps;
|
||||
mlir::Value current = weight;
|
||||
memref::GetGlobalOp getGlobalOp;
|
||||
|
||||
@@ -55,9 +44,9 @@ FailureOr<DenseWeightView> resolveDenseWeightView(ModuleOp moduleOp, mlir::Value
|
||||
if ((getGlobalOp = dyn_cast<memref::GetGlobalOp>(defOp)))
|
||||
break;
|
||||
if (auto subview = dyn_cast<memref::SubViewOp>(defOp)) {
|
||||
if (!allStaticSubviewParts(subview))
|
||||
if (!hasAllStaticSubviewParts(subview))
|
||||
return failure();
|
||||
subviews.push_back(subview);
|
||||
viewOps.push_back(subview);
|
||||
current = subview.getSource();
|
||||
continue;
|
||||
}
|
||||
@@ -65,6 +54,24 @@ FailureOr<DenseWeightView> resolveDenseWeightView(ModuleOp moduleOp, mlir::Value
|
||||
current = cast.getSource();
|
||||
continue;
|
||||
}
|
||||
if (auto collapse = dyn_cast<memref::CollapseShapeOp>(defOp)) {
|
||||
auto srcType = dyn_cast<MemRefType>(collapse.getSrc().getType());
|
||||
auto resultType = dyn_cast<MemRefType>(collapse.getResult().getType());
|
||||
if (!srcType || !resultType || !srcType.hasStaticShape() || !resultType.hasStaticShape())
|
||||
return failure();
|
||||
viewOps.push_back(collapse);
|
||||
current = collapse.getSrc();
|
||||
continue;
|
||||
}
|
||||
if (auto expand = dyn_cast<memref::ExpandShapeOp>(defOp)) {
|
||||
auto srcType = dyn_cast<MemRefType>(expand.getSrc().getType());
|
||||
auto resultType = dyn_cast<MemRefType>(expand.getResult().getType());
|
||||
if (!srcType || !resultType || !srcType.hasStaticShape() || !resultType.hasStaticShape())
|
||||
return failure();
|
||||
viewOps.push_back(expand);
|
||||
current = expand.getSrc();
|
||||
continue;
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
|
||||
@@ -79,9 +86,10 @@ FailureOr<DenseWeightView> resolveDenseWeightView(ModuleOp moduleOp, mlir::Value
|
||||
DenseWeightView view;
|
||||
view.denseAttr = denseAttr;
|
||||
view.shape.assign(denseAttr.getType().getShape().begin(), denseAttr.getType().getShape().end());
|
||||
view.strides = computeRowMajorStridesForShape(view.shape);
|
||||
view.strides = computeRowMajorStrides(view.shape);
|
||||
|
||||
for (memref::SubViewOp subview : llvm::reverse(subviews)) {
|
||||
for (Operation* viewOp : llvm::reverse(viewOps)) {
|
||||
if (auto subview = dyn_cast<memref::SubViewOp>(viewOp)) {
|
||||
SmallVector<int64_t> nextStrides;
|
||||
nextStrides.reserve(subview.getStaticStrides().size());
|
||||
for (auto [offset, stride, sourceStride] :
|
||||
@@ -91,6 +99,28 @@ FailureOr<DenseWeightView> resolveDenseWeightView(ModuleOp moduleOp, mlir::Value
|
||||
}
|
||||
view.shape.assign(subview.getStaticSizes().begin(), subview.getStaticSizes().end());
|
||||
view.strides = std::move(nextStrides);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Collapse/expand are accepted only as contiguous static reshapes of a
|
||||
// dense global view, so a row-major stride recomputation preserves layout.
|
||||
if (auto collapse = dyn_cast<memref::CollapseShapeOp>(viewOp)) {
|
||||
if (view.strides != computeRowMajorStrides(view.shape))
|
||||
return failure();
|
||||
auto resultType = cast<MemRefType>(collapse.getResult().getType());
|
||||
view.shape.assign(resultType.getShape().begin(), resultType.getShape().end());
|
||||
view.strides = computeRowMajorStrides(view.shape);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto expand = dyn_cast<memref::ExpandShapeOp>(viewOp)) {
|
||||
if (view.strides != computeRowMajorStrides(view.shape))
|
||||
return failure();
|
||||
auto resultType = cast<MemRefType>(expand.getResult().getType());
|
||||
view.shape.assign(resultType.getShape().begin(), resultType.getShape().end());
|
||||
view.strides = computeRowMajorStrides(view.shape);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
return view;
|
||||
|
||||
@@ -100,18 +100,27 @@ DenseMap<HSliceId, DenseMap<CoreId, SmallVector<Value>>> tileMatrix(
|
||||
return tiles;
|
||||
}
|
||||
|
||||
tensor::SplatOp
|
||||
broadcastToVector(Value scalarToBroadcast, int64_t length, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
Value broadcastToVector(Value scalarToBroadcast, int64_t length, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
auto oldType = cast<RankedTensorType>(scalarToBroadcast.getType());
|
||||
Type elementType = oldType.getElementType();
|
||||
int64_t shape[2] = {1, length};
|
||||
Type type = oldType.cloneWith(ArrayRef(shape), elementType);
|
||||
|
||||
auto buildBroadcast = [&](Value input) -> Value {
|
||||
auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0).getResult();
|
||||
SmallVector<Value> index(oldType.getRank(), zero);
|
||||
auto elementValue = tensor::ExtractOp::create(rewriter, loc, scalarToBroadcast, index).getResult();
|
||||
|
||||
auto elementValue = tensor::ExtractOp::create(rewriter, loc, input, index).getResult();
|
||||
return tensor::SplatOp::create(rewriter, loc, type, elementValue);
|
||||
};
|
||||
|
||||
if (isHostFoldableValue(scalarToBroadcast))
|
||||
return buildBroadcast(scalarToBroadcast);
|
||||
|
||||
auto broadcastCompute =
|
||||
createSpatCompute<1>(rewriter, loc, TypeRange {type}, {}, ValueRange {scalarToBroadcast}, [&](Value input) {
|
||||
spatial::SpatYieldOp::create(rewriter, loc, buildBroadcast(input));
|
||||
});
|
||||
return broadcastCompute.getResult(0);
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -136,7 +136,7 @@ tileMatrix(mlir::Value& matrixToTile,
|
||||
mlir::ConversionPatternRewriter& rewriter,
|
||||
mlir::Location& loc);
|
||||
|
||||
mlir::tensor::SplatOp broadcastToVector(mlir::Value scalarToBroadcast,
|
||||
mlir::Value broadcastToVector(mlir::Value scalarToBroadcast,
|
||||
int64_t length,
|
||||
mlir::ConversionPatternRewriter& rewriter,
|
||||
mlir::Location loc);
|
||||
|
||||
@@ -1,8 +1,12 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallBitVector.h"
|
||||
#include "llvm/ADT/SmallPtrSet.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
@@ -18,6 +22,11 @@ static bool hasStaticUnitStrides(tensor::ExtractSliceOp extractSliceOp) {
|
||||
return llvm::all_of(extractSliceOp.getStaticStrides(), [](int64_t stride) { return stride == 1; });
|
||||
}
|
||||
|
||||
static bool hasConstantIndices(tensor::ExtractOp extractOp) {
|
||||
return llvm::all_of(extractOp.getIndices(),
|
||||
[](Value index) { return isa_and_nonnull<arith::ConstantIndexOp>(index.getDefiningOp()); });
|
||||
}
|
||||
|
||||
static bool isStaticTensorResult(Operation* op) {
|
||||
return llvm::all_of(op->getResultTypes(), [](Type type) {
|
||||
auto shapedType = dyn_cast<ShapedType>(type);
|
||||
@@ -25,6 +34,167 @@ static bool isStaticTensorResult(Operation* op) {
|
||||
});
|
||||
}
|
||||
|
||||
static SmallVector<int64_t> computeRowMajorStrides(ArrayRef<int64_t> shape) {
|
||||
SmallVector<int64_t> strides(shape.size(), 1);
|
||||
for (int64_t dim = static_cast<int64_t>(shape.size()) - 2; dim >= 0; --dim)
|
||||
strides[dim] = strides[dim + 1] * shape[dim + 1];
|
||||
return strides;
|
||||
}
|
||||
|
||||
static FailureOr<DenseElementsAttr> transposeDenseElements(DenseElementsAttr denseAttr, ArrayRef<int64_t> perms) {
|
||||
auto tensorType = dyn_cast<RankedTensorType>(denseAttr.getType());
|
||||
if (!tensorType)
|
||||
return failure();
|
||||
|
||||
int64_t rank = tensorType.getRank();
|
||||
if (static_cast<int64_t>(perms.size()) != rank)
|
||||
return failure();
|
||||
|
||||
llvm::SmallBitVector seen(rank);
|
||||
SmallVector<int64_t> transposedShape;
|
||||
transposedShape.reserve(rank);
|
||||
for (int64_t perm : perms) {
|
||||
if (perm < 0 || perm >= rank || seen.test(perm))
|
||||
return failure();
|
||||
seen.set(perm);
|
||||
transposedShape.push_back(tensorType.getShape()[perm]);
|
||||
}
|
||||
|
||||
auto transposedType = RankedTensorType::get(transposedShape, tensorType.getElementType(), tensorType.getEncoding());
|
||||
if (denseAttr.isSplat())
|
||||
return DenseElementsAttr::get(transposedType, denseAttr.getSplatValue<Attribute>());
|
||||
|
||||
SmallVector<Attribute> originalValues(denseAttr.getValues<Attribute>());
|
||||
SmallVector<Attribute> transposedValues(originalValues.size());
|
||||
SmallVector<int64_t> originalStrides = computeRowMajorStrides(tensorType.getShape());
|
||||
SmallVector<int64_t> transposedStrides = computeRowMajorStrides(transposedShape);
|
||||
SmallVector<int64_t> originalIndices(rank);
|
||||
|
||||
for (auto [linearIndex, value] : llvm::enumerate(originalValues)) {
|
||||
int64_t remaining = static_cast<int64_t>(linearIndex);
|
||||
for (int64_t dim = 0; dim < rank; ++dim) {
|
||||
originalIndices[dim] = remaining / originalStrides[dim];
|
||||
remaining %= originalStrides[dim];
|
||||
}
|
||||
|
||||
int64_t transposedLinearIndex = 0;
|
||||
for (int64_t dim = 0; dim < rank; ++dim)
|
||||
transposedLinearIndex += originalIndices[perms[dim]] * transposedStrides[dim];
|
||||
|
||||
transposedValues[transposedLinearIndex] = value;
|
||||
}
|
||||
|
||||
return DenseElementsAttr::get(transposedType, transposedValues);
|
||||
}
|
||||
|
||||
static FailureOr<DenseElementsAttr> reshapeDenseElements(DenseElementsAttr denseAttr, RankedTensorType resultType) {
|
||||
auto sourceType = dyn_cast<RankedTensorType>(denseAttr.getType());
|
||||
if (!sourceType || !resultType || sourceType.getNumElements() != resultType.getNumElements())
|
||||
return failure();
|
||||
|
||||
if (denseAttr.isSplat())
|
||||
return DenseElementsAttr::get(resultType, denseAttr.getSplatValue<Attribute>());
|
||||
|
||||
SmallVector<Attribute> values(denseAttr.getValues<Attribute>());
|
||||
return DenseElementsAttr::get(resultType, values);
|
||||
}
|
||||
|
||||
static FailureOr<DenseElementsAttr> extractSliceDenseElements(DenseElementsAttr denseAttr,
|
||||
tensor::ExtractSliceOp extractSliceOp) {
|
||||
auto sourceType = dyn_cast<RankedTensorType>(denseAttr.getType());
|
||||
auto resultType = dyn_cast<RankedTensorType>(extractSliceOp.getType());
|
||||
if (!sourceType || !resultType || !sourceType.hasStaticShape() || !resultType.hasStaticShape())
|
||||
return failure();
|
||||
|
||||
ArrayRef<int64_t> offsets = extractSliceOp.getStaticOffsets();
|
||||
ArrayRef<int64_t> sizes = extractSliceOp.getStaticSizes();
|
||||
ArrayRef<int64_t> strides = extractSliceOp.getStaticStrides();
|
||||
if (llvm::any_of(offsets, [](int64_t value) { return ShapedType::isDynamic(value); })
|
||||
|| llvm::any_of(sizes, [](int64_t value) { return ShapedType::isDynamic(value); })
|
||||
|| llvm::any_of(strides, [](int64_t stride) { return ShapedType::isDynamic(stride) || stride != 1; }))
|
||||
return failure();
|
||||
|
||||
if (denseAttr.isSplat())
|
||||
return DenseElementsAttr::get(resultType, denseAttr.getSplatValue<Attribute>());
|
||||
|
||||
SmallVector<Attribute> sourceValues(denseAttr.getValues<Attribute>());
|
||||
SmallVector<int64_t> sourceStrides = computeRowMajorStrides(sourceType.getShape());
|
||||
SmallVector<int64_t> resultStrides = computeRowMajorStrides(resultType.getShape());
|
||||
SmallVector<Attribute> resultValues;
|
||||
resultValues.reserve(resultType.getNumElements());
|
||||
|
||||
for (int64_t linearIndex = 0; linearIndex < resultType.getNumElements(); ++linearIndex) {
|
||||
int64_t remaining = linearIndex;
|
||||
int64_t sourceLinearIndex = 0;
|
||||
for (int64_t dim = 0; dim < resultType.getRank(); ++dim) {
|
||||
const int64_t resultIndex = resultStrides.empty() ? 0 : remaining / resultStrides[dim];
|
||||
remaining = resultStrides.empty() ? 0 : remaining % resultStrides[dim];
|
||||
sourceLinearIndex += (offsets[dim] + resultIndex) * sourceStrides[dim];
|
||||
}
|
||||
resultValues.push_back(sourceValues[sourceLinearIndex]);
|
||||
}
|
||||
|
||||
return DenseElementsAttr::get(resultType, resultValues);
|
||||
}
|
||||
|
||||
static DenseElementsAttr getDirectDenseConstantAttr(Value value) {
|
||||
if (auto constantOp = value.getDefiningOp<arith::ConstantOp>())
|
||||
return dyn_cast<DenseElementsAttr>(constantOp.getValue());
|
||||
if (auto constantOp = value.getDefiningOp<ONNXConstantOp>())
|
||||
return dyn_cast_or_null<DenseElementsAttr>(constantOp.getValueAttr());
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
static DenseElementsAttr getHostFoldableDenseElementsAttrImpl(Value value, llvm::SmallPtrSetImpl<Operation*>& visited) {
|
||||
auto* definingOp = value.getDefiningOp();
|
||||
if (!definingOp || !visited.insert(definingOp).second)
|
||||
return nullptr;
|
||||
|
||||
// Rebuild dense attributes through view-only host-foldable chains so later
|
||||
// lowering stages can still recognize grouped/sliced constants.
|
||||
if (auto denseAttr = getDirectDenseConstantAttr(value))
|
||||
return denseAttr;
|
||||
|
||||
if (auto transposeOp = dyn_cast<ONNXTransposeOp>(definingOp)) {
|
||||
auto inputAttr = getHostFoldableDenseElementsAttrImpl(transposeOp.getData(), visited);
|
||||
if (!inputAttr)
|
||||
return nullptr;
|
||||
|
||||
SmallVector<int64_t> perm;
|
||||
perm.reserve(transposeOp.getPermAttr().size());
|
||||
for (IntegerAttr attr : transposeOp.getPermAttr().getAsRange<IntegerAttr>())
|
||||
perm.push_back(attr.getInt());
|
||||
auto transposedAttr = transposeDenseElements(inputAttr, perm);
|
||||
return succeeded(transposedAttr) ? *transposedAttr : nullptr;
|
||||
}
|
||||
|
||||
if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(definingOp)) {
|
||||
auto inputAttr = getHostFoldableDenseElementsAttrImpl(collapseShapeOp.getSrc(), visited);
|
||||
if (!inputAttr)
|
||||
return nullptr;
|
||||
auto reshapedAttr = reshapeDenseElements(inputAttr, cast<RankedTensorType>(collapseShapeOp.getType()));
|
||||
return succeeded(reshapedAttr) ? *reshapedAttr : nullptr;
|
||||
}
|
||||
|
||||
if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(definingOp)) {
|
||||
auto inputAttr = getHostFoldableDenseElementsAttrImpl(expandShapeOp.getSrc(), visited);
|
||||
if (!inputAttr)
|
||||
return nullptr;
|
||||
auto reshapedAttr = reshapeDenseElements(inputAttr, cast<RankedTensorType>(expandShapeOp.getType()));
|
||||
return succeeded(reshapedAttr) ? *reshapedAttr : nullptr;
|
||||
}
|
||||
|
||||
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(definingOp)) {
|
||||
auto inputAttr = getHostFoldableDenseElementsAttrImpl(extractSliceOp.getSource(), visited);
|
||||
if (!inputAttr)
|
||||
return nullptr;
|
||||
auto slicedAttr = extractSliceDenseElements(inputAttr, extractSliceOp);
|
||||
return succeeded(slicedAttr) ? *slicedAttr : nullptr;
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
static bool isHostFoldableOpImpl(Operation* op, llvm::SmallPtrSetImpl<Operation*>& visited) {
|
||||
if (!op || !visited.insert(op).second)
|
||||
return false;
|
||||
@@ -32,6 +202,9 @@ static bool isHostFoldableOpImpl(Operation* op, llvm::SmallPtrSetImpl<Operation*
|
||||
if (isa<arith::ConstantOp, ONNXConstantOp, ONNXNoneOp>(op))
|
||||
return true;
|
||||
|
||||
if (auto extractOp = dyn_cast<tensor::ExtractOp>(op))
|
||||
return hasConstantIndices(extractOp) && isHostFoldableValue(extractOp.getTensor());
|
||||
|
||||
if (!isStaticTensorResult(op))
|
||||
return false;
|
||||
|
||||
@@ -47,6 +220,9 @@ static bool isHostFoldableOpImpl(Operation* op, llvm::SmallPtrSetImpl<Operation*
|
||||
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op))
|
||||
return hasStaticUnitStrides(extractSliceOp) && isHostFoldableValue(extractSliceOp.getSource());
|
||||
|
||||
if (auto splatOp = dyn_cast<tensor::SplatOp>(op))
|
||||
return isHostFoldableValue(splatOp.getInput());
|
||||
|
||||
if (auto extractRowsOp = dyn_cast<spatial::SpatExtractRowsOp>(op))
|
||||
return isHostFoldableValue(extractRowsOp.getInput());
|
||||
|
||||
@@ -72,4 +248,9 @@ bool isHostFoldableOp(Operation* op) {
|
||||
return isHostFoldableOpImpl(op, visited);
|
||||
}
|
||||
|
||||
DenseElementsAttr getHostFoldableDenseElementsAttr(Value value) {
|
||||
llvm::SmallPtrSet<Operation*, 8> visited;
|
||||
return getHostFoldableDenseElementsAttrImpl(value, visited);
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
|
||||
@@ -9,4 +10,6 @@ bool isHostFoldableValue(mlir::Value value);
|
||||
|
||||
bool isHostFoldableOp(mlir::Operation* op);
|
||||
|
||||
mlir::DenseElementsAttr getHostFoldableDenseElementsAttr(mlir::Value value);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostLegality.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
@@ -11,7 +12,7 @@ using namespace mlir;
|
||||
namespace onnx_mlir {
|
||||
|
||||
LogicalResult verifyONNXToSpatialHostLegality(func::FuncOp funcOp) {
|
||||
bool hasFailure = false;
|
||||
pim::CappedDiagnosticReporter diagnostics;
|
||||
|
||||
for (Operation& op : funcOp.getFunctionBody().front()) {
|
||||
if (isa<func::ReturnOp, spatial::SpatCompute, spatial::SpatComputeBatch>(&op))
|
||||
@@ -19,11 +20,15 @@ LogicalResult verifyONNXToSpatialHostLegality(func::FuncOp funcOp) {
|
||||
if (isHostFoldableOp(&op))
|
||||
continue;
|
||||
|
||||
op.emitOpError("non-foldable top-level runtime op remains after ONNX-to-Spatial; lower it inside spat.compute");
|
||||
hasFailure = true;
|
||||
diagnostics.report(&op, [](Operation* illegalOp) {
|
||||
illegalOp->emitOpError("non-foldable top-level runtime op remains after ONNX-to-Spatial; lower it inside "
|
||||
"spat.compute");
|
||||
});
|
||||
}
|
||||
|
||||
return success(!hasFailure);
|
||||
diagnostics.emitSuppressedSummary(funcOp, "ONNX-to-Spatial host legality failures");
|
||||
|
||||
return success(!diagnostics.hasFailure());
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -5,17 +5,15 @@
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
|
||||
#include "Common/Common.hpp"
|
||||
#include "Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostLegality.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PostPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PrePatterns.hpp"
|
||||
@@ -87,17 +85,68 @@ static void populateEmptyFunction(func::FuncOp funcOp) {
|
||||
returnOp.setOperand(index, computeResult);
|
||||
}
|
||||
|
||||
static void wrapTopLevelRuntimeTransposes(func::FuncOp funcOp) {
|
||||
IRRewriter rewriter(funcOp.getContext());
|
||||
Block& entryBlock = funcOp.getFunctionBody().front();
|
||||
|
||||
for (Operation& op : llvm::make_early_inc_range(entryBlock)) {
|
||||
auto transposeOp = dyn_cast<ONNXTransposeOp>(&op);
|
||||
if (!transposeOp || isHostFoldableOp(transposeOp))
|
||||
continue;
|
||||
|
||||
// Transpose stays globally legal because constant/view-only cases are
|
||||
// allowed on the host. Any residual runtime transpose must be sunk into
|
||||
// spat.compute before the host legality check.
|
||||
auto resultType = transposeOp.getResult().getType();
|
||||
rewriter.setInsertionPoint(transposeOp);
|
||||
auto computeOp = createSpatCompute<1>(
|
||||
rewriter, transposeOp.getLoc(), TypeRange {resultType}, {}, ValueRange {transposeOp.getData()}, [&](Value input) {
|
||||
Value transposed =
|
||||
ONNXTransposeOp::create(rewriter, transposeOp.getLoc(), resultType, input, transposeOp.getPermAttr());
|
||||
spatial::SpatYieldOp::create(rewriter, transposeOp.getLoc(), transposed);
|
||||
});
|
||||
rewriter.replaceOp(transposeOp, computeOp.getResult(0));
|
||||
}
|
||||
}
|
||||
|
||||
void ONNXToSpatialPass::runOnOperation() {
|
||||
ModuleOp moduleOp = getOperation();
|
||||
MLIRContext* ctx = &getContext();
|
||||
|
||||
ConversionTarget preTarget(*ctx);
|
||||
preTarget.addLegalDialect<spatial::SpatialDialect,
|
||||
ONNXDialect,
|
||||
tensor::TensorDialect,
|
||||
arith::ArithDialect,
|
||||
scf::SCFDialect>();
|
||||
preTarget.addIllegalOp<ONNXConstantOp, ONNXFlattenOp>();
|
||||
|
||||
RewritePatternSet prePatterns(ctx);
|
||||
populatePrePatterns(prePatterns, ctx);
|
||||
if (failed(applyPatternsGreedily(moduleOp, std::move(prePatterns))))
|
||||
llvm::dbgs() << "Failed to apply pre-patterns, continuing...\n";
|
||||
if (failed(applyPartialConversion(moduleOp, preTarget, std::move(prePatterns)))) {
|
||||
moduleOp.emitError("failed to apply ONNX-to-Spatial pre-rewrites");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
auto entryFunc = getPimEntryFunc(moduleOp);
|
||||
if (failed(entryFunc)) {
|
||||
moduleOp.emitError("failed to locate the PIM entry function during ONNX-to-Spatial lowering");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
RewritePatternSet matmulPatterns(ctx);
|
||||
populateMatMulRewritePatterns(matmulPatterns, ctx);
|
||||
walkAndApplyPatterns(moduleOp, std::move(matmulPatterns));
|
||||
|
||||
bool hasUnloweredMatMul = false;
|
||||
moduleOp.walk([&](ONNXMatMulOp matmulOp) {
|
||||
hasUnloweredMatMul = true;
|
||||
matmulOp.emitOpError("remaining ONNX MatMul before the required ONNX-to-Spatial conversion");
|
||||
});
|
||||
if (hasUnloweredMatMul) {
|
||||
moduleOp.emitError("failed to lower all ONNX MatMul ops before ONNX-to-Spatial conversion");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
@@ -130,45 +179,58 @@ void ONNXToSpatialPass::runOnOperation() {
|
||||
RewritePatternSet conversionPatterns(ctx);
|
||||
populateConversionPatterns(conversionPatterns, ctx);
|
||||
if (failed(applyPartialConversion(moduleOp, target, std::move(conversionPatterns)))) {
|
||||
moduleOp.emitError("failed to convert required ONNX ops to Spatial ops");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
ConversionTarget earlyPostTarget(*ctx);
|
||||
earlyPostTarget.addLegalDialect<spatial::SpatialDialect,
|
||||
ONNXDialect,
|
||||
tensor::TensorDialect,
|
||||
arith::ArithDialect,
|
||||
scf::SCFDialect>();
|
||||
earlyPostTarget.addDynamicallyLegalOp<spatial::SpatComputeBatch>(
|
||||
[](spatial::SpatComputeBatch batchOp) { return !requiresEarlyPostRewrite(batchOp); });
|
||||
|
||||
RewritePatternSet earlyPostPatterns(ctx);
|
||||
populateEarlyPostPatterns(earlyPostPatterns, ctx);
|
||||
if (failed(applyPatternsGreedily(*entryFunc, std::move(earlyPostPatterns)))) {
|
||||
if (failed(applyPartialConversion(*entryFunc, earlyPostTarget, std::move(earlyPostPatterns)))) {
|
||||
moduleOp.emitError("failed to normalize single-lane spat.compute_batch ops before core assignment checks");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
if (coresCount != -1) {
|
||||
int computeOpsCount = 0;
|
||||
for (Operation& op : entryFunc->getFunctionBody().front().getOperations())
|
||||
if (isa<spatial::SpatCompute>(op))
|
||||
computeOpsCount++;
|
||||
|
||||
if (computeOpsCount > coresCount) {
|
||||
llvm::dbgs() << "Number of compute ops exceeds the core count\n";
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
PassManager cleanupPM(ctx);
|
||||
cleanupPM.addPass(createCanonicalizerPass());
|
||||
if (failed(cleanupPM.run(moduleOp)))
|
||||
llvm::dbgs() << "Failed to run canonicalization cleanup, continuing...\n";
|
||||
moduleOp.emitWarning("failed to run ONNX-to-Spatial canonicalization cleanup; continuing");
|
||||
|
||||
annotateWeightsConstants(*entryFunc);
|
||||
|
||||
ConversionTarget postTarget(*ctx);
|
||||
postTarget.addLegalDialect<spatial::SpatialDialect,
|
||||
ONNXDialect,
|
||||
tensor::TensorDialect,
|
||||
arith::ArithDialect,
|
||||
scf::SCFDialect>();
|
||||
postTarget.addDynamicallyLegalOp<spatial::SpatCompute>(
|
||||
[](spatial::SpatCompute computeOp) { return !requiresPostRewrite(computeOp); });
|
||||
postTarget.addDynamicallyLegalOp<spatial::SpatComputeBatch>(
|
||||
[](spatial::SpatComputeBatch computeOp) { return !requiresPostRewrite(computeOp); });
|
||||
|
||||
RewritePatternSet postPatterns(ctx);
|
||||
populatePostPatterns(postPatterns, ctx);
|
||||
if (failed(applyPatternsGreedily(*entryFunc, std::move(postPatterns)))) {
|
||||
if (failed(applyPartialConversion(*entryFunc, postTarget, std::move(postPatterns)))) {
|
||||
moduleOp.emitError("failed to normalize weight-like Spatial compute operands before Spatial-to-PIM lowering");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
wrapTopLevelRuntimeTransposes(*entryFunc);
|
||||
|
||||
if (failed(verifyONNXToSpatialHostLegality(*entryFunc))) {
|
||||
moduleOp.emitError("ONNX-to-Spatial host legality verification failed");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
@@ -27,16 +28,6 @@ struct ConvToGemm : OpConversionPattern<ONNXConvOp> {
|
||||
ConversionPatternRewriter& rewriter) const override;
|
||||
};
|
||||
|
||||
static DenseElementsAttr getDenseConstantAttr(Value value) {
|
||||
if (auto constantOp = value.getDefiningOp<arith::ConstantOp>())
|
||||
return dyn_cast<DenseElementsAttr>(constantOp.getValue());
|
||||
|
||||
if (auto constantOp = value.getDefiningOp<ONNXConstantOp>())
|
||||
return dyn_cast_or_null<DenseElementsAttr>(constantOp.getValueAttr());
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
static int64_t getI64FromArrayAttr(ArrayAttr arr, size_t idx) { return cast<IntegerAttr>(arr[idx]).getInt(); }
|
||||
|
||||
static Value expandBiasIfNeeded(Value bias, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
@@ -355,49 +346,22 @@ static Value createCollectedConvOutput(ValueRange gemmRows,
|
||||
return collectComputeOp.getResult(0);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
||||
ONNXConvOpAdaptor convOpAdaptor,
|
||||
ConversionPatternRewriter& rewriter) const {
|
||||
Location loc = convOp.getLoc();
|
||||
Value x = convOpAdaptor.getX();
|
||||
Value w = convOpAdaptor.getW();
|
||||
Value b = convOpAdaptor.getB();
|
||||
|
||||
auto xType = cast<RankedTensorType>(x.getType());
|
||||
auto wType = cast<RankedTensorType>(w.getType());
|
||||
auto outType = cast<RankedTensorType>(convOp.getY().getType());
|
||||
|
||||
if (!xType.hasStaticShape()) {
|
||||
pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv input");
|
||||
return failure();
|
||||
}
|
||||
if (!wType.hasStaticShape()) {
|
||||
pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv weight");
|
||||
return failure();
|
||||
}
|
||||
if (!outType.hasStaticShape()) {
|
||||
pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv result");
|
||||
return failure();
|
||||
}
|
||||
if (xType.getRank() != 4) {
|
||||
pim::emitUnsupportedRankDiagnostic(convOp, "conv input", xType.getRank(), {4});
|
||||
return failure();
|
||||
}
|
||||
if (wType.getRank() != 4) {
|
||||
pim::emitUnsupportedRankDiagnostic(convOp, "conv weight", wType.getRank(), {4});
|
||||
return failure();
|
||||
}
|
||||
if (outType.getRank() != 4) {
|
||||
pim::emitUnsupportedRankDiagnostic(convOp, "conv result", outType.getRank(), {4});
|
||||
return failure();
|
||||
}
|
||||
if (convOp.getGroup() != 1) {
|
||||
convOp.emitOpError("only group=1 convolution is supported for Spatial lowering");
|
||||
return failure();
|
||||
}
|
||||
|
||||
static Value lowerSingleConvGroup(Value x,
|
||||
Value w,
|
||||
Value b,
|
||||
RankedTensorType xType,
|
||||
RankedTensorType wType,
|
||||
RankedTensorType outType,
|
||||
int64_t padHeightBegin,
|
||||
int64_t padHeightEnd,
|
||||
int64_t padWidthBegin,
|
||||
int64_t padWidthEnd,
|
||||
int64_t strideHeight,
|
||||
int64_t strideWidth,
|
||||
int64_t dilationHeight,
|
||||
int64_t dilationWidth,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
const int64_t batchSize = xType.getDimSize(0);
|
||||
const int64_t numChannelsIn = xType.getDimSize(1);
|
||||
const int64_t xHeight = xType.getDimSize(2);
|
||||
@@ -408,71 +372,6 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
||||
const int64_t outHeight = outType.getDimSize(2);
|
||||
const int64_t outWidth = outType.getDimSize(3);
|
||||
|
||||
// Read optional conv attributes (ONNX defaults: stride=1, dilation=1, pad=0)
|
||||
const auto stridesAttr = convOp.getStrides();
|
||||
const auto dilationsAttr = convOp.getDilations();
|
||||
const auto padsAttr = convOp.getPads();
|
||||
|
||||
if (stridesAttr && stridesAttr->size() != 2) {
|
||||
convOp.emitOpError("requires exactly two stride values for Spatial lowering");
|
||||
return failure();
|
||||
}
|
||||
if (dilationsAttr && dilationsAttr->size() != 2) {
|
||||
convOp.emitOpError("requires exactly two dilation values for Spatial lowering");
|
||||
return failure();
|
||||
}
|
||||
if (padsAttr && padsAttr->size() != 4) {
|
||||
convOp.emitOpError("requires exactly four pad values for 2D Spatial lowering");
|
||||
return failure();
|
||||
}
|
||||
|
||||
const int64_t strideHeight = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 0) : 1;
|
||||
const int64_t strideWidth = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 1) : 1;
|
||||
const int64_t dilationHeight = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 0) : 1;
|
||||
const int64_t dilationWidth = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 1) : 1;
|
||||
|
||||
int64_t padHeightBegin = 0;
|
||||
int64_t padHeightEnd = 0;
|
||||
int64_t padWidthBegin = 0;
|
||||
int64_t padWidthEnd = 0;
|
||||
|
||||
if (padsAttr) {
|
||||
padHeightBegin = getI64FromArrayAttr(*padsAttr, 0);
|
||||
padWidthBegin = getI64FromArrayAttr(*padsAttr, 1);
|
||||
padHeightEnd = getI64FromArrayAttr(*padsAttr, 2);
|
||||
padWidthEnd = getI64FromArrayAttr(*padsAttr, 3);
|
||||
}
|
||||
else {
|
||||
// Compute padding from auto_pad attribute
|
||||
const auto autoPad = convOp.getAutoPad();
|
||||
if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") {
|
||||
const int64_t effectiveKernelH = (wHeight - 1) * dilationHeight + 1;
|
||||
const int64_t effectiveKernelW = (wWidth - 1) * dilationWidth + 1;
|
||||
const int64_t totalPadH =
|
||||
std::max(static_cast<int64_t>(0), (outHeight - 1) * strideHeight + effectiveKernelH - xHeight);
|
||||
const int64_t totalPadW =
|
||||
std::max(static_cast<int64_t>(0), (outWidth - 1) * strideWidth + effectiveKernelW - xWidth);
|
||||
|
||||
if (autoPad == "SAME_UPPER") {
|
||||
padHeightBegin = totalPadH / 2;
|
||||
padHeightEnd = totalPadH - padHeightBegin;
|
||||
padWidthBegin = totalPadW / 2;
|
||||
padWidthEnd = totalPadW - padWidthBegin;
|
||||
}
|
||||
else { // SAME_LOWER
|
||||
padHeightEnd = totalPadH / 2;
|
||||
padHeightBegin = totalPadH - padHeightEnd;
|
||||
padWidthEnd = totalPadW / 2;
|
||||
padWidthBegin = totalPadW - padWidthEnd;
|
||||
}
|
||||
}
|
||||
else if (autoPad != "NOTSET" && autoPad != "VALID") {
|
||||
convOp.emitOpError() << "unsupported auto_pad value `" << autoPad << "` for Spatial lowering";
|
||||
return failure();
|
||||
}
|
||||
// "NOTSET" or "VALID" -> all pads stay 0
|
||||
}
|
||||
|
||||
// im2col layout (flipped with respect to the standard, so filters sit in B = crossbar):
|
||||
// A (im2col): [numPatches, patchSize] -- one row per output spatial position
|
||||
// B (weights): [patchSize, cOut] -- W^T, stored in crossbar columns
|
||||
@@ -492,7 +391,7 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
||||
const int64_t xbarSize = static_cast<int64_t>(crossbarSize.getValue());
|
||||
const int64_t wMaxDim = std::max(patchSize, numChannelsOut);
|
||||
const int64_t maxParallelPixels = std::max<int64_t>(1, xbarSize / wMaxDim);
|
||||
auto wDenseAttr = getDenseConstantAttr(w);
|
||||
auto wDenseAttr = getHostFoldableDenseElementsAttr(w);
|
||||
|
||||
// Prepare weight matrix W for crossbar storage:
|
||||
// W: [numChannelsOut, numChannelsIn, wHeight, wWidth] -> [numChannelsOut, patchSize] -> [patchSize, numChannelsOut]
|
||||
@@ -513,7 +412,7 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
||||
DenseElementsAttr biasDenseAttr;
|
||||
if (hasB) {
|
||||
gemmBias = b;
|
||||
biasDenseAttr = getDenseConstantAttr(b);
|
||||
biasDenseAttr = getHostFoldableDenseElementsAttr(b);
|
||||
biasMatrix = expandBiasIfNeeded(b, rewriter, loc);
|
||||
}
|
||||
const bool canPackWeightsAsConstants = static_cast<bool>(wDenseAttr);
|
||||
@@ -589,9 +488,8 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
||||
rewriter.getBoolAttr(false))
|
||||
.getY();
|
||||
|
||||
rewriter.replaceOp(convOp,
|
||||
createCollectedConvOutput(ValueRange {gemmRows},
|
||||
convOp.getType(),
|
||||
return createCollectedConvOutput(ValueRange {gemmRows},
|
||||
outType,
|
||||
gemmOutType,
|
||||
nhwcType,
|
||||
outType,
|
||||
@@ -599,8 +497,238 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
||||
numChannelsOut,
|
||||
effectiveMaxParallelPixels,
|
||||
rewriter,
|
||||
loc);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
||||
ONNXConvOpAdaptor convOpAdaptor,
|
||||
ConversionPatternRewriter& rewriter) const {
|
||||
Location loc = convOp.getLoc();
|
||||
Value x = convOpAdaptor.getX();
|
||||
Value w = convOpAdaptor.getW();
|
||||
Value b = convOpAdaptor.getB();
|
||||
|
||||
auto xType = cast<RankedTensorType>(x.getType());
|
||||
auto wType = cast<RankedTensorType>(w.getType());
|
||||
auto outType = cast<RankedTensorType>(convOp.getY().getType());
|
||||
|
||||
if (!xType.hasStaticShape()) {
|
||||
pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv input");
|
||||
return failure();
|
||||
}
|
||||
if (!wType.hasStaticShape()) {
|
||||
pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv weight");
|
||||
return failure();
|
||||
}
|
||||
if (!outType.hasStaticShape()) {
|
||||
pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv result");
|
||||
return failure();
|
||||
}
|
||||
if (xType.getRank() != 4) {
|
||||
pim::emitUnsupportedRankDiagnostic(convOp, "conv input", xType.getRank(), {4});
|
||||
return failure();
|
||||
}
|
||||
if (wType.getRank() != 4) {
|
||||
pim::emitUnsupportedRankDiagnostic(convOp, "conv weight", wType.getRank(), {4});
|
||||
return failure();
|
||||
}
|
||||
if (outType.getRank() != 4) {
|
||||
pim::emitUnsupportedRankDiagnostic(convOp, "conv result", outType.getRank(), {4});
|
||||
return failure();
|
||||
}
|
||||
if (convOp.getGroup() < 1) {
|
||||
convOp.emitOpError("requires group >= 1 for Spatial lowering");
|
||||
return failure();
|
||||
}
|
||||
|
||||
const int64_t batchSize = xType.getDimSize(0);
|
||||
const int64_t numChannelsIn = xType.getDimSize(1);
|
||||
const int64_t xHeight = xType.getDimSize(2);
|
||||
const int64_t xWidth = xType.getDimSize(3);
|
||||
const int64_t numChannelsOut = wType.getDimSize(0);
|
||||
const int64_t wHeight = wType.getDimSize(2);
|
||||
const int64_t wWidth = wType.getDimSize(3);
|
||||
const int64_t outHeight = outType.getDimSize(2);
|
||||
const int64_t outWidth = outType.getDimSize(3);
|
||||
const int64_t group = convOp.getGroup();
|
||||
const bool hasB = !isa<ONNXNoneOp>(b.getDefiningOp());
|
||||
|
||||
if (numChannelsIn % group != 0) {
|
||||
convOp.emitOpError() << "requires input channels " << numChannelsIn << " to be divisible by group " << group
|
||||
<< " for Spatial lowering";
|
||||
return failure();
|
||||
}
|
||||
if (numChannelsOut % group != 0) {
|
||||
convOp.emitOpError() << "requires output channels " << numChannelsOut << " to be divisible by group " << group
|
||||
<< " for Spatial lowering";
|
||||
return failure();
|
||||
}
|
||||
|
||||
const int64_t numChannelsInPerGroup = numChannelsIn / group;
|
||||
const int64_t numChannelsOutPerGroup = numChannelsOut / group;
|
||||
if (wType.getDimSize(1) != numChannelsInPerGroup) {
|
||||
convOp.emitOpError() << "requires grouped conv weight input channels " << wType.getDimSize(1)
|
||||
<< " to match input channels per group " << numChannelsInPerGroup << " for Spatial lowering";
|
||||
return failure();
|
||||
}
|
||||
if (wType.getDimSize(0) != numChannelsOut) {
|
||||
convOp.emitOpError() << "requires weight output channels " << wType.getDimSize(0) << " to match result channels "
|
||||
<< numChannelsOut << " for Spatial lowering";
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Read optional conv attributes (ONNX defaults: stride=1, dilation=1, pad=0)
|
||||
const auto stridesAttr = convOp.getStrides();
|
||||
const auto dilationsAttr = convOp.getDilations();
|
||||
const auto padsAttr = convOp.getPads();
|
||||
|
||||
if (stridesAttr && stridesAttr->size() != 2) {
|
||||
convOp.emitOpError("requires exactly two stride values for Spatial lowering");
|
||||
return failure();
|
||||
}
|
||||
if (dilationsAttr && dilationsAttr->size() != 2) {
|
||||
convOp.emitOpError("requires exactly two dilation values for Spatial lowering");
|
||||
return failure();
|
||||
}
|
||||
if (padsAttr && padsAttr->size() != 4) {
|
||||
convOp.emitOpError("requires exactly four pad values for 2D Spatial lowering");
|
||||
return failure();
|
||||
}
|
||||
|
||||
const int64_t strideHeight = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 0) : 1;
|
||||
const int64_t strideWidth = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 1) : 1;
|
||||
const int64_t dilationHeight = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 0) : 1;
|
||||
const int64_t dilationWidth = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 1) : 1;
|
||||
|
||||
int64_t padHeightBegin = 0;
|
||||
int64_t padHeightEnd = 0;
|
||||
int64_t padWidthBegin = 0;
|
||||
int64_t padWidthEnd = 0;
|
||||
|
||||
if (padsAttr) {
|
||||
padHeightBegin = getI64FromArrayAttr(*padsAttr, 0);
|
||||
padWidthBegin = getI64FromArrayAttr(*padsAttr, 1);
|
||||
padHeightEnd = getI64FromArrayAttr(*padsAttr, 2);
|
||||
padWidthEnd = getI64FromArrayAttr(*padsAttr, 3);
|
||||
}
|
||||
else {
|
||||
// Compute padding from auto_pad attribute
|
||||
const auto autoPad = convOp.getAutoPad();
|
||||
if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") {
|
||||
const int64_t effectiveKernelH = (wHeight - 1) * dilationHeight + 1;
|
||||
const int64_t effectiveKernelW = (wWidth - 1) * dilationWidth + 1;
|
||||
const int64_t totalPadH =
|
||||
std::max(static_cast<int64_t>(0), (outHeight - 1) * strideHeight + effectiveKernelH - xHeight);
|
||||
const int64_t totalPadW =
|
||||
std::max(static_cast<int64_t>(0), (outWidth - 1) * strideWidth + effectiveKernelW - xWidth);
|
||||
|
||||
if (autoPad == "SAME_UPPER") {
|
||||
padHeightBegin = totalPadH / 2;
|
||||
padHeightEnd = totalPadH - padHeightBegin;
|
||||
padWidthBegin = totalPadW / 2;
|
||||
padWidthEnd = totalPadW - padWidthBegin;
|
||||
}
|
||||
else { // SAME_LOWER
|
||||
padHeightEnd = totalPadH / 2;
|
||||
padHeightBegin = totalPadH - padHeightEnd;
|
||||
padWidthEnd = totalPadW / 2;
|
||||
padWidthBegin = totalPadW - padWidthEnd;
|
||||
}
|
||||
}
|
||||
else if (autoPad != "NOTSET" && autoPad != "VALID") {
|
||||
convOp.emitOpError() << "unsupported auto_pad value `" << autoPad << "` for Spatial lowering";
|
||||
return failure();
|
||||
}
|
||||
// "NOTSET" or "VALID" -> all pads stay 0
|
||||
}
|
||||
|
||||
if (group == 1) {
|
||||
rewriter.replaceOp(convOp,
|
||||
lowerSingleConvGroup(x,
|
||||
w,
|
||||
b,
|
||||
xType,
|
||||
wType,
|
||||
outType,
|
||||
padHeightBegin,
|
||||
padHeightEnd,
|
||||
padWidthBegin,
|
||||
padWidthEnd,
|
||||
strideHeight,
|
||||
strideWidth,
|
||||
dilationHeight,
|
||||
dilationWidth,
|
||||
rewriter,
|
||||
loc));
|
||||
return success();
|
||||
}
|
||||
|
||||
SmallVector<Value> xSlices = sliceTensor(x, /*axis=*/1, numChannelsInPerGroup, rewriter, loc);
|
||||
SmallVector<Value> wSlices = sliceTensor(w, /*axis=*/0, numChannelsOutPerGroup, rewriter, loc);
|
||||
SmallVector<Value> bSlices;
|
||||
if (hasB) {
|
||||
auto biasType = cast<RankedTensorType>(b.getType());
|
||||
int64_t biasAxis = -1;
|
||||
if (biasType.getRank() == 1)
|
||||
biasAxis = 0;
|
||||
else if (biasType.getRank() == 2)
|
||||
biasAxis = biasType.getDimSize(0) != 1 ? 0 : 1;
|
||||
else {
|
||||
convOp.emitOpError() << "requires rank-1 or rank-2 bias for grouped convolution Spatial lowering, but got rank "
|
||||
<< biasType.getRank();
|
||||
return failure();
|
||||
}
|
||||
bSlices = sliceTensor(b, biasAxis, numChannelsOutPerGroup, rewriter, loc);
|
||||
}
|
||||
|
||||
if (xSlices.size() != static_cast<size_t>(group) || wSlices.size() != static_cast<size_t>(group)
|
||||
|| (hasB && bSlices.size() != static_cast<size_t>(group))) {
|
||||
convOp.emitOpError("failed to partition grouped convolution operands for Spatial lowering");
|
||||
return failure();
|
||||
}
|
||||
|
||||
SmallVector<Value> groupResults;
|
||||
groupResults.reserve(group);
|
||||
auto groupOutType =
|
||||
RankedTensorType::get({batchSize, numChannelsOutPerGroup, outHeight, outWidth}, outType.getElementType());
|
||||
Value noBias = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());
|
||||
for (int64_t groupId = 0; groupId < group; groupId++) {
|
||||
Value groupX = xSlices[groupId];
|
||||
Value groupW = wSlices[groupId];
|
||||
Value groupB = hasB ? bSlices[groupId] : noBias;
|
||||
groupResults.push_back(lowerSingleConvGroup(groupX,
|
||||
groupW,
|
||||
groupB,
|
||||
cast<RankedTensorType>(groupX.getType()),
|
||||
cast<RankedTensorType>(groupW.getType()),
|
||||
groupOutType,
|
||||
padHeightBegin,
|
||||
padHeightEnd,
|
||||
padWidthBegin,
|
||||
padWidthEnd,
|
||||
strideHeight,
|
||||
strideWidth,
|
||||
dilationHeight,
|
||||
dilationWidth,
|
||||
rewriter,
|
||||
loc));
|
||||
}
|
||||
|
||||
Value result;
|
||||
if (llvm::all_of(groupResults, isHostFoldableValue)) {
|
||||
result = createSpatConcat(rewriter, loc, /*axis=*/1, groupResults);
|
||||
}
|
||||
else {
|
||||
auto concatCompute = createSpatCompute(rewriter, loc, TypeRange {outType}, {}, groupResults, [&](ValueRange args) {
|
||||
spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, /*axis=*/1, args));
|
||||
});
|
||||
result = concatCompute.getResult(0);
|
||||
}
|
||||
|
||||
rewriter.replaceOp(convOp, result);
|
||||
return success();
|
||||
}
|
||||
|
||||
void populateConvPatterns(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.insert<ConvToGemm>(ctx); }
|
||||
|
||||
@@ -502,9 +502,6 @@ LogicalResult GemmToSpatialComputeBatch::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
}
|
||||
(void) bType;
|
||||
|
||||
if (!isHostFoldableValue(b))
|
||||
return failure();
|
||||
|
||||
Value sharedBias;
|
||||
if (hasC) {
|
||||
auto scaledC = materializeScaledConstantTensor(c, gemmOpAdaptor.getBeta().convertToFloat(), rewriter, loc);
|
||||
|
||||
@@ -2,8 +2,12 @@
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include <functional>
|
||||
#include <numeric>
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||
@@ -19,6 +23,79 @@ static bool haveStaticPositiveShape(ArrayRef<int64_t> shape) {
|
||||
return llvm::all_of(shape, [](int64_t dim) { return dim > 0; });
|
||||
}
|
||||
|
||||
static int64_t getStaticShapeElementCount(ArrayRef<int64_t> shape) {
|
||||
return std::accumulate(shape.begin(), shape.end(), int64_t {1}, std::multiplies<int64_t> {});
|
||||
}
|
||||
|
||||
static FailureOr<SmallVector<int64_t>> inferSupportedBatchShape(ArrayRef<int64_t> lhsBatchShape,
|
||||
ArrayRef<int64_t> rhsBatchShape) {
|
||||
if (lhsBatchShape.empty())
|
||||
return SmallVector<int64_t>(rhsBatchShape.begin(), rhsBatchShape.end());
|
||||
if (rhsBatchShape.empty())
|
||||
return SmallVector<int64_t>(lhsBatchShape.begin(), lhsBatchShape.end());
|
||||
if (!llvm::equal(lhsBatchShape, rhsBatchShape))
|
||||
return failure();
|
||||
return SmallVector<int64_t>(lhsBatchShape.begin(), lhsBatchShape.end());
|
||||
}
|
||||
|
||||
static Value collapseBatchDims(Value value,
|
||||
int64_t batchSize,
|
||||
int64_t rows,
|
||||
int64_t cols,
|
||||
PatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
auto type = cast<RankedTensorType>(value.getType());
|
||||
if (type.getRank() == 2 || type.getRank() == 3)
|
||||
return value;
|
||||
|
||||
auto collapsedType =
|
||||
RankedTensorType::get({batchSize, rows, cols}, type.getElementType(), type.getEncoding());
|
||||
SmallVector<ReassociationIndices> reassociation = {
|
||||
ReassociationIndices {},
|
||||
ReassociationIndices {static_cast<int64_t>(type.getRank() - 2)},
|
||||
ReassociationIndices {static_cast<int64_t>(type.getRank() - 1)}
|
||||
};
|
||||
for (int64_t dim = 0; dim < type.getRank() - 2; ++dim)
|
||||
reassociation.front().push_back(dim);
|
||||
|
||||
auto buildCollapsed = [&](Value input) -> Value {
|
||||
return tensor::CollapseShapeOp::create(rewriter, loc, collapsedType, input, reassociation);
|
||||
};
|
||||
|
||||
if (isHostFoldableValue(value))
|
||||
return buildCollapsed(value);
|
||||
|
||||
auto collapseCompute =
|
||||
createSpatCompute<1>(rewriter, loc, TypeRange {collapsedType}, {}, ValueRange {value}, [&](Value input) {
|
||||
spatial::SpatYieldOp::create(rewriter, loc, buildCollapsed(input));
|
||||
});
|
||||
return collapseCompute.getResult(0);
|
||||
}
|
||||
|
||||
static Value expandBatchDims(Value value,
|
||||
RankedTensorType outputType,
|
||||
size_t batchRank,
|
||||
PatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
if (cast<RankedTensorType>(value.getType()) == outputType)
|
||||
return value;
|
||||
|
||||
SmallVector<ReassociationIndices> reassociation = {
|
||||
ReassociationIndices {},
|
||||
ReassociationIndices {static_cast<int64_t>(batchRank)},
|
||||
ReassociationIndices {static_cast<int64_t>(batchRank + 1)}
|
||||
};
|
||||
for (size_t dim = 0; dim < batchRank; ++dim)
|
||||
reassociation.front().push_back(static_cast<int64_t>(dim));
|
||||
|
||||
auto expandCompute =
|
||||
createSpatCompute<1>(rewriter, loc, TypeRange {outputType}, {}, ValueRange {value}, [&](Value input) {
|
||||
Value expanded = tensor::ExpandShapeOp::create(rewriter, loc, outputType, input, reassociation);
|
||||
spatial::SpatYieldOp::create(rewriter, loc, expanded);
|
||||
});
|
||||
return expandCompute.getResult(0);
|
||||
}
|
||||
|
||||
static Value extractBatchMatrix(Value value,
|
||||
int64_t batchIndex,
|
||||
int64_t batchSize,
|
||||
@@ -62,13 +139,29 @@ static Value extractBatchMatrix(Value value,
|
||||
static Value transposeLastTwoDims(Value value, PatternRewriter& rewriter, Location loc) {
|
||||
auto type = cast<RankedTensorType>(value.getType());
|
||||
auto shape = type.getShape();
|
||||
RankedTensorType transposedType;
|
||||
SmallVector<int64_t> perm;
|
||||
if (type.getRank() == 2) {
|
||||
auto transposedType = RankedTensorType::get({shape[1], shape[0]}, type.getElementType());
|
||||
return ONNXTransposeOp::create(rewriter, loc, transposedType, value, rewriter.getI64ArrayAttr({1, 0}));
|
||||
transposedType = RankedTensorType::get({shape[1], shape[0]}, type.getElementType());
|
||||
perm = {1, 0};
|
||||
}
|
||||
else {
|
||||
transposedType = RankedTensorType::get({shape[0], shape[2], shape[1]}, type.getElementType());
|
||||
perm = {0, 2, 1};
|
||||
}
|
||||
|
||||
auto transposedType = RankedTensorType::get({shape[0], shape[2], shape[1]}, type.getElementType());
|
||||
return ONNXTransposeOp::create(rewriter, loc, transposedType, value, rewriter.getI64ArrayAttr({0, 2, 1}));
|
||||
auto buildTranspose = [&](Value input) -> Value {
|
||||
return ONNXTransposeOp::create(rewriter, loc, transposedType, input, rewriter.getI64ArrayAttr(perm));
|
||||
};
|
||||
|
||||
if (isHostFoldableValue(value))
|
||||
return buildTranspose(value);
|
||||
|
||||
auto transposeCompute =
|
||||
createSpatCompute<1>(rewriter, loc, TypeRange {transposedType}, {}, ValueRange {value}, [&](Value input) {
|
||||
spatial::SpatYieldOp::create(rewriter, loc, buildTranspose(input));
|
||||
});
|
||||
return transposeCompute.getResult(0);
|
||||
}
|
||||
|
||||
static Value transposeLastTwoDimsInCompute(Value value, PatternRewriter& rewriter, Location loc) {
|
||||
@@ -120,24 +213,25 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
|
||||
if (!lhsType || !rhsType || !outType || !lhsType.hasStaticShape() || !rhsType.hasStaticShape()
|
||||
|| !outType.hasStaticShape())
|
||||
return failure();
|
||||
if ((lhsType.getRank() != 2 && lhsType.getRank() != 3) || (rhsType.getRank() != 2 && rhsType.getRank() != 3)
|
||||
|| (outType.getRank() != 2 && outType.getRank() != 3))
|
||||
if (lhsType.getRank() < 2 || rhsType.getRank() < 2 || outType.getRank() < 2)
|
||||
return failure();
|
||||
if (!haveStaticPositiveShape(lhsType.getShape()) || !haveStaticPositiveShape(rhsType.getShape())
|
||||
|| !haveStaticPositiveShape(outType.getShape()))
|
||||
return failure();
|
||||
|
||||
const int64_t lhsBatch = lhsType.getRank() == 3 ? lhsType.getDimSize(0) : 1;
|
||||
const int64_t rhsBatch = rhsType.getRank() == 3 ? rhsType.getDimSize(0) : 1;
|
||||
const int64_t batch = std::max(lhsBatch, rhsBatch);
|
||||
|
||||
if ((lhsBatch != 1 && lhsBatch != batch) || (rhsBatch != 1 && rhsBatch != batch))
|
||||
SmallVector<int64_t> lhsBatchShape(lhsType.getShape().begin(), lhsType.getShape().end() - 2);
|
||||
SmallVector<int64_t> rhsBatchShape(rhsType.getShape().begin(), rhsType.getShape().end() - 2);
|
||||
auto batchShape = inferSupportedBatchShape(lhsBatchShape, rhsBatchShape);
|
||||
if (failed(batchShape))
|
||||
return failure();
|
||||
const int64_t lhsBatch = lhsBatchShape.empty() ? 1 : getStaticShapeElementCount(lhsBatchShape);
|
||||
const int64_t rhsBatch = rhsBatchShape.empty() ? 1 : getStaticShapeElementCount(rhsBatchShape);
|
||||
const int64_t batch = batchShape->empty() ? 1 : getStaticShapeElementCount(*batchShape);
|
||||
|
||||
const int64_t m = lhsType.getRank() == 3 ? lhsType.getDimSize(1) : lhsType.getDimSize(0);
|
||||
const int64_t k = lhsType.getRank() == 3 ? lhsType.getDimSize(2) : lhsType.getDimSize(1);
|
||||
const int64_t rhsK = rhsType.getRank() == 3 ? rhsType.getDimSize(1) : rhsType.getDimSize(0);
|
||||
const int64_t n = rhsType.getRank() == 3 ? rhsType.getDimSize(2) : rhsType.getDimSize(1);
|
||||
const int64_t m = lhsType.getDimSize(lhsType.getRank() - 2);
|
||||
const int64_t k = lhsType.getDimSize(lhsType.getRank() - 1);
|
||||
const int64_t rhsK = rhsType.getDimSize(rhsType.getRank() - 2);
|
||||
const int64_t n = rhsType.getDimSize(rhsType.getRank() - 1);
|
||||
if (k != rhsK)
|
||||
return failure();
|
||||
|
||||
@@ -146,15 +240,17 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
|
||||
return failure();
|
||||
}
|
||||
else {
|
||||
if (outType.getDimSize(0) != batch || outType.getDimSize(1) != m || outType.getDimSize(2) != n)
|
||||
SmallVector<int64_t> outBatchShape(outType.getShape().begin(), outType.getShape().end() - 2);
|
||||
if (!llvm::equal(outBatchShape, *batchShape) || outType.getDimSize(outType.getRank() - 2) != m
|
||||
|| outType.getDimSize(outType.getRank() - 1) != n)
|
||||
return failure();
|
||||
}
|
||||
|
||||
Location loc = matmulOp.getLoc();
|
||||
bool useTransposedForm = isHostFoldableValue(matmulOp.getA()) && !isHostFoldableValue(matmulOp.getB());
|
||||
|
||||
Value lhs = matmulOp.getA();
|
||||
Value rhs = matmulOp.getB();
|
||||
Value lhs = collapseBatchDims(matmulOp.getA(), lhsBatch, m, k, rewriter, loc);
|
||||
Value rhs = collapseBatchDims(matmulOp.getB(), rhsBatch, k, n, rewriter, loc);
|
||||
int64_t lhsBatchForGemm = lhsBatch;
|
||||
int64_t rhsBatchForGemm = rhsBatch;
|
||||
int64_t gemmM = m;
|
||||
@@ -239,6 +335,7 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
|
||||
}
|
||||
|
||||
Value result = concatValues(batchResults, /*axis=*/0, rewriter, loc);
|
||||
result = expandBatchDims(result, outType, batchShape->size(), rewriter, loc);
|
||||
rewriter.replaceOp(matmulOp, result);
|
||||
return success();
|
||||
}
|
||||
|
||||
@@ -4,9 +4,9 @@
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/APFloat.h"
|
||||
#include "llvm/ADT/APInt.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <optional>
|
||||
@@ -47,8 +47,8 @@ static Value materializeContiguousTile(ConversionPatternRewriter& rewriter, Loca
|
||||
return tensor::InsertSliceOp::create(rewriter, loc, tile, empty, offsets, sizes, strides);
|
||||
}
|
||||
|
||||
static Value createPoolFillElement(
|
||||
ConversionPatternRewriter& rewriter, Location loc, Type elementType, bool useMinimumValue) {
|
||||
static Value
|
||||
createPoolFillElement(ConversionPatternRewriter& rewriter, Location loc, Type elementType, bool useMinimumValue) {
|
||||
if (!useMinimumValue)
|
||||
return arith::ConstantOp::create(rewriter, loc, elementType, rewriter.getZeroAttr(elementType));
|
||||
|
||||
@@ -65,8 +65,10 @@ static Value createPoolFillElement(
|
||||
llvm_unreachable("unsupported pool element type");
|
||||
}
|
||||
|
||||
static Value createPoolFillTensor(
|
||||
ConversionPatternRewriter& rewriter, Location loc, RankedTensorType tensorType, bool useMinimumValue) {
|
||||
static Value createPoolFillTensor(ConversionPatternRewriter& rewriter,
|
||||
Location loc,
|
||||
RankedTensorType tensorType,
|
||||
bool useMinimumValue) {
|
||||
auto fillElement = createPoolFillElement(rewriter, loc, tensorType.getElementType(), useMinimumValue);
|
||||
return tensor::SplatOp::create(rewriter, loc, tensorType, fillElement);
|
||||
}
|
||||
@@ -90,10 +92,8 @@ static Value createPaddedPoolInput(ConversionPatternRewriter& rewriter,
|
||||
inputType.getDimSize(3) + padLeft + padRight},
|
||||
inputType.getElementType(),
|
||||
inputType.getEncoding());
|
||||
SmallVector<OpFoldResult> lowPads = {rewriter.getIndexAttr(0),
|
||||
rewriter.getIndexAttr(0),
|
||||
rewriter.getIndexAttr(padTop),
|
||||
rewriter.getIndexAttr(padLeft)};
|
||||
SmallVector<OpFoldResult> lowPads = {
|
||||
rewriter.getIndexAttr(0), rewriter.getIndexAttr(0), rewriter.getIndexAttr(padTop), rewriter.getIndexAttr(padLeft)};
|
||||
SmallVector<OpFoldResult> highPads = {rewriter.getIndexAttr(0),
|
||||
rewriter.getIndexAttr(0),
|
||||
rewriter.getIndexAttr(padBottom),
|
||||
@@ -104,8 +104,8 @@ static Value createPaddedPoolInput(ConversionPatternRewriter& rewriter,
|
||||
padBlock->addArgument(rewriter.getIndexType(), loc);
|
||||
padOp.getRegion().push_back(padBlock);
|
||||
rewriter.setInsertionPointToStart(padBlock);
|
||||
Value padValue = createPoolFillElement(
|
||||
rewriter, loc, inputType.getElementType(), std::is_same_v<PoolOp, ONNXMaxPoolSingleOutOp>);
|
||||
Value padValue =
|
||||
createPoolFillElement(rewriter, loc, inputType.getElementType(), std::is_same_v<PoolOp, ONNXMaxPoolSingleOutOp>);
|
||||
tensor::YieldOp::create(rewriter, loc, padValue);
|
||||
rewriter.setInsertionPointAfter(padOp);
|
||||
return padOp.getResult();
|
||||
@@ -279,7 +279,8 @@ struct PoolToSpatialComputeBase : public OpConversionPattern<PoolOp> {
|
||||
constexpr size_t numInputs = 1;
|
||||
auto computeOp =
|
||||
createSpatCompute<numInputs>(rewriter, loc, outType, {}, ValueRange {x}, [&](Value xArg) -> LogicalResult {
|
||||
Value paddedInput = createPaddedPoolInput(rewriter, loc, poolOp, xArg, xType, padTop, padLeft, padBottom, padRight);
|
||||
Value paddedInput =
|
||||
createPaddedPoolInput(rewriter, loc, poolOp, xArg, xType, padTop, padLeft, padBottom, padRight);
|
||||
Value pooledOutputInit = tensor::EmptyOp::create(rewriter, loc, outType.getShape(), outType.getElementType());
|
||||
|
||||
Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
|
||||
@@ -307,8 +308,8 @@ struct PoolToSpatialComputeBase : public OpConversionPattern<PoolOp> {
|
||||
for (int64_t channelTile = 0; channelTile < channelTileCount; ++channelTile) {
|
||||
const int64_t tileChannels = std::min<int64_t>(xbarSize, channels - channelTile * xbarSize);
|
||||
auto tileType = RankedTensorType::get({1, tileChannels, 1, 1}, outType.getElementType());
|
||||
Value reducedWindow = createPoolFillTensor(
|
||||
rewriter, loc, tileType, std::is_same_v<PoolOp, ONNXMaxPoolSingleOutOp>);
|
||||
Value reducedWindow =
|
||||
createPoolFillTensor(rewriter, loc, tileType, std::is_same_v<PoolOp, ONNXMaxPoolSingleOutOp>);
|
||||
|
||||
for (int64_t kernelH = 0; kernelH < kernelHeight; ++kernelH) {
|
||||
Value paddedInH = windowBaseH;
|
||||
@@ -324,18 +325,14 @@ struct PoolToSpatialComputeBase : public OpConversionPattern<PoolOp> {
|
||||
paddedInW = arith::AddIOp::create(rewriter, loc, paddedInW, kernelWOffset);
|
||||
}
|
||||
|
||||
SmallVector<OpFoldResult> offsets = {batchIndex,
|
||||
rewriter.getIndexAttr(channelTile * xbarSize),
|
||||
paddedInH,
|
||||
paddedInW};
|
||||
SmallVector<OpFoldResult> offsets = {
|
||||
batchIndex, rewriter.getIndexAttr(channelTile * xbarSize), paddedInH, paddedInW};
|
||||
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(tileChannels),
|
||||
rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(1)};
|
||||
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(1)};
|
||||
SmallVector<OpFoldResult> strides = {
|
||||
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
Value windowValue =
|
||||
tensor::ExtractSliceOp::create(rewriter, loc, tileType, paddedInput, offsets, sizes, strides);
|
||||
windowValue = materializeContiguousTile(rewriter, loc, windowValue);
|
||||
@@ -344,36 +341,28 @@ struct PoolToSpatialComputeBase : public OpConversionPattern<PoolOp> {
|
||||
}
|
||||
|
||||
if constexpr (std::is_same_v<PoolOp, ONNXAveragePoolOp>) {
|
||||
SmallVector<OpFoldResult> scaleOffsets = {rewriter.getIndexAttr(0),
|
||||
rewriter.getIndexAttr(channelTile * xbarSize),
|
||||
outHeightIndex,
|
||||
outWidthIndex};
|
||||
SmallVector<OpFoldResult> scaleOffsets = {
|
||||
rewriter.getIndexAttr(0), rewriter.getIndexAttr(channelTile * xbarSize), outHeightIndex, outWidthIndex};
|
||||
SmallVector<OpFoldResult> scaleSizes = {rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(tileChannels),
|
||||
rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(1)};
|
||||
SmallVector<OpFoldResult> scaleStrides = {rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(1)};
|
||||
SmallVector<OpFoldResult> scaleStrides = {
|
||||
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
Value scaleSlice = tensor::ExtractSliceOp::create(
|
||||
rewriter, loc, tileType, averageScaleTensor, scaleOffsets, scaleSizes, scaleStrides);
|
||||
scaleSlice = materializeContiguousTile(rewriter, loc, scaleSlice);
|
||||
reducedWindow = spatial::SpatVMulOp::create(rewriter, loc, tileType, reducedWindow, scaleSlice);
|
||||
}
|
||||
|
||||
SmallVector<OpFoldResult> outputOffsets = {batchIndex,
|
||||
rewriter.getIndexAttr(channelTile * xbarSize),
|
||||
outHeightIndex,
|
||||
outWidthIndex};
|
||||
SmallVector<OpFoldResult> outputOffsets = {
|
||||
batchIndex, rewriter.getIndexAttr(channelTile * xbarSize), outHeightIndex, outWidthIndex};
|
||||
SmallVector<OpFoldResult> outputSizes = {rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(tileChannels),
|
||||
rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(1)};
|
||||
SmallVector<OpFoldResult> outputStrides = {rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(1)};
|
||||
SmallVector<OpFoldResult> outputStrides = {
|
||||
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
updatedOutput = tensor::InsertSliceOp::create(
|
||||
rewriter, loc, reducedWindow, updatedOutput, outputOffsets, outputSizes, outputStrides);
|
||||
}
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
@@ -22,53 +23,83 @@ static SmallVector<int64_t> permuteShape(ArrayRef<int64_t> shape, ArrayRef<int64
|
||||
return permutedShape;
|
||||
}
|
||||
|
||||
static Value createSoftmaxCompute(Value input, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
static Value buildLoopSoftmaxSlice(Value input,
|
||||
Value accumulator,
|
||||
RankedTensorType inputType,
|
||||
ArrayRef<Value> outerIndices,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
int64_t rank = inputType.getRank();
|
||||
SmallVector<int64_t> sliceShape(static_cast<size_t>(rank - 1), 1);
|
||||
sliceShape.push_back(inputType.getDimSize(rank - 1));
|
||||
auto sliceType = RankedTensorType::get(sliceShape, inputType.getElementType(), inputType.getEncoding());
|
||||
|
||||
SmallVector<OpFoldResult> offsets;
|
||||
SmallVector<OpFoldResult> sizes;
|
||||
SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
|
||||
offsets.reserve(rank);
|
||||
sizes.reserve(rank);
|
||||
|
||||
for (Value outerIndex : outerIndices) {
|
||||
offsets.push_back(outerIndex);
|
||||
sizes.push_back(rewriter.getIndexAttr(1));
|
||||
}
|
||||
offsets.push_back(rewriter.getIndexAttr(0));
|
||||
sizes.push_back(rewriter.getIndexAttr(inputType.getDimSize(rank - 1)));
|
||||
|
||||
Value inputSlice = tensor::ExtractSliceOp::create(rewriter, loc, sliceType, input, offsets, sizes, strides);
|
||||
Value softmaxSlice = spatial::SpatSoftmaxOp::create(rewriter, loc, sliceType, inputSlice).getResult();
|
||||
return tensor::InsertSliceOp::create(rewriter, loc, softmaxSlice, accumulator, offsets, sizes, strides);
|
||||
}
|
||||
|
||||
static Value buildLoopSoftmaxNest(Value input,
|
||||
Value accumulator,
|
||||
RankedTensorType inputType,
|
||||
int64_t axis,
|
||||
SmallVectorImpl<Value>& outerIndices,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
if (axis == inputType.getRank() - 1)
|
||||
return buildLoopSoftmaxSlice(input, accumulator, inputType, outerIndices, rewriter, loc);
|
||||
|
||||
Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
|
||||
Value c1 = arith::ConstantIndexOp::create(rewriter, loc, 1);
|
||||
Value cUpper = arith::ConstantIndexOp::create(rewriter, loc, inputType.getDimSize(axis));
|
||||
|
||||
auto loop = scf::ForOp::create(rewriter, loc, c0, cUpper, c1, ValueRange {accumulator});
|
||||
rewriter.setInsertionPointToStart(loop.getBody());
|
||||
|
||||
Value loopIndex = loop.getInductionVar();
|
||||
Value loopAccumulator = loop.getRegionIterArgs().front();
|
||||
outerIndices.push_back(loopIndex);
|
||||
Value updatedAccumulator =
|
||||
buildLoopSoftmaxNest(input, loopAccumulator, inputType, axis + 1, outerIndices, rewriter, loc);
|
||||
outerIndices.pop_back();
|
||||
|
||||
scf::YieldOp::create(rewriter, loc, updatedAccumulator);
|
||||
rewriter.setInsertionPointAfter(loop);
|
||||
return loop.getResult(0);
|
||||
}
|
||||
|
||||
static Value createLoopSoftmaxCompute(Value input, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
auto inputType = cast<RankedTensorType>(input.getType());
|
||||
constexpr size_t numInputs = 1;
|
||||
auto computeOp =
|
||||
createSpatCompute<numInputs>(rewriter, loc, TypeRange {inputType}, {}, ValueRange {input}, [&](Value x) {
|
||||
auto softmaxOp = spatial::SpatSoftmaxOp::create(rewriter, loc, inputType, x);
|
||||
spatial::SpatYieldOp::create(rewriter, loc, softmaxOp.getResult());
|
||||
if (inputType.getRank() == 1) {
|
||||
Value softmax = spatial::SpatSoftmaxOp::create(rewriter, loc, inputType, x).getResult();
|
||||
spatial::SpatYieldOp::create(rewriter, loc, softmax);
|
||||
return;
|
||||
}
|
||||
|
||||
Value outputInit = tensor::EmptyOp::create(rewriter, loc, inputType.getShape(), inputType.getElementType());
|
||||
SmallVector<Value> outerIndices;
|
||||
Value result = buildLoopSoftmaxNest(x, outputInit, inputType, /*axis=*/0, outerIndices, rewriter, loc);
|
||||
spatial::SpatYieldOp::create(rewriter, loc, result);
|
||||
});
|
||||
return computeOp.getResult(0);
|
||||
}
|
||||
|
||||
static Value concatValues(ValueRange inputs, int64_t axis, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
auto firstType = cast<RankedTensorType>(inputs.front().getType());
|
||||
SmallVector<int64_t> outputShape(firstType.getShape().begin(), firstType.getShape().end());
|
||||
int64_t concatDimSize = 0;
|
||||
for (Value input : inputs)
|
||||
concatDimSize += cast<RankedTensorType>(input.getType()).getDimSize(axis);
|
||||
outputShape[axis] = concatDimSize;
|
||||
auto resultType = RankedTensorType::get(outputShape, firstType.getElementType(), firstType.getEncoding());
|
||||
|
||||
if (llvm::all_of(inputs, isHostFoldableValue))
|
||||
return createSpatConcat(rewriter, loc, axis, inputs);
|
||||
|
||||
auto concatCompute = createSpatCompute(rewriter, loc, TypeRange {resultType}, {}, inputs, [&](ValueRange args) {
|
||||
spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, axis, args));
|
||||
});
|
||||
return concatCompute.getResult(0);
|
||||
}
|
||||
|
||||
static Value
|
||||
buildSoftmax(Value input, int64_t softmaxAxis, int64_t axis, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
auto inputType = cast<RankedTensorType>(input.getType());
|
||||
if (axis == inputType.getRank())
|
||||
return createSoftmaxCompute(input, rewriter, loc);
|
||||
|
||||
if (axis == softmaxAxis)
|
||||
return buildSoftmax(input, softmaxAxis, axis + 1, rewriter, loc);
|
||||
|
||||
SmallVector<Value> slices = sliceTensor(input, axis, /*sliceSize=*/1, rewriter, loc);
|
||||
SmallVector<Value> rebuiltSlices;
|
||||
rebuiltSlices.reserve(slices.size());
|
||||
for (Value slice : slices)
|
||||
rebuiltSlices.push_back(buildSoftmax(slice, softmaxAxis, axis + 1, rewriter, loc));
|
||||
|
||||
return concatValues(rebuiltSlices, axis, rewriter, loc);
|
||||
}
|
||||
|
||||
struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> {
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
|
||||
@@ -86,7 +117,7 @@ struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> {
|
||||
Value input = adaptor.getInput();
|
||||
Value result;
|
||||
if (axis == inputType.getRank() - 1) {
|
||||
result = buildSoftmax(input, axis, /*axis=*/0, rewriter, softmaxOp.getLoc());
|
||||
result = createLoopSoftmaxCompute(input, rewriter, softmaxOp.getLoc());
|
||||
}
|
||||
else {
|
||||
SmallVector<int64_t> permutation;
|
||||
@@ -109,8 +140,7 @@ struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> {
|
||||
spatial::SpatYieldOp::create(rewriter, softmaxOp.getLoc(), transposed);
|
||||
});
|
||||
Value transposedInput = preTransposeCompute.getResult(0);
|
||||
Value transposedResult = buildSoftmax(
|
||||
transposedInput, /*softmaxAxis=*/inputType.getRank() - 1, /*axis=*/0, rewriter, softmaxOp.getLoc());
|
||||
Value transposedResult = createLoopSoftmaxCompute(transposedInput, rewriter, softmaxOp.getLoc());
|
||||
auto postTransposeCompute =
|
||||
createSpatCompute<1>(rewriter, softmaxOp.getLoc(), TypeRange {inputType}, {}, transposedResult, [&](Value x) {
|
||||
Value transposed = ONNXTransposeOp::create(
|
||||
|
||||
@@ -80,6 +80,22 @@ static bool inferExpandReassociation(ArrayRef<int64_t> sourceShape,
|
||||
return sourceIdx == sourceShape.size() && resultIdx == resultShape.size();
|
||||
}
|
||||
|
||||
static SmallVector<ReassociationIndices> getCollapseTo1DReassociation(size_t rank) {
|
||||
SmallVector<ReassociationIndices> reassociation(1);
|
||||
reassociation.front().reserve(rank);
|
||||
for (size_t dim = 0; dim < rank; ++dim)
|
||||
reassociation.front().push_back(static_cast<int64_t>(dim));
|
||||
return reassociation;
|
||||
}
|
||||
|
||||
static SmallVector<ReassociationIndices> getExpandFrom1DReassociation(size_t rank) {
|
||||
SmallVector<ReassociationIndices> reassociation(1);
|
||||
reassociation.front().reserve(rank);
|
||||
for (size_t dim = 0; dim < rank; ++dim)
|
||||
reassociation.front().push_back(static_cast<int64_t>(dim));
|
||||
return reassociation;
|
||||
}
|
||||
|
||||
struct Reshape : OpConversionPattern<ONNXReshapeOp> {
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
|
||||
@@ -126,6 +142,23 @@ struct Reshape : OpConversionPattern<ONNXReshapeOp> {
|
||||
return tensor::ExpandShapeOp::create(rewriter, reshapeOp.getLoc(), resultType, data, reassociation);
|
||||
});
|
||||
|
||||
if (sourceType.getNumElements() != resultType.getNumElements())
|
||||
return failure();
|
||||
|
||||
return replaceWithReshape([&](Value data) -> Value {
|
||||
Value reshaped = data;
|
||||
if (sourceType.getRank() != 1) {
|
||||
auto flatType = RankedTensorType::get({sourceType.getNumElements()}, sourceType.getElementType());
|
||||
reshaped = tensor::CollapseShapeOp::create(
|
||||
rewriter, reshapeOp.getLoc(), flatType, reshaped, getCollapseTo1DReassociation(sourceType.getRank()));
|
||||
}
|
||||
if (resultType.getRank() == 1)
|
||||
return reshaped;
|
||||
return tensor::ExpandShapeOp::create(
|
||||
rewriter, reshapeOp.getLoc(), resultType, reshaped, getExpandFrom1DReassociation(resultType.getRank()))
|
||||
.getResult();
|
||||
});
|
||||
|
||||
return failure();
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
@@ -15,42 +15,88 @@ using namespace mlir;
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
static Value
|
||||
extractSliceAt(Value input, int64_t axis, int64_t offset, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
auto inputType = cast<RankedTensorType>(input.getType());
|
||||
SmallVector<OpFoldResult> offsets(inputType.getRank(), rewriter.getIndexAttr(0));
|
||||
SmallVector<OpFoldResult> sizes;
|
||||
SmallVector<OpFoldResult> strides(inputType.getRank(), rewriter.getIndexAttr(1));
|
||||
sizes.reserve(inputType.getRank());
|
||||
for (int64_t dim : inputType.getShape())
|
||||
sizes.push_back(rewriter.getIndexAttr(dim));
|
||||
offsets[axis] = rewriter.getIndexAttr(offset);
|
||||
sizes[axis] = rewriter.getIndexAttr(1);
|
||||
return tensor::ExtractSliceOp::create(rewriter, loc, input, offsets, sizes, strides);
|
||||
static Value buildNearestAsymmetricIndex(
|
||||
Value outputIndex, int64_t inputDim, int64_t outputDim, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
Value cInputDim = arith::ConstantIndexOp::create(rewriter, loc, inputDim);
|
||||
Value cOutputDim = arith::ConstantIndexOp::create(rewriter, loc, outputDim);
|
||||
Value cInputDimLast = arith::ConstantIndexOp::create(rewriter, loc, inputDim - 1);
|
||||
Value scaledIndex = arith::MulIOp::create(rewriter, loc, outputIndex, cInputDim);
|
||||
Value inputIndex = arith::DivUIOp::create(rewriter, loc, scaledIndex, cOutputDim);
|
||||
return arith::MinUIOp::create(rewriter, loc, inputIndex, cInputDimLast);
|
||||
}
|
||||
|
||||
static int64_t nearestAsymmetricIndex(int64_t outputIndex, int64_t inputDim, int64_t outputDim) {
|
||||
return std::min<int64_t>((outputIndex * inputDim) / outputDim, inputDim - 1);
|
||||
}
|
||||
|
||||
static Value buildNearestResize(Value input,
|
||||
ArrayRef<int64_t> inputShape,
|
||||
ArrayRef<int64_t> outputShape,
|
||||
int64_t axis,
|
||||
static Value buildNearestResizeLoop(Value input,
|
||||
RankedTensorType inputType,
|
||||
RankedTensorType resultType,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
if (axis == static_cast<int64_t>(outputShape.size()))
|
||||
return input;
|
||||
auto elemType = resultType.getElementType();
|
||||
SmallVector<int64_t> unitShape(resultType.getRank(), 1);
|
||||
auto unitTensorType = RankedTensorType::get(unitShape, elemType);
|
||||
|
||||
SmallVector<Value> slices;
|
||||
slices.reserve(outputShape[axis]);
|
||||
for (int64_t outputIndex = 0; outputIndex < outputShape[axis]; ++outputIndex) {
|
||||
int64_t inputIndex = nearestAsymmetricIndex(outputIndex, inputShape[axis], outputShape[axis]);
|
||||
Value slice = extractSliceAt(input, axis, inputIndex, rewriter, loc);
|
||||
slices.push_back(buildNearestResize(slice, inputShape, outputShape, axis + 1, rewriter, loc));
|
||||
}
|
||||
SmallVector<OpFoldResult> unitSizes(resultType.getRank(), rewriter.getIndexAttr(1));
|
||||
SmallVector<OpFoldResult> unitStrides(resultType.getRank(), rewriter.getIndexAttr(1));
|
||||
|
||||
return createSpatConcat(rewriter, loc, axis, slices);
|
||||
Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
|
||||
Value c1 = arith::ConstantIndexOp::create(rewriter, loc, 1);
|
||||
Value cOutputN = arith::ConstantIndexOp::create(rewriter, loc, resultType.getDimSize(0));
|
||||
Value cOutputC = arith::ConstantIndexOp::create(rewriter, loc, resultType.getDimSize(1));
|
||||
Value cOutputH = arith::ConstantIndexOp::create(rewriter, loc, resultType.getDimSize(2));
|
||||
Value cOutputW = arith::ConstantIndexOp::create(rewriter, loc, resultType.getDimSize(3));
|
||||
|
||||
Value outputInit = tensor::EmptyOp::create(rewriter, loc, resultType.getShape(), elemType);
|
||||
|
||||
auto batchLoop = scf::ForOp::create(rewriter, loc, c0, cOutputN, c1, ValueRange {outputInit});
|
||||
rewriter.setInsertionPointToStart(batchLoop.getBody());
|
||||
|
||||
Value outputN = batchLoop.getInductionVar();
|
||||
Value outputBatchAcc = batchLoop.getRegionIterArgs().front();
|
||||
Value inputN = buildNearestAsymmetricIndex(outputN, inputType.getDimSize(0), resultType.getDimSize(0), rewriter, loc);
|
||||
|
||||
auto channelLoop = scf::ForOp::create(rewriter, loc, c0, cOutputC, c1, ValueRange {outputBatchAcc});
|
||||
rewriter.setInsertionPointToStart(channelLoop.getBody());
|
||||
|
||||
Value outputC = channelLoop.getInductionVar();
|
||||
Value outputChannelAcc = channelLoop.getRegionIterArgs().front();
|
||||
Value inputC =
|
||||
buildNearestAsymmetricIndex(outputC, inputType.getDimSize(1), resultType.getDimSize(1), rewriter, loc);
|
||||
|
||||
auto heightLoop = scf::ForOp::create(rewriter, loc, c0, cOutputH, c1, ValueRange {outputChannelAcc});
|
||||
rewriter.setInsertionPointToStart(heightLoop.getBody());
|
||||
|
||||
Value outputH = heightLoop.getInductionVar();
|
||||
Value outputHeightAcc = heightLoop.getRegionIterArgs().front();
|
||||
Value inputH =
|
||||
buildNearestAsymmetricIndex(outputH, inputType.getDimSize(2), resultType.getDimSize(2), rewriter, loc);
|
||||
|
||||
auto widthLoop = scf::ForOp::create(rewriter, loc, c0, cOutputW, c1, ValueRange {outputHeightAcc});
|
||||
rewriter.setInsertionPointToStart(widthLoop.getBody());
|
||||
|
||||
Value outputW = widthLoop.getInductionVar();
|
||||
Value outputWidthAcc = widthLoop.getRegionIterArgs().front();
|
||||
Value inputW =
|
||||
buildNearestAsymmetricIndex(outputW, inputType.getDimSize(3), resultType.getDimSize(3), rewriter, loc);
|
||||
|
||||
SmallVector<OpFoldResult> inputOffsets = {inputN, inputC, inputH, inputW};
|
||||
Value inputSlice =
|
||||
tensor::ExtractSliceOp::create(rewriter, loc, unitTensorType, input, inputOffsets, unitSizes, unitStrides);
|
||||
|
||||
SmallVector<OpFoldResult> outputOffsets = {outputN, outputC, outputH, outputW};
|
||||
Value updatedOutput =
|
||||
tensor::InsertSliceOp::create(rewriter, loc, inputSlice, outputWidthAcc, outputOffsets, unitSizes, unitStrides);
|
||||
scf::YieldOp::create(rewriter, loc, updatedOutput);
|
||||
|
||||
rewriter.setInsertionPointAfter(widthLoop);
|
||||
scf::YieldOp::create(rewriter, loc, widthLoop.getResult(0));
|
||||
|
||||
rewriter.setInsertionPointAfter(heightLoop);
|
||||
scf::YieldOp::create(rewriter, loc, heightLoop.getResult(0));
|
||||
|
||||
rewriter.setInsertionPointAfter(channelLoop);
|
||||
scf::YieldOp::create(rewriter, loc, channelLoop.getResult(0));
|
||||
|
||||
rewriter.setInsertionPointAfter(batchLoop);
|
||||
return batchLoop.getResult(0);
|
||||
}
|
||||
|
||||
struct Resize : OpConversionPattern<ONNXResizeOp> {
|
||||
@@ -62,20 +108,22 @@ struct Resize : OpConversionPattern<ONNXResizeOp> {
|
||||
auto inputType = dyn_cast<RankedTensorType>(adaptor.getX().getType());
|
||||
auto resultType = dyn_cast<RankedTensorType>(resizeOp.getY().getType());
|
||||
if (!inputType || !resultType || !inputType.hasStaticShape() || !resultType.hasStaticShape())
|
||||
return failure();
|
||||
return rewriter.notifyMatchFailure(resizeOp, "resize lowering requires static ranked tensor types.");
|
||||
if (inputType.getRank() != 4 || resultType.getRank() != 4)
|
||||
return rewriter.notifyMatchFailure(resizeOp, "resize lowering currently supports only rank-4 NCHW tensors.");
|
||||
|
||||
if (resizeOp.getMode() != "nearest" || resizeOp.getCoordinateTransformationMode() != "asymmetric"
|
||||
|| resizeOp.getNearestMode() != "floor")
|
||||
return failure();
|
||||
return rewriter.notifyMatchFailure(
|
||||
resizeOp, "resize lowering currently supports only nearest + asymmetric + floor.");
|
||||
|
||||
if (llvm::any_of(inputType.getShape(), [](int64_t dim) { return dim <= 0; })
|
||||
|| llvm::any_of(resultType.getShape(), [](int64_t dim) { return dim <= 0; }))
|
||||
return failure();
|
||||
return rewriter.notifyMatchFailure(resizeOp, "resize lowering requires positive static dimensions.");
|
||||
|
||||
auto computeOp =
|
||||
createSpatCompute<1>(rewriter, resizeOp.getLoc(), TypeRange {resultType}, {}, adaptor.getX(), [&](Value x) {
|
||||
Value result =
|
||||
buildNearestResize(x, inputType.getShape(), resultType.getShape(), /*axis=*/0, rewriter, resizeOp.getLoc());
|
||||
Value result = buildNearestResizeLoop(x, inputType, resultType, rewriter, resizeOp.getLoc());
|
||||
spatial::SpatYieldOp::create(rewriter, resizeOp.getLoc(), result);
|
||||
});
|
||||
rewriter.replaceOp(resizeOp, computeOp.getResults());
|
||||
|
||||
@@ -31,6 +31,21 @@ static bool isDirectConstantValue(Value value) {
|
||||
return isa_and_nonnull<arith::ConstantOp, ONNXConstantOp>(value.getDefiningOp());
|
||||
}
|
||||
|
||||
template <typename ComputeOpTy>
|
||||
static bool hasPromotableWeightLikeInputs(ComputeOpTy compute) {
|
||||
Block& block = compute.getBody().front();
|
||||
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
||||
if (inputIdx >= block.getNumArguments())
|
||||
continue;
|
||||
if (!isWeightLikeComputeOperand(input))
|
||||
continue;
|
||||
if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(block.getArgument(inputIdx)))
|
||||
continue;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Collapses one-lane batches so later phases do not carry batch-only structure unnecessarily.
|
||||
struct FoldSingleLaneComputeBatchPattern : OpRewritePattern<spatial::SpatComputeBatch> {
|
||||
using OpRewritePattern<spatial::SpatComputeBatch>::OpRewritePattern;
|
||||
@@ -262,4 +277,10 @@ void annotateWeightsConstants(func::FuncOp funcOp) {
|
||||
});
|
||||
}
|
||||
|
||||
bool requiresEarlyPostRewrite(spatial::SpatComputeBatch batchOp) { return batchOp.getLaneCount() == 1; }
|
||||
|
||||
bool requiresPostRewrite(spatial::SpatCompute computeOp) { return hasPromotableWeightLikeInputs(computeOp); }
|
||||
|
||||
bool requiresPostRewrite(spatial::SpatComputeBatch computeOp) { return hasPromotableWeightLikeInputs(computeOp); }
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -3,8 +3,16 @@
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
bool requiresEarlyPostRewrite(spatial::SpatComputeBatch batchOp);
|
||||
|
||||
bool requiresPostRewrite(spatial::SpatCompute computeOp);
|
||||
|
||||
bool requiresPostRewrite(spatial::SpatComputeBatch computeOp);
|
||||
|
||||
void populateEarlyPostPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
void populatePostPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
@@ -17,9 +17,7 @@ void populatePrePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* c
|
||||
patterns.add<convAddToConvWithBiasLeft>(ctx);
|
||||
patterns.add<convAddToConvWithBiasRight>(ctx);
|
||||
patterns.add<matMulAddToGemm>(ctx);
|
||||
patterns.add<matMulToGemm>(ctx);
|
||||
patterns.add<removeFlattenSameShape>(ctx);
|
||||
populateMatMulRewritePatterns(patterns, ctx);
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -202,6 +202,7 @@ void SpatialToGraphvizPass::runOnOperation() {
|
||||
|
||||
auto entryFunc = getPimEntryFunc(module);
|
||||
if (failed(entryFunc)) {
|
||||
module.emitError("failed to locate the PIM entry function for Spatial graph visualization");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -9,12 +9,14 @@
|
||||
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/SymbolTable.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
|
||||
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
#include <cassert>
|
||||
#include <utility>
|
||||
@@ -68,183 +70,81 @@ private:
|
||||
|
||||
} // namespace
|
||||
|
||||
static int32_t translateSpatialCoreIdToPimCoreId(size_t spatialCoreId) { return static_cast<int32_t>(spatialCoreId); }
|
||||
static memref::GlobalOp getOrCreateZeroGlobal(IRRewriter& rewriter, Location loc, RankedTensorType tensorType) {
|
||||
auto moduleOp = rewriter.getBlock()->getParentOp()->getParentOfType<ModuleOp>();
|
||||
auto memRefType = MemRefType::get(tensorType.getShape(), tensorType.getElementType());
|
||||
auto zeroAttr = DenseElementsAttr::get(tensorType, rewriter.getZeroAttr(tensorType.getElementType()));
|
||||
|
||||
static void lowerChannelSend(spatial::SpatChannelSendOp sendOp, IRRewriter& rewriter) {
|
||||
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, sendOp.getInput());
|
||||
auto targetCoreIdAttr = rewriter.getI32IntegerAttr(translateSpatialCoreIdToPimCoreId(sendOp.getTargetCoreId()));
|
||||
|
||||
rewriter.setInsertionPoint(sendOp);
|
||||
PimSendOp::create(rewriter, sendOp.getLoc(), sendOp.getInput(), sizeAttr, targetCoreIdAttr);
|
||||
rewriter.eraseOp(sendOp);
|
||||
}
|
||||
|
||||
static void lowerChannelReceive(spatial::SpatChannelReceiveOp receiveOp, IRRewriter& rewriter) {
|
||||
if (receiveOp->use_empty()) {
|
||||
rewriter.eraseOp(receiveOp);
|
||||
return;
|
||||
for (auto globalOp : moduleOp.getOps<memref::GlobalOp>()) {
|
||||
if (!globalOp.getConstant() || globalOp.getType() != memRefType || !globalOp.getInitialValue())
|
||||
continue;
|
||||
if (dyn_cast<DenseElementsAttr>(*globalOp.getInitialValue()) == zeroAttr)
|
||||
return globalOp;
|
||||
}
|
||||
|
||||
auto outputType = cast<ShapedType>(receiveOp.getResult().getType());
|
||||
rewriter.setInsertionPoint(receiveOp);
|
||||
auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveOp.getLoc(), outputType);
|
||||
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, receiveOp.getResult());
|
||||
auto sourceCoreIdAttr = rewriter.getI32IntegerAttr(translateSpatialCoreIdToPimCoreId(receiveOp.getSourceCoreId()));
|
||||
std::string nameStem;
|
||||
llvm::raw_string_ostream nameStream(nameStem);
|
||||
nameStream << "__pim_zero_" << tensorType.getRank() << "d_" << tensorType.getNumElements();
|
||||
nameStream.flush();
|
||||
|
||||
Value received =
|
||||
PimReceiveOp::create(rewriter, receiveOp.getLoc(), outputBuffer.getType(), outputBuffer, sizeAttr, sourceCoreIdAttr)
|
||||
std::string symbolName = nameStem;
|
||||
unsigned suffix = 0;
|
||||
while (SymbolTable::lookupSymbolIn(moduleOp, symbolName))
|
||||
symbolName = (nameStem + "_" + Twine(suffix++)).str();
|
||||
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPointToStart(moduleOp.getBody());
|
||||
return memref::GlobalOp::create(rewriter,
|
||||
loc,
|
||||
rewriter.getStringAttr(symbolName),
|
||||
rewriter.getStringAttr("private"),
|
||||
TypeAttr::get(memRefType),
|
||||
zeroAttr,
|
||||
rewriter.getUnitAttr(),
|
||||
IntegerAttr {});
|
||||
}
|
||||
|
||||
static Value createZeroedDeviceHVector(IRRewriter& rewriter, Location loc, RankedTensorType tensorType) {
|
||||
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, tensorType);
|
||||
auto zeroGlobal = getOrCreateZeroGlobal(rewriter, loc, tensorType);
|
||||
auto zeroValue = memref::GetGlobalOp::create(rewriter, loc, zeroGlobal.getType(), zeroGlobal.getName());
|
||||
auto zeroAttr = rewriter.getI32IntegerAttr(0);
|
||||
auto sizeAttr = rewriter.getI32IntegerAttr(static_cast<int32_t>(getShapedTypeSizeInBytes(tensorType)));
|
||||
|
||||
if (outputBuffer->getParentOfType<PimCoreBatchOp>())
|
||||
return PimMemCopyHostToDevBatchOp::create(
|
||||
rewriter, loc, tensorType, outputBuffer, zeroValue, zeroAttr, zeroAttr, sizeAttr)
|
||||
.getOutput();
|
||||
rewriter.replaceOp(receiveOp, received);
|
||||
}
|
||||
|
||||
static void lowerChannelSendTensor(spatial::SpatChannelSendTensorOp sendTensorOp, IRRewriter& rewriter) {
|
||||
SmallVector<int32_t> targetCoreIds;
|
||||
targetCoreIds.reserve(sendTensorOp.getTargetCoreIds().size());
|
||||
for (int32_t targetCoreId : sendTensorOp.getTargetCoreIds())
|
||||
targetCoreIds.push_back(translateSpatialCoreIdToPimCoreId(targetCoreId));
|
||||
|
||||
rewriter.setInsertionPoint(sendTensorOp);
|
||||
PimSendTensorOp::create(
|
||||
rewriter, sendTensorOp.getLoc(), sendTensorOp.getInput(), rewriter.getDenseI32ArrayAttr(targetCoreIds));
|
||||
rewriter.eraseOp(sendTensorOp);
|
||||
}
|
||||
|
||||
static void lowerChannelReceiveTensor(spatial::SpatChannelReceiveTensorOp receiveTensorOp, IRRewriter& rewriter) {
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
sourceCoreIds.reserve(receiveTensorOp.getSourceCoreIds().size());
|
||||
for (int32_t sourceCoreId : receiveTensorOp.getSourceCoreIds())
|
||||
sourceCoreIds.push_back(translateSpatialCoreIdToPimCoreId(sourceCoreId));
|
||||
|
||||
rewriter.setInsertionPoint(receiveTensorOp);
|
||||
auto outputType = cast<ShapedType>(receiveTensorOp.getOutput().getType());
|
||||
Value outputBuffer = createEmptyTensorFromShaped(rewriter, receiveTensorOp.getLoc(), outputType).getResult();
|
||||
Value received = PimReceiveTensorOp::create(rewriter,
|
||||
receiveTensorOp.getLoc(),
|
||||
receiveTensorOp.getOutput().getType(),
|
||||
outputBuffer,
|
||||
rewriter.getDenseI32ArrayAttr(sourceCoreIds))
|
||||
return PimMemCopyHostToDevOp::create(rewriter, loc, tensorType, outputBuffer, zeroValue, zeroAttr, zeroAttr, sizeAttr)
|
||||
.getOutput();
|
||||
rewriter.replaceOp(receiveTensorOp, received);
|
||||
}
|
||||
|
||||
static void lowerExtractRows(spatial::SpatExtractRowsOp extractRowsOp, IRRewriter& rewriter) {
|
||||
rewriter.setInsertionPoint(extractRowsOp);
|
||||
auto inputType = cast<RankedTensorType>(extractRowsOp.getInput().getType());
|
||||
SmallVector<Value> replacements;
|
||||
replacements.reserve(extractRowsOp.getNumResults());
|
||||
for (auto [rowIndex, output] : llvm::enumerate(extractRowsOp.getOutputs())) {
|
||||
auto outputType = cast<RankedTensorType>(output.getType());
|
||||
SmallVector<OpFoldResult> offsets = {
|
||||
rewriter.getIndexAttr(static_cast<int64_t>(rowIndex) * outputType.getDimSize(0)), rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(outputType.getDimSize(0)),
|
||||
rewriter.getIndexAttr(inputType.getDimSize(1))};
|
||||
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
replacements.push_back(
|
||||
tensor::ExtractSliceOp::create(
|
||||
rewriter, extractRowsOp.getLoc(), outputType, extractRowsOp.getInput(), offsets, sizes, strides)
|
||||
.getResult());
|
||||
}
|
||||
rewriter.replaceOp(extractRowsOp, replacements);
|
||||
}
|
||||
static Value padHVectorInputToCrossbarSize(IRRewriter& rewriter, Location loc, Value vector) {
|
||||
auto vectorType = cast<RankedTensorType>(vector.getType());
|
||||
ArrayRef<int64_t> shape = vectorType.getShape();
|
||||
assert(isHVectorShape(shape) && "expected a horizontal vector");
|
||||
assert(shape[1] <= static_cast<int64_t>(crossbarSize) && "vector width must fit in one crossbar");
|
||||
|
||||
static void compactSpatialTensorGroups(func::FuncOp funcOp, IRRewriter& rewriter) {
|
||||
SmallVector<spatial::SpatConcatOp> concatOps;
|
||||
funcOp.walk([&](spatial::SpatConcatOp concatOp) { concatOps.push_back(concatOp); });
|
||||
for (auto concatOp : concatOps) {
|
||||
if (concatOp.getAxis() != 0 || concatOp.getInputs().empty())
|
||||
continue;
|
||||
if (shape[1] == static_cast<int64_t>(crossbarSize))
|
||||
return vector;
|
||||
|
||||
SmallVector<Value> packedInputs;
|
||||
bool changed = false;
|
||||
rewriter.setInsertionPoint(concatOp);
|
||||
|
||||
for (unsigned index = 0; index < concatOp.getInputs().size();) {
|
||||
Value input = concatOp.getInputs()[index];
|
||||
|
||||
if (input.getDefiningOp<tensor::ExtractSliceOp>()) {
|
||||
unsigned endIndex = index + 1;
|
||||
while (endIndex < concatOp.getInputs().size()
|
||||
&& concatOp.getInputs()[endIndex].getDefiningOp<tensor::ExtractSliceOp>())
|
||||
++endIndex;
|
||||
|
||||
Value packedInput = createPackedExtractSliceTensor(
|
||||
concatOp.getInputs().slice(index, endIndex - index), rewriter, concatOp.getLoc());
|
||||
if (packedInput) {
|
||||
packedInputs.push_back(packedInput);
|
||||
changed = true;
|
||||
index = endIndex;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
auto result = dyn_cast<OpResult>(input);
|
||||
if (!result) {
|
||||
packedInputs.push_back(input);
|
||||
++index;
|
||||
continue;
|
||||
}
|
||||
|
||||
Operation* owner = result.getOwner();
|
||||
unsigned startIndex = result.getResultNumber();
|
||||
unsigned endIndex = index + 1;
|
||||
while (endIndex < concatOp.getInputs().size()) {
|
||||
auto nextResult = dyn_cast<OpResult>(concatOp.getInputs()[endIndex]);
|
||||
if (!nextResult || nextResult.getOwner() != owner
|
||||
|| nextResult.getResultNumber() != startIndex + (endIndex - index))
|
||||
break;
|
||||
++endIndex;
|
||||
}
|
||||
|
||||
unsigned count = endIndex - index;
|
||||
Value packedInput;
|
||||
if (auto extractRowsOp = dyn_cast<spatial::SpatExtractRowsOp>(owner))
|
||||
packedInput = createPackedExtractRowsSlice(extractRowsOp, startIndex, count, rewriter, concatOp.getLoc());
|
||||
|
||||
if (packedInput) {
|
||||
packedInputs.push_back(packedInput);
|
||||
changed = true;
|
||||
}
|
||||
else {
|
||||
for (unsigned oldIndex = index; oldIndex < endIndex; ++oldIndex)
|
||||
packedInputs.push_back(concatOp.getInputs()[oldIndex]);
|
||||
}
|
||||
|
||||
index = endIndex;
|
||||
}
|
||||
|
||||
if (!changed)
|
||||
continue;
|
||||
|
||||
auto newConcat = pim::PimConcatOp::create(
|
||||
rewriter,
|
||||
concatOp.getLoc(),
|
||||
concatOp.getOutput().getType(),
|
||||
concatOp.getAxisAttr(),
|
||||
ValueRange(packedInputs),
|
||||
createEmptyTensorFromShaped(rewriter, concatOp.getLoc(), cast<ShapedType>(concatOp.getOutput().getType()))
|
||||
.getResult());
|
||||
rewriter.replaceOp(concatOp, newConcat.getOutput());
|
||||
}
|
||||
auto eraseUnusedOps = [&](auto tag) {
|
||||
using OpTy = decltype(tag);
|
||||
SmallVector<OpTy> ops;
|
||||
funcOp.walk([&](OpTy op) { ops.push_back(op); });
|
||||
for (auto op : llvm::reverse(ops))
|
||||
if (op->use_empty())
|
||||
rewriter.eraseOp(op);
|
||||
};
|
||||
eraseUnusedOps(tensor::ConcatOp {});
|
||||
eraseUnusedOps(tensor::ExtractSliceOp {});
|
||||
eraseUnusedOps(spatial::SpatExtractRowsOp {});
|
||||
auto paddedType = RankedTensorType::get(
|
||||
{shape[0], static_cast<int64_t>(crossbarSize)}, vectorType.getElementType(), vectorType.getEncoding());
|
||||
Value zeroed = createZeroedDeviceHVector(rewriter, loc, paddedType);
|
||||
auto zeroAttr = rewriter.getI32IntegerAttr(0);
|
||||
auto sizeAttr = rewriter.getI32IntegerAttr(static_cast<int32_t>(getShapedTypeSizeInBytes(vectorType)));
|
||||
return PimMemCopyOp::create(rewriter, loc, paddedType, zeroed, vector, zeroAttr, zeroAttr, sizeAttr).getOutput();
|
||||
}
|
||||
|
||||
void SpatialToPimPass::runOnOperation() {
|
||||
coreId = 1;
|
||||
coreId = 0;
|
||||
ModuleOp moduleOp = getOperation();
|
||||
MLIRContext* ctx = moduleOp.getContext();
|
||||
|
||||
auto entryFunc = getPimEntryFunc(moduleOp);
|
||||
if (failed(entryFunc)) {
|
||||
moduleOp.emitError("failed to locate the PIM entry function during Spatial-to-PIM lowering");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
@@ -270,26 +170,22 @@ void SpatialToPimPass::runOnOperation() {
|
||||
spatial::SpatChannelSendTensorBatchOp,
|
||||
spatial::SpatExtractRowsOp>();
|
||||
|
||||
{
|
||||
RewritePatternSet patterns(ctx);
|
||||
populateWithGenerated(patterns);
|
||||
|
||||
if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
|
||||
RewritePatternSet initialPatterns(ctx);
|
||||
populateWithGenerated(initialPatterns);
|
||||
if (failed(applyPartialConversion(moduleOp, target, std::move(initialPatterns)))) {
|
||||
moduleOp.emitError("failed to lower required Spatial ops to the initial PIM form");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
RewritePatternSet patterns(ctx);
|
||||
populateGlobalTensorMaterializationPatterns(patterns);
|
||||
|
||||
walkAndApplyPatterns(moduleOp, std::move(patterns));
|
||||
}
|
||||
RewritePatternSet globalTensorPatterns(ctx);
|
||||
populateGlobalTensorMaterializationPatterns(globalTensorPatterns);
|
||||
walkAndApplyPatterns(moduleOp, std::move(globalTensorPatterns));
|
||||
auto returnOp = cast<func::ReturnOp>(funcOp.front().getTerminator());
|
||||
|
||||
addReturnOutputBuffers(returnOp, rewriter, outputTensors);
|
||||
if (failed(allocateAndInitializeCoreLocalVariables(funcOp, rewriter))) {
|
||||
funcOp.emitOpError("failed to allocate or initialize core-local tensors during Spatial-to-PIM lowering");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
@@ -298,6 +194,7 @@ void SpatialToPimPass::runOnOperation() {
|
||||
for (auto computeOp : funcOp.getOps<spatial::SpatCompute>()) {
|
||||
markOpToRemove(computeOp);
|
||||
if (failed(lowerComputeOp(computeOp, coreLoweringState, rewriter))) {
|
||||
computeOp.emitOpError("failed to lower spat.compute to pim.core");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
@@ -306,12 +203,16 @@ void SpatialToPimPass::runOnOperation() {
|
||||
for (auto computeBatchOp : funcOp.getOps<spatial::SpatComputeBatch>()) {
|
||||
markOpToRemove(computeBatchOp);
|
||||
if (failed(lowerComputeBatchOp(computeBatchOp, coreLoweringState, rewriter))) {
|
||||
computeBatchOp.emitOpError("failed to lower spat.compute_batch to pim.core_batch");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
compactSpatialTensorGroups(funcOp, rewriter);
|
||||
RewritePatternSet initialTensorPackingPatterns(ctx);
|
||||
populateTensorPackingPatterns(initialTensorPackingPatterns);
|
||||
walkAndApplyPatterns(funcOp, std::move(initialTensorPackingPatterns));
|
||||
eraseUnusedTensorPackingOps(funcOp, rewriter);
|
||||
|
||||
SmallVector<spatial::SpatChannelReceiveOp> receiveOps;
|
||||
for (auto op : funcOp.getOps<spatial::SpatChannelReceiveOp>())
|
||||
@@ -323,38 +224,8 @@ void SpatialToPimPass::runOnOperation() {
|
||||
markOpToRemove(receiveOp);
|
||||
continue;
|
||||
}
|
||||
if (receiveOp->use_empty()) {
|
||||
rewriter.eraseOp(receiveOp);
|
||||
continue;
|
||||
}
|
||||
lowerChannelReceive(receiveOp, rewriter);
|
||||
}
|
||||
|
||||
SmallVector<spatial::SpatChannelReceiveTensorOp> receiveTensorOps;
|
||||
for (auto op : funcOp.getOps<spatial::SpatChannelReceiveTensorOp>())
|
||||
receiveTensorOps.push_back(op);
|
||||
for (auto receiveTensorOp : receiveTensorOps)
|
||||
lowerChannelReceiveTensor(receiveTensorOp, rewriter);
|
||||
|
||||
SmallVector<spatial::SpatChannelSendOp> sendOps;
|
||||
for (auto op : funcOp.getOps<spatial::SpatChannelSendOp>())
|
||||
sendOps.push_back(op);
|
||||
for (auto sendOp : sendOps)
|
||||
lowerChannelSend(sendOp, rewriter);
|
||||
|
||||
SmallVector<spatial::SpatChannelSendTensorOp> sendTensorOps;
|
||||
for (auto op : funcOp.getOps<spatial::SpatChannelSendTensorOp>())
|
||||
sendTensorOps.push_back(op);
|
||||
for (auto sendTensorOp : sendTensorOps)
|
||||
lowerChannelSendTensor(sendTensorOp, rewriter);
|
||||
|
||||
SmallVector<spatial::SpatExtractRowsOp> extractRowsOps;
|
||||
for (auto op : funcOp.getOps<spatial::SpatExtractRowsOp>())
|
||||
extractRowsOps.push_back(op);
|
||||
for (auto extractRowsOp : extractRowsOps)
|
||||
lowerExtractRows(extractRowsOp, rewriter);
|
||||
|
||||
{
|
||||
RewritePatternSet coreBodyPatterns(ctx);
|
||||
populateWithGenerated(coreBodyPatterns);
|
||||
FrozenRewritePatternSet frozenCoreBodyPatterns(std::move(coreBodyPatterns));
|
||||
@@ -363,6 +234,7 @@ void SpatialToPimPass::runOnOperation() {
|
||||
funcOp.walk([&](pim::PimCoreOp coreOp) { coreOps.push_back(coreOp); });
|
||||
for (auto coreOp : coreOps) {
|
||||
if (failed(applyFullConversion(coreOp.getOperation(), target, frozenCoreBodyPatterns))) {
|
||||
coreOp.emitOpError("failed to convert nested Spatial ops inside pim.core");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
@@ -372,11 +244,11 @@ void SpatialToPimPass::runOnOperation() {
|
||||
funcOp.walk([&](pim::PimCoreBatchOp coreBatchOp) { coreBatchOps.push_back(coreBatchOp); });
|
||||
for (auto coreBatchOp : coreBatchOps) {
|
||||
if (failed(applyFullConversion(coreBatchOp.getOperation(), target, frozenCoreBodyPatterns))) {
|
||||
coreBatchOp.emitOpError("failed to convert nested Spatial ops inside pim.core_batch");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
enlargeVMMOutTensorsToCrossbarSize(funcOp, rewriter);
|
||||
ReturnPathState returnPathState {outputTensors, operationsToRemove};
|
||||
@@ -384,13 +256,16 @@ void SpatialToPimPass::runOnOperation() {
|
||||
|
||||
SmallVector<Operation*> pendingRemovals(operationsToRemove.begin(), operationsToRemove.end());
|
||||
if (failed(erasePendingOps(pendingRemovals, rewriter))) {
|
||||
funcOp.emitOpError("failed to erase obsolete Spatial ops after lowering to PIM");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
compactSpatialTensorGroups(funcOp, rewriter);
|
||||
RewritePatternSet finalTensorPackingPatterns(ctx);
|
||||
populateTensorPackingPatterns(finalTensorPackingPatterns);
|
||||
walkAndApplyPatterns(funcOp, std::move(finalTensorPackingPatterns));
|
||||
eraseUnusedTensorPackingOps(funcOp, rewriter);
|
||||
|
||||
{
|
||||
ConversionTarget communicationTarget(*ctx);
|
||||
communicationTarget.addLegalDialect<PimDialect,
|
||||
tensor::TensorDialect,
|
||||
@@ -411,12 +286,13 @@ void SpatialToPimPass::runOnOperation() {
|
||||
RewritePatternSet communicationPatterns(ctx);
|
||||
populateChannelLoweringPatterns(communicationPatterns);
|
||||
if (failed(applyFullConversion(funcOp, communicationTarget, std::move(communicationPatterns)))) {
|
||||
funcOp.emitOpError("failed to lower Spatial communication ops to PIM communication ops");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if (failed(verifySpatialToPimBoundary(moduleOp))) {
|
||||
moduleOp.emitError("Spatial-to-PIM boundary verification failed");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
@@ -426,54 +302,35 @@ void SpatialToPimPass::runOnOperation() {
|
||||
}
|
||||
|
||||
void SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter) {
|
||||
auto enlargeTiedDpsChain = [&](Value value, RankedTensorType newType, auto& self) -> void {
|
||||
auto* definingOp = value.getDefiningOp();
|
||||
if (!definingOp)
|
||||
return;
|
||||
auto dpsDefiningOp = dyn_cast<DestinationStyleOpInterface>(definingOp);
|
||||
if (!dpsDefiningOp)
|
||||
return;
|
||||
auto* tiedOperand = dpsDefiningOp.getTiedOpOperand(cast<OpResult>(value));
|
||||
if (!tiedOperand)
|
||||
return;
|
||||
Value tiedValue = tiedOperand->get();
|
||||
assert(tiedValue.hasOneUse() && "Tied DPS operand expected to have a single use");
|
||||
tiedValue.setType(newType);
|
||||
self(tiedValue, newType, self);
|
||||
};
|
||||
|
||||
funcOp.walk([&](PimVMMOp vmmOp) {
|
||||
auto outTensorOperand = vmmOp.getOutputBuffer();
|
||||
auto resultTensor = vmmOp.getOutput();
|
||||
auto outShape = getTensorShape(outTensorOperand);
|
||||
assert(isHVectorShape(outShape));
|
||||
if (outShape[1] != static_cast<int64_t>(crossbarSize)) {
|
||||
auto newShape = SmallVector<int64_t> {outShape[0], static_cast<int64_t>(crossbarSize)};
|
||||
auto newType = RankedTensorType::get(newShape, outTensorOperand.getType().getElementType());
|
||||
if (outTensorOperand == vmmOp.getInput()) {
|
||||
rewriter.setInsertionPoint(vmmOp);
|
||||
auto newOutputBuffer =
|
||||
tensor::EmptyOp::create(rewriter, vmmOp.getLoc(), newShape, outTensorOperand.getType().getElementType());
|
||||
vmmOp.getOutputBufferMutable().assign(newOutputBuffer);
|
||||
}
|
||||
else {
|
||||
enlargeTiedDpsChain(outTensorOperand, newType, enlargeTiedDpsChain);
|
||||
outTensorOperand.setType(newType);
|
||||
}
|
||||
resultTensor.setType(newType);
|
||||
auto outputType = cast<RankedTensorType>(vmmOp.getOutput().getType());
|
||||
ArrayRef<int64_t> outputShape = outputType.getShape();
|
||||
assert(isHVectorShape(outputShape) && "expected a horizontal vector output");
|
||||
assert(outputShape[1] <= static_cast<int64_t>(crossbarSize) && "output width must fit in one crossbar");
|
||||
|
||||
IntegerAttr zeroAttr = rewriter.getIndexAttr(0);
|
||||
IntegerAttr oneAttr = rewriter.getIndexAttr(1);
|
||||
IntegerAttr oldShapeZeroAttr = rewriter.getIndexAttr(outShape[0]);
|
||||
IntegerAttr oldShapeOneAttr = rewriter.getIndexAttr(outShape[1]);
|
||||
SmallVector<OpFoldResult> offsets = {zeroAttr, zeroAttr};
|
||||
SmallVector<OpFoldResult> sizes = {oldShapeZeroAttr, oldShapeOneAttr};
|
||||
SmallVector<OpFoldResult> strides = {oneAttr, oneAttr};
|
||||
rewriter.setInsertionPoint(vmmOp);
|
||||
Value paddedInput = padHVectorInputToCrossbarSize(rewriter, vmmOp.getLoc(), vmmOp.getInput());
|
||||
auto paddedOutputType = RankedTensorType::get(
|
||||
{outputShape[0], static_cast<int64_t>(crossbarSize)}, outputType.getElementType(), outputType.getEncoding());
|
||||
Value paddedOutputBuffer = outputShape[1] == static_cast<int64_t>(crossbarSize)
|
||||
? vmmOp.getOutputBuffer()
|
||||
: createEmptyTensorFromShaped(rewriter, vmmOp.getLoc(), paddedOutputType).getResult();
|
||||
vmmOp.getInputMutable().assign(paddedInput);
|
||||
vmmOp.getOutputBufferMutable().assign(paddedOutputBuffer);
|
||||
|
||||
vmmOp.getOutput().setType(paddedOutputType);
|
||||
|
||||
if (outputShape[1] == static_cast<int64_t>(crossbarSize))
|
||||
return;
|
||||
|
||||
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(outputShape[0]), rewriter.getIndexAttr(outputShape[1])};
|
||||
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
rewriter.setInsertionPointAfter(vmmOp);
|
||||
auto sliceOp = tensor::ExtractSliceOp::create(rewriter, vmmOp.getLoc(), resultTensor, offsets, sizes, strides);
|
||||
auto sliceOp =
|
||||
tensor::ExtractSliceOp::create(rewriter, vmmOp.getLoc(), outputType, vmmOp.getOutput(), offsets, sizes, strides);
|
||||
SmallPtrSet<Operation*, 2> exceptions = {vmmOp, sliceOp};
|
||||
resultTensor.replaceAllUsesExcept(sliceOp.getResult(), exceptions);
|
||||
}
|
||||
vmmOp.getOutput().replaceAllUsesExcept(sliceOp.getResult(), exceptions);
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -1,8 +1,93 @@
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/TensorPackingPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
struct PackSpatialConcatInputsPattern final : OpRewritePattern<spatial::SpatConcatOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(spatial::SpatConcatOp concatOp, PatternRewriter& rewriter) const override {
|
||||
if (concatOp.getAxis() != 0 || concatOp.getInputs().empty())
|
||||
return failure();
|
||||
|
||||
SmallVector<Value> packedInputs;
|
||||
bool changed = false;
|
||||
|
||||
for (unsigned index = 0; index < concatOp.getInputs().size();) {
|
||||
Value input = concatOp.getInputs()[index];
|
||||
|
||||
if (input.getDefiningOp<tensor::ExtractSliceOp>()) {
|
||||
unsigned endIndex = index + 1;
|
||||
while (endIndex < concatOp.getInputs().size()
|
||||
&& concatOp.getInputs()[endIndex].getDefiningOp<tensor::ExtractSliceOp>())
|
||||
++endIndex;
|
||||
|
||||
Value packedInput = createPackedExtractSliceTensor(
|
||||
concatOp.getInputs().slice(index, endIndex - index), rewriter, concatOp.getLoc());
|
||||
if (packedInput) {
|
||||
packedInputs.push_back(packedInput);
|
||||
changed = true;
|
||||
index = endIndex;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
auto result = dyn_cast<OpResult>(input);
|
||||
if (!result) {
|
||||
packedInputs.push_back(input);
|
||||
++index;
|
||||
continue;
|
||||
}
|
||||
|
||||
Operation* owner = result.getOwner();
|
||||
unsigned startIndex = result.getResultNumber();
|
||||
unsigned endIndex = index + 1;
|
||||
while (endIndex < concatOp.getInputs().size()) {
|
||||
auto nextResult = dyn_cast<OpResult>(concatOp.getInputs()[endIndex]);
|
||||
if (!nextResult || nextResult.getOwner() != owner
|
||||
|| nextResult.getResultNumber() != startIndex + (endIndex - index))
|
||||
break;
|
||||
++endIndex;
|
||||
}
|
||||
|
||||
unsigned count = endIndex - index;
|
||||
Value packedInput;
|
||||
if (auto extractRowsOp = dyn_cast<spatial::SpatExtractRowsOp>(owner))
|
||||
packedInput = createPackedExtractRowsSlice(extractRowsOp, startIndex, count, rewriter, concatOp.getLoc());
|
||||
|
||||
if (packedInput) {
|
||||
packedInputs.push_back(packedInput);
|
||||
changed = true;
|
||||
}
|
||||
else {
|
||||
for (unsigned oldIndex = index; oldIndex < endIndex; ++oldIndex)
|
||||
packedInputs.push_back(concatOp.getInputs()[oldIndex]);
|
||||
}
|
||||
|
||||
index = endIndex;
|
||||
}
|
||||
|
||||
if (!changed)
|
||||
return failure();
|
||||
|
||||
auto outputType = cast<ShapedType>(concatOp.getOutput().getType());
|
||||
auto newConcat = pim::PimConcatOp::create(
|
||||
rewriter,
|
||||
concatOp.getLoc(),
|
||||
concatOp.getOutput().getType(),
|
||||
concatOp.getAxisAttr(),
|
||||
ValueRange(packedInputs),
|
||||
tensor::EmptyOp::create(rewriter, concatOp.getLoc(), outputType.getShape(), outputType.getElementType())
|
||||
.getResult());
|
||||
rewriter.replaceOp(concatOp, newConcat.getOutput());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
RankedTensorType getPackedTensorType(RankedTensorType elementType, int64_t count) {
|
||||
SmallVector<int64_t> packedShape(elementType.getShape().begin(), elementType.getShape().end());
|
||||
@@ -146,4 +231,23 @@ Value createPackedExtractSliceTensor(ValueRange values, OpBuilder& builder, Loca
|
||||
return tensor::ExtractSliceOp::create(builder, loc, packedType, firstSliceOp.getSource(), offsets, sizes, strides)
|
||||
.getResult();
|
||||
}
|
||||
|
||||
void populateTensorPackingPatterns(RewritePatternSet& patterns) {
|
||||
patterns.add<PackSpatialConcatInputsPattern>(patterns.getContext());
|
||||
}
|
||||
|
||||
void eraseUnusedTensorPackingOps(func::FuncOp funcOp, IRRewriter& rewriter) {
|
||||
auto eraseUnusedOps = [&](auto tag) {
|
||||
using OpTy = decltype(tag);
|
||||
SmallVector<OpTy> ops;
|
||||
funcOp.walk([&](OpTy op) { ops.push_back(op); });
|
||||
for (auto op : llvm::reverse(ops))
|
||||
if (op->use_empty())
|
||||
rewriter.eraseOp(op);
|
||||
};
|
||||
eraseUnusedOps(tensor::ConcatOp {});
|
||||
eraseUnusedOps(tensor::ExtractSliceOp {});
|
||||
eraseUnusedOps(spatial::SpatExtractRowsOp {});
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
@@ -19,5 +20,7 @@ mlir::Value createPackedExtractRowsSlice(spatial::SpatExtractRowsOp extractRowsO
|
||||
mlir::OpBuilder& builder,
|
||||
mlir::Location loc);
|
||||
mlir::Value createPackedExtractSliceTensor(mlir::ValueRange values, mlir::OpBuilder& builder, mlir::Location loc);
|
||||
void populateTensorPackingPatterns(mlir::RewritePatternSet& patterns);
|
||||
void eraseUnusedTensorPackingOps(mlir::func::FuncOp funcOp, mlir::IRRewriter& rewriter);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -2,6 +2,7 @@ add_onnx_mlir_dialect(Pim pim)
|
||||
add_onnx_mlir_dialect_doc(pim Pim.td)
|
||||
|
||||
add_subdirectory(Transforms/Bufferization)
|
||||
add_subdirectory(Transforms/StaticMemoryCoalescing)
|
||||
|
||||
add_pim_library(PimOps
|
||||
PimOps.hpp
|
||||
|
||||
@@ -389,6 +389,7 @@ def PimVMMOp : PimOp<"vmm", [DestinationStyleOpInterface]> {
|
||||
}
|
||||
}];
|
||||
|
||||
let hasVerifier = 1;
|
||||
let assemblyFormat = [{
|
||||
`(` $input `,` $outputBuffer `)` attr-dict `:` `(` type($input) `,` type($outputBuffer) `)` `->` type($output)
|
||||
}];
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
|
||||
#include "llvm/Support/LogicalResult.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
@@ -77,6 +78,22 @@ verifyTensorBatchCommunication(Operation* op, Type type, ArrayRef<int32_t> coreI
|
||||
return success();
|
||||
}
|
||||
|
||||
static FailureOr<ArrayRef<int64_t>> getWeightShapeForVMM(Operation* op, size_t weightIndex) {
|
||||
if (auto coreOp = op->getParentOfType<PimCoreOp>()) {
|
||||
if (weightIndex >= coreOp.getWeights().size())
|
||||
return failure();
|
||||
return cast<ShapedType>(coreOp.getWeights()[weightIndex].getType()).getShape();
|
||||
}
|
||||
|
||||
if (auto coreBatchOp = op->getParentOfType<PimCoreBatchOp>()) {
|
||||
if (weightIndex >= coreBatchOp.getWeights().size())
|
||||
return failure();
|
||||
return cast<ShapedType>(coreBatchOp.getWeights()[weightIndex].getType()).getShape();
|
||||
}
|
||||
|
||||
return failure();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
LogicalResult PimSendTensorOp::verify() {
|
||||
@@ -104,6 +121,47 @@ LogicalResult PimReceiveTensorBatchOp::verify() {
|
||||
getOperation(), getOutput().getType(), getSourceCoreIds(), "receive_tensor_batch");
|
||||
}
|
||||
|
||||
LogicalResult PimVMMOp::verify() {
|
||||
if (failed(verifyCompatibleShapedTypes(
|
||||
getOperation(), getOutputBuffer().getType(), getOutput().getType(), "output buffer and output must match")))
|
||||
return failure();
|
||||
|
||||
auto matrixShapeOpt = getWeightShapeForVMM(getOperation(), getWeightIndex());
|
||||
if (failed(matrixShapeOpt))
|
||||
return emitError("must be nested inside pim.core or pim.core_batch with a valid weightIndex");
|
||||
ArrayRef<int64_t> matrixShape = *matrixShapeOpt;
|
||||
|
||||
auto vectorType = dyn_cast<ShapedType>(getInput().getType());
|
||||
auto outputType = dyn_cast<ShapedType>(getOutput().getType());
|
||||
if (!vectorType || !outputType)
|
||||
return emitError("input and output must be shaped types");
|
||||
|
||||
ArrayRef<int64_t> vectorShape = vectorType.getShape();
|
||||
ArrayRef<int64_t> outputShape = outputType.getShape();
|
||||
|
||||
if (matrixShape.size() != 2 || vectorShape.size() != 2 || outputShape.size() != 2)
|
||||
return emitError("matrix, vector and output must have rank 2");
|
||||
|
||||
int64_t N = matrixShape[0];
|
||||
int64_t M = matrixShape[1];
|
||||
if (N <= 0 || M <= 0)
|
||||
return emitError("matrix shape must be (N, M) with N > 0 and M > 0");
|
||||
if (N > static_cast<int64_t>(crossbarSize) || M > static_cast<int64_t>(crossbarSize))
|
||||
return emitError("matrix dimensions must fit in one crossbar");
|
||||
|
||||
int64_t vector1 = vectorShape[0];
|
||||
int64_t vectorWidth = vectorShape[1];
|
||||
if (vector1 != 1 || vectorWidth != static_cast<int64_t>(crossbarSize))
|
||||
return emitError("vector shape must be (1, crossbar-size)");
|
||||
|
||||
int64_t output1 = outputShape[0];
|
||||
int64_t outputWidth = outputShape[1];
|
||||
if (output1 != 1 || outputWidth != static_cast<int64_t>(crossbarSize))
|
||||
return emitError("output shape must be (1, crossbar-size)");
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult PimConcatOp::verify() {
|
||||
if (getInputs().empty())
|
||||
return emitError("requires at least one input");
|
||||
|
||||
@@ -105,6 +105,37 @@ struct MemCopyDevToHostOpInterface
|
||||
}
|
||||
};
|
||||
|
||||
struct MemCopyOpInterface : DstBufferizableOpInterfaceExternalModel<MemCopyOpInterface, PimMemCopyOp> {
|
||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
|
||||
}
|
||||
|
||||
LogicalResult bufferize(Operation* op,
|
||||
RewriterBase& rewriter,
|
||||
const BufferizationOptions& options,
|
||||
BufferizationState& state) const {
|
||||
auto memCopyOp = cast<PimMemCopyOp>(op);
|
||||
|
||||
auto targetOpt = getBufferOrValue(rewriter, memCopyOp.getTarget(), options, state);
|
||||
if (failed(targetOpt))
|
||||
return failure();
|
||||
|
||||
auto sourceOpt = getBufferOrValue(rewriter, memCopyOp.getSource(), options, state);
|
||||
if (failed(sourceOpt))
|
||||
return failure();
|
||||
|
||||
replaceOpWithNewBufferizedOp<PimMemCopyOp>(rewriter,
|
||||
memCopyOp,
|
||||
targetOpt->getType(),
|
||||
*targetOpt,
|
||||
*sourceOpt,
|
||||
memCopyOp.getTargetOffsetAttr(),
|
||||
memCopyOp.getSourceOffsetAttr(),
|
||||
memCopyOp.getSizeAttr());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct ReceiveOpInterface : DstBufferizableOpInterfaceExternalModel<ReceiveOpInterface, PimReceiveOp> {
|
||||
bool bufferizesToMemoryRead(Operation* op, OpOperand& opOperand, const AnalysisState& state) const {
|
||||
return !cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
|
||||
@@ -626,6 +657,7 @@ void registerOpBufferizationInterfaces(DialectRegistry& registry) {
|
||||
PimMemCopyHostToDevOp::attachInterface<MemCopyHostToDevOpInterface>(*ctx);
|
||||
PimMemCopyHostToDevBatchOp::attachInterface<MemCopyHostToDevBatchOpInterface>(*ctx);
|
||||
PimMemCopyDevToHostOp::attachInterface<MemCopyDevToHostOpInterface>(*ctx);
|
||||
PimMemCopyOp::attachInterface<MemCopyOpInterface>(*ctx);
|
||||
PimTransposeOp::attachInterface<TransposeOpInterface>(*ctx);
|
||||
PimVMMOp::attachInterface<VMMOpInterface>(*ctx);
|
||||
|
||||
|
||||
@@ -79,6 +79,7 @@ void PimBufferizationPass::runOnOperation() {
|
||||
return WalkResult::skip();
|
||||
});
|
||||
if (hasFailed) {
|
||||
moduleOp.emitError("failed to lower memref.copy-like ops inside PIM core bodies during bufferization");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -0,0 +1,14 @@
|
||||
add_pim_library(OMPimStaticMemoryCoalescing
|
||||
StaticMemoryCoalescing.cpp
|
||||
StaticMemoryCoalescing.hpp
|
||||
StaticMemoryCoalescingPass.cpp
|
||||
|
||||
EXCLUDE_FROM_OM_LIBS
|
||||
|
||||
INCLUDE_DIRS PUBLIC
|
||||
${PIM_PUBLIC_INCLUDE_DIRS}
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
OMPimCommon
|
||||
PimOps
|
||||
)
|
||||
@@ -0,0 +1,178 @@
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallPtrSet.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include <limits>
|
||||
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/Transforms/StaticMemoryCoalescing/StaticMemoryCoalescing.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace pim {
|
||||
|
||||
namespace {
|
||||
|
||||
static bool isSupportedAliasOp(Operation* op) {
|
||||
return isa<memref::SubViewOp, memref::CastOp, memref::CollapseShapeOp, memref::ExpandShapeOp>(op);
|
||||
}
|
||||
|
||||
static bool isCandidateAllocType(MemRefType type) {
|
||||
return type && type.hasStaticShape() && type.getLayout().isIdentity() && type.getElementTypeBitWidth() > 0;
|
||||
}
|
||||
|
||||
static uint64_t getTypeSizeBytes(MemRefType type) {
|
||||
return static_cast<uint64_t>(type.getNumElements() * type.getElementTypeBitWidth() / 8);
|
||||
}
|
||||
|
||||
static FailureOr<uint64_t>
|
||||
getLastUseInstruction(memref::AllocOp allocOp, Block& body, const DenseMap<Operation*, uint64_t>& opOrder) {
|
||||
uint64_t endInstruction = opOrder.lookup(allocOp);
|
||||
SmallPtrSet<Operation*, 16> visited;
|
||||
SmallVector<Value> pendingValues;
|
||||
pendingValues.push_back(allocOp.getResult());
|
||||
|
||||
while (!pendingValues.empty()) {
|
||||
Value value = pendingValues.pop_back_val();
|
||||
for (Operation* user : value.getUsers()) {
|
||||
if (user->getBlock() != &body)
|
||||
return failure();
|
||||
if (!visited.insert(user).second)
|
||||
continue;
|
||||
|
||||
if (isSupportedAliasOp(user))
|
||||
for (Value result : user->getResults())
|
||||
pendingValues.push_back(result);
|
||||
|
||||
if (auto forOp = dyn_cast<scf::ForOp>(user)) {
|
||||
for (auto [index, initArg] : llvm::enumerate(forOp.getInitArgs())) {
|
||||
if (initArg == value)
|
||||
pendingValues.push_back(forOp.getResult(index));
|
||||
}
|
||||
}
|
||||
|
||||
if (auto dpsOp = dyn_cast<DestinationStyleOpInterface>(user)) {
|
||||
for (OpResult result : user->getResults()) {
|
||||
OpOperand* tiedOperand = dpsOp.getTiedOpOperand(result);
|
||||
if (!tiedOperand || tiedOperand->get() != value)
|
||||
continue;
|
||||
pendingValues.push_back(result);
|
||||
}
|
||||
}
|
||||
|
||||
auto order = opOrder.find(user);
|
||||
if (order == opOrder.end())
|
||||
return failure();
|
||||
endInstruction = std::max(endInstruction, order->second);
|
||||
}
|
||||
}
|
||||
|
||||
return endInstruction;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
StaticMemoryCoalescingAnalysis analyzeStaticMemoryCoalescingCandidates(Operation* coreLikeOp) {
|
||||
StaticMemoryCoalescingAnalysis analysis;
|
||||
if (!coreLikeOp || coreLikeOp->getNumRegions() != 1 || coreLikeOp->getRegion(0).empty())
|
||||
return analysis;
|
||||
|
||||
Block& body = coreLikeOp->getRegion(0).front();
|
||||
DenseMap<Operation*, uint64_t> opOrder;
|
||||
uint64_t nextInstruction = 0;
|
||||
for (Operation& op : body)
|
||||
opOrder.try_emplace(&op, nextInstruction++);
|
||||
|
||||
for (Operation& op : body) {
|
||||
auto allocOp = dyn_cast<memref::AllocOp>(&op);
|
||||
if (!allocOp)
|
||||
continue;
|
||||
|
||||
auto allocType = dyn_cast<MemRefType>(allocOp.getType());
|
||||
if (!isCandidateAllocType(allocType)) {
|
||||
++analysis.skippedAllocations;
|
||||
continue;
|
||||
}
|
||||
|
||||
auto endInstruction = getLastUseInstruction(allocOp, body, opOrder);
|
||||
if (failed(endInstruction)) {
|
||||
++analysis.skippedAllocations;
|
||||
continue;
|
||||
}
|
||||
|
||||
analysis.candidates.push_back(
|
||||
StaticAllocationCandidate {allocOp, opOrder.lookup(allocOp), *endInstruction, getTypeSizeBytes(allocType)});
|
||||
}
|
||||
|
||||
return analysis;
|
||||
}
|
||||
|
||||
StaticMemoryCoalescingStats coalesceStaticMemory(Operation* coreLikeOp, RewriterBase& rewriter) {
|
||||
StaticMemoryCoalescingStats stats;
|
||||
auto analysis = analyzeStaticMemoryCoalescingCandidates(coreLikeOp);
|
||||
stats.skippedAllocations = analysis.skippedAllocations;
|
||||
|
||||
llvm::sort(analysis.candidates, [](const StaticAllocationCandidate& lhs, const StaticAllocationCandidate& rhs) {
|
||||
if (lhs.startInstruction != rhs.startInstruction)
|
||||
return lhs.startInstruction < rhs.startInstruction;
|
||||
return lhs.endInstruction < rhs.endInstruction;
|
||||
});
|
||||
|
||||
struct ActiveStorage {
|
||||
memref::AllocOp root;
|
||||
uint64_t endInstruction = 0;
|
||||
};
|
||||
|
||||
SmallVector<ActiveStorage> active;
|
||||
SmallVector<memref::AllocOp> freeList;
|
||||
|
||||
for (StaticAllocationCandidate& candidate : analysis.candidates) {
|
||||
for (auto it = active.begin(); it != active.end();) {
|
||||
if (it->endInstruction < candidate.startInstruction) {
|
||||
freeList.push_back(it->root);
|
||||
it = active.erase(it);
|
||||
continue;
|
||||
}
|
||||
++it;
|
||||
}
|
||||
|
||||
auto bestFit = freeList.end();
|
||||
uint64_t bestFitBytes = std::numeric_limits<uint64_t>::max();
|
||||
auto candidateType = cast<MemRefType>(candidate.alloc.getType());
|
||||
for (auto it = freeList.begin(); it != freeList.end(); ++it) {
|
||||
auto freeType = cast<MemRefType>((*it).getType());
|
||||
if (freeType != candidateType)
|
||||
continue;
|
||||
|
||||
uint64_t freeBytes = getTypeSizeBytes(freeType);
|
||||
if (freeBytes < candidate.sizeBytes || freeBytes >= bestFitBytes)
|
||||
continue;
|
||||
|
||||
bestFit = it;
|
||||
bestFitBytes = freeBytes;
|
||||
}
|
||||
|
||||
if (bestFit == freeList.end()) {
|
||||
active.push_back(ActiveStorage {candidate.alloc, candidate.endInstruction});
|
||||
continue;
|
||||
}
|
||||
|
||||
memref::AllocOp root = *bestFit;
|
||||
freeList.erase(bestFit);
|
||||
candidate.alloc.getResult().replaceAllUsesWith(root.getResult());
|
||||
rewriter.eraseOp(candidate.alloc);
|
||||
active.push_back(ActiveStorage {root, candidate.endInstruction});
|
||||
++stats.removedAllocs;
|
||||
stats.savedBytes += candidate.sizeBytes;
|
||||
}
|
||||
|
||||
return stats;
|
||||
}
|
||||
|
||||
} // namespace pim
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,34 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace pim {
|
||||
|
||||
struct StaticAllocationCandidate {
|
||||
mlir::memref::AllocOp alloc;
|
||||
uint64_t startInstruction = 0;
|
||||
uint64_t endInstruction = 0;
|
||||
uint64_t sizeBytes = 0;
|
||||
};
|
||||
|
||||
struct StaticMemoryCoalescingAnalysis {
|
||||
llvm::SmallVector<StaticAllocationCandidate> candidates;
|
||||
uint64_t skippedAllocations = 0;
|
||||
};
|
||||
|
||||
struct StaticMemoryCoalescingStats {
|
||||
uint64_t removedAllocs = 0;
|
||||
uint64_t savedBytes = 0;
|
||||
uint64_t skippedAllocations = 0;
|
||||
};
|
||||
|
||||
StaticMemoryCoalescingAnalysis analyzeStaticMemoryCoalescingCandidates(mlir::Operation* coreLikeOp);
|
||||
|
||||
StaticMemoryCoalescingStats coalesceStaticMemory(mlir::Operation* coreLikeOp, mlir::RewriterBase& rewriter);
|
||||
|
||||
} // namespace pim
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,204 @@
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/Support/raw_os_ostream.h"
|
||||
|
||||
#include <fstream>
|
||||
|
||||
#include "Common/IR/CompactAsmUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Common/Support/DebugDump.hpp"
|
||||
#include "src/Accelerators/PIM/Common/Support/ReportUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/Transforms/StaticMemoryCoalescing/StaticMemoryCoalescing.hpp"
|
||||
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace onnx_mlir::compact_asm;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
struct CoalescingReportRow {
|
||||
uint64_t numCandidates = 0;
|
||||
uint64_t numSkipped = 0;
|
||||
uint64_t numRemoved = 0;
|
||||
uint64_t savedBytes = 0;
|
||||
|
||||
bool operator==(const CoalescingReportRow& other) const {
|
||||
return numCandidates == other.numCandidates && numSkipped == other.numSkipped && numRemoved == other.numRemoved
|
||||
&& savedBytes == other.savedBytes;
|
||||
}
|
||||
};
|
||||
|
||||
struct CoalescingReportEntry {
|
||||
enum class Kind {
|
||||
Core,
|
||||
Batch
|
||||
};
|
||||
|
||||
Kind kind = Kind::Core;
|
||||
uint64_t id = 0;
|
||||
llvm::SmallVector<int32_t, 8> coreIds;
|
||||
CoalescingReportRow row;
|
||||
};
|
||||
|
||||
static std::string formatMemory(uint64_t bytes) { return formatReportMemory(bytes); }
|
||||
|
||||
static SmallVector<int32_t> getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp) {
|
||||
auto coreIdsAttr = coreBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName);
|
||||
assert(coreIdsAttr && "pim.core_batch requires coreIds array attribute");
|
||||
return SmallVector<int32_t>(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end());
|
||||
}
|
||||
|
||||
static void printReportRow(raw_ostream& os, const CoalescingReportRow& row) {
|
||||
llvm::SmallVector<ReportField, 4> fields = {
|
||||
{"Number of candidates", std::to_string(row.numCandidates)},
|
||||
{"Skipped allocations", std::to_string(row.numSkipped) },
|
||||
{"Removed allocations", std::to_string(row.numRemoved) },
|
||||
{"Saved memory", formatMemory(row.savedBytes) }
|
||||
};
|
||||
printReportFlatFields(os, fields);
|
||||
}
|
||||
|
||||
static CoalescingReportRow getTotalRow(const CoalescingReportEntry& entry) {
|
||||
uint64_t factor = std::max<uint64_t>(1, entry.coreIds.size());
|
||||
return {entry.row.numCandidates * factor,
|
||||
entry.row.numSkipped * factor,
|
||||
entry.row.numRemoved * factor,
|
||||
entry.row.savedBytes * factor};
|
||||
}
|
||||
|
||||
static void emitReport(ArrayRef<CoalescingReportEntry> entries) {
|
||||
std::fstream file = openReportFile("static_memory_coalescing_report");
|
||||
if (!file.is_open())
|
||||
return;
|
||||
|
||||
llvm::raw_os_ostream os(file);
|
||||
CoalescingReportRow totalRow;
|
||||
for (const CoalescingReportEntry& entry : entries) {
|
||||
CoalescingReportRow entryTotal = getTotalRow(entry);
|
||||
totalRow.numCandidates += entryTotal.numCandidates;
|
||||
totalRow.numSkipped += entryTotal.numSkipped;
|
||||
totalRow.numRemoved += entryTotal.numRemoved;
|
||||
totalRow.savedBytes += entryTotal.savedBytes;
|
||||
}
|
||||
|
||||
llvm::SmallVector<ReportField, 4> totalFields = {
|
||||
{"Number of candidates", std::to_string(totalRow.numCandidates)},
|
||||
{"Skipped allocations", std::to_string(totalRow.numSkipped) },
|
||||
{"Removed allocations", std::to_string(totalRow.numRemoved) },
|
||||
{"Saved memory", formatMemory(totalRow.savedBytes) }
|
||||
};
|
||||
printReportTotalsBlock(os, totalFields);
|
||||
if (!entries.empty())
|
||||
os << "\n";
|
||||
|
||||
llvm::SmallVector<CoalescingReportEntry, 32> sortedEntries(entries.begin(), entries.end());
|
||||
sortReportEntriesByFirstCore(sortedEntries);
|
||||
|
||||
for (size_t index = 0; index < sortedEntries.size();) {
|
||||
size_t runEnd = index + 1;
|
||||
while (runEnd < sortedEntries.size() && sortedEntries[runEnd].kind == sortedEntries[index].kind
|
||||
&& sortedEntries[runEnd].row == sortedEntries[index].row) {
|
||||
++runEnd;
|
||||
}
|
||||
|
||||
if (sortedEntries[index].kind == CoalescingReportEntry::Kind::Batch) {
|
||||
os << "Batch ";
|
||||
for (size_t batchIndex = index; batchIndex < runEnd; ++batchIndex) {
|
||||
if (batchIndex != index)
|
||||
os << ",\n ";
|
||||
os << sortedEntries[batchIndex].id << " (cores ";
|
||||
printCompressedIntegerEntries(os, ArrayRef<int32_t>(sortedEntries[batchIndex].coreIds));
|
||||
os << ")";
|
||||
}
|
||||
}
|
||||
else {
|
||||
llvm::SmallVector<int32_t, 8> coreIds;
|
||||
for (size_t coreIndex = index; coreIndex < runEnd; ++coreIndex)
|
||||
coreIds.push_back(sortedEntries[coreIndex].coreIds.front());
|
||||
os << "Core ";
|
||||
printCompressedIntegerEntries(os, ArrayRef<int32_t>(coreIds));
|
||||
}
|
||||
|
||||
os << ":\n";
|
||||
if (sortedEntries[index].kind == CoalescingReportEntry::Kind::Batch) {
|
||||
llvm::SmallVector<ReportField, 4> perCoreFields = {
|
||||
{"Number of candidates", std::to_string(sortedEntries[index].row.numCandidates)},
|
||||
{"Skipped allocations", std::to_string(sortedEntries[index].row.numSkipped) },
|
||||
{"Removed allocations", std::to_string(sortedEntries[index].row.numRemoved) },
|
||||
{"Saved memory", formatMemory(sortedEntries[index].row.savedBytes) }
|
||||
};
|
||||
CoalescingReportRow totalRow = getTotalRow(sortedEntries[index]);
|
||||
llvm::SmallVector<ReportField, 4> totalFields = {
|
||||
{"Number of candidates", std::to_string(totalRow.numCandidates)},
|
||||
{"Skipped allocations", std::to_string(totalRow.numSkipped) },
|
||||
{"Removed allocations", std::to_string(totalRow.numRemoved) },
|
||||
{"Saved memory", formatMemory(totalRow.savedBytes) }
|
||||
};
|
||||
printReportPerCoreAndTotalFields(os, perCoreFields, totalFields);
|
||||
}
|
||||
else {
|
||||
printReportRow(os, sortedEntries[index].row);
|
||||
}
|
||||
printReportEntrySeparator(os, runEnd < sortedEntries.size());
|
||||
index = runEnd;
|
||||
}
|
||||
|
||||
os.flush();
|
||||
file.close();
|
||||
}
|
||||
|
||||
struct StaticMemoryCoalescingPass : PassWrapper<StaticMemoryCoalescingPass, OperationPass<ModuleOp>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(StaticMemoryCoalescingPass)
|
||||
|
||||
StringRef getArgument() const override { return "pim-static-memory-coalescing"; }
|
||||
StringRef getDescription() const override { return "Analyze static local PIM memory reuse opportunities"; }
|
||||
|
||||
StaticMemoryCoalescingPass() = default;
|
||||
StaticMemoryCoalescingPass(const StaticMemoryCoalescingPass& pass) {}
|
||||
|
||||
void runOnOperation() override {
|
||||
IRRewriter rewriter(&getContext());
|
||||
SmallVector<CoalescingReportEntry, 32> reportEntries;
|
||||
uint64_t nextBatchId = 0;
|
||||
|
||||
getOperation().walk([&](Operation* op) {
|
||||
if (!isa<pim::PimCoreOp, pim::PimCoreBatchOp>(op))
|
||||
return;
|
||||
|
||||
auto analysis = pim::analyzeStaticMemoryCoalescingCandidates(op);
|
||||
auto stats = pim::coalesceStaticMemory(op, rewriter);
|
||||
CoalescingReportRow row {
|
||||
analysis.candidates.size(), stats.skippedAllocations, stats.removedAllocs, stats.savedBytes};
|
||||
|
||||
if (auto coreOp = dyn_cast<pim::PimCoreOp>(op)) {
|
||||
reportEntries.push_back({CoalescingReportEntry::Kind::Core,
|
||||
static_cast<uint64_t>(coreOp.getCoreId()),
|
||||
{static_cast<int32_t>(coreOp.getCoreId())},
|
||||
row});
|
||||
return;
|
||||
}
|
||||
|
||||
auto coreIds = getBatchCoreIds(cast<pim::PimCoreBatchOp>(op));
|
||||
CoalescingReportEntry entry;
|
||||
entry.kind = CoalescingReportEntry::Kind::Batch;
|
||||
entry.id = nextBatchId++;
|
||||
llvm::append_range(entry.coreIds, coreIds);
|
||||
entry.row = row;
|
||||
reportEntries.push_back(std::move(entry));
|
||||
});
|
||||
|
||||
emitReport(reportEntries);
|
||||
dumpModule(getOperation(), "pim2_coalesced");
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<Pass> createPimStaticMemoryCoalescingPass() { return std::make_unique<StaticMemoryCoalescingPass>(); }
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -8,7 +8,14 @@ add_pim_library(SpatialOps
|
||||
SpatialOpsVerify.cpp
|
||||
SpatialOpsCanonicalization.cpp
|
||||
Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp
|
||||
Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp
|
||||
Transforms/MergeComputeNodes/PostMergeCompaction.cpp
|
||||
Transforms/MergeComputeNodes/RegularOpCompaction.cpp
|
||||
Transforms/MergeComputeNodes/Scheduling/ComputeGraph.cpp
|
||||
Transforms/MergeComputeNodes/Scheduling/ComputeInstanceUtils.cpp
|
||||
Transforms/MergeComputeNodes/Scheduling/DcpScheduler.cpp
|
||||
Transforms/MergeComputeNodes/Scheduling/MergeSchedulingAnalysis.cpp
|
||||
Transforms/MergeComputeNodes/Scheduling/PeftScheduler.cpp
|
||||
Transforms/MergeComputeNodes/DCPGraph/Graph.cpp
|
||||
Transforms/MergeComputeNodes/DCPGraph/GraphDebug.cpp
|
||||
Transforms/MergeComputeNodes/DCPGraph/GraphSupport.cpp
|
||||
|
||||
@@ -52,9 +52,10 @@ static void printTensorSendOp(OpAsmPrinter& printer, TensorSendOpTy op) {
|
||||
printer << " ";
|
||||
printer.printOperand(op.getInput());
|
||||
printChannelMetadata(printer, op.getChannelIds(), op.getSourceCoreIds(), op.getTargetCoreIds());
|
||||
printer.printOptionalAttrDict(
|
||||
op->getAttrs(),
|
||||
{op.getChannelIdsAttrName().getValue(), op.getSourceCoreIdsAttrName().getValue(), op.getTargetCoreIdsAttrName().getValue()});
|
||||
printer.printOptionalAttrDict(op->getAttrs(),
|
||||
{op.getChannelIdsAttrName().getValue(),
|
||||
op.getSourceCoreIdsAttrName().getValue(),
|
||||
op.getTargetCoreIdsAttrName().getValue()});
|
||||
printer << " : ";
|
||||
printer.printType(op.getInput().getType());
|
||||
}
|
||||
@@ -62,9 +63,10 @@ static void printTensorSendOp(OpAsmPrinter& printer, TensorSendOpTy op) {
|
||||
template <typename TensorReceiveOpTy>
|
||||
static void printTensorReceiveOp(OpAsmPrinter& printer, TensorReceiveOpTy op) {
|
||||
printChannelMetadata(printer, op.getChannelIds(), op.getSourceCoreIds(), op.getTargetCoreIds());
|
||||
printer.printOptionalAttrDict(
|
||||
op->getAttrs(),
|
||||
{op.getChannelIdsAttrName().getValue(), op.getSourceCoreIdsAttrName().getValue(), op.getTargetCoreIdsAttrName().getValue()});
|
||||
printer.printOptionalAttrDict(op->getAttrs(),
|
||||
{op.getChannelIdsAttrName().getValue(),
|
||||
op.getSourceCoreIdsAttrName().getValue(),
|
||||
op.getTargetCoreIdsAttrName().getValue()});
|
||||
printer << " : ";
|
||||
printer.printType(op.getOutput().getType());
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
#include "mlir/IR/Diagnostics.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
|
||||
#include "llvm/ADT/DenseSet.h"
|
||||
#include "llvm/Support/LogicalResult.h"
|
||||
@@ -338,6 +339,19 @@ LogicalResult SpatConcatOp::verify() {
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult verifyComputeResultsUses(Operation* op) {
|
||||
if (!isa<SpatCompute, SpatComputeBatch>(op))
|
||||
return op->emitError("verifyComputeResultUses: Op is not a SpatCompute/SpatComputeBatch operation");
|
||||
if (!llvm::all_of(op->getResults(), [](Value result) {
|
||||
return llvm::all_of(result.getUsers(), [](Operation* op) {
|
||||
return !(op->getParentOfType<SpatCompute>() || op->getParentOfType<SpatComputeBatch>());
|
||||
});
|
||||
})) {
|
||||
return op->emitError("ComputeResult used directly inside another Compute" );
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult SpatCompute::verify() {
|
||||
auto& block = getBody().front();
|
||||
if (block.mightHaveTerminator()) {
|
||||
@@ -375,7 +389,8 @@ LogicalResult SpatCompute::verify() {
|
||||
for (auto arg : block.getArguments())
|
||||
if (arg.use_empty())
|
||||
return emitError("ComputeOp block argument is not used");
|
||||
|
||||
if (failed(verifyComputeResultsUses(this->getOperation())))
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
|
||||
@@ -465,8 +480,8 @@ LogicalResult SpatComputeBatch::verify() {
|
||||
return emitError("compute_batch coreIds attribute must be a dense i32 array");
|
||||
if (coreIdsAttr.size() != static_cast<int64_t>(laneCountSz))
|
||||
return emitError("compute_batch coreIds array length must match laneCount");
|
||||
if (llvm::any_of(coreIdsAttr.asArrayRef(), [](int32_t coreId) { return coreId <= 0; }))
|
||||
return emitError("compute_batch coreIds values must be positive");
|
||||
if (llvm::any_of(coreIdsAttr.asArrayRef(), [](int32_t coreId) { return coreId < 0; }))
|
||||
return emitError("compute_batch coreIds values must be non-negative");
|
||||
llvm::SmallDenseSet<int32_t, 8> seenCoreIds;
|
||||
for (int32_t coreId : coreIdsAttr.asArrayRef())
|
||||
if (!seenCoreIds.insert(coreId).second)
|
||||
@@ -485,6 +500,8 @@ LogicalResult SpatComputeBatch::verify() {
|
||||
return emitError("body block argument type must match input type");
|
||||
}
|
||||
|
||||
if (failed(verifyComputeResultsUses(this->getOperation())))
|
||||
return failure();
|
||||
return verifyBatchBody(getOperation(), block, getResultTypes(), weightsPerLane);
|
||||
}
|
||||
|
||||
|
||||
@@ -1,802 +1,19 @@
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/IR/ValueRange.h"
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdlib>
|
||||
#include <iterator>
|
||||
#include <numeric>
|
||||
#include <optional>
|
||||
#include <queue>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "DCPAnalysis.hpp"
|
||||
#include "Graph.hpp"
|
||||
#include "../Scheduling/ComputeGraph.hpp"
|
||||
#include "../Scheduling/DcpScheduler.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Support/TypeUtilities.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace spatial {
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
using SpatCompute = onnx_mlir::spatial::SpatCompute;
|
||||
using SpatComputeBatch = onnx_mlir::spatial::SpatComputeBatch;
|
||||
|
||||
bool isDcpCoarsenDebugEnabled() { return std::getenv("DCP_COARSEN_DEBUG") != nullptr; }
|
||||
|
||||
struct VirtualNode {
|
||||
SmallVector<size_t, 4> originalComputeIndices;
|
||||
Weight weight = 0;
|
||||
CrossbarUsage crossbarUsage = 0;
|
||||
};
|
||||
|
||||
struct VirtualGraph {
|
||||
std::vector<VirtualNode> nodes;
|
||||
std::vector<IndexedEdge> edges;
|
||||
};
|
||||
|
||||
struct TimingInfo {
|
||||
std::vector<Time> aest;
|
||||
std::vector<Time> alst;
|
||||
std::vector<size_t> topologicalOrder;
|
||||
bool valid = false;
|
||||
};
|
||||
|
||||
struct WindowScheduleResult {
|
||||
std::vector<std::vector<size_t>> mergeGroups;
|
||||
CPU cpuCount = 0;
|
||||
size_t mergedNodeCount = 0;
|
||||
size_t maxMergeGroupSize = 0;
|
||||
};
|
||||
|
||||
size_t getSchedulingCpuBudget() {
|
||||
if (coresCount.getValue() > 0)
|
||||
return static_cast<size_t>(coresCount.getValue());
|
||||
return std::numeric_limits<size_t>::max();
|
||||
}
|
||||
|
||||
size_t getBatchChunkTargetCount(int32_t laneCount) {
|
||||
assert(laneCount > 0 && "laneCount must be positive");
|
||||
return std::min(static_cast<size_t>(laneCount), std::max<size_t>(1, getSchedulingCpuBudget()));
|
||||
}
|
||||
|
||||
ComputeInstance getBatchChunkForIndex(SpatComputeBatch batch, size_t chunkIndex) {
|
||||
size_t totalLanes = static_cast<size_t>(batch.getLaneCount());
|
||||
size_t chunkCount = getBatchChunkTargetCount(batch.getLaneCount());
|
||||
size_t baseChunkSize = totalLanes / chunkCount;
|
||||
size_t largeChunkCount = totalLanes % chunkCount;
|
||||
|
||||
size_t laneStart = chunkIndex * baseChunkSize + std::min(chunkIndex, largeChunkCount);
|
||||
size_t laneCount = baseChunkSize + (chunkIndex < largeChunkCount ? 1 : 0);
|
||||
return {batch.getOperation(), static_cast<uint32_t>(laneStart), static_cast<uint32_t>(laneCount)};
|
||||
}
|
||||
|
||||
ComputeInstance getBatchChunkForLane(SpatComputeBatch batch, uint32_t lane) {
|
||||
size_t totalLanes = static_cast<size_t>(batch.getLaneCount());
|
||||
size_t chunkCount = getBatchChunkTargetCount(batch.getLaneCount());
|
||||
size_t baseChunkSize = totalLanes / chunkCount;
|
||||
size_t largeChunkCount = totalLanes % chunkCount;
|
||||
size_t largeChunkSpan = largeChunkCount * (baseChunkSize + 1);
|
||||
|
||||
size_t chunkIndex = 0;
|
||||
if (static_cast<size_t>(lane) < largeChunkSpan)
|
||||
chunkIndex = static_cast<size_t>(lane) / (baseChunkSize + 1);
|
||||
else
|
||||
chunkIndex = largeChunkCount + (static_cast<size_t>(lane) - largeChunkSpan) / baseChunkSize;
|
||||
return getBatchChunkForIndex(batch, chunkIndex);
|
||||
}
|
||||
|
||||
std::vector<IndexedEdge> aggregateEdges(ArrayRef<IndexedEdge> edges) {
|
||||
llvm::DenseMap<std::pair<size_t, size_t>, Weight> edgeWeights;
|
||||
for (auto [start, end, weight] : edges) {
|
||||
size_t startIndex = static_cast<size_t>(start);
|
||||
size_t endIndex = static_cast<size_t>(end);
|
||||
if (startIndex == endIndex)
|
||||
continue;
|
||||
auto key = std::make_pair(startIndex, endIndex);
|
||||
Weight edgeWeight = static_cast<Weight>(weight);
|
||||
auto inserted = edgeWeights.try_emplace(key, edgeWeight);
|
||||
if (!inserted.second)
|
||||
inserted.first->second = std::max(inserted.first->second, edgeWeight);
|
||||
}
|
||||
|
||||
std::vector<IndexedEdge> aggregatedEdges;
|
||||
aggregatedEdges.reserve(edgeWeights.size());
|
||||
for (auto [key, weight] : edgeWeights)
|
||||
aggregatedEdges.push_back(
|
||||
{static_cast<int64_t>(key.first), static_cast<int64_t>(key.second), static_cast<int64_t>(weight)});
|
||||
llvm::sort(aggregatedEdges, [](const IndexedEdge& lhs, const IndexedEdge& rhs) {
|
||||
if (std::get<0>(lhs) != std::get<0>(rhs))
|
||||
return std::get<0>(lhs) < std::get<0>(rhs);
|
||||
return std::get<1>(lhs) < std::get<1>(rhs);
|
||||
});
|
||||
return aggregatedEdges;
|
||||
}
|
||||
|
||||
Weight getComputeBodyWeight(Region& body) {
|
||||
constexpr Weight kOperationWeight = 100;
|
||||
Weight numOperations = 0;
|
||||
for (auto& block : body)
|
||||
for ([[maybe_unused]] auto& op : block)
|
||||
numOperations = checkedAdd(numOperations, static_cast<Weight>(1));
|
||||
return checkedMultiply(numOperations, kOperationWeight);
|
||||
}
|
||||
|
||||
CrossbarUsage getComputeBodyCrossbarUsage(Region& body) {
|
||||
CrossbarUsage crossbarUsage = 0;
|
||||
for (auto& block : body)
|
||||
for (auto& op : block)
|
||||
if (isa<SpatVMMOp>(op))
|
||||
crossbarUsage = checkedAdd(crossbarUsage, static_cast<CrossbarUsage>(1));
|
||||
return crossbarUsage;
|
||||
}
|
||||
|
||||
Weight getComputeInstanceWeight(const ComputeInstance& instance) {
|
||||
if (auto spatCompute = dyn_cast<SpatCompute>(instance.op))
|
||||
return getSpatComputeWeight(spatCompute);
|
||||
auto batch = cast<SpatComputeBatch>(instance.op);
|
||||
return checkedMultiply(getComputeBodyWeight(batch.getBody()), static_cast<Weight>(instance.laneCount));
|
||||
}
|
||||
|
||||
CrossbarUsage getComputeInstanceCrossbarUsage(const ComputeInstance& instance) {
|
||||
if (auto spatCompute = dyn_cast<SpatCompute>(instance.op))
|
||||
return getSpatComputeCrossbarUsage(spatCompute);
|
||||
auto batch = cast<SpatComputeBatch>(instance.op);
|
||||
return checkedMultiply(getComputeBodyCrossbarUsage(batch.getBody()), static_cast<CrossbarUsage>(instance.laneCount));
|
||||
}
|
||||
|
||||
SmallVector<Value, 4> getComputeInstanceInputs(const ComputeInstance& instance) {
|
||||
if (auto spatCompute = dyn_cast<SpatCompute>(instance.op))
|
||||
return SmallVector<Value, 4>(spatCompute.getInputs().begin(), spatCompute.getInputs().end());
|
||||
auto batch = cast<SpatComputeBatch>(instance.op);
|
||||
SmallVector<Value, 4> inputs;
|
||||
inputs.reserve(instance.laneCount);
|
||||
for (uint32_t lane = instance.laneStart; lane < instance.laneStart + instance.laneCount; ++lane)
|
||||
inputs.push_back(batch.getInputs()[lane]);
|
||||
return inputs;
|
||||
}
|
||||
|
||||
std::optional<ComputeInstance> getOriginalComputeInstance(Value value) {
|
||||
Operation* op = value.getDefiningOp();
|
||||
if (!op)
|
||||
return std::nullopt;
|
||||
|
||||
while (auto extract = dyn_cast<tensor::ExtractSliceOp>(op)) {
|
||||
value = extract.getSource();
|
||||
op = value.getDefiningOp();
|
||||
if (!op)
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
if (auto spatCompute = dyn_cast<SpatCompute>(op))
|
||||
return ComputeInstance {spatCompute.getOperation(), 0, 1};
|
||||
if (auto batch = dyn_cast<SpatComputeBatch>(op))
|
||||
return getBatchChunkForLane(batch, static_cast<uint32_t>(cast<OpResult>(value).getResultNumber()));
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
SmallVector<ComputeInstance> collectComputeInstances(Operation* entryOp) {
|
||||
SmallVector<ComputeInstance> instances;
|
||||
auto isUsedAsWeightOnly = [](Operation* producerOp) {
|
||||
if (producerOp->getNumResults() == 0)
|
||||
return false;
|
||||
for (Value result : producerOp->getResults()) {
|
||||
if (result.use_empty())
|
||||
return false;
|
||||
for (Operation* user : result.getUsers()) {
|
||||
if (auto compute = dyn_cast<SpatCompute>(user)) {
|
||||
if (!llvm::is_contained(compute.getWeights(), result))
|
||||
return false;
|
||||
continue;
|
||||
}
|
||||
if (auto batch = dyn_cast<SpatComputeBatch>(user)) {
|
||||
if (!llvm::is_contained(batch.getWeights(), result))
|
||||
return false;
|
||||
continue;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
};
|
||||
for (Region& region : entryOp->getRegions()) {
|
||||
for (Block& block : region) {
|
||||
for (Operation& op : block) {
|
||||
if (auto spatCompute = dyn_cast<SpatCompute>(&op)) {
|
||||
if (isUsedAsWeightOnly(spatCompute.getOperation()))
|
||||
continue;
|
||||
instances.push_back({spatCompute.getOperation(), 0, 1});
|
||||
continue;
|
||||
}
|
||||
if (auto batch = dyn_cast<SpatComputeBatch>(&op)) {
|
||||
if (isUsedAsWeightOnly(batch.getOperation()))
|
||||
continue;
|
||||
size_t chunkCount = getBatchChunkTargetCount(batch.getLaneCount());
|
||||
for (size_t chunkIndex = 0; chunkIndex < chunkCount; ++chunkIndex)
|
||||
instances.push_back(getBatchChunkForIndex(batch, chunkIndex));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return instances;
|
||||
}
|
||||
|
||||
VirtualGraph buildInitialVirtualGraph(ArrayRef<ComputeInstance> computeInstances, ArrayRef<IndexedEdge> edges) {
|
||||
VirtualGraph graph;
|
||||
graph.nodes.reserve(computeInstances.size());
|
||||
for (auto [index, computeInstance] : llvm::enumerate(computeInstances)) {
|
||||
VirtualNode node;
|
||||
node.originalComputeIndices.push_back(index);
|
||||
node.weight = getComputeInstanceWeight(computeInstance);
|
||||
node.crossbarUsage = getComputeInstanceCrossbarUsage(computeInstance);
|
||||
graph.nodes.push_back(std::move(node));
|
||||
}
|
||||
graph.edges = aggregateEdges(edges);
|
||||
return graph;
|
||||
}
|
||||
|
||||
TimingInfo computeTiming(const VirtualGraph& graph) {
|
||||
TimingInfo timing;
|
||||
size_t nodeCount = graph.nodes.size();
|
||||
timing.aest.assign(nodeCount, 0);
|
||||
timing.alst.assign(nodeCount, 0);
|
||||
timing.topologicalOrder.reserve(nodeCount);
|
||||
|
||||
std::vector<std::vector<std::pair<size_t, Weight>>> parents(nodeCount);
|
||||
std::vector<std::vector<std::pair<size_t, Weight>>> children(nodeCount);
|
||||
std::vector<size_t> incomingEdgeCount(nodeCount, 0);
|
||||
|
||||
for (auto [start, end, weight] : graph.edges) {
|
||||
size_t startIndex = static_cast<size_t>(start);
|
||||
size_t endIndex = static_cast<size_t>(end);
|
||||
Weight edgeWeight = static_cast<Weight>(weight);
|
||||
assert(startIndex < nodeCount && endIndex < nodeCount && "virtual edge endpoint out of range");
|
||||
children[startIndex].push_back({endIndex, edgeWeight});
|
||||
parents[endIndex].push_back({startIndex, edgeWeight});
|
||||
incomingEdgeCount[endIndex]++;
|
||||
}
|
||||
|
||||
auto getVirtualNodeOrderKey = [&](size_t nodeIndex) {
|
||||
const VirtualNode& node = graph.nodes[nodeIndex];
|
||||
if (!node.originalComputeIndices.empty())
|
||||
return node.originalComputeIndices.front();
|
||||
return nodeIndex;
|
||||
};
|
||||
auto readyNodeGreater = [&](size_t lhs, size_t rhs) {
|
||||
size_t lhsKey = getVirtualNodeOrderKey(lhs);
|
||||
size_t rhsKey = getVirtualNodeOrderKey(rhs);
|
||||
if (lhsKey != rhsKey)
|
||||
return lhsKey > rhsKey;
|
||||
return lhs > rhs;
|
||||
};
|
||||
std::priority_queue<size_t, std::vector<size_t>, decltype(readyNodeGreater)> readyNodes(readyNodeGreater);
|
||||
for (size_t i = 0; i < nodeCount; ++i)
|
||||
if (incomingEdgeCount[i] == 0)
|
||||
readyNodes.push(i);
|
||||
|
||||
while (!readyNodes.empty()) {
|
||||
size_t current = readyNodes.top();
|
||||
readyNodes.pop();
|
||||
timing.topologicalOrder.push_back(current);
|
||||
for (auto [child, weight] : children[current]) {
|
||||
(void) weight;
|
||||
assert(incomingEdgeCount[child] > 0 && "incoming edge count underflow");
|
||||
incomingEdgeCount[child]--;
|
||||
if (incomingEdgeCount[child] == 0)
|
||||
readyNodes.push(child);
|
||||
}
|
||||
}
|
||||
|
||||
if (timing.topologicalOrder.size() != nodeCount)
|
||||
return timing;
|
||||
|
||||
Time dcpl = 0;
|
||||
for (size_t nodeIndex : timing.topologicalOrder) {
|
||||
Time maxParentAest = 0;
|
||||
for (auto [parent, transferCost] : parents[nodeIndex]) {
|
||||
maxParentAest =
|
||||
std::max(maxParentAest, addOrMax(addOrMax(timing.aest[parent], graph.nodes[parent].weight), transferCost));
|
||||
}
|
||||
timing.aest[nodeIndex] = maxParentAest;
|
||||
dcpl = std::max(dcpl, addOrMax(maxParentAest, graph.nodes[nodeIndex].weight));
|
||||
}
|
||||
|
||||
for (size_t nodeIndex : llvm::reverse(timing.topologicalOrder)) {
|
||||
Time minAlst = std::numeric_limits<Time>::max();
|
||||
if (children[nodeIndex].empty())
|
||||
minAlst = subtractOrZero(dcpl, graph.nodes[nodeIndex].weight);
|
||||
for (auto [child, transferCost] : children[nodeIndex]) {
|
||||
minAlst =
|
||||
std::min(minAlst, subtractOrZero(timing.alst[child], addOrMax(graph.nodes[nodeIndex].weight, transferCost)));
|
||||
}
|
||||
timing.alst[nodeIndex] = minAlst;
|
||||
}
|
||||
|
||||
timing.valid = true;
|
||||
return timing;
|
||||
}
|
||||
|
||||
std::vector<std::vector<size_t>> buildUndirectedAdjacency(const VirtualGraph& graph) {
|
||||
std::vector<std::vector<size_t>> adjacency(graph.nodes.size());
|
||||
for (auto [start, end, weight] : graph.edges) {
|
||||
(void) weight;
|
||||
size_t startIndex = static_cast<size_t>(start);
|
||||
size_t endIndex = static_cast<size_t>(end);
|
||||
assert(startIndex < graph.nodes.size() && endIndex < graph.nodes.size() && "virtual edge endpoint out of range");
|
||||
adjacency[startIndex].push_back(endIndex);
|
||||
adjacency[endIndex].push_back(startIndex);
|
||||
}
|
||||
for (auto& neighbours : adjacency) {
|
||||
llvm::sort(neighbours);
|
||||
neighbours.erase(std::unique(neighbours.begin(), neighbours.end()), neighbours.end());
|
||||
}
|
||||
return adjacency;
|
||||
}
|
||||
|
||||
std::vector<size_t> selectCriticalWindow(const VirtualGraph& graph, const TimingInfo& timing, size_t windowSize) {
|
||||
std::vector<size_t> ranked(timing.aest.size());
|
||||
std::iota(ranked.begin(), ranked.end(), 0);
|
||||
auto isHigherPriority = [&](size_t lhs, size_t rhs) {
|
||||
Time lhsSlack = slackOrZero(timing.aest[lhs], timing.alst[lhs]);
|
||||
Time rhsSlack = slackOrZero(timing.aest[rhs], timing.alst[rhs]);
|
||||
if (lhsSlack != rhsSlack)
|
||||
return lhsSlack < rhsSlack;
|
||||
if (timing.aest[lhs] != timing.aest[rhs])
|
||||
return timing.aest[lhs] < timing.aest[rhs];
|
||||
return lhs < rhs;
|
||||
};
|
||||
|
||||
windowSize = std::min(windowSize, ranked.size());
|
||||
if (windowSize == 0)
|
||||
return {};
|
||||
if (windowSize == ranked.size()) {
|
||||
llvm::sort(ranked, isHigherPriority);
|
||||
return ranked;
|
||||
}
|
||||
|
||||
size_t criticalPoolSize = std::min(ranked.size(), std::max(windowSize, windowSize * 2));
|
||||
if (criticalPoolSize < ranked.size())
|
||||
std::nth_element(
|
||||
ranked.begin(), ranked.begin() + static_cast<std::ptrdiff_t>(criticalPoolSize), ranked.end(), isHigherPriority);
|
||||
|
||||
std::vector<char> inCriticalPool(ranked.size(), false);
|
||||
for (size_t i = 0; i < criticalPoolSize; ++i)
|
||||
inCriticalPool[ranked[i]] = true;
|
||||
|
||||
size_t seed = *std::min_element(ranked.begin(), ranked.end(), isHigherPriority);
|
||||
std::vector<std::vector<size_t>> adjacency = buildUndirectedAdjacency(graph);
|
||||
std::vector<size_t> selected;
|
||||
std::vector<char> inWindow(ranked.size(), false);
|
||||
selected.reserve(windowSize);
|
||||
|
||||
struct FrontierEntry {
|
||||
size_t node;
|
||||
};
|
||||
auto frontierCompare = [&](FrontierEntry lhs, FrontierEntry rhs) { return isHigherPriority(rhs.node, lhs.node); };
|
||||
std::priority_queue<FrontierEntry, std::vector<FrontierEntry>, decltype(frontierCompare)> frontier(frontierCompare);
|
||||
|
||||
auto addToWindow = [&](size_t node, const std::vector<char>& eligible) {
|
||||
if (inWindow[node])
|
||||
return;
|
||||
inWindow[node] = true;
|
||||
selected.push_back(node);
|
||||
for (size_t neighbour : adjacency[node])
|
||||
if (!inWindow[neighbour] && eligible[neighbour])
|
||||
frontier.push({neighbour});
|
||||
};
|
||||
|
||||
addToWindow(seed, inCriticalPool);
|
||||
while (!frontier.empty() && selected.size() < windowSize) {
|
||||
size_t node = frontier.top().node;
|
||||
frontier.pop();
|
||||
if (!inWindow[node])
|
||||
addToWindow(node, inCriticalPool);
|
||||
}
|
||||
|
||||
if (selected.size() < windowSize) {
|
||||
std::vector<char> anyNode(ranked.size(), true);
|
||||
for (size_t node : selected)
|
||||
for (size_t neighbour : adjacency[node])
|
||||
if (!inWindow[neighbour])
|
||||
frontier.push({neighbour});
|
||||
while (!frontier.empty() && selected.size() < windowSize) {
|
||||
size_t node = frontier.top().node;
|
||||
frontier.pop();
|
||||
if (!inWindow[node])
|
||||
addToWindow(node, anyNode);
|
||||
}
|
||||
}
|
||||
|
||||
if (selected.size() < windowSize) {
|
||||
llvm::sort(ranked, isHigherPriority);
|
||||
for (size_t node : ranked) {
|
||||
if (selected.size() == windowSize)
|
||||
break;
|
||||
if (!inWindow[node]) {
|
||||
inWindow[node] = true;
|
||||
selected.push_back(node);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
llvm::sort(selected, isHigherPriority);
|
||||
return selected;
|
||||
}
|
||||
|
||||
std::vector<IndexedEdge> buildWindowEdges(const VirtualGraph& graph, const std::vector<int64_t>& nodeToWindowIndex) {
|
||||
std::vector<IndexedEdge> windowEdges;
|
||||
windowEdges.reserve(graph.edges.size());
|
||||
for (auto [start, end, weight] : graph.edges) {
|
||||
int64_t mappedStart = nodeToWindowIndex[static_cast<size_t>(start)];
|
||||
int64_t mappedEnd = nodeToWindowIndex[static_cast<size_t>(end)];
|
||||
if (mappedStart == -1 || mappedEnd == -1)
|
||||
continue;
|
||||
windowEdges.push_back({mappedStart, mappedEnd, weight});
|
||||
}
|
||||
return aggregateEdges(windowEdges);
|
||||
}
|
||||
|
||||
WindowScheduleResult scheduleWindow(const VirtualGraph& graph, ArrayRef<size_t> selectedNodes, MLIRContext* context) {
|
||||
std::vector<Weight> windowWeights;
|
||||
std::vector<CrossbarUsage> windowCrossbarUsage;
|
||||
std::vector<int64_t> windowNodeOrderKeys;
|
||||
std::vector<int64_t> nodeToWindowIndex(graph.nodes.size(), -1);
|
||||
windowWeights.reserve(selectedNodes.size());
|
||||
windowCrossbarUsage.reserve(selectedNodes.size());
|
||||
windowNodeOrderKeys.reserve(selectedNodes.size());
|
||||
|
||||
for (auto [windowIndex, nodeIndex] : llvm::enumerate(selectedNodes)) {
|
||||
nodeToWindowIndex[nodeIndex] = static_cast<int64_t>(windowIndex);
|
||||
windowWeights.push_back(graph.nodes[nodeIndex].weight);
|
||||
windowCrossbarUsage.push_back(graph.nodes[nodeIndex].crossbarUsage);
|
||||
windowNodeOrderKeys.push_back(static_cast<int64_t>(nodeIndex));
|
||||
}
|
||||
|
||||
GraphDCP windowGraph(
|
||||
windowWeights, buildWindowEdges(graph, nodeToWindowIndex), windowNodeOrderKeys, windowCrossbarUsage);
|
||||
if (coresCount.getValue() > 0)
|
||||
windowGraph.setMaxCpuCount(static_cast<int>(coresCount.getValue()));
|
||||
windowGraph.setContext(context);
|
||||
windowGraph.runDcp();
|
||||
|
||||
WindowScheduleResult result;
|
||||
result.cpuCount = windowGraph.cpuCount();
|
||||
for (CPU cpu = 0; cpu < windowGraph.cpuCount(); ++cpu) {
|
||||
auto scheduledTasks = windowGraph.getScheduledTasks(cpu);
|
||||
if (scheduledTasks.size() < 2)
|
||||
continue;
|
||||
|
||||
result.mergedNodeCount += scheduledTasks.size();
|
||||
result.maxMergeGroupSize = std::max(result.maxMergeGroupSize, scheduledTasks.size());
|
||||
std::vector<size_t> mergeGroup;
|
||||
mergeGroup.reserve(scheduledTasks.size());
|
||||
for (const auto& task : scheduledTasks)
|
||||
mergeGroup.push_back(selectedNodes[task.nodeIndex]);
|
||||
result.mergeGroups.push_back(std::move(mergeGroup));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
bool coarsenGraph(const VirtualGraph& graph,
|
||||
ArrayRef<std::vector<size_t>> mergeGroups,
|
||||
VirtualGraph& coarsenedGraph,
|
||||
std::vector<size_t>& oldToNewNode) {
|
||||
TimingInfo timing = computeTiming(graph);
|
||||
std::vector<size_t> topologicalRank(graph.nodes.size());
|
||||
std::iota(topologicalRank.begin(), topologicalRank.end(), 0);
|
||||
if (timing.valid)
|
||||
for (auto [rank, nodeIndex] : llvm::enumerate(timing.topologicalOrder))
|
||||
topologicalRank[nodeIndex] = rank;
|
||||
|
||||
std::vector<std::vector<size_t>> orderedMergeGroups;
|
||||
orderedMergeGroups.reserve(mergeGroups.size());
|
||||
for (const auto& mergeGroup : mergeGroups) {
|
||||
orderedMergeGroups.emplace_back(mergeGroup.begin(), mergeGroup.end());
|
||||
std::stable_sort(orderedMergeGroups.back().begin(), orderedMergeGroups.back().end(), [&](size_t lhs, size_t rhs) {
|
||||
if (topologicalRank[lhs] != topologicalRank[rhs])
|
||||
return topologicalRank[lhs] < topologicalRank[rhs];
|
||||
return lhs < rhs;
|
||||
});
|
||||
}
|
||||
|
||||
std::vector<int64_t> nodeToMergeGroup(graph.nodes.size(), -1);
|
||||
for (auto [groupIndex, mergeGroup] : llvm::enumerate(orderedMergeGroups)) {
|
||||
if (mergeGroup.size() < 2)
|
||||
continue;
|
||||
for (size_t nodeIndex : mergeGroup) {
|
||||
assert(nodeIndex < graph.nodes.size() && "merge group node out of range");
|
||||
nodeToMergeGroup[nodeIndex] = static_cast<int64_t>(groupIndex);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::optional<size_t>> mergeGroupToNewNode(orderedMergeGroups.size());
|
||||
std::vector<size_t> newNodeRank;
|
||||
oldToNewNode.assign(graph.nodes.size(), 0);
|
||||
bool mergedAny = false;
|
||||
coarsenedGraph.nodes.clear();
|
||||
coarsenedGraph.edges.clear();
|
||||
coarsenedGraph.nodes.reserve(graph.nodes.size());
|
||||
newNodeRank.reserve(graph.nodes.size());
|
||||
|
||||
for (size_t nodeIndex = 0; nodeIndex < graph.nodes.size(); ++nodeIndex) {
|
||||
int64_t mergeGroupIndex = nodeToMergeGroup[nodeIndex];
|
||||
if (mergeGroupIndex == -1) {
|
||||
oldToNewNode[nodeIndex] = coarsenedGraph.nodes.size();
|
||||
coarsenedGraph.nodes.push_back(graph.nodes[nodeIndex]);
|
||||
newNodeRank.push_back(topologicalRank[nodeIndex]);
|
||||
continue;
|
||||
}
|
||||
|
||||
auto& newNodeIndex = mergeGroupToNewNode[static_cast<size_t>(mergeGroupIndex)];
|
||||
if (newNodeIndex.has_value()) {
|
||||
oldToNewNode[nodeIndex] = *newNodeIndex;
|
||||
continue;
|
||||
}
|
||||
|
||||
VirtualNode mergedNode;
|
||||
for (size_t memberIndex : orderedMergeGroups[static_cast<size_t>(mergeGroupIndex)]) {
|
||||
const VirtualNode& memberNode = graph.nodes[memberIndex];
|
||||
mergedNode.originalComputeIndices.append(memberNode.originalComputeIndices.begin(),
|
||||
memberNode.originalComputeIndices.end());
|
||||
mergedNode.weight = addOrMax(mergedNode.weight, memberNode.weight);
|
||||
mergedNode.crossbarUsage = addOrMax(mergedNode.crossbarUsage, memberNode.crossbarUsage);
|
||||
}
|
||||
std::sort(mergedNode.originalComputeIndices.begin(), mergedNode.originalComputeIndices.end());
|
||||
|
||||
mergedAny = true;
|
||||
newNodeIndex = coarsenedGraph.nodes.size();
|
||||
for (size_t memberIndex : orderedMergeGroups[static_cast<size_t>(mergeGroupIndex)])
|
||||
oldToNewNode[memberIndex] = *newNodeIndex;
|
||||
newNodeRank.push_back(topologicalRank[orderedMergeGroups[static_cast<size_t>(mergeGroupIndex)].front()]);
|
||||
coarsenedGraph.nodes.push_back(std::move(mergedNode));
|
||||
}
|
||||
|
||||
if (!mergedAny)
|
||||
return false;
|
||||
|
||||
std::vector<IndexedEdge> remappedEdges;
|
||||
remappedEdges.reserve(graph.edges.size());
|
||||
for (auto [start, end, weight] : graph.edges) {
|
||||
size_t newStart = oldToNewNode[static_cast<size_t>(start)];
|
||||
size_t newEnd = oldToNewNode[static_cast<size_t>(end)];
|
||||
if (newStart == newEnd)
|
||||
continue;
|
||||
if (newNodeRank[newStart] >= newNodeRank[newEnd])
|
||||
continue;
|
||||
remappedEdges.push_back({static_cast<int64_t>(newStart), static_cast<int64_t>(newEnd), weight});
|
||||
}
|
||||
coarsenedGraph.edges = aggregateEdges(remappedEdges);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
CPU getVirtualGraphMaxCpuCount() { return static_cast<CPU>(getSchedulingCpuBudget()); }
|
||||
|
||||
size_t getDcpCoarseningWindowSize(size_t nodeCount) {
|
||||
size_t windowSize = std::min(dcpCriticalWindowSize.getValue(), nodeCount);
|
||||
CPU maxCpuCount = std::max<CPU>(1, getVirtualGraphMaxCpuCount());
|
||||
if (nodeCount > static_cast<size_t>(maxCpuCount))
|
||||
windowSize = std::max(windowSize, std::min(nodeCount, static_cast<size_t>(maxCpuCount) + 1));
|
||||
return windowSize;
|
||||
}
|
||||
|
||||
DCPAnalysisResult buildResultFromVirtualGraph(const VirtualGraph& graph, ArrayRef<ComputeInstance> computeInstances) {
|
||||
DCPAnalysisResult result;
|
||||
|
||||
TimingInfo timing = computeTiming(graph);
|
||||
std::vector<size_t> virtualNodeOrder;
|
||||
if (timing.valid) {
|
||||
virtualNodeOrder = std::move(timing.topologicalOrder);
|
||||
}
|
||||
else {
|
||||
virtualNodeOrder.resize(graph.nodes.size());
|
||||
std::iota(virtualNodeOrder.begin(), virtualNodeOrder.end(), 0);
|
||||
}
|
||||
|
||||
std::vector<size_t> originalComputeToCpu(computeInstances.size(), 0);
|
||||
for (auto [cpu, virtualNodeIndex] : llvm::enumerate(virtualNodeOrder)) {
|
||||
const VirtualNode& virtualNode = graph.nodes[virtualNodeIndex];
|
||||
for (size_t originalIndex : virtualNode.originalComputeIndices)
|
||||
originalComputeToCpu[originalIndex] = cpu;
|
||||
}
|
||||
|
||||
result.dominanceOrderCompute.reserve(computeInstances.size());
|
||||
llvm::DenseMap<size_t, size_t> nextCpuSlot;
|
||||
for (auto [originalIndex, computeInstance] : llvm::enumerate(computeInstances)) {
|
||||
size_t cpu = originalComputeToCpu[originalIndex];
|
||||
result.dominanceOrderCompute.push_back(computeInstance);
|
||||
result.computeToCpuMap[computeInstance] = cpu;
|
||||
result.computeToCpuSlotMap[computeInstance] = nextCpuSlot[cpu]++;
|
||||
result.computeToAestMap[computeInstance] = originalIndex;
|
||||
result.cpuToLastComputeMap[cpu] = computeInstance;
|
||||
}
|
||||
for (const auto& [cpu, lastCompute] : result.cpuToLastComputeMap)
|
||||
result.isLastComputeOfCpu.insert(lastCompute);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
DCPAnalysisResult buildResultFromScheduledGraph(GraphDCP& graphDCP, ArrayRef<ComputeInstance> computeInstances) {
|
||||
DCPAnalysisResult result;
|
||||
result.dominanceOrderCompute.assign(computeInstances.begin(), computeInstances.end());
|
||||
|
||||
for (CPU cpu = 0; cpu < graphDCP.cpuCount(); ++cpu) {
|
||||
auto scheduledTasks = graphDCP.getScheduledTasks(cpu);
|
||||
if (scheduledTasks.empty())
|
||||
continue;
|
||||
|
||||
for (auto [slot, task] : llvm::enumerate(scheduledTasks)) {
|
||||
ComputeInstance instance = computeInstances[task.nodeIndex];
|
||||
result.computeToCpuMap[instance] = cpu;
|
||||
result.computeToCpuSlotMap[instance] = slot;
|
||||
result.computeToAestMap[instance] = static_cast<uint64_t>(task.aest);
|
||||
}
|
||||
result.cpuToLastComputeMap[cpu] = computeInstances[scheduledTasks.back().nodeIndex];
|
||||
result.isLastComputeOfCpu.insert(computeInstances[scheduledTasks.back().nodeIndex]);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
DCPAnalysisResult
|
||||
runLegacyDcp(ArrayRef<ComputeInstance> computeInstances, ArrayRef<IndexedEdge> edges, MLIRContext* context) {
|
||||
SmallVector<Weight> nodeWeights;
|
||||
SmallVector<CrossbarUsage> nodeCrossbarUsage;
|
||||
SmallVector<int64_t> nodeOrderKeys;
|
||||
nodeWeights.reserve(computeInstances.size());
|
||||
nodeCrossbarUsage.reserve(computeInstances.size());
|
||||
nodeOrderKeys.reserve(computeInstances.size());
|
||||
for (auto [index, instance] : llvm::enumerate(computeInstances)) {
|
||||
nodeWeights.push_back(getComputeInstanceWeight(instance));
|
||||
nodeCrossbarUsage.push_back(getComputeInstanceCrossbarUsage(instance));
|
||||
nodeOrderKeys.push_back(static_cast<int64_t>(index));
|
||||
}
|
||||
|
||||
GraphDCP graphDCP(nodeWeights, edges, nodeOrderKeys, nodeCrossbarUsage);
|
||||
if (coresCount.getValue() > 0)
|
||||
graphDCP.setMaxCpuCount(static_cast<int>(coresCount.getValue()));
|
||||
graphDCP.setContext(context);
|
||||
graphDCP.runDcp();
|
||||
return buildResultFromScheduledGraph(graphDCP, computeInstances);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
SpatCompute getOriginalSpatCompute(Operation* op) {
|
||||
if (!op)
|
||||
return {};
|
||||
while (auto extract = dyn_cast<tensor::ExtractSliceOp>(op)) {
|
||||
op = extract.getSource().getDefiningOp();
|
||||
if (!op)
|
||||
return {};
|
||||
}
|
||||
if (auto res = dyn_cast<SpatCompute>(op))
|
||||
return res;
|
||||
return {};
|
||||
}
|
||||
|
||||
DCPAnalysisResult DCPAnalysis::run() {
|
||||
SmallVector<ComputeInstance> computeInstances = collectComputeInstances(entryOp);
|
||||
SmallVector<IndexedEdge, 10> edges;
|
||||
|
||||
llvm::DenseMap<ComputeInstance, size_t> instanceToIndex;
|
||||
instanceToIndex.reserve(computeInstances.size());
|
||||
for (auto [index, instance] : llvm::enumerate(computeInstances))
|
||||
instanceToIndex[instance] = index;
|
||||
|
||||
for (auto [indexEndEdge, computeInstance] : llvm::enumerate(computeInstances)) {
|
||||
for (Value input : getComputeInstanceInputs(computeInstance)) {
|
||||
if (auto producerInstance = getOriginalComputeInstance(input)) {
|
||||
auto producerIt = instanceToIndex.find(*producerInstance);
|
||||
assert(producerIt != instanceToIndex.end());
|
||||
auto indexStartEdge = producerIt->second;
|
||||
edges.push_back({static_cast<int64_t>(indexStartEdge),
|
||||
static_cast<int64_t>(indexEndEdge),
|
||||
static_cast<int64_t>(getSizeInBytes(cast<ShapedType>(input.getType())))});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (coresCount.getValue() > 0) {
|
||||
size_t schedulingCpuBudget = getSchedulingCpuBudget();
|
||||
bool needsExactScheduledBatches = llvm::any_of(computeInstances, [&](const ComputeInstance& instance) {
|
||||
auto batch = dyn_cast<SpatComputeBatch>(instance.op);
|
||||
return batch && static_cast<size_t>(batch.getLaneCount()) > schedulingCpuBudget;
|
||||
});
|
||||
if (needsExactScheduledBatches)
|
||||
return runLegacyDcp(computeInstances, edges, entryOp->getContext());
|
||||
}
|
||||
|
||||
if (dcpCriticalWindowSize.getValue() == 0)
|
||||
return runLegacyDcp(computeInstances, edges, entryOp->getContext());
|
||||
|
||||
VirtualGraph virtualGraph = buildInitialVirtualGraph(computeInstances, edges);
|
||||
size_t iteration = 0;
|
||||
bool debugCoarsening = isDcpCoarsenDebugEnabled();
|
||||
auto tryCoarsenSelectedNodes = [&](ArrayRef<size_t> selectedNodes) {
|
||||
size_t oldNodeCount = virtualGraph.nodes.size();
|
||||
WindowScheduleResult windowSchedule = scheduleWindow(virtualGraph, selectedNodes, entryOp->getContext());
|
||||
if (windowSchedule.mergeGroups.empty()) {
|
||||
if (debugCoarsening && oldNodeCount >= 200)
|
||||
llvm::errs() << llvm::formatv("[DCP-COARSEN] iter={0} old={1} selected={2} windowCpus={3} "
|
||||
"groups=0 mergedNodes=0 maxGroup=0 new={1} changed=0\n",
|
||||
iteration,
|
||||
oldNodeCount,
|
||||
selectedNodes.size(),
|
||||
windowSchedule.cpuCount);
|
||||
return false;
|
||||
}
|
||||
|
||||
VirtualGraph coarsenedGraph;
|
||||
std::vector<size_t> oldToNewNode;
|
||||
if (!coarsenGraph(virtualGraph, windowSchedule.mergeGroups, coarsenedGraph, oldToNewNode))
|
||||
return false;
|
||||
if (debugCoarsening && (oldNodeCount >= 200 || coarsenedGraph.nodes.size() >= 200))
|
||||
llvm::errs() << llvm::formatv("[DCP-COARSEN] iter={0} old={1} selected={2} windowCpus={3} "
|
||||
"groups={4} mergedNodes={5} maxGroup={6} new={7} changed={8}\n",
|
||||
iteration,
|
||||
oldNodeCount,
|
||||
selectedNodes.size(),
|
||||
windowSchedule.cpuCount,
|
||||
windowSchedule.mergeGroups.size(),
|
||||
windowSchedule.mergedNodeCount,
|
||||
windowSchedule.maxMergeGroupSize,
|
||||
coarsenedGraph.nodes.size(),
|
||||
oldNodeCount - coarsenedGraph.nodes.size());
|
||||
virtualGraph = std::move(coarsenedGraph);
|
||||
return true;
|
||||
};
|
||||
|
||||
while (virtualGraph.nodes.size() > 1) {
|
||||
if (virtualGraph.nodes.size() <= getSchedulingCpuBudget()) {
|
||||
if (debugCoarsening && virtualGraph.nodes.size() >= 200)
|
||||
llvm::errs() << llvm::formatv(
|
||||
"[DCP-COARSEN] iter={0} old={1} stop=cpu-budget\n", iteration, virtualGraph.nodes.size());
|
||||
break;
|
||||
}
|
||||
|
||||
iteration++;
|
||||
TimingInfo timing = computeTiming(virtualGraph);
|
||||
if (!timing.valid) {
|
||||
if (debugCoarsening && virtualGraph.nodes.size() >= 200)
|
||||
llvm::errs() << llvm::formatv(
|
||||
"[DCP-COARSEN] iter={0} old={1} invalid-timing\n", iteration, virtualGraph.nodes.size());
|
||||
break;
|
||||
}
|
||||
|
||||
SmallVector<size_t> selectedNodes;
|
||||
auto criticalWindow =
|
||||
selectCriticalWindow(virtualGraph, timing, getDcpCoarseningWindowSize(virtualGraph.nodes.size()));
|
||||
selectedNodes.append(criticalWindow.begin(), criticalWindow.end());
|
||||
|
||||
if (selectedNodes.size() < 2) {
|
||||
if (debugCoarsening && virtualGraph.nodes.size() >= 200)
|
||||
llvm::errs() << llvm::formatv("[DCP-COARSEN] iter={0} old={1} selected={2} stop=small-window\n",
|
||||
iteration,
|
||||
virtualGraph.nodes.size(),
|
||||
selectedNodes.size());
|
||||
break;
|
||||
}
|
||||
|
||||
if (tryCoarsenSelectedNodes(selectedNodes))
|
||||
continue;
|
||||
if (debugCoarsening && virtualGraph.nodes.size() >= 200)
|
||||
llvm::errs() << llvm::formatv(
|
||||
"[DCP-COARSEN] iter={0} old={1} stop=no-merge\n", iteration, virtualGraph.nodes.size());
|
||||
break;
|
||||
}
|
||||
|
||||
return buildResultFromVirtualGraph(virtualGraph, computeInstances);
|
||||
ComputeGraph graph = buildComputeGraph(entryOp);
|
||||
DcpScheduleOptions options;
|
||||
if (coresCount.getValue() > 0)
|
||||
options.processorCount = static_cast<size_t>(coresCount.getValue());
|
||||
options.criticalWindowSize = dcpCriticalWindowSize.getValue();
|
||||
options.allowFallbackForAutoCoreCount = true;
|
||||
return runDcpScheduler(graph, options, entryOp->getContext());
|
||||
}
|
||||
|
||||
} // namespace spatial
|
||||
|
||||
@@ -2,64 +2,27 @@
|
||||
|
||||
#include "mlir/IR/Operation.h"
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/DenseSet.h"
|
||||
|
||||
#include <cstdint>
|
||||
#include <vector>
|
||||
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
// A scheduling identity that covers both spat.compute and scheduled shards of
|
||||
// spat.compute_batch.
|
||||
struct ComputeInstance {
|
||||
mlir::Operation* op = nullptr;
|
||||
uint32_t laneStart = 0;
|
||||
uint32_t laneCount = 1;
|
||||
|
||||
bool operator==(const ComputeInstance& other) const {
|
||||
return op == other.op && laneStart == other.laneStart && laneCount == other.laneCount;
|
||||
}
|
||||
};
|
||||
|
||||
struct DCPAnalysisResult {
|
||||
std::vector<ComputeInstance> dominanceOrderCompute;
|
||||
llvm::DenseMap<ComputeInstance, size_t> computeToCpuMap;
|
||||
llvm::DenseMap<ComputeInstance, size_t> computeToCpuSlotMap;
|
||||
llvm::DenseMap<ComputeInstance, uint64_t> computeToAestMap;
|
||||
llvm::DenseSet<ComputeInstance> isLastComputeOfCpu;
|
||||
llvm::DenseMap<size_t, ComputeInstance> cpuToLastComputeMap;
|
||||
};
|
||||
#include "../Scheduling/MergeSchedule.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace spatial {
|
||||
using DCPAnalysisResult = MergeScheduleResult;
|
||||
|
||||
struct DCPAnalysis {
|
||||
private:
|
||||
DCPAnalysisResult result;
|
||||
mlir::Operation* entryOp;
|
||||
mlir::Operation *entryOp;
|
||||
DCPAnalysisResult run();
|
||||
|
||||
public:
|
||||
DCPAnalysis(mlir::Operation* op)
|
||||
DCPAnalysis(mlir::Operation *op)
|
||||
: entryOp(op) {
|
||||
result = run();
|
||||
}
|
||||
DCPAnalysisResult& getResult() { return result; }
|
||||
DCPAnalysisResult &getResult() { return result; }
|
||||
};
|
||||
|
||||
} // namespace spatial
|
||||
} // namespace onnx_mlir
|
||||
|
||||
namespace llvm {
|
||||
template <>
|
||||
struct DenseMapInfo<ComputeInstance> {
|
||||
static ComputeInstance getEmptyKey() {
|
||||
return {DenseMapInfo<mlir::Operation*>::getEmptyKey(), UINT32_MAX, UINT32_MAX};
|
||||
}
|
||||
static ComputeInstance getTombstoneKey() {
|
||||
return {DenseMapInfo<mlir::Operation*>::getTombstoneKey(), UINT32_MAX, UINT32_MAX};
|
||||
}
|
||||
static unsigned getHashValue(const ComputeInstance& v) { return llvm::hash_combine(v.op, v.laneStart, v.laneCount); }
|
||||
static bool isEqual(const ComputeInstance& a, const ComputeInstance& b) { return a == b; }
|
||||
};
|
||||
} // namespace llvm
|
||||
using DCPAnalysisResult = onnx_mlir::spatial::DCPAnalysisResult;
|
||||
|
||||
@@ -0,0 +1,636 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/DenseSet.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <functional>
|
||||
#include <optional>
|
||||
#include <utility>
|
||||
|
||||
#include "MaterializeMergeSchedule.hpp"
|
||||
#include "Scheduling/ComputeInstanceUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace spatial {
|
||||
|
||||
namespace {
|
||||
|
||||
using SpatCompute = spatial::SpatCompute;
|
||||
using ProducerValueRef = spatial::ProducerValueRef;
|
||||
using spatial::getComputeInstanceInputs;
|
||||
using spatial::getComputeInstanceOutputTypes;
|
||||
using spatial::getComputeInstanceOutputValues;
|
||||
using spatial::getComputeInstanceTemplateBlock;
|
||||
using spatial::getComputeInstanceWeights;
|
||||
using spatial::getProducerValueRef;
|
||||
|
||||
class MergeScheduleMaterializerImpl {
|
||||
public:
|
||||
explicit MergeScheduleMaterializerImpl(func::FuncOp funcOp)
|
||||
: func(funcOp), loc(funcOp.getLoc()), returnOp(cast<func::ReturnOp>(funcOp.getBody().front().getTerminator())) {}
|
||||
|
||||
LogicalResult run(const MergeScheduleResult& scheduleResult, int64_t& nextChannelIdRef) {
|
||||
schedule = &scheduleResult;
|
||||
nextChannelId = &nextChannelIdRef;
|
||||
|
||||
collectScheduledTasks();
|
||||
buildTaskIndex();
|
||||
collectExternalInputsAndWeights();
|
||||
planRemoteChannels();
|
||||
planReceiveReordering();
|
||||
createCpuComputeOps();
|
||||
if (failed(cloneTaskBodies()))
|
||||
return failure();
|
||||
replaceExternalUses();
|
||||
if (failed(eraseOldScheduledOps()))
|
||||
return failure();
|
||||
moveExternalUsersBeforeReturn();
|
||||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
struct ScheduledTask {
|
||||
ComputeInstance computeInstance;
|
||||
size_t cpu = 0;
|
||||
size_t orderWithinCpu = 0;
|
||||
};
|
||||
|
||||
struct ChannelInfo {
|
||||
int64_t channelId = -1;
|
||||
int32_t sourceCoreId = -1;
|
||||
int32_t targetCoreId = -1;
|
||||
};
|
||||
|
||||
struct CpuProgram {
|
||||
SpatCompute op;
|
||||
DenseMap<Value, Value> externalInputMap;
|
||||
DenseMap<Value, size_t> weightToIndex;
|
||||
};
|
||||
|
||||
struct RemoteSendInfo {
|
||||
ChannelInfo channelInfo;
|
||||
ComputeInstance consumer;
|
||||
size_t inputIndex = 0;
|
||||
size_t consumerOrder = 0;
|
||||
size_t sourceOrder = 0;
|
||||
};
|
||||
|
||||
struct RemoteReceiveEntry {
|
||||
ChannelInfo channelInfo;
|
||||
ComputeInstance consumer;
|
||||
size_t inputIndex = 0;
|
||||
size_t sourceOrder = 0;
|
||||
};
|
||||
|
||||
static uint64_t getRemoteSendPairKey(const ChannelInfo& channelInfo) {
|
||||
return (static_cast<uint64_t>(static_cast<uint32_t>(channelInfo.sourceCoreId)) << 32)
|
||||
| static_cast<uint32_t>(channelInfo.targetCoreId);
|
||||
}
|
||||
|
||||
void collectExternalUsers(Operation* op) {
|
||||
if (!externalUsersToMove.insert(op).second)
|
||||
return;
|
||||
for (Value result : op->getResults()) {
|
||||
for (Operation* user : result.getUsers()) {
|
||||
if (oldComputeOps.contains(user) || isa<func::ReturnOp>(user))
|
||||
continue;
|
||||
collectExternalUsers(user);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void collectScheduledTasks() {
|
||||
for (ComputeInstance scheduledInstance : schedule->dominanceOrderCompute) {
|
||||
oldComputeOps.insert(scheduledInstance.op);
|
||||
scheduledTasks.push_back({scheduledInstance,
|
||||
schedule->computeToCpuMap.lookup(scheduledInstance),
|
||||
schedule->computeToCpuSlotMap.lookup(scheduledInstance)});
|
||||
}
|
||||
}
|
||||
|
||||
void buildTaskIndex() {
|
||||
auto markCpuSeen = [&](size_t cpu) {
|
||||
if (seenCpus.insert(cpu).second)
|
||||
orderedCpus.push_back(cpu);
|
||||
};
|
||||
|
||||
for (const ScheduledTask& task : scheduledTasks) {
|
||||
taskByComputeInstance[task.computeInstance] = task;
|
||||
tasksByCpu[task.cpu].push_back(task);
|
||||
markCpuSeen(task.cpu);
|
||||
}
|
||||
|
||||
llvm::sort(orderedCpus);
|
||||
for (size_t cpu : orderedCpus)
|
||||
llvm::stable_sort(tasksByCpu[cpu], [&](const ScheduledTask& lhs, const ScheduledTask& rhs) {
|
||||
return lhs.orderWithinCpu < rhs.orderWithinCpu;
|
||||
});
|
||||
}
|
||||
|
||||
void collectExternalInputsAndWeights() {
|
||||
for (size_t cpu : orderedCpus) {
|
||||
for (const ScheduledTask& task : tasksByCpu[cpu]) {
|
||||
auto& thisCpuWeights = cpuWeights[cpu];
|
||||
auto& thisSeenWeights = seenWeightsByCpu[cpu];
|
||||
auto taskWeights = getComputeInstanceWeights(task.computeInstance);
|
||||
for (Value weight : taskWeights)
|
||||
if (thisSeenWeights.insert(weight).second)
|
||||
thisCpuWeights.push_back(weight);
|
||||
|
||||
auto taskInputs = getComputeInstanceInputs(task.computeInstance);
|
||||
auto& remoteInputs = remoteInputsByTask[task.computeInstance];
|
||||
remoteInputs.resize(taskInputs.size());
|
||||
for (auto [inputIndex, input] : llvm::enumerate(taskInputs)) {
|
||||
auto producerRef = getProducerValueRef(input);
|
||||
if (producerRef) {
|
||||
auto producerIt = taskByComputeInstance.find(producerRef->instance);
|
||||
if (producerIt != taskByComputeInstance.end()) {
|
||||
if (producerIt->second.cpu != cpu) {
|
||||
ChannelInfo info {
|
||||
(*nextChannelId)++,
|
||||
static_cast<int32_t>(producerIt->second.cpu),
|
||||
static_cast<int32_t>(cpu),
|
||||
};
|
||||
remoteInputs[inputIndex] = info;
|
||||
auto& perResultChannels = remoteSendsByTask[producerRef->instance];
|
||||
if (perResultChannels.empty())
|
||||
perResultChannels.resize(getComputeInstanceOutputTypes(producerIt->second.computeInstance).size());
|
||||
perResultChannels[producerRef->resultIndex].push_back(
|
||||
{info, task.computeInstance, inputIndex, task.orderWithinCpu, 0});
|
||||
}
|
||||
continue;
|
||||
}
|
||||
}
|
||||
if (seenExternalInputsByCpu[cpu].insert(input).second)
|
||||
cpuExternalInputs[cpu].push_back(input);
|
||||
}
|
||||
|
||||
auto taskOutputs = getComputeInstanceOutputValues(task.computeInstance);
|
||||
for (auto [resultIndex, output] : llvm::enumerate(taskOutputs)) {
|
||||
bool hasExternalUser = false;
|
||||
for (auto& use : output.getUses()) {
|
||||
Operation* useOwner = use.getOwner();
|
||||
if (oldComputeOps.contains(useOwner))
|
||||
continue;
|
||||
hasExternalUser = true;
|
||||
if (!isa<func::ReturnOp>(useOwner))
|
||||
collectExternalUsers(useOwner);
|
||||
}
|
||||
if (hasExternalUser)
|
||||
cpuExternalOutputs[cpu].push_back({task.computeInstance, resultIndex});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void planRemoteChannels() {
|
||||
for (size_t cpu : orderedCpus) {
|
||||
DenseMap<uint64_t, size_t> nextSourceOrderByPair;
|
||||
DenseMap<uint64_t, size_t> lastConsumerOrderByPair;
|
||||
for (const ScheduledTask& task : tasksByCpu[cpu]) {
|
||||
auto sendsIt = remoteSendsByTask.find(task.computeInstance);
|
||||
if (sendsIt == remoteSendsByTask.end())
|
||||
continue;
|
||||
for (auto& sendInfos : sendsIt->second) {
|
||||
for (RemoteSendInfo& sendInfo : sendInfos) {
|
||||
uint64_t pairKey = getRemoteSendPairKey(sendInfo.channelInfo);
|
||||
sendInfo.sourceOrder = nextSourceOrderByPair[pairKey]++;
|
||||
auto [it, inserted] = lastConsumerOrderByPair.try_emplace(pairKey, sendInfo.consumerOrder);
|
||||
if (!inserted) {
|
||||
if (sendInfo.consumerOrder < it->second)
|
||||
pairsNeedingReceiveReorder.insert(pairKey);
|
||||
it->second = sendInfo.consumerOrder;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void planReceiveReordering() {
|
||||
DenseMap<uint64_t, SmallVector<RemoteSendInfo*>> reorderedSendsByPair;
|
||||
for (auto& taskSends : remoteSendsByTask) {
|
||||
for (auto& sendInfos : taskSends.second) {
|
||||
for (RemoteSendInfo& sendInfo : sendInfos) {
|
||||
uint64_t pairKey = getRemoteSendPairKey(sendInfo.channelInfo);
|
||||
if (pairsNeedingReceiveReorder.contains(pairKey))
|
||||
reorderedSendsByPair[pairKey].push_back(&sendInfo);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (auto& pairSends : reorderedSendsByPair) {
|
||||
llvm::stable_sort(pairSends.second, [](const RemoteSendInfo* lhs, const RemoteSendInfo* rhs) {
|
||||
if (lhs->sourceOrder != rhs->sourceOrder)
|
||||
return lhs->sourceOrder < rhs->sourceOrder;
|
||||
return lhs->channelInfo.channelId < rhs->channelInfo.channelId;
|
||||
});
|
||||
for (RemoteSendInfo* sendInfo : pairSends.second) {
|
||||
int64_t channelId = (*nextChannelId)++;
|
||||
sendInfo->channelInfo.channelId = channelId;
|
||||
auto remoteInputsIt = remoteInputsByTask.find(sendInfo->consumer);
|
||||
assert(remoteInputsIt != remoteInputsByTask.end() && "missing remote input for reordered send");
|
||||
assert(sendInfo->inputIndex < remoteInputsIt->second.size() && "remote input index out of range");
|
||||
assert(remoteInputsIt->second[sendInfo->inputIndex] && "missing reordered remote input channel");
|
||||
remoteInputsIt->second[sendInfo->inputIndex]->channelId = channelId;
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto& taskSends : remoteSendsByTask) {
|
||||
for (const auto& sendInfos : taskSends.second) {
|
||||
for (const RemoteSendInfo& sendInfo : sendInfos) {
|
||||
auto remoteInputsIt = remoteInputsByTask.find(sendInfo.consumer);
|
||||
assert(remoteInputsIt != remoteInputsByTask.end() && "missing remote input for send");
|
||||
assert(sendInfo.inputIndex < remoteInputsIt->second.size() && "remote input index out of range");
|
||||
assert(remoteInputsIt->second[sendInfo.inputIndex] && "missing remote input channel");
|
||||
remoteInputsIt->second[sendInfo.inputIndex] = sendInfo.channelInfo;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (auto& taskSends : remoteSendsByTask) {
|
||||
for (const auto& sendInfos : taskSends.second) {
|
||||
for (const RemoteSendInfo& sendInfo : sendInfos) {
|
||||
uint64_t pairKey = getRemoteSendPairKey(sendInfo.channelInfo);
|
||||
if (!pairsNeedingReceiveReorder.contains(pairKey))
|
||||
continue;
|
||||
size_t targetCpu = static_cast<size_t>(sendInfo.channelInfo.targetCoreId);
|
||||
receiveQueuesByCpu[targetCpu][pairKey].push_back(
|
||||
{sendInfo.channelInfo, sendInfo.consumer, sendInfo.inputIndex, sendInfo.sourceOrder});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (auto& cpuQueues : receiveQueuesByCpu) {
|
||||
for (auto& pairQueue : cpuQueues.second) {
|
||||
llvm::stable_sort(pairQueue.second, [](const RemoteReceiveEntry& lhs, const RemoteReceiveEntry& rhs) {
|
||||
if (lhs.sourceOrder != rhs.sourceOrder)
|
||||
return lhs.sourceOrder < rhs.sourceOrder;
|
||||
return lhs.channelInfo.channelId < rhs.channelInfo.channelId;
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void createCpuComputeOps() {
|
||||
IRRewriter rewriter(func.getContext());
|
||||
for (size_t cpu : orderedCpus) {
|
||||
SmallVector<Value> operands;
|
||||
operands.reserve(cpuWeights[cpu].size() + cpuExternalInputs[cpu].size());
|
||||
llvm::append_range(operands, cpuWeights[cpu]);
|
||||
llvm::append_range(operands, cpuExternalInputs[cpu]);
|
||||
|
||||
SmallVector<Type> resultTypes;
|
||||
resultTypes.reserve(cpuExternalOutputs[cpu].size());
|
||||
for (ProducerValueRef outputRef : cpuExternalOutputs[cpu]) {
|
||||
ScheduledTask task = taskByComputeInstance.at(outputRef.instance);
|
||||
resultTypes.push_back(getComputeInstanceOutputTypes(task.computeInstance)[outputRef.resultIndex]);
|
||||
}
|
||||
|
||||
rewriter.setInsertionPoint(returnOp);
|
||||
auto newCompute = SpatCompute::create(rewriter, loc, TypeRange(resultTypes), ValueRange(operands));
|
||||
newCompute.getProperties().setOperandSegmentSizes(
|
||||
{static_cast<int>(cpuWeights[cpu].size()), static_cast<int>(cpuExternalInputs[cpu].size())});
|
||||
newCompute->setAttr(onnx_mlir::kCoreIdAttrName, rewriter.getI32IntegerAttr(static_cast<int32_t>(cpu)));
|
||||
|
||||
SmallVector<Type> blockArgTypes;
|
||||
SmallVector<Location> blockArgLocs;
|
||||
blockArgTypes.reserve(cpuExternalInputs[cpu].size());
|
||||
blockArgLocs.reserve(cpuExternalInputs[cpu].size());
|
||||
for (Value input : cpuExternalInputs[cpu]) {
|
||||
blockArgTypes.push_back(input.getType());
|
||||
blockArgLocs.push_back(loc);
|
||||
}
|
||||
Block* newBlock =
|
||||
rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
|
||||
|
||||
CpuProgram program;
|
||||
program.op = newCompute;
|
||||
for (auto [weightIndex, weight] : llvm::enumerate(cpuWeights[cpu]))
|
||||
program.weightToIndex[weight] = weightIndex;
|
||||
for (auto [inputIndex, input] : llvm::enumerate(cpuExternalInputs[cpu]))
|
||||
program.externalInputMap[input] = newBlock->getArgument(inputIndex);
|
||||
for (auto [resultIndex, outputRef] : llvm::enumerate(cpuExternalOutputs[cpu])) {
|
||||
ScheduledTask task = taskByComputeInstance.at(outputRef.instance);
|
||||
oldToNewExternalValueMap[getComputeInstanceOutputValues(task.computeInstance)[outputRef.resultIndex]] =
|
||||
newCompute.getResult(resultIndex);
|
||||
}
|
||||
cpuPrograms[cpu] = std::move(program);
|
||||
}
|
||||
}
|
||||
|
||||
FailureOr<Value> receiveThroughInput(IRRewriter& rewriter,
|
||||
size_t cpu,
|
||||
DenseMap<uint64_t, size_t>& receiveQueueIndices,
|
||||
DenseMap<ComputeInstance, SmallVector<Value>>& preReceivedInputsByTask,
|
||||
const ChannelInfo& requestedChannelInfo,
|
||||
ComputeInstance requestedConsumer,
|
||||
size_t requestedInputIndex) {
|
||||
uint64_t pairKey = getRemoteSendPairKey(requestedChannelInfo);
|
||||
auto cpuQueuesIt = receiveQueuesByCpu.find(cpu);
|
||||
if (cpuQueuesIt == receiveQueuesByCpu.end())
|
||||
return failure();
|
||||
auto queueIt = cpuQueuesIt->second.find(pairKey);
|
||||
if (queueIt == cpuQueuesIt->second.end())
|
||||
return failure();
|
||||
|
||||
auto& queue = queueIt->second;
|
||||
size_t& queueIndex = receiveQueueIndices[pairKey];
|
||||
while (queueIndex < queue.size()) {
|
||||
const RemoteReceiveEntry& entry = queue[queueIndex++];
|
||||
auto consumerTaskIt = taskByComputeInstance.find(entry.consumer);
|
||||
if (consumerTaskIt == taskByComputeInstance.end())
|
||||
return failure();
|
||||
SmallVector<Value> consumerInputs = getComputeInstanceInputs(consumerTaskIt->second.computeInstance);
|
||||
if (consumerInputs.size() <= entry.inputIndex)
|
||||
return failure();
|
||||
Type inputType = consumerInputs[entry.inputIndex].getType();
|
||||
auto receive = spatial::SpatChannelReceiveOp::create(rewriter,
|
||||
loc,
|
||||
inputType,
|
||||
rewriter.getI64IntegerAttr(entry.channelInfo.channelId),
|
||||
rewriter.getI32IntegerAttr(entry.channelInfo.sourceCoreId),
|
||||
rewriter.getI32IntegerAttr(entry.channelInfo.targetCoreId));
|
||||
|
||||
auto& receivedInputs = preReceivedInputsByTask[entry.consumer];
|
||||
if (receivedInputs.size() <= entry.inputIndex)
|
||||
receivedInputs.resize(entry.inputIndex + 1);
|
||||
receivedInputs[entry.inputIndex] = receive.getResult();
|
||||
|
||||
if (entry.consumer == requestedConsumer && entry.inputIndex == requestedInputIndex)
|
||||
return receive.getResult();
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
|
||||
LogicalResult cloneTaskBodies() {
|
||||
for (size_t cpu : orderedCpus) {
|
||||
CpuProgram& program = cpuPrograms[cpu];
|
||||
IRRewriter rewriter(func.getContext());
|
||||
rewriter.setInsertionPointToEnd(&program.op.getBody().front());
|
||||
DenseMap<uint64_t, size_t> receiveQueueIndices;
|
||||
DenseMap<ComputeInstance, SmallVector<Value>> preReceivedInputsByTask;
|
||||
|
||||
auto lookupPreReceivedInput = [&](ComputeInstance consumer, size_t inputIndex) -> std::optional<Value> {
|
||||
auto inputsIt = preReceivedInputsByTask.find(consumer);
|
||||
if (inputsIt == preReceivedInputsByTask.end() || inputsIt->second.size() <= inputIndex)
|
||||
return std::nullopt;
|
||||
Value value = inputsIt->second[inputIndex];
|
||||
if (!value)
|
||||
return std::nullopt;
|
||||
return value;
|
||||
};
|
||||
|
||||
for (const ScheduledTask& task : tasksByCpu[cpu]) {
|
||||
SmallVector<Value> taskInputs = getComputeInstanceInputs(task.computeInstance);
|
||||
auto taskWeights = getComputeInstanceWeights(task.computeInstance);
|
||||
Block& templateBlock = getComputeInstanceTemplateBlock(task.computeInstance);
|
||||
|
||||
SmallVector<Value> resolvedInputs;
|
||||
resolvedInputs.reserve(taskInputs.size());
|
||||
auto remoteInputsIt = remoteInputsByTask.find(task.computeInstance);
|
||||
for (auto [inputIndex, input] : llvm::enumerate(taskInputs)) {
|
||||
auto producerRef = getProducerValueRef(input);
|
||||
if (producerRef) {
|
||||
auto producerIt = taskByComputeInstance.find(producerRef->instance);
|
||||
if (producerIt != taskByComputeInstance.end()) {
|
||||
if (producerIt->second.cpu == cpu) {
|
||||
auto producedIt = producedValuesByTask.find(producerRef->instance);
|
||||
if (producedIt == producedValuesByTask.end() || producedIt->second.size() <= producerRef->resultIndex) {
|
||||
task.computeInstance.op->emitOpError("missing local producer value during per-cpu merge materialization")
|
||||
<< " consumerCpu=" << cpu << " producerCpu=" << producerIt->second.cpu
|
||||
<< " producerLaneStart=" << producerRef->instance.laneStart
|
||||
<< " producerLaneCount=" << producerRef->instance.laneCount;
|
||||
return failure();
|
||||
}
|
||||
resolvedInputs.push_back(producedIt->second[producerRef->resultIndex]);
|
||||
continue;
|
||||
}
|
||||
const ChannelInfo& channelInfo = *remoteInputsIt->second[inputIndex];
|
||||
uint64_t pairKey = getRemoteSendPairKey(channelInfo);
|
||||
if (pairsNeedingReceiveReorder.contains(pairKey)) {
|
||||
if (std::optional<Value> preReceived = lookupPreReceivedInput(task.computeInstance, inputIndex)) {
|
||||
resolvedInputs.push_back(*preReceived);
|
||||
continue;
|
||||
}
|
||||
FailureOr<Value> received = receiveThroughInput(rewriter,
|
||||
cpu,
|
||||
receiveQueueIndices,
|
||||
preReceivedInputsByTask,
|
||||
channelInfo,
|
||||
task.computeInstance,
|
||||
inputIndex);
|
||||
if (failed(received)) {
|
||||
task.computeInstance.op->emitOpError("failed to materialize reordered remote receive")
|
||||
<< " consumerCpu=" << cpu << " sourceCoreId=" << channelInfo.sourceCoreId
|
||||
<< " targetCoreId=" << channelInfo.targetCoreId << " channelId=" << channelInfo.channelId;
|
||||
return failure();
|
||||
}
|
||||
resolvedInputs.push_back(*received);
|
||||
continue;
|
||||
}
|
||||
auto receive =
|
||||
spatial::SpatChannelReceiveOp::create(rewriter,
|
||||
loc,
|
||||
input.getType(),
|
||||
rewriter.getI64IntegerAttr(channelInfo.channelId),
|
||||
rewriter.getI32IntegerAttr(channelInfo.sourceCoreId),
|
||||
rewriter.getI32IntegerAttr(channelInfo.targetCoreId));
|
||||
resolvedInputs.push_back(receive.getResult());
|
||||
continue;
|
||||
}
|
||||
}
|
||||
resolvedInputs.push_back(program.externalInputMap.at(input));
|
||||
}
|
||||
|
||||
SmallVector<Value> taskYieldValues;
|
||||
rewriter.setInsertionPointToEnd(&program.op.getBody().front());
|
||||
if (isa<SpatCompute>(task.computeInstance.op)) {
|
||||
IRMapping mapper;
|
||||
for (auto [argIndex, oldArg] : llvm::enumerate(templateBlock.getArguments()))
|
||||
mapper.map(oldArg, resolvedInputs[argIndex]);
|
||||
|
||||
for (Operation& op : templateBlock) {
|
||||
if (auto yield = dyn_cast<spatial::SpatYieldOp>(&op)) {
|
||||
for (Value yieldOperand : yield.getOperands())
|
||||
taskYieldValues.push_back(mapper.lookup(yieldOperand));
|
||||
continue;
|
||||
}
|
||||
|
||||
Operation* clonedOp = rewriter.clone(op, mapper);
|
||||
if (auto oldWeightedMvmOp = dyn_cast<spatial::SpatMVMOp>(&op)) {
|
||||
auto newWeightedMvmOp = cast<spatial::SpatMVMOp>(clonedOp);
|
||||
Value weight = taskWeights[oldWeightedMvmOp.getWeightIndex()];
|
||||
newWeightedMvmOp.setWeightIndex(program.weightToIndex.at(weight));
|
||||
}
|
||||
if (auto oldWeightedVmmOp = dyn_cast<spatial::SpatVMMOp>(&op)) {
|
||||
auto newWeightedVmmOp = cast<spatial::SpatVMMOp>(clonedOp);
|
||||
Value weight = taskWeights[oldWeightedVmmOp.getWeightIndex()];
|
||||
newWeightedVmmOp.setWeightIndex(program.weightToIndex.at(weight));
|
||||
}
|
||||
}
|
||||
}
|
||||
else {
|
||||
for (size_t laneOffset = 0; laneOffset < task.computeInstance.laneCount; ++laneOffset) {
|
||||
IRMapping mapper;
|
||||
if (templateBlock.getNumArguments() == 1)
|
||||
mapper.map(templateBlock.getArgument(0), resolvedInputs[laneOffset]);
|
||||
|
||||
for (Operation& op : templateBlock) {
|
||||
if (auto yield = dyn_cast<spatial::SpatYieldOp>(&op)) {
|
||||
for (Value yieldOperand : yield.getOperands())
|
||||
taskYieldValues.push_back(mapper.lookup(yieldOperand));
|
||||
continue;
|
||||
}
|
||||
|
||||
Operation* clonedOp = rewriter.clone(op, mapper);
|
||||
if (auto oldWeightedMvmOp = dyn_cast<spatial::SpatMVMOp>(&op)) {
|
||||
if (oldWeightedMvmOp.getWeightIndex() != 0) {
|
||||
task.computeInstance.op->emitOpError(
|
||||
"batched per-cpu merge materialization expects lane-local weight index 0");
|
||||
return failure();
|
||||
}
|
||||
auto newWeightedMvmOp = cast<spatial::SpatMVMOp>(clonedOp);
|
||||
newWeightedMvmOp.setWeightIndex(program.weightToIndex.at(taskWeights[laneOffset]));
|
||||
}
|
||||
if (auto oldWeightedVmmOp = dyn_cast<spatial::SpatVMMOp>(&op)) {
|
||||
if (oldWeightedVmmOp.getWeightIndex() != 0) {
|
||||
task.computeInstance.op->emitOpError(
|
||||
"batched per-cpu merge materialization expects lane-local weight index 0");
|
||||
return failure();
|
||||
}
|
||||
auto newWeightedVmmOp = cast<spatial::SpatVMMOp>(clonedOp);
|
||||
newWeightedVmmOp.setWeightIndex(program.weightToIndex.at(taskWeights[laneOffset]));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
producedValuesByTask[task.computeInstance] = taskYieldValues;
|
||||
if (auto sendsIt = remoteSendsByTask.find(task.computeInstance); sendsIt != remoteSendsByTask.end()) {
|
||||
for (auto [resultIndex, sendInfos] : llvm::enumerate(sendsIt->second)) {
|
||||
if (sendInfos.empty())
|
||||
continue;
|
||||
Value producedValue = taskYieldValues[resultIndex];
|
||||
for (const RemoteSendInfo& sendInfo : sendInfos) {
|
||||
spatial::SpatChannelSendOp::create(rewriter,
|
||||
loc,
|
||||
rewriter.getI64IntegerAttr(sendInfo.channelInfo.channelId),
|
||||
rewriter.getI32IntegerAttr(sendInfo.channelInfo.sourceCoreId),
|
||||
rewriter.getI32IntegerAttr(sendInfo.channelInfo.targetCoreId),
|
||||
producedValue);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
SmallVector<Value> yieldValues;
|
||||
yieldValues.reserve(cpuExternalOutputs[cpu].size());
|
||||
for (ProducerValueRef outputRef : cpuExternalOutputs[cpu]) {
|
||||
auto producedIt = producedValuesByTask.find(outputRef.instance);
|
||||
if (producedIt == producedValuesByTask.end() || producedIt->second.size() <= outputRef.resultIndex) {
|
||||
ScheduledTask task = taskByComputeInstance.at(outputRef.instance);
|
||||
task.computeInstance.op->emitOpError("missing yielded external value during per-cpu merge materialization")
|
||||
<< " cpu=" << cpu << " laneStart=" << outputRef.instance.laneStart;
|
||||
return failure();
|
||||
}
|
||||
yieldValues.push_back(producedIt->second[outputRef.resultIndex]);
|
||||
}
|
||||
spatial::SpatYieldOp::create(rewriter, loc, ValueRange(yieldValues));
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
void replaceExternalUses() {
|
||||
for (auto [oldValue, newValue] : oldToNewExternalValueMap) {
|
||||
for (auto& use : llvm::make_early_inc_range(oldValue.getUses()))
|
||||
if (!oldComputeOps.contains(use.getOwner()))
|
||||
use.assign(newValue);
|
||||
}
|
||||
}
|
||||
|
||||
LogicalResult eraseOldScheduledOps() {
|
||||
SmallVector<Operation*> orderedOpsToErase;
|
||||
for (Operation& op : func.getBody().front())
|
||||
if (oldComputeOps.contains(&op))
|
||||
orderedOpsToErase.push_back(&op);
|
||||
|
||||
for (Operation* op : llvm::reverse(orderedOpsToErase)) {
|
||||
SmallVector<Operation*> remainingUsers;
|
||||
for (Value result : op->getResults())
|
||||
for (Operation* user : result.getUsers())
|
||||
remainingUsers.push_back(user);
|
||||
if (!remainingUsers.empty()) {
|
||||
InFlightDiagnostic diagnostic = op->emitOpError("still has uses during per-cpu merge cleanup")
|
||||
<< "; erase-set=" << (oldComputeOps.contains(op) ? "yes" : "no");
|
||||
for (Operation* user : remainingUsers) {
|
||||
diagnostic.attachNote(user->getLoc())
|
||||
<< "remaining user " << user->getName() << "; erase-set=" << (oldComputeOps.contains(user) ? "yes" : "no");
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
op->erase();
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
void moveExternalUsersBeforeReturn() {
|
||||
SmallVector<Operation*> orderedUsersToMove;
|
||||
for (Operation& op : func.getBody().front()) {
|
||||
if (&op == returnOp.getOperation())
|
||||
break;
|
||||
if (externalUsersToMove.contains(&op))
|
||||
orderedUsersToMove.push_back(&op);
|
||||
}
|
||||
for (Operation* op : orderedUsersToMove)
|
||||
op->moveBefore(returnOp);
|
||||
}
|
||||
|
||||
func::FuncOp func;
|
||||
const MergeScheduleResult* schedule = nullptr;
|
||||
int64_t* nextChannelId = nullptr;
|
||||
Location loc;
|
||||
func::ReturnOp returnOp;
|
||||
|
||||
SmallVector<ScheduledTask> scheduledTasks;
|
||||
DenseSet<Operation*> oldComputeOps;
|
||||
DenseMap<ComputeInstance, ScheduledTask> taskByComputeInstance;
|
||||
DenseMap<size_t, SmallVector<ScheduledTask>> tasksByCpu;
|
||||
SmallVector<size_t> orderedCpus;
|
||||
DenseSet<size_t> seenCpus;
|
||||
DenseSet<Operation*> externalUsersToMove;
|
||||
DenseMap<ComputeInstance, SmallVector<SmallVector<RemoteSendInfo>>> remoteSendsByTask;
|
||||
DenseMap<ComputeInstance, SmallVector<std::optional<ChannelInfo>>> remoteInputsByTask;
|
||||
DenseMap<size_t, SmallVector<Value>> cpuExternalInputs;
|
||||
DenseMap<size_t, SmallVector<Value>> cpuWeights;
|
||||
DenseMap<size_t, SmallVector<ProducerValueRef>> cpuExternalOutputs;
|
||||
DenseMap<size_t, DenseSet<Value>> seenExternalInputsByCpu;
|
||||
DenseMap<size_t, DenseSet<Value>> seenWeightsByCpu;
|
||||
DenseSet<uint64_t> pairsNeedingReceiveReorder;
|
||||
DenseMap<size_t, DenseMap<uint64_t, SmallVector<RemoteReceiveEntry>>> receiveQueuesByCpu;
|
||||
DenseMap<size_t, CpuProgram> cpuPrograms;
|
||||
DenseMap<Value, Value> oldToNewExternalValueMap;
|
||||
DenseMap<ComputeInstance, SmallVector<Value>> producedValuesByTask;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
LogicalResult
|
||||
MergeScheduleMaterializer::run(func::FuncOp func, const MergeScheduleResult& schedule, int64_t& nextChannelId) {
|
||||
return MergeScheduleMaterializerImpl(func).run(schedule, nextChannelId);
|
||||
}
|
||||
|
||||
} // namespace spatial
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,18 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
|
||||
#include "Scheduling/MergeSchedule.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace spatial {
|
||||
|
||||
class MergeScheduleMaterializer {
|
||||
public:
|
||||
mlir::LogicalResult
|
||||
run(mlir::func::FuncOp func, const MergeScheduleResult &schedule, int64_t &nextChannelId);
|
||||
};
|
||||
|
||||
} // namespace spatial
|
||||
} // namespace onnx_mlir
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,459 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/DenseSet.h"
|
||||
#include "llvm/ADT/Hashing.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
#include <chrono>
|
||||
#include <cstdlib>
|
||||
#include <limits>
|
||||
#include <optional>
|
||||
|
||||
#include "PostMergeCompaction.hpp"
|
||||
#include "RegularOpCompaction.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
using SpatCompute = spatial::SpatCompute;
|
||||
using SpatComputeBatch = spatial::SpatComputeBatch;
|
||||
|
||||
bool isMergeProfilingEnabled() { return std::getenv("RAPTOR_PROFILE_MERGE") != nullptr; }
|
||||
|
||||
class ScopedMergePhaseTimer {
|
||||
public:
|
||||
explicit ScopedMergePhaseTimer(StringRef phaseName)
|
||||
: enabled(isMergeProfilingEnabled()), phase(phaseName.str()) {
|
||||
if (enabled)
|
||||
start = std::chrono::steady_clock::now();
|
||||
}
|
||||
|
||||
~ScopedMergePhaseTimer() {
|
||||
if (!enabled)
|
||||
return;
|
||||
auto elapsed = std::chrono::steady_clock::now() - start;
|
||||
double millis = std::chrono::duration<double, std::milli>(elapsed).count();
|
||||
llvm::errs() << "[merge-profile] " << phase << ": " << llvm::formatv("{0:F3}", millis) << " ms\n";
|
||||
}
|
||||
|
||||
private:
|
||||
bool enabled = false;
|
||||
std::string phase;
|
||||
std::chrono::steady_clock::time_point start;
|
||||
};
|
||||
|
||||
std::optional<int32_t> getComputeCoreId(SpatCompute compute) {
|
||||
if (auto coreIdAttr = compute->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
|
||||
return static_cast<int32_t>(coreIdAttr.getInt());
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
static constexpr StringLiteral kRebatchPhaseAttrName = "_pim_rebatch_phase";
|
||||
|
||||
std::optional<uint64_t> getComputeRebatchPhase(SpatCompute compute) {
|
||||
if (auto phaseAttr = compute->getAttrOfType<IntegerAttr>(kRebatchPhaseAttrName))
|
||||
return static_cast<uint64_t>(phaseAttr.getInt());
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
struct RebatchKey {
|
||||
unsigned inputCount = 0;
|
||||
unsigned resultCount = 0;
|
||||
unsigned weightCount = 0;
|
||||
uint64_t phase = 0;
|
||||
bool hasPhase = false;
|
||||
uint64_t structureHash = 0;
|
||||
|
||||
bool operator==(const RebatchKey& other) const {
|
||||
return inputCount == other.inputCount && resultCount == other.resultCount && weightCount == other.weightCount
|
||||
&& phase == other.phase && hasPhase == other.hasPhase && structureHash == other.structureHash;
|
||||
}
|
||||
};
|
||||
|
||||
struct RebatchKeyInfo {
|
||||
static inline RebatchKey getEmptyKey() { return {std::numeric_limits<unsigned>::max(), 0, 0, 0, false, 0}; }
|
||||
|
||||
static inline RebatchKey getTombstoneKey() { return {std::numeric_limits<unsigned>::max() - 1, 0, 0, 0, false, 0}; }
|
||||
|
||||
static unsigned getHashValue(const RebatchKey& key) {
|
||||
return static_cast<unsigned>(
|
||||
llvm::hash_combine(key.inputCount, key.resultCount, key.weightCount, key.phase, key.hasPhase, key.structureHash));
|
||||
}
|
||||
|
||||
static bool isEqual(const RebatchKey& lhs, const RebatchKey& rhs) { return lhs == rhs; }
|
||||
};
|
||||
|
||||
uint64_t getTypeHash(Type type) { return reinterpret_cast<uintptr_t>(type.getAsOpaquePointer()); }
|
||||
|
||||
uint64_t getValueHash(Value value) { return reinterpret_cast<uintptr_t>(value.getAsOpaquePointer()); }
|
||||
|
||||
uint64_t getAttributeHash(Attribute attr) { return reinterpret_cast<uintptr_t>(attr.getAsOpaquePointer()); }
|
||||
|
||||
RebatchKey computeRebatchKey(SpatCompute compute) {
|
||||
llvm::hash_code structureHash =
|
||||
llvm::hash_combine(compute.getInputs().size(), compute.getResultTypes().size(), compute.getWeights().size());
|
||||
|
||||
for (Value weight : compute.getWeights())
|
||||
structureHash = llvm::hash_combine(structureHash, getValueHash(weight));
|
||||
if (std::optional<uint64_t> phase = getComputeRebatchPhase(compute))
|
||||
structureHash = llvm::hash_combine(structureHash, *phase);
|
||||
|
||||
Block& body = compute.getBody().front();
|
||||
structureHash = llvm::hash_combine(structureHash, body.getNumArguments());
|
||||
for (BlockArgument arg : body.getArguments())
|
||||
structureHash = llvm::hash_combine(structureHash, getTypeHash(arg.getType()));
|
||||
|
||||
for (Operation& op : body) {
|
||||
structureHash = llvm::hash_combine(
|
||||
structureHash, op.getName().getStringRef(), op.getNumOperands(), op.getNumResults(), op.getNumRegions());
|
||||
for (Type type : op.getResultTypes())
|
||||
structureHash = llvm::hash_combine(structureHash, getTypeHash(type));
|
||||
for (NamedAttribute attr : op.getAttrs())
|
||||
structureHash = llvm::hash_combine(structureHash, attr.getName().strref(), getAttributeHash(attr.getValue()));
|
||||
}
|
||||
|
||||
std::optional<uint64_t> phase = getComputeRebatchPhase(compute);
|
||||
return {static_cast<unsigned>(compute.getInputs().size()),
|
||||
static_cast<unsigned>(compute.getResultTypes().size()),
|
||||
static_cast<unsigned>(compute.getWeights().size()),
|
||||
phase.value_or(0),
|
||||
phase.has_value(),
|
||||
static_cast<uint64_t>(structureHash)};
|
||||
}
|
||||
|
||||
bool areEquivalentForRebatch(SpatCompute lhs, SpatCompute rhs) {
|
||||
if (!lhs || !rhs)
|
||||
return false;
|
||||
if (lhs.getInputs().size() != rhs.getInputs().size())
|
||||
return false;
|
||||
if (lhs.getResultTypes() != rhs.getResultTypes())
|
||||
return false;
|
||||
if (lhs.getWeights().size() != rhs.getWeights().size())
|
||||
return false;
|
||||
if (getComputeRebatchPhase(lhs) != getComputeRebatchPhase(rhs))
|
||||
return false;
|
||||
if (!llvm::equal(lhs.getWeights(), rhs.getWeights()))
|
||||
return false;
|
||||
|
||||
auto& lhsBlock = lhs.getBody().front();
|
||||
auto& rhsBlock = rhs.getBody().front();
|
||||
if (lhsBlock.getNumArguments() != rhsBlock.getNumArguments())
|
||||
return false;
|
||||
|
||||
DenseMap<Value, Value> mappedValues;
|
||||
for (auto [lhsArg, rhsArg] : llvm::zip(lhsBlock.getArguments(), rhsBlock.getArguments())) {
|
||||
if (lhsArg.getType() != rhsArg.getType())
|
||||
return false;
|
||||
mappedValues[lhsArg] = rhsArg;
|
||||
}
|
||||
auto lhsIt = lhsBlock.begin();
|
||||
auto rhsIt = rhsBlock.begin();
|
||||
for (; lhsIt != lhsBlock.end() && rhsIt != rhsBlock.end(); ++lhsIt, ++rhsIt) {
|
||||
Operation& lhsOp = *lhsIt;
|
||||
Operation& rhsOp = *rhsIt;
|
||||
|
||||
if (lhsOp.getName() != rhsOp.getName())
|
||||
return false;
|
||||
if (lhsOp.getNumOperands() != rhsOp.getNumOperands())
|
||||
return false;
|
||||
if (lhsOp.getNumResults() != rhsOp.getNumResults())
|
||||
return false;
|
||||
if (lhsOp.getNumRegions() != 0 || rhsOp.getNumRegions() != 0)
|
||||
return false;
|
||||
|
||||
for (auto [lhsOperand, rhsOperand] : llvm::zip(lhsOp.getOperands(), rhsOp.getOperands())) {
|
||||
auto mapped = mappedValues.find(lhsOperand);
|
||||
if (mapped != mappedValues.end()) {
|
||||
if (mapped->second != rhsOperand)
|
||||
return false;
|
||||
continue;
|
||||
}
|
||||
if (lhsOperand != rhsOperand)
|
||||
return false;
|
||||
}
|
||||
|
||||
if (auto lhsReceive = dyn_cast<spatial::SpatChannelReceiveOp>(lhsOp)) {
|
||||
auto rhsReceive = cast<spatial::SpatChannelReceiveOp>(rhsOp);
|
||||
if (lhsReceive.getOutput().getType() != rhsReceive.getOutput().getType())
|
||||
return false;
|
||||
}
|
||||
else if (auto lhsSend = dyn_cast<spatial::SpatChannelSendOp>(lhsOp)) {
|
||||
auto rhsSend = cast<spatial::SpatChannelSendOp>(rhsOp);
|
||||
if (lhsSend.getInput().getType() != rhsSend.getInput().getType())
|
||||
return false;
|
||||
}
|
||||
else if (lhsOp.getAttrs() != rhsOp.getAttrs()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (lhsOp.getResultTypes() != rhsOp.getResultTypes())
|
||||
return false;
|
||||
for (auto [lhsResult, rhsResult] : llvm::zip(lhsOp.getResults(), rhsOp.getResults()))
|
||||
mappedValues[lhsResult] = rhsResult;
|
||||
}
|
||||
|
||||
return lhsIt == lhsBlock.end() && rhsIt == rhsBlock.end();
|
||||
}
|
||||
|
||||
void rebatchEquivalentComputes(func::FuncOp funcOp) {
|
||||
IRRewriter rewriter(funcOp.getContext());
|
||||
SmallVector<SpatCompute> computes(funcOp.getOps<SpatCompute>());
|
||||
DenseSet<Operation*> consumed;
|
||||
DenseMap<Operation*, size_t> computeOrder;
|
||||
DenseMap<RebatchKey, SmallVector<SpatCompute>, RebatchKeyInfo> candidatesByKey;
|
||||
|
||||
for (auto [index, compute] : llvm::enumerate(computes)) {
|
||||
computeOrder[compute.getOperation()] = index;
|
||||
if (compute.getInputs().size() <= 1 && compute.getResults().empty())
|
||||
candidatesByKey[computeRebatchKey(compute)].push_back(compute);
|
||||
}
|
||||
|
||||
for (size_t index = 0; index < computes.size(); ++index) {
|
||||
auto anchor = computes[index];
|
||||
if (consumed.contains(anchor))
|
||||
continue;
|
||||
if (anchor.getInputs().size() > 1)
|
||||
continue;
|
||||
if (!anchor.getResults().empty())
|
||||
continue;
|
||||
|
||||
SmallVector<SpatCompute> group {anchor};
|
||||
llvm::SmallDenseSet<int32_t, 8> usedCoreIds;
|
||||
if (auto coreId = getComputeCoreId(anchor))
|
||||
usedCoreIds.insert(*coreId);
|
||||
|
||||
auto bucketIt = candidatesByKey.find(computeRebatchKey(anchor));
|
||||
if (bucketIt == candidatesByKey.end())
|
||||
continue;
|
||||
|
||||
for (auto candidate : bucketIt->second) {
|
||||
if (computeOrder.lookup(candidate.getOperation()) <= index)
|
||||
continue;
|
||||
if (consumed.contains(candidate))
|
||||
continue;
|
||||
if (!areEquivalentForRebatch(anchor, candidate))
|
||||
continue;
|
||||
|
||||
if (auto coreId = getComputeCoreId(candidate))
|
||||
if (!usedCoreIds.insert(*coreId).second)
|
||||
continue;
|
||||
|
||||
group.push_back(candidate);
|
||||
}
|
||||
|
||||
if (group.size() <= 1)
|
||||
continue;
|
||||
|
||||
auto insertionAnchor = group.front();
|
||||
if (llvm::all_of(group, [](SpatCompute compute) { return getComputeCoreId(compute).has_value(); })) {
|
||||
llvm::stable_sort(
|
||||
group, [](SpatCompute lhs, SpatCompute rhs) { return *getComputeCoreId(lhs) < *getComputeCoreId(rhs); });
|
||||
}
|
||||
|
||||
SmallVector<Value> weights;
|
||||
weights.reserve(group.size() * anchor.getWeights().size());
|
||||
SmallVector<Value> inputs;
|
||||
inputs.reserve(group.size() * anchor.getInputs().size());
|
||||
SmallVector<int32_t> coreIds;
|
||||
coreIds.reserve(group.size());
|
||||
bool haveAllCoreIds = true;
|
||||
for (auto compute : group) {
|
||||
llvm::append_range(weights, compute.getWeights());
|
||||
llvm::append_range(inputs, compute.getInputs());
|
||||
auto coreIdAttr = compute->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName);
|
||||
if (!coreIdAttr)
|
||||
haveAllCoreIds = false;
|
||||
else if (haveAllCoreIds)
|
||||
coreIds.push_back(static_cast<int32_t>(coreIdAttr.getInt()));
|
||||
}
|
||||
|
||||
rewriter.setInsertionPoint(insertionAnchor);
|
||||
auto rebatched = SpatComputeBatch::create(rewriter,
|
||||
insertionAnchor.getLoc(),
|
||||
TypeRange {},
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(group.size())),
|
||||
ValueRange(weights),
|
||||
ValueRange(inputs));
|
||||
rebatched.getProperties().setOperandSegmentSizes(
|
||||
{static_cast<int>(weights.size()), static_cast<int>(inputs.size())});
|
||||
if (haveAllCoreIds)
|
||||
rebatched->setAttr(onnx_mlir::kCoreIdsAttrName, rewriter.getDenseI32ArrayAttr(coreIds));
|
||||
|
||||
SmallVector<Type> blockArgTypes;
|
||||
SmallVector<Location> blockArgLocs;
|
||||
for (BlockArgument arg : anchor.getBody().front().getArguments()) {
|
||||
blockArgTypes.push_back(arg.getType());
|
||||
blockArgLocs.push_back(arg.getLoc());
|
||||
}
|
||||
auto* newBlock =
|
||||
rewriter.createBlock(&rebatched.getBody(), rebatched.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
|
||||
rewriter.setInsertionPointToEnd(newBlock);
|
||||
|
||||
IRMapping mapper;
|
||||
auto& anchorBlock = anchor.getBody().front();
|
||||
for (auto [oldArg, newArg] : llvm::zip(anchorBlock.getArguments(), newBlock->getArguments()))
|
||||
mapper.map(oldArg, newArg);
|
||||
auto opIts = llvm::map_to_vector(group, [](SpatCompute compute) { return compute.getBody().front().begin(); });
|
||||
for (Operation& anchorOp : anchorBlock) {
|
||||
if (auto receiveOp = dyn_cast<spatial::SpatChannelReceiveOp>(&anchorOp)) {
|
||||
struct BatchReceiveEntry {
|
||||
uint64_t channelId = 0;
|
||||
uint32_t sourceCoreId = 0;
|
||||
uint32_t targetCoreId = 0;
|
||||
};
|
||||
SmallVector<BatchReceiveEntry> entries;
|
||||
entries.reserve(group.size());
|
||||
for (auto [groupIndex, compute] : llvm::enumerate(group)) {
|
||||
auto groupReceive = cast<spatial::SpatChannelReceiveOp>(&*opIts[groupIndex]);
|
||||
entries.push_back(
|
||||
{groupReceive.getChannelId(), groupReceive.getSourceCoreId(), groupReceive.getTargetCoreId()});
|
||||
++opIts[groupIndex];
|
||||
}
|
||||
SmallVector<int64_t> channelIds;
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
SmallVector<int32_t> targetCoreIds;
|
||||
channelIds.reserve(group.size());
|
||||
sourceCoreIds.reserve(group.size());
|
||||
targetCoreIds.reserve(group.size());
|
||||
for (const BatchReceiveEntry& entry : entries) {
|
||||
channelIds.push_back(static_cast<int64_t>(entry.channelId));
|
||||
sourceCoreIds.push_back(static_cast<int32_t>(entry.sourceCoreId));
|
||||
targetCoreIds.push_back(static_cast<int32_t>(entry.targetCoreId));
|
||||
}
|
||||
auto batchReceive = spatial::SpatChannelReceiveBatchOp::create(rewriter,
|
||||
receiveOp.getLoc(),
|
||||
receiveOp.getOutput().getType(),
|
||||
rewriter.getDenseI64ArrayAttr(channelIds),
|
||||
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
|
||||
rewriter.getDenseI32ArrayAttr(targetCoreIds));
|
||||
mapper.map(receiveOp.getOutput(), batchReceive.getOutput());
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto sendOp = dyn_cast<spatial::SpatChannelSendOp>(&anchorOp)) {
|
||||
struct BatchSendEntry {
|
||||
uint64_t channelId = 0;
|
||||
uint32_t sourceCoreId = 0;
|
||||
uint32_t targetCoreId = 0;
|
||||
};
|
||||
SmallVector<BatchSendEntry> entries;
|
||||
entries.reserve(group.size());
|
||||
for (auto [groupIndex, compute] : llvm::enumerate(group)) {
|
||||
auto groupSend = cast<spatial::SpatChannelSendOp>(&*opIts[groupIndex]);
|
||||
entries.push_back({groupSend.getChannelId(), groupSend.getSourceCoreId(), groupSend.getTargetCoreId()});
|
||||
++opIts[groupIndex];
|
||||
}
|
||||
SmallVector<int64_t> channelIds;
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
SmallVector<int32_t> targetCoreIds;
|
||||
channelIds.reserve(group.size());
|
||||
sourceCoreIds.reserve(group.size());
|
||||
targetCoreIds.reserve(group.size());
|
||||
for (const BatchSendEntry& entry : entries) {
|
||||
channelIds.push_back(static_cast<int64_t>(entry.channelId));
|
||||
sourceCoreIds.push_back(static_cast<int32_t>(entry.sourceCoreId));
|
||||
targetCoreIds.push_back(static_cast<int32_t>(entry.targetCoreId));
|
||||
}
|
||||
spatial::SpatChannelSendBatchOp::create(rewriter,
|
||||
sendOp.getLoc(),
|
||||
rewriter.getDenseI64ArrayAttr(channelIds),
|
||||
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
|
||||
rewriter.getDenseI32ArrayAttr(targetCoreIds),
|
||||
mapper.lookup(sendOp.getInput()));
|
||||
continue;
|
||||
}
|
||||
|
||||
if (isa<spatial::SpatYieldOp>(anchorOp)) {
|
||||
for (auto& opIt : opIts)
|
||||
++opIt;
|
||||
spatial::SpatYieldOp::create(rewriter, anchorOp.getLoc(), ValueRange {});
|
||||
continue;
|
||||
}
|
||||
|
||||
Operation* cloned = rewriter.clone(anchorOp, mapper);
|
||||
for (auto [originalResult, clonedResult] : llvm::zip(anchorOp.getResults(), cloned->getResults()))
|
||||
mapper.map(originalResult, clonedResult);
|
||||
for (auto& opIt : opIts)
|
||||
++opIt;
|
||||
}
|
||||
|
||||
for (auto compute : group) {
|
||||
compute->removeAttr(kRebatchPhaseAttrName);
|
||||
consumed.insert(compute);
|
||||
rewriter.eraseOp(compute);
|
||||
}
|
||||
}
|
||||
|
||||
for (auto compute : funcOp.getOps<SpatCompute>())
|
||||
compute->removeAttr(kRebatchPhaseAttrName);
|
||||
}
|
||||
|
||||
void cleanupDeadPackingOps(func::FuncOp funcOp) {
|
||||
auto eraseUnusedOps = [&](auto tag) {
|
||||
using OpTy = decltype(tag);
|
||||
SmallVector<OpTy> ops;
|
||||
funcOp.walk([&](OpTy op) { ops.push_back(op); });
|
||||
for (auto op : llvm::reverse(ops))
|
||||
if (op->use_empty())
|
||||
op.erase();
|
||||
};
|
||||
eraseUnusedOps(tensor::ExtractSliceOp {});
|
||||
eraseUnusedOps(spatial::SpatConcatOp {});
|
||||
eraseUnusedOps(spatial::SpatExtractRowsOp {});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
LogicalResult runPostMergeCompactionPipeline(func::FuncOp funcOp, int64_t& nextChannelId) {
|
||||
{
|
||||
ScopedMergePhaseTimer timer("order-bilateral-channel-ops");
|
||||
orderBilateralChannelOps(funcOp);
|
||||
}
|
||||
{
|
||||
ScopedMergePhaseTimer timer("rebatch-equivalent-computes");
|
||||
rebatchEquivalentComputes(funcOp);
|
||||
}
|
||||
{
|
||||
ScopedMergePhaseTimer timer("compact-scalar-channel-runs-1");
|
||||
compactScalarChannelRuns(funcOp, nextChannelId);
|
||||
}
|
||||
{
|
||||
ScopedMergePhaseTimer timer("compact-batch-channel-runs-1");
|
||||
compactBatchChannelRuns(funcOp);
|
||||
}
|
||||
{
|
||||
ScopedMergePhaseTimer timer("compact-regular-op-runs");
|
||||
compactRegularOpRuns(funcOp);
|
||||
}
|
||||
{
|
||||
ScopedMergePhaseTimer timer("compact-row-wise-wvmm-runs");
|
||||
compactRowWiseWvmmRuns(funcOp);
|
||||
}
|
||||
{
|
||||
ScopedMergePhaseTimer timer("compact-scalar-channel-runs-2");
|
||||
compactScalarChannelRuns(funcOp, nextChannelId);
|
||||
}
|
||||
{
|
||||
ScopedMergePhaseTimer timer("compact-batch-channel-runs-2");
|
||||
compactBatchChannelRuns(funcOp);
|
||||
}
|
||||
{
|
||||
ScopedMergePhaseTimer timer("cleanup-dead-packing-ops");
|
||||
cleanupDeadPackingOps(funcOp);
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,12 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
mlir::LogicalResult runPostMergeCompactionPipeline(mlir::func::FuncOp funcOp, int64_t &nextChannelId);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -7,12 +7,13 @@
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/DenseSet.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include <tuple>
|
||||
|
||||
#include "RegularOpCompaction.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/TensorPackingPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
@@ -42,6 +43,47 @@ struct RegularChunk {
|
||||
Value output;
|
||||
};
|
||||
|
||||
struct RegularCompactionResult {
|
||||
bool changed = false;
|
||||
Operation* resumeAfter = nullptr;
|
||||
};
|
||||
|
||||
template <typename OpTy>
|
||||
struct ConsecutiveRun {
|
||||
SmallVector<OpTy> ops;
|
||||
Block::iterator end;
|
||||
};
|
||||
|
||||
template <typename OpTy, typename Predicate>
|
||||
static ConsecutiveRun<OpTy>
|
||||
collectConsecutiveRun(Block::iterator start, Block::iterator blockEnd, Predicate predicate) {
|
||||
ConsecutiveRun<OpTy> run;
|
||||
run.end = start;
|
||||
while (run.end != blockEnd) {
|
||||
auto current = dyn_cast<OpTy>(&*run.end);
|
||||
if (!current || !predicate(current))
|
||||
break;
|
||||
run.ops.push_back(current);
|
||||
++run.end;
|
||||
}
|
||||
return run;
|
||||
}
|
||||
|
||||
static uint64_t getEndpointKey(uint32_t sourceCoreId, uint32_t targetCoreId) {
|
||||
return (static_cast<uint64_t>(sourceCoreId) << 32) | static_cast<uint64_t>(targetCoreId);
|
||||
}
|
||||
|
||||
static void appendChannelAttrs(SmallVectorImpl<int64_t>& channelIds,
|
||||
SmallVectorImpl<int32_t>& sourceCoreIds,
|
||||
SmallVectorImpl<int32_t>& targetCoreIds,
|
||||
uint64_t channelId,
|
||||
uint32_t sourceCoreId,
|
||||
uint32_t targetCoreId) {
|
||||
channelIds.push_back(static_cast<int64_t>(channelId));
|
||||
sourceCoreIds.push_back(static_cast<int32_t>(sourceCoreId));
|
||||
targetCoreIds.push_back(static_cast<int32_t>(targetCoreId));
|
||||
}
|
||||
|
||||
static spatial::SpatConcatOp getContiguousConcatUse(ValueRange values, unsigned& startOperandIndex) {
|
||||
if (values.empty() || !values.front().hasOneUse())
|
||||
return {};
|
||||
@@ -168,6 +210,17 @@ static bool areEquivalentRegularChunks(const RegularChunk& lhs, const RegularChu
|
||||
[](auto pair) { return areEquivalentRegularSteps(std::get<0>(pair), std::get<1>(pair)); });
|
||||
}
|
||||
|
||||
static bool isForwardedChannelPayload(Value value, Block& block) {
|
||||
Operation* op = value.getDefiningOp();
|
||||
if (!op || op->getBlock() != &block)
|
||||
return true;
|
||||
|
||||
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op))
|
||||
return isForwardedChannelPayload(extractSliceOp.getSource(), block);
|
||||
|
||||
return isa<spatial::SpatChannelReceiveOp, spatial::SpatChannelReceiveTensorOp>(op);
|
||||
}
|
||||
|
||||
static FailureOr<RegularChunk> analyzeRegularChunk(spatial::SpatVMMOp startOp) {
|
||||
RegularChunk chunk;
|
||||
chunk.startOp = startOp.getOperation();
|
||||
@@ -202,9 +255,10 @@ static FailureOr<RegularChunk> analyzeRegularChunk(spatial::SpatVMMOp startOp) {
|
||||
return chunk;
|
||||
}
|
||||
|
||||
static void compactRegularChunkRun(IRRewriter& rewriter, ArrayRef<RegularChunk> run) {
|
||||
static RegularCompactionResult compactRegularChunkRun(IRRewriter& rewriter, ArrayRef<RegularChunk> run) {
|
||||
assert(!run.empty() && "expected a non-empty regular chunk run");
|
||||
const RegularChunk& anchorChunk = run.front();
|
||||
RegularCompactionResult result;
|
||||
|
||||
SmallVector<Value> inputs;
|
||||
inputs.reserve(run.size());
|
||||
@@ -214,7 +268,7 @@ static void compactRegularChunkRun(IRRewriter& rewriter, ArrayRef<RegularChunk>
|
||||
rewriter.setInsertionPoint(anchorChunk.startOp);
|
||||
Value packedInput = createPackedTensorForValues(ValueRange(inputs), rewriter, anchorChunk.startOp->getLoc());
|
||||
if (!packedInput)
|
||||
return;
|
||||
return result;
|
||||
|
||||
auto inputType = cast<RankedTensorType>(anchorChunk.input.getType());
|
||||
auto outputType = cast<RankedTensorType>(anchorChunk.output.getType());
|
||||
@@ -317,10 +371,79 @@ static void compactRegularChunkRun(IRRewriter& rewriter, ArrayRef<RegularChunk>
|
||||
llvm::append_range(opsToErase, chunk.ops);
|
||||
for (Operation* op : llvm::reverse(opsToErase))
|
||||
rewriter.eraseOp(op);
|
||||
|
||||
result.changed = true;
|
||||
result.resumeAfter = loop.getOperation()->getNextNode();
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void orderBilateralChannelOps(func::FuncOp funcOp) {
|
||||
for (auto compute : funcOp.getOps<spatial::SpatCompute>()) {
|
||||
auto coreIdAttr = compute->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName);
|
||||
if (!coreIdAttr)
|
||||
continue;
|
||||
|
||||
int32_t coreId = static_cast<int32_t>(coreIdAttr.getInt());
|
||||
Block& block = compute.getBody().front();
|
||||
SmallVector<std::pair<spatial::SpatChannelReceiveOp, Operation*>> moves;
|
||||
DenseMap<uint64_t, Operation*> firstForwardedSendByEndpoint;
|
||||
|
||||
for (Operation& op : block) {
|
||||
if (auto sendOp = dyn_cast<spatial::SpatChannelSendOp>(&op)) {
|
||||
if (sendOp.getSourceCoreId() == static_cast<uint32_t>(coreId)
|
||||
&& isForwardedChannelPayload(sendOp.getInput(), block)) {
|
||||
uint64_t key = getEndpointKey(sendOp.getSourceCoreId(), sendOp.getTargetCoreId());
|
||||
firstForwardedSendByEndpoint.try_emplace(key, sendOp.getOperation());
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
auto receiveOp = dyn_cast<spatial::SpatChannelReceiveOp>(&op);
|
||||
if (!receiveOp || receiveOp.getTargetCoreId() != static_cast<uint32_t>(coreId)
|
||||
|| receiveOp.getSourceCoreId() >= static_cast<uint32_t>(coreId)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
uint64_t key = getEndpointKey(static_cast<uint32_t>(coreId), receiveOp.getSourceCoreId());
|
||||
auto firstMatchingSend = firstForwardedSendByEndpoint.find(key);
|
||||
if (firstMatchingSend != firstForwardedSendByEndpoint.end())
|
||||
moves.push_back({receiveOp, firstMatchingSend->second});
|
||||
}
|
||||
|
||||
for (auto [receiveOp, insertionPoint] : moves)
|
||||
receiveOp->moveBefore(insertionPoint);
|
||||
|
||||
for (auto it = block.begin(); it != block.end();) {
|
||||
auto receiveOp = dyn_cast<spatial::SpatChannelReceiveOp>(&*it);
|
||||
if (!receiveOp || receiveOp.getSourceCoreId() >= static_cast<uint32_t>(coreId)) {
|
||||
++it;
|
||||
continue;
|
||||
}
|
||||
|
||||
Type outputType = receiveOp.getOutput().getType();
|
||||
auto run = collectConsecutiveRun<spatial::SpatChannelReceiveOp>(
|
||||
it, block.end(), [&](spatial::SpatChannelReceiveOp current) {
|
||||
return current.getOutput().getType() == outputType
|
||||
&& current.getSourceCoreId() < static_cast<uint32_t>(coreId);
|
||||
});
|
||||
|
||||
if (run.ops.size() > 1) {
|
||||
SmallVector<spatial::SpatChannelReceiveOp> sorted(run.ops);
|
||||
llvm::stable_sort(sorted, [](spatial::SpatChannelReceiveOp lhs, spatial::SpatChannelReceiveOp rhs) {
|
||||
return lhs.getSourceCoreId() > rhs.getSourceCoreId();
|
||||
});
|
||||
Block::iterator insertIt = run.end;
|
||||
for (auto op : sorted)
|
||||
op->moveBefore(&block, insertIt);
|
||||
}
|
||||
|
||||
it = run.end;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
|
||||
IRRewriter rewriter(funcOp.getContext());
|
||||
|
||||
@@ -329,18 +452,23 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
|
||||
for (auto it = block.begin(); it != block.end();) {
|
||||
auto receiveOp = dyn_cast<spatial::SpatChannelReceiveOp>(&*it);
|
||||
if (receiveOp) {
|
||||
SmallVector<spatial::SpatChannelReceiveOp> run;
|
||||
Type outputType = receiveOp.getOutput().getType();
|
||||
auto runIt = it;
|
||||
while (runIt != block.end()) {
|
||||
auto current = dyn_cast<spatial::SpatChannelReceiveOp>(&*runIt);
|
||||
if (!current || current.getOutput().getType() != outputType)
|
||||
auto run = collectConsecutiveRun<spatial::SpatChannelReceiveOp>(
|
||||
it, block.end(), [&](spatial::SpatChannelReceiveOp current) {
|
||||
return current.getOutput().getType() == outputType;
|
||||
});
|
||||
|
||||
bool hasRepeatedEndpoint = false;
|
||||
DenseSet<uint64_t> seenEndpoints;
|
||||
for (auto op : run.ops) {
|
||||
uint64_t endpointKey = getEndpointKey(op.getSourceCoreId(), op.getTargetCoreId());
|
||||
if (!seenEndpoints.insert(endpointKey).second) {
|
||||
hasRepeatedEndpoint = true;
|
||||
break;
|
||||
run.push_back(current);
|
||||
++runIt;
|
||||
}
|
||||
}
|
||||
|
||||
if (run.size() > 1) {
|
||||
if (run.ops.size() > 1 && !hasRepeatedEndpoint) {
|
||||
struct ReceiveEntry {
|
||||
spatial::SpatChannelReceiveOp op;
|
||||
size_t originalIndex = 0;
|
||||
@@ -349,13 +477,9 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
|
||||
uint64_t channelId = 0;
|
||||
};
|
||||
SmallVector<ReceiveEntry> sortedEntries;
|
||||
sortedEntries.reserve(run.size());
|
||||
for (auto [originalIndex, op] : llvm::enumerate(run))
|
||||
sortedEntries.reserve(run.ops.size());
|
||||
for (auto [originalIndex, op] : llvm::enumerate(run.ops))
|
||||
sortedEntries.push_back({op, originalIndex, op.getSourceCoreId(), op.getTargetCoreId(), op.getChannelId()});
|
||||
llvm::stable_sort(sortedEntries, [](const ReceiveEntry& lhs, const ReceiveEntry& rhs) {
|
||||
return std::tuple(lhs.sourceCoreId, lhs.targetCoreId, lhs.channelId)
|
||||
< std::tuple(rhs.sourceCoreId, rhs.targetCoreId, rhs.channelId);
|
||||
});
|
||||
|
||||
SmallVector<int64_t> channelIds;
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
@@ -364,13 +488,11 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
|
||||
sourceCoreIds.reserve(sortedEntries.size());
|
||||
targetCoreIds.reserve(sortedEntries.size());
|
||||
for (ReceiveEntry& entry : sortedEntries) {
|
||||
(void) entry;
|
||||
channelIds.push_back(nextChannelId++);
|
||||
sourceCoreIds.push_back(static_cast<int32_t>(entry.sourceCoreId));
|
||||
targetCoreIds.push_back(static_cast<int32_t>(entry.targetCoreId));
|
||||
appendChannelAttrs(
|
||||
channelIds, sourceCoreIds, targetCoreIds, entry.channelId, entry.sourceCoreId, entry.targetCoreId);
|
||||
}
|
||||
|
||||
auto rowType = cast<RankedTensorType>(run.front().getOutput().getType());
|
||||
auto rowType = cast<RankedTensorType>(run.ops.front().getOutput().getType());
|
||||
auto fallbackPackedType = getPackedTensorType(rowType, static_cast<int64_t>(sortedEntries.size()));
|
||||
SmallVector<Value> sortedOutputs;
|
||||
sortedOutputs.reserve(sortedEntries.size());
|
||||
@@ -383,10 +505,10 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
|
||||
concatOp ? getPackedConcatSliceType(concatOp, concatStartIndex, static_cast<unsigned>(sortedOutputs.size()))
|
||||
: RankedTensorType {};
|
||||
auto packedType = concatPackedType ? concatPackedType : fallbackPackedType;
|
||||
rewriter.setInsertionPoint(run.front());
|
||||
rewriter.setInsertionPoint(run.ops.front());
|
||||
auto compactReceive =
|
||||
spatial::SpatChannelReceiveTensorOp::create(rewriter,
|
||||
run.front().getLoc(),
|
||||
run.ops.front().getLoc(),
|
||||
packedType,
|
||||
rewriter.getDenseI64ArrayAttr(channelIds),
|
||||
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
|
||||
@@ -403,7 +525,7 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
|
||||
entry.op.getOutput().replaceAllUsesWith(extractPackedChunk(
|
||||
compactReceive.getOutput(), rowType, static_cast<unsigned>(sortedIndex), rewriter, entry.op.getLoc()));
|
||||
}
|
||||
for (auto op : run)
|
||||
for (auto op : run.ops)
|
||||
rewriter.eraseOp(op);
|
||||
|
||||
it = compactReceive->getIterator();
|
||||
@@ -414,18 +536,13 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
|
||||
|
||||
auto sendOp = dyn_cast<spatial::SpatChannelSendOp>(&*it);
|
||||
if (sendOp) {
|
||||
SmallVector<spatial::SpatChannelSendOp> run;
|
||||
Type inputType = sendOp.getInput().getType();
|
||||
auto runIt = it;
|
||||
while (runIt != block.end()) {
|
||||
auto current = dyn_cast<spatial::SpatChannelSendOp>(&*runIt);
|
||||
if (!current || current.getInput().getType() != inputType)
|
||||
break;
|
||||
run.push_back(current);
|
||||
++runIt;
|
||||
}
|
||||
auto run =
|
||||
collectConsecutiveRun<spatial::SpatChannelSendOp>(it, block.end(), [&](spatial::SpatChannelSendOp current) {
|
||||
return current.getInput().getType() == inputType;
|
||||
});
|
||||
|
||||
if (run.size() > 1) {
|
||||
if (run.ops.size() > 1) {
|
||||
struct SendEntry {
|
||||
spatial::SpatChannelSendOp op;
|
||||
uint32_t sourceCoreId = 0;
|
||||
@@ -433,13 +550,9 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
|
||||
uint64_t channelId = 0;
|
||||
};
|
||||
SmallVector<SendEntry> sortedEntries;
|
||||
sortedEntries.reserve(run.size());
|
||||
for (auto op : run)
|
||||
sortedEntries.reserve(run.ops.size());
|
||||
for (auto op : run.ops)
|
||||
sortedEntries.push_back({op, op.getSourceCoreId(), op.getTargetCoreId(), op.getChannelId()});
|
||||
llvm::stable_sort(sortedEntries, [](const SendEntry& lhs, const SendEntry& rhs) {
|
||||
return std::tuple(lhs.sourceCoreId, lhs.targetCoreId, lhs.channelId)
|
||||
< std::tuple(rhs.sourceCoreId, rhs.targetCoreId, rhs.channelId);
|
||||
});
|
||||
|
||||
SmallVector<int64_t> channelIds;
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
@@ -450,26 +563,24 @@ void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) {
|
||||
targetCoreIds.reserve(sortedEntries.size());
|
||||
inputs.reserve(sortedEntries.size());
|
||||
for (SendEntry& entry : sortedEntries) {
|
||||
(void) entry;
|
||||
channelIds.push_back(nextChannelId++);
|
||||
sourceCoreIds.push_back(static_cast<int32_t>(entry.sourceCoreId));
|
||||
targetCoreIds.push_back(static_cast<int32_t>(entry.targetCoreId));
|
||||
appendChannelAttrs(
|
||||
channelIds, sourceCoreIds, targetCoreIds, entry.channelId, entry.sourceCoreId, entry.targetCoreId);
|
||||
inputs.push_back(entry.op.getInput());
|
||||
}
|
||||
|
||||
rewriter.setInsertionPoint(run.front());
|
||||
Value packedInput = createPackedTensorForValues(ValueRange(inputs), rewriter, run.front().getLoc());
|
||||
rewriter.setInsertionPoint(run.ops.front());
|
||||
Value packedInput = createPackedTensorForValues(ValueRange(inputs), rewriter, run.ops.front().getLoc());
|
||||
if (packedInput) {
|
||||
spatial::SpatChannelSendTensorOp::create(rewriter,
|
||||
run.front().getLoc(),
|
||||
run.ops.front().getLoc(),
|
||||
rewriter.getDenseI64ArrayAttr(channelIds),
|
||||
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
|
||||
rewriter.getDenseI32ArrayAttr(targetCoreIds),
|
||||
packedInput);
|
||||
for (auto op : run)
|
||||
for (auto op : run.ops)
|
||||
rewriter.eraseOp(op);
|
||||
|
||||
it = runIt;
|
||||
it = run.end;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
@@ -488,32 +599,27 @@ void compactBatchChannelRuns(func::FuncOp funcOp) {
|
||||
for (auto it = block.begin(); it != block.end();) {
|
||||
auto receiveOp = dyn_cast<spatial::SpatChannelReceiveBatchOp>(&*it);
|
||||
if (receiveOp) {
|
||||
SmallVector<spatial::SpatChannelReceiveBatchOp> run;
|
||||
Type outputType = receiveOp.getOutput().getType();
|
||||
auto runIt = it;
|
||||
while (runIt != block.end()) {
|
||||
auto current = dyn_cast<spatial::SpatChannelReceiveBatchOp>(&*runIt);
|
||||
if (!current || current.getOutput().getType() != outputType)
|
||||
break;
|
||||
run.push_back(current);
|
||||
++runIt;
|
||||
}
|
||||
auto run = collectConsecutiveRun<spatial::SpatChannelReceiveBatchOp>(
|
||||
it, block.end(), [&](spatial::SpatChannelReceiveBatchOp current) {
|
||||
return current.getOutput().getType() == outputType;
|
||||
});
|
||||
|
||||
if (run.size() > 1) {
|
||||
if (run.ops.size() > 1) {
|
||||
SmallVector<int64_t> channelIds;
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
SmallVector<int32_t> targetCoreIds;
|
||||
for (auto op : run) {
|
||||
for (auto op : run.ops) {
|
||||
llvm::append_range(channelIds, op.getChannelIds());
|
||||
llvm::append_range(sourceCoreIds, op.getSourceCoreIds());
|
||||
llvm::append_range(targetCoreIds, op.getTargetCoreIds());
|
||||
}
|
||||
|
||||
auto rowType = cast<RankedTensorType>(run.front().getOutput().getType());
|
||||
auto fallbackPackedType = getPackedTensorType(rowType, static_cast<int64_t>(run.size()));
|
||||
auto rowType = cast<RankedTensorType>(run.ops.front().getOutput().getType());
|
||||
auto fallbackPackedType = getPackedTensorType(rowType, static_cast<int64_t>(run.ops.size()));
|
||||
SmallVector<Value> outputs;
|
||||
outputs.reserve(run.size());
|
||||
for (auto op : run)
|
||||
outputs.reserve(run.ops.size());
|
||||
for (auto op : run.ops)
|
||||
outputs.push_back(op.getOutput());
|
||||
|
||||
unsigned concatStartIndex = 0;
|
||||
@@ -522,10 +628,10 @@ void compactBatchChannelRuns(func::FuncOp funcOp) {
|
||||
concatOp ? getPackedConcatSliceType(concatOp, concatStartIndex, static_cast<unsigned>(outputs.size()))
|
||||
: RankedTensorType {};
|
||||
auto packedType = concatPackedType ? concatPackedType : fallbackPackedType;
|
||||
rewriter.setInsertionPoint(run.front());
|
||||
rewriter.setInsertionPoint(run.ops.front());
|
||||
auto compactReceive =
|
||||
spatial::SpatChannelReceiveTensorBatchOp::create(rewriter,
|
||||
run.front().getLoc(),
|
||||
run.ops.front().getLoc(),
|
||||
packedType,
|
||||
rewriter.getDenseI64ArrayAttr(channelIds),
|
||||
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
|
||||
@@ -535,11 +641,11 @@ void compactBatchChannelRuns(func::FuncOp funcOp) {
|
||||
concatOp, concatStartIndex, static_cast<unsigned>(outputs.size()), compactReceive.getOutput(), rewriter);
|
||||
}
|
||||
else {
|
||||
for (auto [index, op] : llvm::enumerate(run))
|
||||
for (auto [index, op] : llvm::enumerate(run.ops))
|
||||
op.getOutput().replaceAllUsesWith(extractPackedChunk(
|
||||
compactReceive.getOutput(), rowType, static_cast<unsigned>(index), rewriter, op.getLoc()));
|
||||
}
|
||||
for (auto op : run)
|
||||
for (auto op : run.ops)
|
||||
rewriter.eraseOp(op);
|
||||
|
||||
it = compactReceive->getIterator();
|
||||
@@ -550,43 +656,38 @@ void compactBatchChannelRuns(func::FuncOp funcOp) {
|
||||
|
||||
auto sendOp = dyn_cast<spatial::SpatChannelSendBatchOp>(&*it);
|
||||
if (sendOp) {
|
||||
SmallVector<spatial::SpatChannelSendBatchOp> run;
|
||||
Type inputType = sendOp.getInput().getType();
|
||||
auto runIt = it;
|
||||
while (runIt != block.end()) {
|
||||
auto current = dyn_cast<spatial::SpatChannelSendBatchOp>(&*runIt);
|
||||
if (!current || current.getInput().getType() != inputType)
|
||||
break;
|
||||
run.push_back(current);
|
||||
++runIt;
|
||||
}
|
||||
auto run = collectConsecutiveRun<spatial::SpatChannelSendBatchOp>(
|
||||
it, block.end(), [&](spatial::SpatChannelSendBatchOp current) {
|
||||
return current.getInput().getType() == inputType;
|
||||
});
|
||||
|
||||
if (run.size() > 1) {
|
||||
if (run.ops.size() > 1) {
|
||||
SmallVector<int64_t> channelIds;
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
SmallVector<int32_t> targetCoreIds;
|
||||
SmallVector<Value> inputs;
|
||||
inputs.reserve(run.size());
|
||||
for (auto op : run) {
|
||||
inputs.reserve(run.ops.size());
|
||||
for (auto op : run.ops) {
|
||||
llvm::append_range(channelIds, op.getChannelIds());
|
||||
llvm::append_range(sourceCoreIds, op.getSourceCoreIds());
|
||||
llvm::append_range(targetCoreIds, op.getTargetCoreIds());
|
||||
inputs.push_back(op.getInput());
|
||||
}
|
||||
|
||||
rewriter.setInsertionPoint(run.front());
|
||||
Value packedInput = createPackedTensorForValues(ValueRange(inputs), rewriter, run.front().getLoc());
|
||||
rewriter.setInsertionPoint(run.ops.front());
|
||||
Value packedInput = createPackedTensorForValues(ValueRange(inputs), rewriter, run.ops.front().getLoc());
|
||||
if (packedInput) {
|
||||
spatial::SpatChannelSendTensorBatchOp::create(rewriter,
|
||||
run.front().getLoc(),
|
||||
run.ops.front().getLoc(),
|
||||
rewriter.getDenseI64ArrayAttr(channelIds),
|
||||
rewriter.getDenseI32ArrayAttr(sourceCoreIds),
|
||||
rewriter.getDenseI32ArrayAttr(targetCoreIds),
|
||||
packedInput);
|
||||
for (auto op : run)
|
||||
for (auto op : run.ops)
|
||||
rewriter.eraseOp(op);
|
||||
|
||||
it = runIt;
|
||||
it = run.end;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
@@ -614,8 +715,9 @@ void compactRegularOpRuns(func::FuncOp funcOp) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto anchorEndIt = std::next(it, static_cast<std::ptrdiff_t>(anchorChunk->ops.size()));
|
||||
SmallVector<RegularChunk> run {*anchorChunk};
|
||||
auto runIt = std::next(it, static_cast<std::ptrdiff_t>(anchorChunk->ops.size()));
|
||||
auto runIt = anchorEndIt;
|
||||
while (runIt != block.end()) {
|
||||
auto candidateStart = dyn_cast<spatial::SpatVMMOp>(&*runIt);
|
||||
if (!candidateStart)
|
||||
@@ -630,12 +732,26 @@ void compactRegularOpRuns(func::FuncOp funcOp) {
|
||||
}
|
||||
|
||||
if (run.size() <= 1) {
|
||||
++it;
|
||||
it = anchorEndIt;
|
||||
continue;
|
||||
}
|
||||
|
||||
compactRegularChunkRun(rewriter, run);
|
||||
it = runIt;
|
||||
size_t originalOpCount = 0;
|
||||
for (const RegularChunk& chunk : run)
|
||||
originalOpCount += chunk.ops.size();
|
||||
|
||||
RegularCompactionResult result = compactRegularChunkRun(rewriter, run);
|
||||
if (result.changed) {
|
||||
assert(originalOpCount > anchorChunk->ops.size() && "successful regular compaction must consume the run");
|
||||
if (!result.resumeAfter) {
|
||||
it = block.end();
|
||||
continue;
|
||||
}
|
||||
it = result.resumeAfter->getIterator();
|
||||
continue;
|
||||
}
|
||||
|
||||
it = anchorEndIt;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -666,37 +782,32 @@ void compactRowWiseWvmmRuns(func::FuncOp funcOp) {
|
||||
continue;
|
||||
}
|
||||
|
||||
SmallVector<spatial::SpatVMMOp> run;
|
||||
auto runIt = it;
|
||||
int64_t expectedRow = static_cast<int64_t>(rowResult.getResultNumber());
|
||||
while (runIt != block.end()) {
|
||||
auto current = dyn_cast<spatial::SpatVMMOp>(&*runIt);
|
||||
if (!current || current.getWeightIndex() != wvmmOp.getWeightIndex()
|
||||
auto run = collectConsecutiveRun<spatial::SpatVMMOp>(it, block.end(), [&](spatial::SpatVMMOp current) {
|
||||
if (current.getWeightIndex() != wvmmOp.getWeightIndex()
|
||||
|| current.getInput().getDefiningOp<spatial::SpatExtractRowsOp>() != extractRowsOp
|
||||
|| current.getInput().getType() != wvmmOp.getInput().getType()
|
||||
|| current.getOutput().getType() != wvmmOp.getOutput().getType()) {
|
||||
break;
|
||||
}
|
||||
|| current.getOutput().getType() != wvmmOp.getOutput().getType())
|
||||
return false;
|
||||
|
||||
auto currentRow = dyn_cast<OpResult>(current.getInput());
|
||||
if (!currentRow || currentRow.getResultNumber() != static_cast<unsigned>(expectedRow))
|
||||
break;
|
||||
return false;
|
||||
|
||||
run.push_back(current);
|
||||
++expectedRow;
|
||||
++runIt;
|
||||
}
|
||||
return true;
|
||||
});
|
||||
|
||||
if (run.size() <= 1) {
|
||||
if (run.ops.size() <= 1) {
|
||||
++it;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!run.front().getOutput().hasOneUse()) {
|
||||
if (!run.ops.front().getOutput().hasOneUse()) {
|
||||
++it;
|
||||
continue;
|
||||
}
|
||||
auto concatUse = run.front().getOutput().getUses().begin();
|
||||
auto concatUse = run.ops.front().getOutput().getUses().begin();
|
||||
auto concatOp = dyn_cast<spatial::SpatConcatOp>(concatUse->getOwner());
|
||||
if (!concatOp) {
|
||||
++it;
|
||||
@@ -705,7 +816,7 @@ void compactRowWiseWvmmRuns(func::FuncOp funcOp) {
|
||||
|
||||
unsigned concatStartIndex = concatUse->getOperandNumber();
|
||||
bool validConcatRun = true;
|
||||
for (auto [index, op] : llvm::enumerate(run)) {
|
||||
for (auto [index, op] : llvm::enumerate(run.ops)) {
|
||||
if (!op.getOutput().hasOneUse()) {
|
||||
validConcatRun = false;
|
||||
break;
|
||||
@@ -736,17 +847,17 @@ void compactRowWiseWvmmRuns(func::FuncOp funcOp) {
|
||||
}
|
||||
|
||||
int64_t firstRow = static_cast<int64_t>(rowResult.getResultNumber());
|
||||
int64_t runLength = static_cast<int64_t>(run.size());
|
||||
int64_t runLength = static_cast<int64_t>(run.ops.size());
|
||||
auto packedType = RankedTensorType::get({runLength, outputCols}, outputType.getElementType());
|
||||
|
||||
rewriter.setInsertionPoint(run.front());
|
||||
auto zero = arith::ConstantIndexOp::create(rewriter, run.front().getLoc(), 0);
|
||||
auto upper = arith::ConstantIndexOp::create(rewriter, run.front().getLoc(), runLength);
|
||||
auto step = arith::ConstantIndexOp::create(rewriter, run.front().getLoc(), 1);
|
||||
rewriter.setInsertionPoint(run.ops.front());
|
||||
auto zero = arith::ConstantIndexOp::create(rewriter, run.ops.front().getLoc(), 0);
|
||||
auto upper = arith::ConstantIndexOp::create(rewriter, run.ops.front().getLoc(), runLength);
|
||||
auto step = arith::ConstantIndexOp::create(rewriter, run.ops.front().getLoc(), 1);
|
||||
auto packedInit =
|
||||
tensor::EmptyOp::create(rewriter, run.front().getLoc(), packedType.getShape(), packedType.getElementType());
|
||||
tensor::EmptyOp::create(rewriter, run.ops.front().getLoc(), packedType.getShape(), packedType.getElementType());
|
||||
auto loop =
|
||||
scf::ForOp::create(rewriter, run.front().getLoc(), zero, upper, step, ValueRange {packedInit.getResult()});
|
||||
scf::ForOp::create(rewriter, run.ops.front().getLoc(), zero, upper, step, ValueRange {packedInit.getResult()});
|
||||
|
||||
{
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
@@ -757,41 +868,41 @@ void compactRowWiseWvmmRuns(func::FuncOp funcOp) {
|
||||
|
||||
Value sourceRow = iv;
|
||||
if (firstRow != 0) {
|
||||
auto firstRowValue = arith::ConstantIndexOp::create(rewriter, run.front().getLoc(), firstRow);
|
||||
sourceRow = arith::AddIOp::create(rewriter, run.front().getLoc(), iv, firstRowValue);
|
||||
auto firstRowValue = arith::ConstantIndexOp::create(rewriter, run.ops.front().getLoc(), firstRow);
|
||||
sourceRow = arith::AddIOp::create(rewriter, run.ops.front().getLoc(), iv, firstRowValue);
|
||||
}
|
||||
|
||||
SmallVector<OpFoldResult> extractOffsets = {sourceRow, rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> extractSizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(inputCols)};
|
||||
SmallVector<OpFoldResult> extractStrides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
auto extractedRow = tensor::ExtractSliceOp::create(rewriter,
|
||||
run.front().getLoc(),
|
||||
run.ops.front().getLoc(),
|
||||
inputType,
|
||||
extractRowsOp.getInput(),
|
||||
extractOffsets,
|
||||
extractSizes,
|
||||
extractStrides);
|
||||
auto loopWvmm = spatial::SpatVMMOp::create(
|
||||
rewriter, run.front().getLoc(), outputType, wvmmOp.getWeightIndex(), extractedRow.getResult());
|
||||
rewriter, run.ops.front().getLoc(), outputType, wvmmOp.getWeightIndex(), extractedRow.getResult());
|
||||
|
||||
SmallVector<OpFoldResult> insertOffsets = {iv, rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> insertSizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(outputCols)};
|
||||
SmallVector<OpFoldResult> insertStrides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
auto inserted = tensor::InsertSliceOp::create(
|
||||
rewriter, run.front().getLoc(), loopWvmm.getResult(), acc, insertOffsets, insertSizes, insertStrides);
|
||||
scf::YieldOp::create(rewriter, run.front().getLoc(), inserted.getResult());
|
||||
rewriter, run.ops.front().getLoc(), loopWvmm.getResult(), acc, insertOffsets, insertSizes, insertStrides);
|
||||
scf::YieldOp::create(rewriter, run.ops.front().getLoc(), inserted.getResult());
|
||||
}
|
||||
|
||||
SmallVector<Value> newConcatInputs;
|
||||
newConcatInputs.reserve(concatOp.getInputs().size() - run.size() + 1);
|
||||
newConcatInputs.reserve(concatOp.getInputs().size() - run.ops.size() + 1);
|
||||
for (auto [operandIndex, operand] : llvm::enumerate(concatOp.getInputs())) {
|
||||
if (operandIndex == concatStartIndex)
|
||||
newConcatInputs.push_back(loop.getResult(0));
|
||||
if (operandIndex < concatStartIndex || operandIndex >= concatStartIndex + run.size())
|
||||
if (operandIndex < concatStartIndex || operandIndex >= concatStartIndex + run.ops.size())
|
||||
newConcatInputs.push_back(operand);
|
||||
}
|
||||
rewriter.modifyOpInPlace(concatOp, [&] { concatOp->setOperands(newConcatInputs); });
|
||||
for (auto op : run)
|
||||
for (auto op : run.ops)
|
||||
rewriter.eraseOp(op);
|
||||
|
||||
it = loop->getIterator();
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
void orderBilateralChannelOps(mlir::func::FuncOp funcOp);
|
||||
void compactScalarChannelRuns(mlir::func::FuncOp funcOp, int64_t& nextChannelId);
|
||||
void compactBatchChannelRuns(mlir::func::FuncOp funcOp);
|
||||
void compactRegularOpRuns(mlir::func::FuncOp funcOp);
|
||||
|
||||
@@ -0,0 +1,189 @@
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <limits>
|
||||
#include <optional>
|
||||
#include <queue>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "ComputeGraph.hpp"
|
||||
#include "src/Support/TypeUtilities.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace spatial {
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
|
||||
Weight getComputeBodyWeight(Region &body) {
|
||||
constexpr Weight kOperationWeight = 100;
|
||||
Weight numOperations = 0;
|
||||
for (auto &block : body)
|
||||
for ([[maybe_unused]] auto &op : block)
|
||||
numOperations = checkedAdd(numOperations, static_cast<Weight>(1));
|
||||
return checkedMultiply(numOperations, kOperationWeight);
|
||||
}
|
||||
|
||||
CrossbarUsage getComputeBodyCrossbarUsage(Region &body) {
|
||||
CrossbarUsage crossbarUsage = 0;
|
||||
for (auto &block : body)
|
||||
for (auto &op : block)
|
||||
if (isa<SpatVMMOp>(op))
|
||||
crossbarUsage = checkedAdd(crossbarUsage, static_cast<CrossbarUsage>(1));
|
||||
return crossbarUsage;
|
||||
}
|
||||
|
||||
bool isUsedAsWeightOnly(Operation *producerOp) {
|
||||
if (producerOp->getNumResults() == 0)
|
||||
return false;
|
||||
for (Value result : producerOp->getResults()) {
|
||||
if (result.use_empty())
|
||||
return false;
|
||||
for (Operation *user : result.getUsers()) {
|
||||
if (auto compute = dyn_cast<SpatCompute>(user)) {
|
||||
if (!llvm::is_contained(compute.getWeights(), result))
|
||||
return false;
|
||||
continue;
|
||||
}
|
||||
if (auto batch = dyn_cast<SpatComputeBatch>(user)) {
|
||||
if (!llvm::is_contained(batch.getWeights(), result))
|
||||
return false;
|
||||
continue;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<ComputeGraphEdge> aggregateEdges(llvm::ArrayRef<ComputeGraphEdge> edges) {
|
||||
llvm::DenseMap<std::pair<size_t, size_t>, Weight> edgeWeights;
|
||||
for (const ComputeGraphEdge &edge : edges) {
|
||||
if (edge.source == edge.target)
|
||||
continue;
|
||||
auto inserted = edgeWeights.try_emplace({edge.source, edge.target}, edge.transferCost);
|
||||
if (!inserted.second)
|
||||
inserted.first->second = std::max(inserted.first->second, edge.transferCost);
|
||||
}
|
||||
|
||||
std::vector<ComputeGraphEdge> aggregatedEdges;
|
||||
aggregatedEdges.reserve(edgeWeights.size());
|
||||
for (const auto &[key, weight] : edgeWeights)
|
||||
aggregatedEdges.push_back({key.first, key.second, weight});
|
||||
llvm::sort(aggregatedEdges, [](const ComputeGraphEdge &lhs, const ComputeGraphEdge &rhs) {
|
||||
if (lhs.source != rhs.source)
|
||||
return lhs.source < rhs.source;
|
||||
return lhs.target < rhs.target;
|
||||
});
|
||||
return aggregatedEdges;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Weight getComputeInstanceWeight(const ComputeInstance &instance) {
|
||||
if (auto spatCompute = dyn_cast<SpatCompute>(instance.op))
|
||||
return getSpatComputeWeight(spatCompute);
|
||||
auto batch = cast<SpatComputeBatch>(instance.op);
|
||||
return checkedMultiply(getComputeBodyWeight(batch.getBody()), static_cast<Weight>(instance.laneCount));
|
||||
}
|
||||
|
||||
CrossbarUsage getComputeInstanceCrossbarUsage(const ComputeInstance &instance) {
|
||||
if (auto spatCompute = dyn_cast<SpatCompute>(instance.op))
|
||||
return getSpatComputeCrossbarUsage(spatCompute);
|
||||
auto batch = cast<SpatComputeBatch>(instance.op);
|
||||
return checkedMultiply(getComputeBodyCrossbarUsage(batch.getBody()),
|
||||
static_cast<CrossbarUsage>(instance.laneCount));
|
||||
}
|
||||
|
||||
ComputeGraph buildComputeGraph(Operation *entryOp) {
|
||||
ComputeGraph graph;
|
||||
|
||||
for (Region ®ion : entryOp->getRegions()) {
|
||||
for (Block &block : region) {
|
||||
for (Operation &op : block) {
|
||||
if (auto spatCompute = dyn_cast<SpatCompute>(&op)) {
|
||||
if (isUsedAsWeightOnly(spatCompute.getOperation()))
|
||||
continue;
|
||||
ComputeInstance instance {spatCompute.getOperation(), 0, 1};
|
||||
size_t index = graph.nodes.size();
|
||||
graph.nodes.push_back({instance, getComputeInstanceWeight(instance), getComputeInstanceCrossbarUsage(instance), index});
|
||||
graph.instanceToIndex[instance] = index;
|
||||
continue;
|
||||
}
|
||||
if (auto batch = dyn_cast<SpatComputeBatch>(&op)) {
|
||||
if (isUsedAsWeightOnly(batch.getOperation()))
|
||||
continue;
|
||||
size_t chunkCount = getBatchChunkTargetCount(batch.getLaneCount());
|
||||
for (size_t chunkIndex = 0; chunkIndex < chunkCount; ++chunkIndex) {
|
||||
ComputeInstance instance = getBatchChunkForIndex(batch, chunkIndex);
|
||||
size_t index = graph.nodes.size();
|
||||
graph.nodes.push_back(
|
||||
{instance, getComputeInstanceWeight(instance), getComputeInstanceCrossbarUsage(instance), index});
|
||||
graph.instanceToIndex[instance] = index;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
llvm::SmallVector<ComputeGraphEdge, 16> rawEdges;
|
||||
for (const auto &[targetIndex, node] : llvm::enumerate(graph.nodes)) {
|
||||
for (Value input : getComputeInstanceInputs(node.instance)) {
|
||||
auto producerInstance = getComputeProducerInstance(input);
|
||||
if (!producerInstance)
|
||||
continue;
|
||||
auto producerIt = graph.instanceToIndex.find(*producerInstance);
|
||||
if (producerIt == graph.instanceToIndex.end())
|
||||
continue;
|
||||
rawEdges.push_back(
|
||||
{producerIt->second, targetIndex, static_cast<Weight>(getSizeInBytes(cast<ShapedType>(input.getType())))});
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<ComputeGraphEdge> aggregatedEdges = aggregateEdges(rawEdges);
|
||||
graph.edges.append(aggregatedEdges.begin(), aggregatedEdges.end());
|
||||
graph.successors.assign(graph.nodes.size(), {});
|
||||
graph.predecessors.assign(graph.nodes.size(), {});
|
||||
for (const ComputeGraphEdge &edge : graph.edges) {
|
||||
graph.successors[edge.source].push_back({edge.target, edge.transferCost});
|
||||
graph.predecessors[edge.target].push_back({edge.source, edge.transferCost});
|
||||
}
|
||||
|
||||
return graph;
|
||||
}
|
||||
|
||||
bool verifyAcyclic(const ComputeGraph &graph) {
|
||||
std::vector<size_t> remainingParents(graph.nodes.size(), 0);
|
||||
std::queue<size_t> readyNodes;
|
||||
for (size_t node = 0; node < graph.nodes.size(); ++node) {
|
||||
remainingParents[node] = graph.predecessors[node].size();
|
||||
if (remainingParents[node] == 0)
|
||||
readyNodes.push(node);
|
||||
}
|
||||
|
||||
size_t visited = 0;
|
||||
while (!readyNodes.empty()) {
|
||||
size_t node = readyNodes.front();
|
||||
readyNodes.pop();
|
||||
++visited;
|
||||
for (const auto &[child, weight] : graph.successors[node]) {
|
||||
(void) weight;
|
||||
assert(remainingParents[child] > 0 && "remaining parent count underflow");
|
||||
if (--remainingParents[child] == 0)
|
||||
readyNodes.push(child);
|
||||
}
|
||||
}
|
||||
|
||||
return visited == graph.nodes.size();
|
||||
}
|
||||
|
||||
} // namespace spatial
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,49 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include <cstddef>
|
||||
#include <optional>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "../DCPGraph/Utils.hpp"
|
||||
#include "ComputeInstance.hpp"
|
||||
#include "ComputeInstanceUtils.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace spatial {
|
||||
|
||||
struct ComputeGraphNode {
|
||||
ComputeInstance instance;
|
||||
Weight weight = 0;
|
||||
CrossbarUsage crossbarUsage = 0;
|
||||
size_t originalOrder = 0;
|
||||
};
|
||||
|
||||
struct ComputeGraphEdge {
|
||||
size_t source = 0;
|
||||
size_t target = 0;
|
||||
Weight transferCost = 0;
|
||||
};
|
||||
|
||||
struct ComputeGraph {
|
||||
llvm::SmallVector<ComputeGraphNode> nodes;
|
||||
llvm::SmallVector<ComputeGraphEdge> edges;
|
||||
std::vector<std::vector<std::pair<size_t, Weight>>> successors;
|
||||
std::vector<std::vector<std::pair<size_t, Weight>>> predecessors;
|
||||
llvm::DenseMap<ComputeInstance, size_t> instanceToIndex;
|
||||
};
|
||||
|
||||
ComputeGraph buildComputeGraph(mlir::Operation *entryOp);
|
||||
bool verifyAcyclic(const ComputeGraph &graph);
|
||||
|
||||
Weight getComputeInstanceWeight(const ComputeInstance &instance);
|
||||
CrossbarUsage getComputeInstanceCrossbarUsage(const ComputeInstance &instance);
|
||||
|
||||
} // namespace spatial
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,45 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/Operation.h"
|
||||
|
||||
#include "llvm/ADT/DenseMapInfo.h"
|
||||
#include "llvm/ADT/Hashing.h"
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace spatial {
|
||||
|
||||
struct ComputeInstance {
|
||||
mlir::Operation *op = nullptr;
|
||||
uint32_t laneStart = 0;
|
||||
uint32_t laneCount = 1;
|
||||
|
||||
bool operator==(const ComputeInstance &other) const {
|
||||
return op == other.op && laneStart == other.laneStart && laneCount == other.laneCount;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace spatial
|
||||
} // namespace onnx_mlir
|
||||
|
||||
using ComputeInstance = onnx_mlir::spatial::ComputeInstance;
|
||||
|
||||
namespace llvm {
|
||||
template <>
|
||||
struct DenseMapInfo<onnx_mlir::spatial::ComputeInstance> {
|
||||
static onnx_mlir::spatial::ComputeInstance getEmptyKey() {
|
||||
return {DenseMapInfo<mlir::Operation *>::getEmptyKey(), UINT32_MAX, UINT32_MAX};
|
||||
}
|
||||
static onnx_mlir::spatial::ComputeInstance getTombstoneKey() {
|
||||
return {DenseMapInfo<mlir::Operation *>::getTombstoneKey(), UINT32_MAX, UINT32_MAX};
|
||||
}
|
||||
static unsigned getHashValue(const onnx_mlir::spatial::ComputeInstance &value) {
|
||||
return llvm::hash_combine(value.op, value.laneStart, value.laneCount);
|
||||
}
|
||||
static bool isEqual(const onnx_mlir::spatial::ComputeInstance &lhs,
|
||||
const onnx_mlir::spatial::ComputeInstance &rhs) {
|
||||
return lhs == rhs;
|
||||
}
|
||||
};
|
||||
} // namespace llvm
|
||||
+151
@@ -0,0 +1,151 @@
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
|
||||
#include <limits>
|
||||
|
||||
#include "ComputeInstanceUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace spatial {
|
||||
|
||||
size_t getSchedulingCpuBudget() {
|
||||
if (coresCount.getValue() > 0)
|
||||
return static_cast<size_t>(coresCount.getValue());
|
||||
return std::numeric_limits<size_t>::max();
|
||||
}
|
||||
|
||||
size_t getBatchChunkTargetCount(int32_t laneCount) {
|
||||
assert(laneCount > 0 && "laneCount must be positive");
|
||||
return std::min(static_cast<size_t>(laneCount), std::max<size_t>(1, getSchedulingCpuBudget()));
|
||||
}
|
||||
|
||||
ComputeInstance getBatchChunkForIndex(SpatComputeBatch batch, size_t chunkIndex) {
|
||||
size_t totalLanes = static_cast<size_t>(batch.getLaneCount());
|
||||
size_t chunkCount = getBatchChunkTargetCount(batch.getLaneCount());
|
||||
size_t baseChunkSize = totalLanes / chunkCount;
|
||||
size_t largeChunkCount = totalLanes % chunkCount;
|
||||
|
||||
size_t laneStart = chunkIndex * baseChunkSize + std::min(chunkIndex, largeChunkCount);
|
||||
size_t laneCount = baseChunkSize + (chunkIndex < largeChunkCount ? 1 : 0);
|
||||
return {batch.getOperation(), static_cast<uint32_t>(laneStart), static_cast<uint32_t>(laneCount)};
|
||||
}
|
||||
|
||||
ComputeInstance getBatchChunkForLane(SpatComputeBatch batch, uint32_t lane) {
|
||||
size_t totalLanes = static_cast<size_t>(batch.getLaneCount());
|
||||
size_t chunkCount = getBatchChunkTargetCount(batch.getLaneCount());
|
||||
size_t baseChunkSize = totalLanes / chunkCount;
|
||||
size_t largeChunkCount = totalLanes % chunkCount;
|
||||
size_t largeChunkSpan = largeChunkCount * (baseChunkSize + 1);
|
||||
|
||||
size_t chunkIndex = 0;
|
||||
if (static_cast<size_t>(lane) < largeChunkSpan)
|
||||
chunkIndex = static_cast<size_t>(lane) / (baseChunkSize + 1);
|
||||
else
|
||||
chunkIndex = largeChunkCount + (static_cast<size_t>(lane) - largeChunkSpan) / baseChunkSize;
|
||||
return getBatchChunkForIndex(batch, chunkIndex);
|
||||
}
|
||||
|
||||
SpatCompute getOriginalSpatCompute(Operation *op) {
|
||||
if (!op)
|
||||
return {};
|
||||
|
||||
while (auto extract = dyn_cast<tensor::ExtractSliceOp>(op)) {
|
||||
op = extract.getSource().getDefiningOp();
|
||||
if (!op)
|
||||
return {};
|
||||
}
|
||||
|
||||
return dyn_cast<SpatCompute>(op);
|
||||
}
|
||||
|
||||
std::optional<ProducerValueRef> getProducerValueRef(Value value) {
|
||||
Operation *op = value.getDefiningOp();
|
||||
if (!op)
|
||||
return std::nullopt;
|
||||
|
||||
//TODO Extract Slice is not the only global non compute operation. There are other legal op
|
||||
while (auto extract = dyn_cast<tensor::ExtractSliceOp>(op)) {
|
||||
value = extract.getSource();
|
||||
op = value.getDefiningOp();
|
||||
if (!op)
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
if (auto compute = dyn_cast<SpatCompute>(op)) {
|
||||
return ProducerValueRef {
|
||||
ComputeInstance {compute.getOperation(), 0, 1},
|
||||
static_cast<size_t>(cast<OpResult>(value).getResultNumber())
|
||||
};
|
||||
}
|
||||
|
||||
if (auto batch = dyn_cast<SpatComputeBatch>(op)) {
|
||||
uint32_t lane = static_cast<uint32_t>(cast<OpResult>(value).getResultNumber());
|
||||
ComputeInstance instance = getBatchChunkForLane(batch, lane);
|
||||
size_t resultIndex = static_cast<size_t>(lane - instance.laneStart);
|
||||
return ProducerValueRef {instance, resultIndex};
|
||||
}
|
||||
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
std::optional<ComputeInstance> getComputeProducerInstance(Value value) {
|
||||
if (std::optional<ProducerValueRef> producer = getProducerValueRef(value))
|
||||
return producer->instance;
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
llvm::SmallVector<Value, 4> getComputeInstanceInputs(const ComputeInstance &instance) {
|
||||
if (auto compute = dyn_cast<SpatCompute>(instance.op))
|
||||
return llvm::SmallVector<Value, 4>(compute.getInputs().begin(), compute.getInputs().end());
|
||||
|
||||
auto batch = cast<SpatComputeBatch>(instance.op);
|
||||
llvm::SmallVector<Value, 4> inputs;
|
||||
inputs.reserve(instance.laneCount);
|
||||
for (uint32_t lane = instance.laneStart; lane < instance.laneStart + instance.laneCount; ++lane)
|
||||
if (!batch.getInputs().empty())
|
||||
inputs.push_back(batch.getInputs()[lane]);
|
||||
return inputs;
|
||||
}
|
||||
|
||||
llvm::SmallVector<Value, 4> getComputeInstanceWeights(const ComputeInstance &instance) {
|
||||
if (auto compute = dyn_cast<SpatCompute>(instance.op))
|
||||
return llvm::SmallVector<Value, 4>(compute.getWeights().begin(), compute.getWeights().end());
|
||||
|
||||
auto batch = cast<SpatComputeBatch>(instance.op);
|
||||
llvm::SmallVector<Value, 4> weights;
|
||||
weights.reserve(instance.laneCount);
|
||||
for (uint32_t lane = instance.laneStart; lane < instance.laneStart + instance.laneCount; ++lane)
|
||||
weights.push_back(batch.getWeights()[lane]);
|
||||
return weights;
|
||||
}
|
||||
|
||||
llvm::SmallVector<Value, 4> getComputeInstanceOutputValues(const ComputeInstance &instance) {
|
||||
if (auto compute = dyn_cast<SpatCompute>(instance.op))
|
||||
return llvm::SmallVector<Value, 4>(compute.getResults().begin(), compute.getResults().end());
|
||||
|
||||
auto batch = cast<SpatComputeBatch>(instance.op);
|
||||
llvm::SmallVector<Value, 4> outputs;
|
||||
outputs.reserve(instance.laneCount);
|
||||
for (uint32_t lane = instance.laneStart; lane < instance.laneStart + instance.laneCount; ++lane)
|
||||
if (!batch.getOutputs().empty())
|
||||
outputs.push_back(batch.getOutputs()[lane]);
|
||||
return outputs;
|
||||
}
|
||||
|
||||
llvm::SmallVector<Type, 4> getComputeInstanceOutputTypes(const ComputeInstance &instance) {
|
||||
llvm::SmallVector<Type, 4> outputTypes;
|
||||
for (Value output : getComputeInstanceOutputValues(instance))
|
||||
outputTypes.push_back(output.getType());
|
||||
return outputTypes;
|
||||
}
|
||||
|
||||
Block &getComputeInstanceTemplateBlock(const ComputeInstance &instance) {
|
||||
if (auto compute = dyn_cast<SpatCompute>(instance.op))
|
||||
return compute.getBody().front();
|
||||
return cast<SpatComputeBatch>(instance.op).getBody().front();
|
||||
}
|
||||
|
||||
} // namespace spatial
|
||||
} // namespace onnx_mlir
|
||||
+40
@@ -0,0 +1,40 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/Block.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include <cstddef>
|
||||
#include <optional>
|
||||
|
||||
#include "ComputeInstance.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace spatial {
|
||||
|
||||
struct ProducerValueRef {
|
||||
ComputeInstance instance;
|
||||
size_t resultIndex = 0;
|
||||
};
|
||||
|
||||
size_t getSchedulingCpuBudget();
|
||||
size_t getBatchChunkTargetCount(int32_t laneCount);
|
||||
ComputeInstance getBatchChunkForIndex(SpatComputeBatch batch, size_t chunkIndex);
|
||||
ComputeInstance getBatchChunkForLane(SpatComputeBatch batch, uint32_t lane);
|
||||
|
||||
SpatCompute getOriginalSpatCompute(mlir::Operation *op);
|
||||
std::optional<ProducerValueRef> getProducerValueRef(mlir::Value value);
|
||||
std::optional<ComputeInstance> getComputeProducerInstance(mlir::Value value);
|
||||
|
||||
llvm::SmallVector<mlir::Value, 4> getComputeInstanceInputs(const ComputeInstance &instance);
|
||||
llvm::SmallVector<mlir::Value, 4> getComputeInstanceWeights(const ComputeInstance &instance);
|
||||
llvm::SmallVector<mlir::Value, 4> getComputeInstanceOutputValues(const ComputeInstance &instance);
|
||||
llvm::SmallVector<mlir::Type, 4> getComputeInstanceOutputTypes(const ComputeInstance &instance);
|
||||
mlir::Block &getComputeInstanceTemplateBlock(const ComputeInstance &instance);
|
||||
|
||||
} // namespace spatial
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,720 @@
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdlib>
|
||||
#include <limits>
|
||||
#include <numeric>
|
||||
#include <optional>
|
||||
#include <queue>
|
||||
#include <vector>
|
||||
|
||||
#include "DcpScheduler.hpp"
|
||||
#include "../DCPGraph/Graph.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace spatial {
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
|
||||
bool isDcpCoarsenDebugEnabled() { return std::getenv("DCP_COARSEN_DEBUG") != nullptr; }
|
||||
|
||||
struct VirtualNode {
|
||||
llvm::SmallVector<size_t, 4> originalNodeIndices;
|
||||
Weight weight = 0;
|
||||
CrossbarUsage crossbarUsage = 0;
|
||||
};
|
||||
|
||||
struct VirtualGraph {
|
||||
std::vector<VirtualNode> nodes;
|
||||
std::vector<IndexedEdge> edges;
|
||||
};
|
||||
|
||||
struct TimingInfo {
|
||||
std::vector<Time> aest;
|
||||
std::vector<Time> alst;
|
||||
std::vector<size_t> topologicalOrder;
|
||||
bool valid = false;
|
||||
};
|
||||
|
||||
struct WindowScheduleResult {
|
||||
std::vector<std::vector<size_t>> mergeGroups;
|
||||
CPU cpuCount = 0;
|
||||
size_t mergedNodeCount = 0;
|
||||
size_t maxMergeGroupSize = 0;
|
||||
};
|
||||
|
||||
size_t getSchedulingCpuBudget(const DcpScheduleOptions &options) {
|
||||
if (options.processorCount > 0)
|
||||
return options.processorCount;
|
||||
return std::numeric_limits<size_t>::max();
|
||||
}
|
||||
|
||||
std::vector<IndexedEdge> aggregateEdges(llvm::ArrayRef<IndexedEdge> edges) {
|
||||
llvm::DenseMap<std::pair<size_t, size_t>, Weight> edgeWeights;
|
||||
for (auto [start, end, weight] : edges) {
|
||||
size_t startIndex = static_cast<size_t>(start);
|
||||
size_t endIndex = static_cast<size_t>(end);
|
||||
if (startIndex == endIndex)
|
||||
continue;
|
||||
auto key = std::make_pair(startIndex, endIndex);
|
||||
Weight edgeWeight = static_cast<Weight>(weight);
|
||||
auto inserted = edgeWeights.try_emplace(key, edgeWeight);
|
||||
if (!inserted.second)
|
||||
inserted.first->second = std::max(inserted.first->second, edgeWeight);
|
||||
}
|
||||
|
||||
std::vector<IndexedEdge> aggregatedEdges;
|
||||
aggregatedEdges.reserve(edgeWeights.size());
|
||||
for (auto [key, weight] : edgeWeights)
|
||||
aggregatedEdges.push_back(
|
||||
{static_cast<int64_t>(key.first), static_cast<int64_t>(key.second), static_cast<int64_t>(weight)});
|
||||
llvm::sort(aggregatedEdges, [](const IndexedEdge &lhs, const IndexedEdge &rhs) {
|
||||
if (std::get<0>(lhs) != std::get<0>(rhs))
|
||||
return std::get<0>(lhs) < std::get<0>(rhs);
|
||||
return std::get<1>(lhs) < std::get<1>(rhs);
|
||||
});
|
||||
return aggregatedEdges;
|
||||
}
|
||||
|
||||
VirtualGraph buildInitialVirtualGraph(const ComputeGraph &graph) {
|
||||
VirtualGraph virtualGraph;
|
||||
virtualGraph.nodes.reserve(graph.nodes.size());
|
||||
for (auto [index, node] : llvm::enumerate(graph.nodes)) {
|
||||
VirtualNode virtualNode;
|
||||
virtualNode.originalNodeIndices.push_back(index);
|
||||
virtualNode.weight = node.weight;
|
||||
virtualNode.crossbarUsage = node.crossbarUsage;
|
||||
virtualGraph.nodes.push_back(std::move(virtualNode));
|
||||
}
|
||||
|
||||
std::vector<IndexedEdge> edges;
|
||||
edges.reserve(graph.edges.size());
|
||||
for (const ComputeGraphEdge &edge : graph.edges)
|
||||
edges.push_back(
|
||||
{static_cast<int64_t>(edge.source), static_cast<int64_t>(edge.target), static_cast<int64_t>(edge.transferCost)});
|
||||
virtualGraph.edges = aggregateEdges(edges);
|
||||
return virtualGraph;
|
||||
}
|
||||
|
||||
TimingInfo computeTiming(const VirtualGraph &graph) {
|
||||
TimingInfo timing;
|
||||
size_t nodeCount = graph.nodes.size();
|
||||
timing.aest.assign(nodeCount, 0);
|
||||
timing.alst.assign(nodeCount, 0);
|
||||
timing.topologicalOrder.reserve(nodeCount);
|
||||
|
||||
std::vector<std::vector<std::pair<size_t, Weight>>> parents(nodeCount);
|
||||
std::vector<std::vector<std::pair<size_t, Weight>>> children(nodeCount);
|
||||
std::vector<size_t> incomingEdgeCount(nodeCount, 0);
|
||||
|
||||
for (auto [start, end, weight] : graph.edges) {
|
||||
size_t startIndex = static_cast<size_t>(start);
|
||||
size_t endIndex = static_cast<size_t>(end);
|
||||
Weight edgeWeight = static_cast<Weight>(weight);
|
||||
assert(startIndex < nodeCount && endIndex < nodeCount && "virtual edge endpoint out of range");
|
||||
children[startIndex].push_back({endIndex, edgeWeight});
|
||||
parents[endIndex].push_back({startIndex, edgeWeight});
|
||||
incomingEdgeCount[endIndex]++;
|
||||
}
|
||||
|
||||
auto getVirtualNodeOrderKey = [&](size_t nodeIndex) {
|
||||
const VirtualNode &node = graph.nodes[nodeIndex];
|
||||
if (!node.originalNodeIndices.empty())
|
||||
return node.originalNodeIndices.front();
|
||||
return nodeIndex;
|
||||
};
|
||||
auto readyNodeGreater = [&](size_t lhs, size_t rhs) {
|
||||
size_t lhsKey = getVirtualNodeOrderKey(lhs);
|
||||
size_t rhsKey = getVirtualNodeOrderKey(rhs);
|
||||
if (lhsKey != rhsKey)
|
||||
return lhsKey > rhsKey;
|
||||
return lhs > rhs;
|
||||
};
|
||||
std::priority_queue<size_t, std::vector<size_t>, decltype(readyNodeGreater)> readyNodes(readyNodeGreater);
|
||||
for (size_t i = 0; i < nodeCount; ++i)
|
||||
if (incomingEdgeCount[i] == 0)
|
||||
readyNodes.push(i);
|
||||
|
||||
while (!readyNodes.empty()) {
|
||||
size_t current = readyNodes.top();
|
||||
readyNodes.pop();
|
||||
timing.topologicalOrder.push_back(current);
|
||||
for (auto [child, weight] : children[current]) {
|
||||
(void) weight;
|
||||
assert(incomingEdgeCount[child] > 0 && "incoming edge count underflow");
|
||||
incomingEdgeCount[child]--;
|
||||
if (incomingEdgeCount[child] == 0)
|
||||
readyNodes.push(child);
|
||||
}
|
||||
}
|
||||
|
||||
if (timing.topologicalOrder.size() != nodeCount)
|
||||
return timing;
|
||||
|
||||
Time dcpl = 0;
|
||||
for (size_t nodeIndex : timing.topologicalOrder) {
|
||||
Time maxParentAest = 0;
|
||||
for (auto [parent, transferCost] : parents[nodeIndex]) {
|
||||
maxParentAest =
|
||||
std::max(maxParentAest, addOrMax(addOrMax(timing.aest[parent], graph.nodes[parent].weight), transferCost));
|
||||
}
|
||||
timing.aest[nodeIndex] = maxParentAest;
|
||||
dcpl = std::max(dcpl, addOrMax(maxParentAest, graph.nodes[nodeIndex].weight));
|
||||
}
|
||||
|
||||
for (size_t nodeIndex : llvm::reverse(timing.topologicalOrder)) {
|
||||
Time minAlst = std::numeric_limits<Time>::max();
|
||||
if (children[nodeIndex].empty())
|
||||
minAlst = subtractOrZero(dcpl, graph.nodes[nodeIndex].weight);
|
||||
for (auto [child, transferCost] : children[nodeIndex]) {
|
||||
minAlst =
|
||||
std::min(minAlst, subtractOrZero(timing.alst[child], addOrMax(graph.nodes[nodeIndex].weight, transferCost)));
|
||||
}
|
||||
timing.alst[nodeIndex] = minAlst;
|
||||
}
|
||||
|
||||
timing.valid = true;
|
||||
return timing;
|
||||
}
|
||||
|
||||
std::vector<std::vector<size_t>> buildUndirectedAdjacency(const VirtualGraph &graph) {
|
||||
std::vector<std::vector<size_t>> adjacency(graph.nodes.size());
|
||||
for (auto [start, end, weight] : graph.edges) {
|
||||
(void) weight;
|
||||
size_t startIndex = static_cast<size_t>(start);
|
||||
size_t endIndex = static_cast<size_t>(end);
|
||||
assert(startIndex < graph.nodes.size() && endIndex < graph.nodes.size() && "virtual edge endpoint out of range");
|
||||
adjacency[startIndex].push_back(endIndex);
|
||||
adjacency[endIndex].push_back(startIndex);
|
||||
}
|
||||
for (auto &neighbours : adjacency) {
|
||||
llvm::sort(neighbours);
|
||||
neighbours.erase(std::unique(neighbours.begin(), neighbours.end()), neighbours.end());
|
||||
}
|
||||
return adjacency;
|
||||
}
|
||||
|
||||
std::vector<size_t> selectCriticalWindow(const VirtualGraph &graph, const TimingInfo &timing, size_t windowSize) {
|
||||
std::vector<size_t> ranked(timing.aest.size());
|
||||
std::iota(ranked.begin(), ranked.end(), 0);
|
||||
auto isHigherPriority = [&](size_t lhs, size_t rhs) {
|
||||
Time lhsSlack = slackOrZero(timing.aest[lhs], timing.alst[lhs]);
|
||||
Time rhsSlack = slackOrZero(timing.aest[rhs], timing.alst[rhs]);
|
||||
if (lhsSlack != rhsSlack)
|
||||
return lhsSlack < rhsSlack;
|
||||
if (timing.aest[lhs] != timing.aest[rhs])
|
||||
return timing.aest[lhs] < timing.aest[rhs];
|
||||
return lhs < rhs;
|
||||
};
|
||||
|
||||
windowSize = std::min(windowSize, ranked.size());
|
||||
if (windowSize == 0)
|
||||
return {};
|
||||
if (windowSize == ranked.size()) {
|
||||
llvm::sort(ranked, isHigherPriority);
|
||||
return ranked;
|
||||
}
|
||||
|
||||
size_t criticalPoolSize = std::min(ranked.size(), std::max(windowSize, windowSize * 2));
|
||||
if (criticalPoolSize < ranked.size())
|
||||
std::nth_element(
|
||||
ranked.begin(), ranked.begin() + static_cast<std::ptrdiff_t>(criticalPoolSize), ranked.end(), isHigherPriority);
|
||||
|
||||
std::vector<char> inCriticalPool(ranked.size(), false);
|
||||
for (size_t i = 0; i < criticalPoolSize; ++i)
|
||||
inCriticalPool[ranked[i]] = true;
|
||||
|
||||
size_t seed = *std::min_element(ranked.begin(), ranked.end(), isHigherPriority);
|
||||
std::vector<std::vector<size_t>> adjacency = buildUndirectedAdjacency(graph);
|
||||
std::vector<size_t> selected;
|
||||
std::vector<char> inWindow(ranked.size(), false);
|
||||
selected.reserve(windowSize);
|
||||
|
||||
struct FrontierEntry {
|
||||
size_t node;
|
||||
};
|
||||
auto frontierCompare = [&](FrontierEntry lhs, FrontierEntry rhs) { return isHigherPriority(rhs.node, lhs.node); };
|
||||
std::priority_queue<FrontierEntry, std::vector<FrontierEntry>, decltype(frontierCompare)> frontier(frontierCompare);
|
||||
|
||||
auto addToWindow = [&](size_t node, const std::vector<char> &eligible) {
|
||||
if (inWindow[node])
|
||||
return;
|
||||
inWindow[node] = true;
|
||||
selected.push_back(node);
|
||||
for (size_t neighbour : adjacency[node])
|
||||
if (!inWindow[neighbour] && eligible[neighbour])
|
||||
frontier.push({neighbour});
|
||||
};
|
||||
|
||||
addToWindow(seed, inCriticalPool);
|
||||
while (!frontier.empty() && selected.size() < windowSize) {
|
||||
size_t node = frontier.top().node;
|
||||
frontier.pop();
|
||||
if (!inWindow[node])
|
||||
addToWindow(node, inCriticalPool);
|
||||
}
|
||||
|
||||
if (selected.size() < windowSize) {
|
||||
std::vector<char> anyNode(ranked.size(), true);
|
||||
for (size_t node : selected)
|
||||
for (size_t neighbour : adjacency[node])
|
||||
if (!inWindow[neighbour])
|
||||
frontier.push({neighbour});
|
||||
while (!frontier.empty() && selected.size() < windowSize) {
|
||||
size_t node = frontier.top().node;
|
||||
frontier.pop();
|
||||
if (!inWindow[node])
|
||||
addToWindow(node, anyNode);
|
||||
}
|
||||
}
|
||||
|
||||
if (selected.size() < windowSize) {
|
||||
llvm::sort(ranked, isHigherPriority);
|
||||
for (size_t node : ranked) {
|
||||
if (selected.size() == windowSize)
|
||||
break;
|
||||
if (!inWindow[node]) {
|
||||
inWindow[node] = true;
|
||||
selected.push_back(node);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
llvm::sort(selected, isHigherPriority);
|
||||
return selected;
|
||||
}
|
||||
|
||||
std::vector<IndexedEdge> buildWindowEdges(const VirtualGraph &graph, const std::vector<int64_t> &nodeToWindowIndex) {
|
||||
std::vector<IndexedEdge> windowEdges;
|
||||
windowEdges.reserve(graph.edges.size());
|
||||
for (auto [start, end, weight] : graph.edges) {
|
||||
int64_t mappedStart = nodeToWindowIndex[static_cast<size_t>(start)];
|
||||
int64_t mappedEnd = nodeToWindowIndex[static_cast<size_t>(end)];
|
||||
if (mappedStart == -1 || mappedEnd == -1)
|
||||
continue;
|
||||
windowEdges.push_back({mappedStart, mappedEnd, weight});
|
||||
}
|
||||
return aggregateEdges(windowEdges);
|
||||
}
|
||||
|
||||
WindowScheduleResult scheduleWindow(const VirtualGraph &graph,
|
||||
llvm::ArrayRef<size_t> selectedNodes,
|
||||
const DcpScheduleOptions &options,
|
||||
mlir::MLIRContext *context) {
|
||||
std::vector<Weight> windowWeights;
|
||||
std::vector<CrossbarUsage> windowCrossbarUsage;
|
||||
std::vector<int64_t> windowNodeOrderKeys;
|
||||
std::vector<int64_t> nodeToWindowIndex(graph.nodes.size(), -1);
|
||||
windowWeights.reserve(selectedNodes.size());
|
||||
windowCrossbarUsage.reserve(selectedNodes.size());
|
||||
windowNodeOrderKeys.reserve(selectedNodes.size());
|
||||
|
||||
for (auto [windowIndex, nodeIndex] : llvm::enumerate(selectedNodes)) {
|
||||
nodeToWindowIndex[nodeIndex] = static_cast<int64_t>(windowIndex);
|
||||
windowWeights.push_back(graph.nodes[nodeIndex].weight);
|
||||
windowCrossbarUsage.push_back(graph.nodes[nodeIndex].crossbarUsage);
|
||||
windowNodeOrderKeys.push_back(static_cast<int64_t>(nodeIndex));
|
||||
}
|
||||
|
||||
GraphDCP windowGraph(
|
||||
windowWeights, buildWindowEdges(graph, nodeToWindowIndex), windowNodeOrderKeys, windowCrossbarUsage);
|
||||
if (options.processorCount > 0)
|
||||
windowGraph.setMaxCpuCount(static_cast<int>(options.processorCount));
|
||||
windowGraph.setContext(context);
|
||||
windowGraph.runDcp();
|
||||
|
||||
WindowScheduleResult result;
|
||||
result.cpuCount = windowGraph.cpuCount();
|
||||
for (CPU cpu = 0; cpu < windowGraph.cpuCount(); ++cpu) {
|
||||
auto scheduledTasks = windowGraph.getScheduledTasks(cpu);
|
||||
if (scheduledTasks.size() < 2)
|
||||
continue;
|
||||
|
||||
result.mergedNodeCount += scheduledTasks.size();
|
||||
result.maxMergeGroupSize = std::max(result.maxMergeGroupSize, scheduledTasks.size());
|
||||
std::vector<size_t> mergeGroup;
|
||||
mergeGroup.reserve(scheduledTasks.size());
|
||||
for (const auto &task : scheduledTasks)
|
||||
mergeGroup.push_back(selectedNodes[task.nodeIndex]);
|
||||
result.mergeGroups.push_back(std::move(mergeGroup));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
bool coarsenGraph(const VirtualGraph &graph,
|
||||
llvm::ArrayRef<std::vector<size_t>> mergeGroups,
|
||||
VirtualGraph &coarsenedGraph,
|
||||
std::vector<size_t> &oldToNewNode) {
|
||||
TimingInfo timing = computeTiming(graph);
|
||||
std::vector<size_t> topologicalRank(graph.nodes.size());
|
||||
std::iota(topologicalRank.begin(), topologicalRank.end(), 0);
|
||||
if (timing.valid)
|
||||
for (auto [rank, nodeIndex] : llvm::enumerate(timing.topologicalOrder))
|
||||
topologicalRank[nodeIndex] = rank;
|
||||
|
||||
std::vector<std::vector<size_t>> orderedMergeGroups;
|
||||
orderedMergeGroups.reserve(mergeGroups.size());
|
||||
for (const auto &mergeGroup : mergeGroups) {
|
||||
orderedMergeGroups.emplace_back(mergeGroup.begin(), mergeGroup.end());
|
||||
std::stable_sort(orderedMergeGroups.back().begin(), orderedMergeGroups.back().end(), [&](size_t lhs, size_t rhs) {
|
||||
if (topologicalRank[lhs] != topologicalRank[rhs])
|
||||
return topologicalRank[lhs] < topologicalRank[rhs];
|
||||
return lhs < rhs;
|
||||
});
|
||||
}
|
||||
|
||||
std::vector<int64_t> nodeToMergeGroup(graph.nodes.size(), -1);
|
||||
for (auto [groupIndex, mergeGroup] : llvm::enumerate(orderedMergeGroups)) {
|
||||
if (mergeGroup.size() < 2)
|
||||
continue;
|
||||
for (size_t nodeIndex : mergeGroup) {
|
||||
assert(nodeIndex < graph.nodes.size() && "merge group node out of range");
|
||||
nodeToMergeGroup[nodeIndex] = static_cast<int64_t>(groupIndex);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::optional<size_t>> mergeGroupToNewNode(orderedMergeGroups.size());
|
||||
std::vector<size_t> newNodeRank;
|
||||
oldToNewNode.assign(graph.nodes.size(), 0);
|
||||
bool mergedAny = false;
|
||||
coarsenedGraph.nodes.clear();
|
||||
coarsenedGraph.edges.clear();
|
||||
coarsenedGraph.nodes.reserve(graph.nodes.size());
|
||||
newNodeRank.reserve(graph.nodes.size());
|
||||
|
||||
for (size_t nodeIndex = 0; nodeIndex < graph.nodes.size(); ++nodeIndex) {
|
||||
int64_t mergeGroupIndex = nodeToMergeGroup[nodeIndex];
|
||||
if (mergeGroupIndex == -1) {
|
||||
oldToNewNode[nodeIndex] = coarsenedGraph.nodes.size();
|
||||
coarsenedGraph.nodes.push_back(graph.nodes[nodeIndex]);
|
||||
newNodeRank.push_back(topologicalRank[nodeIndex]);
|
||||
continue;
|
||||
}
|
||||
|
||||
auto &newNodeIndex = mergeGroupToNewNode[static_cast<size_t>(mergeGroupIndex)];
|
||||
if (newNodeIndex.has_value()) {
|
||||
oldToNewNode[nodeIndex] = *newNodeIndex;
|
||||
continue;
|
||||
}
|
||||
|
||||
VirtualNode mergedNode;
|
||||
for (size_t memberIndex : orderedMergeGroups[static_cast<size_t>(mergeGroupIndex)]) {
|
||||
const VirtualNode &memberNode = graph.nodes[memberIndex];
|
||||
mergedNode.originalNodeIndices.append(memberNode.originalNodeIndices.begin(), memberNode.originalNodeIndices.end());
|
||||
mergedNode.weight = addOrMax(mergedNode.weight, memberNode.weight);
|
||||
mergedNode.crossbarUsage = addOrMax(mergedNode.crossbarUsage, memberNode.crossbarUsage);
|
||||
}
|
||||
std::sort(mergedNode.originalNodeIndices.begin(), mergedNode.originalNodeIndices.end());
|
||||
|
||||
mergedAny = true;
|
||||
newNodeIndex = coarsenedGraph.nodes.size();
|
||||
for (size_t memberIndex : orderedMergeGroups[static_cast<size_t>(mergeGroupIndex)])
|
||||
oldToNewNode[memberIndex] = *newNodeIndex;
|
||||
newNodeRank.push_back(topologicalRank[orderedMergeGroups[static_cast<size_t>(mergeGroupIndex)].front()]);
|
||||
coarsenedGraph.nodes.push_back(std::move(mergedNode));
|
||||
}
|
||||
|
||||
if (!mergedAny)
|
||||
return false;
|
||||
|
||||
std::vector<IndexedEdge> remappedEdges;
|
||||
remappedEdges.reserve(graph.edges.size());
|
||||
for (auto [start, end, weight] : graph.edges) {
|
||||
size_t newStart = oldToNewNode[static_cast<size_t>(start)];
|
||||
size_t newEnd = oldToNewNode[static_cast<size_t>(end)];
|
||||
if (newStart == newEnd)
|
||||
continue;
|
||||
if (newNodeRank[newStart] >= newNodeRank[newEnd])
|
||||
continue;
|
||||
remappedEdges.push_back({static_cast<int64_t>(newStart), static_cast<int64_t>(newEnd), weight});
|
||||
}
|
||||
coarsenedGraph.edges = aggregateEdges(remappedEdges);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
size_t getDcpCoarseningWindowSize(size_t nodeCount, const DcpScheduleOptions &options) {
|
||||
size_t windowSize = std::min(options.criticalWindowSize, nodeCount);
|
||||
CPU maxCpuCount = std::max<CPU>(1, static_cast<CPU>(getSchedulingCpuBudget(options)));
|
||||
if (nodeCount > static_cast<size_t>(maxCpuCount))
|
||||
windowSize = std::max(windowSize, std::min(nodeCount, static_cast<size_t>(maxCpuCount) + 1));
|
||||
return windowSize;
|
||||
}
|
||||
|
||||
void assignFeasibleAest(const ComputeGraph &graph, MergeScheduleResult &result) {
|
||||
llvm::DenseMap<ComputeInstance, size_t> nodeIndexByInstance;
|
||||
nodeIndexByInstance.reserve(graph.nodes.size());
|
||||
for (auto [nodeIndex, node] : llvm::enumerate(graph.nodes))
|
||||
nodeIndexByInstance[node.instance] = nodeIndex;
|
||||
|
||||
struct ScheduledEdge {
|
||||
size_t target = 0;
|
||||
Time delay = 0;
|
||||
};
|
||||
|
||||
std::vector<std::vector<ScheduledEdge>> scheduledChildren(graph.nodes.size());
|
||||
std::vector<size_t> incomingEdgeCount(graph.nodes.size(), 0);
|
||||
for (const ComputeGraphEdge &edge : graph.edges) {
|
||||
const ComputeInstance sourceInstance = graph.nodes[edge.source].instance;
|
||||
const ComputeInstance targetInstance = graph.nodes[edge.target].instance;
|
||||
const size_t sourceCpu = result.computeToCpuMap.lookup(sourceInstance);
|
||||
const size_t targetCpu = result.computeToCpuMap.lookup(targetInstance);
|
||||
|
||||
Time delay = graph.nodes[edge.source].weight;
|
||||
if (sourceCpu != targetCpu)
|
||||
delay = addOrMax(delay, edge.transferCost);
|
||||
|
||||
scheduledChildren[edge.source].push_back({edge.target, delay});
|
||||
incomingEdgeCount[edge.target]++;
|
||||
}
|
||||
|
||||
llvm::DenseMap<size_t, std::vector<std::pair<size_t, size_t>>> tasksByCpu;
|
||||
for (const ComputeGraphNode &node : graph.nodes) {
|
||||
size_t cpu = result.computeToCpuMap.lookup(node.instance);
|
||||
size_t slot = result.computeToCpuSlotMap.lookup(node.instance);
|
||||
tasksByCpu[cpu].push_back({slot, nodeIndexByInstance.lookup(node.instance)});
|
||||
}
|
||||
|
||||
for (auto &entry : tasksByCpu) {
|
||||
auto &scheduledTasks = entry.second;
|
||||
llvm::sort(scheduledTasks, [](const auto &lhs, const auto &rhs) {
|
||||
if (lhs.first != rhs.first)
|
||||
return lhs.first < rhs.first;
|
||||
return lhs.second < rhs.second;
|
||||
});
|
||||
|
||||
for (size_t i = 1; i < scheduledTasks.size(); ++i) {
|
||||
size_t sourceIndex = scheduledTasks[i - 1].second;
|
||||
size_t targetIndex = scheduledTasks[i].second;
|
||||
scheduledChildren[sourceIndex].push_back({targetIndex, graph.nodes[sourceIndex].weight});
|
||||
incomingEdgeCount[targetIndex]++;
|
||||
}
|
||||
}
|
||||
|
||||
auto readyNodeGreater = [&](size_t lhs, size_t rhs) {
|
||||
if (graph.nodes[lhs].originalOrder != graph.nodes[rhs].originalOrder)
|
||||
return graph.nodes[lhs].originalOrder > graph.nodes[rhs].originalOrder;
|
||||
return lhs > rhs;
|
||||
};
|
||||
std::priority_queue<size_t, std::vector<size_t>, decltype(readyNodeGreater)> readyNodes(readyNodeGreater);
|
||||
for (size_t nodeIndex = 0; nodeIndex < graph.nodes.size(); ++nodeIndex)
|
||||
if (incomingEdgeCount[nodeIndex] == 0)
|
||||
readyNodes.push(nodeIndex);
|
||||
|
||||
std::vector<Time> startTimes(graph.nodes.size(), 0);
|
||||
size_t processedNodeCount = 0;
|
||||
while (!readyNodes.empty()) {
|
||||
size_t sourceIndex = readyNodes.top();
|
||||
readyNodes.pop();
|
||||
processedNodeCount++;
|
||||
|
||||
for (const ScheduledEdge &edge : scheduledChildren[sourceIndex]) {
|
||||
startTimes[edge.target] = std::max(startTimes[edge.target], addOrMax(startTimes[sourceIndex], edge.delay));
|
||||
assert(incomingEdgeCount[edge.target] > 0 && "scheduled incoming edge count underflow");
|
||||
incomingEdgeCount[edge.target]--;
|
||||
if (incomingEdgeCount[edge.target] == 0)
|
||||
readyNodes.push(edge.target);
|
||||
}
|
||||
}
|
||||
|
||||
if (processedNodeCount != graph.nodes.size())
|
||||
llvm::report_fatal_error("merge scheduling: coarsened DCP schedule is cyclic");
|
||||
|
||||
for (auto [nodeIndex, node] : llvm::enumerate(graph.nodes))
|
||||
result.computeToAestMap[node.instance] = startTimes[nodeIndex];
|
||||
}
|
||||
|
||||
MergeScheduleResult buildResultFromVirtualGraph(const VirtualGraph &graph, const ComputeGraph &originalGraph) {
|
||||
MergeScheduleResult result;
|
||||
|
||||
TimingInfo timing = computeTiming(graph);
|
||||
std::vector<size_t> virtualNodeOrder;
|
||||
if (timing.valid)
|
||||
virtualNodeOrder = std::move(timing.topologicalOrder);
|
||||
else {
|
||||
virtualNodeOrder.resize(graph.nodes.size());
|
||||
std::iota(virtualNodeOrder.begin(), virtualNodeOrder.end(), 0);
|
||||
}
|
||||
|
||||
std::vector<size_t> originalNodeToCpu(originalGraph.nodes.size(), 0);
|
||||
for (auto [cpu, virtualNodeIndex] : llvm::enumerate(virtualNodeOrder)) {
|
||||
const VirtualNode &virtualNode = graph.nodes[virtualNodeIndex];
|
||||
for (size_t originalIndex : virtualNode.originalNodeIndices)
|
||||
originalNodeToCpu[originalIndex] = cpu;
|
||||
}
|
||||
|
||||
result.dominanceOrderCompute.reserve(originalGraph.nodes.size());
|
||||
llvm::DenseMap<size_t, size_t> nextCpuSlot;
|
||||
for (auto [originalIndex, node] : llvm::enumerate(originalGraph.nodes)) {
|
||||
size_t cpu = originalNodeToCpu[originalIndex];
|
||||
result.dominanceOrderCompute.push_back(node.instance);
|
||||
result.computeToCpuMap[node.instance] = cpu;
|
||||
result.computeToCpuSlotMap[node.instance] = nextCpuSlot[cpu]++;
|
||||
result.cpuToLastComputeMap[cpu] = node.instance;
|
||||
}
|
||||
for (const auto &[cpu, lastCompute] : result.cpuToLastComputeMap)
|
||||
result.isLastComputeOfCpu.insert(lastCompute);
|
||||
assignFeasibleAest(originalGraph, result);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
MergeScheduleResult buildResultFromScheduledGraph(GraphDCP &graphDCP, const ComputeGraph &graph) {
|
||||
MergeScheduleResult result;
|
||||
result.dominanceOrderCompute.reserve(graph.nodes.size());
|
||||
for (const ComputeGraphNode &node : graph.nodes)
|
||||
result.dominanceOrderCompute.push_back(node.instance);
|
||||
|
||||
for (CPU cpu = 0; cpu < graphDCP.cpuCount(); ++cpu) {
|
||||
auto scheduledTasks = graphDCP.getScheduledTasks(cpu);
|
||||
if (scheduledTasks.empty())
|
||||
continue;
|
||||
|
||||
for (auto [slot, task] : llvm::enumerate(scheduledTasks)) {
|
||||
const ComputeInstance instance = graph.nodes[task.nodeIndex].instance;
|
||||
result.computeToCpuMap[instance] = cpu;
|
||||
result.computeToCpuSlotMap[instance] = slot;
|
||||
result.computeToAestMap[instance] = static_cast<uint64_t>(task.aest);
|
||||
}
|
||||
|
||||
const ComputeInstance lastInstance = graph.nodes[scheduledTasks.back().nodeIndex].instance;
|
||||
result.cpuToLastComputeMap[cpu] = lastInstance;
|
||||
result.isLastComputeOfCpu.insert(lastInstance);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
MergeScheduleResult runLegacyDcp(const ComputeGraph &graph, const DcpScheduleOptions &options, mlir::MLIRContext *context) {
|
||||
llvm::SmallVector<Weight> nodeWeights;
|
||||
llvm::SmallVector<CrossbarUsage> nodeCrossbarUsage;
|
||||
llvm::SmallVector<int64_t> nodeOrderKeys;
|
||||
llvm::SmallVector<IndexedEdge> edges;
|
||||
nodeWeights.reserve(graph.nodes.size());
|
||||
nodeCrossbarUsage.reserve(graph.nodes.size());
|
||||
nodeOrderKeys.reserve(graph.nodes.size());
|
||||
edges.reserve(graph.edges.size());
|
||||
|
||||
for (const ComputeGraphNode &node : graph.nodes) {
|
||||
nodeWeights.push_back(node.weight);
|
||||
nodeCrossbarUsage.push_back(node.crossbarUsage);
|
||||
nodeOrderKeys.push_back(static_cast<int64_t>(node.originalOrder));
|
||||
}
|
||||
for (const ComputeGraphEdge &edge : graph.edges) {
|
||||
edges.push_back(
|
||||
{static_cast<int64_t>(edge.source), static_cast<int64_t>(edge.target), static_cast<int64_t>(edge.transferCost)});
|
||||
}
|
||||
|
||||
GraphDCP graphDCP(nodeWeights, edges, nodeOrderKeys, nodeCrossbarUsage);
|
||||
if (options.processorCount > 0)
|
||||
graphDCP.setMaxCpuCount(static_cast<int>(options.processorCount));
|
||||
graphDCP.setContext(context);
|
||||
graphDCP.runDcp();
|
||||
return buildResultFromScheduledGraph(graphDCP, graph);
|
||||
}
|
||||
|
||||
bool needsExactScheduledBatches(const ComputeGraph &graph, const DcpScheduleOptions &options) {
|
||||
if (options.processorCount == 0 || !options.allowFallbackForAutoCoreCount)
|
||||
return false;
|
||||
size_t schedulingCpuBudget = getSchedulingCpuBudget(options);
|
||||
return llvm::any_of(graph.nodes, [&](const ComputeGraphNode &node) {
|
||||
auto batch = dyn_cast<SpatComputeBatch>(node.instance.op);
|
||||
return batch && static_cast<size_t>(batch.getLaneCount()) > schedulingCpuBudget;
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
MergeScheduleResult
|
||||
runDcpScheduler(const ComputeGraph &graph, const DcpScheduleOptions &options, mlir::MLIRContext *context) {
|
||||
if (needsExactScheduledBatches(graph, options))
|
||||
return runLegacyDcp(graph, options, context);
|
||||
|
||||
if (options.criticalWindowSize == 0)
|
||||
return runLegacyDcp(graph, options, context);
|
||||
|
||||
VirtualGraph virtualGraph = buildInitialVirtualGraph(graph);
|
||||
size_t iteration = 0;
|
||||
bool debugCoarsening = isDcpCoarsenDebugEnabled();
|
||||
auto tryCoarsenSelectedNodes = [&](llvm::ArrayRef<size_t> selectedNodes) {
|
||||
size_t oldNodeCount = virtualGraph.nodes.size();
|
||||
WindowScheduleResult windowSchedule = scheduleWindow(virtualGraph, selectedNodes, options, context);
|
||||
if (windowSchedule.mergeGroups.empty()) {
|
||||
if (debugCoarsening && oldNodeCount >= 200)
|
||||
llvm::errs() << llvm::formatv("[DCP-COARSEN] iter={0} old={1} selected={2} windowCpus={3} "
|
||||
"groups=0 mergedNodes=0 maxGroup=0 new={1} changed=0\n",
|
||||
iteration,
|
||||
oldNodeCount,
|
||||
selectedNodes.size(),
|
||||
windowSchedule.cpuCount);
|
||||
return false;
|
||||
}
|
||||
|
||||
VirtualGraph coarsenedGraph;
|
||||
std::vector<size_t> oldToNewNode;
|
||||
if (!coarsenGraph(virtualGraph, windowSchedule.mergeGroups, coarsenedGraph, oldToNewNode))
|
||||
return false;
|
||||
if (debugCoarsening && (oldNodeCount >= 200 || coarsenedGraph.nodes.size() >= 200))
|
||||
llvm::errs() << llvm::formatv("[DCP-COARSEN] iter={0} old={1} selected={2} windowCpus={3} "
|
||||
"groups={4} mergedNodes={5} maxGroup={6} new={7} changed={8}\n",
|
||||
iteration,
|
||||
oldNodeCount,
|
||||
selectedNodes.size(),
|
||||
windowSchedule.cpuCount,
|
||||
windowSchedule.mergeGroups.size(),
|
||||
windowSchedule.mergedNodeCount,
|
||||
windowSchedule.maxMergeGroupSize,
|
||||
coarsenedGraph.nodes.size(),
|
||||
oldNodeCount - coarsenedGraph.nodes.size());
|
||||
virtualGraph = std::move(coarsenedGraph);
|
||||
return true;
|
||||
};
|
||||
|
||||
while (virtualGraph.nodes.size() > 1) {
|
||||
if (virtualGraph.nodes.size() <= getSchedulingCpuBudget(options)) {
|
||||
if (debugCoarsening && virtualGraph.nodes.size() >= 200)
|
||||
llvm::errs() << llvm::formatv(
|
||||
"[DCP-COARSEN] iter={0} old={1} stop=cpu-budget\n", iteration, virtualGraph.nodes.size());
|
||||
break;
|
||||
}
|
||||
|
||||
iteration++;
|
||||
TimingInfo timing = computeTiming(virtualGraph);
|
||||
if (!timing.valid) {
|
||||
if (debugCoarsening && virtualGraph.nodes.size() >= 200)
|
||||
llvm::errs() << llvm::formatv(
|
||||
"[DCP-COARSEN] iter={0} old={1} invalid-timing\n", iteration, virtualGraph.nodes.size());
|
||||
break;
|
||||
}
|
||||
|
||||
llvm::SmallVector<size_t> selectedNodes;
|
||||
auto criticalWindow =
|
||||
selectCriticalWindow(virtualGraph, timing, getDcpCoarseningWindowSize(virtualGraph.nodes.size(), options));
|
||||
selectedNodes.append(criticalWindow.begin(), criticalWindow.end());
|
||||
|
||||
if (selectedNodes.size() < 2) {
|
||||
if (debugCoarsening && virtualGraph.nodes.size() >= 200)
|
||||
llvm::errs() << llvm::formatv("[DCP-COARSEN] iter={0} old={1} selected={2} stop=small-window\n",
|
||||
iteration,
|
||||
virtualGraph.nodes.size(),
|
||||
selectedNodes.size());
|
||||
break;
|
||||
}
|
||||
|
||||
if (tryCoarsenSelectedNodes(selectedNodes))
|
||||
continue;
|
||||
if (debugCoarsening && virtualGraph.nodes.size() >= 200)
|
||||
llvm::errs() << llvm::formatv(
|
||||
"[DCP-COARSEN] iter={0} old={1} stop=no-merge\n", iteration, virtualGraph.nodes.size());
|
||||
break;
|
||||
}
|
||||
|
||||
return buildResultFromVirtualGraph(virtualGraph, graph);
|
||||
}
|
||||
|
||||
} // namespace spatial
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,21 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
|
||||
#include "ComputeGraph.hpp"
|
||||
#include "MergeSchedule.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace spatial {
|
||||
|
||||
struct DcpScheduleOptions {
|
||||
size_t processorCount = 0;
|
||||
size_t criticalWindowSize = 0;
|
||||
bool allowFallbackForAutoCoreCount = true;
|
||||
};
|
||||
|
||||
MergeScheduleResult
|
||||
runDcpScheduler(const ComputeGraph &graph, const DcpScheduleOptions &options, mlir::MLIRContext *context);
|
||||
|
||||
} // namespace spatial
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,25 @@
|
||||
#pragma once
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/DenseSet.h"
|
||||
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <vector>
|
||||
|
||||
#include "ComputeInstance.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace spatial {
|
||||
|
||||
struct MergeScheduleResult {
|
||||
std::vector<ComputeInstance> dominanceOrderCompute;
|
||||
llvm::DenseMap<ComputeInstance, size_t> computeToCpuMap;
|
||||
llvm::DenseMap<ComputeInstance, size_t> computeToCpuSlotMap;
|
||||
llvm::DenseMap<ComputeInstance, uint64_t> computeToAestMap;
|
||||
llvm::DenseSet<ComputeInstance> isLastComputeOfCpu;
|
||||
llvm::DenseMap<size_t, ComputeInstance> cpuToLastComputeMap;
|
||||
};
|
||||
|
||||
} // namespace spatial
|
||||
} // namespace onnx_mlir
|
||||
+139
@@ -0,0 +1,139 @@
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/Support/ErrorHandling.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
|
||||
#include <limits>
|
||||
#include <vector>
|
||||
|
||||
#include "ComputeGraph.hpp"
|
||||
#include "../DCPGraph/DCPAnalysis.hpp"
|
||||
#include "DcpScheduler.hpp"
|
||||
#include "MergeSchedulingAnalysis.hpp"
|
||||
#include "PeftScheduler.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace spatial {
|
||||
|
||||
namespace {
|
||||
|
||||
MergeSchedulerKind getSchedulerKind() {
|
||||
switch (pimMergeScheduler.getValue()) {
|
||||
case MergeSchedulerPeft:
|
||||
return MergeSchedulerKind::Peft;
|
||||
case MergeSchedulerDcp:
|
||||
return MergeSchedulerKind::Dcp;
|
||||
}
|
||||
llvm_unreachable("unknown merge scheduler kind");
|
||||
}
|
||||
|
||||
void verifySchedule(const ComputeGraph &graph, const MergeScheduleResult &result, CrossbarUsage crossbarCapacity) {
|
||||
llvm::DenseMap<size_t, std::vector<std::pair<size_t, size_t>>> tasksByCpu;
|
||||
tasksByCpu.reserve(result.cpuToLastComputeMap.size());
|
||||
|
||||
for (size_t nodeIndex = 0; nodeIndex < graph.nodes.size(); ++nodeIndex) {
|
||||
const ComputeInstance instance = graph.nodes[nodeIndex].instance;
|
||||
if (!result.computeToCpuMap.count(instance))
|
||||
llvm::report_fatal_error("merge scheduling: missing CPU assignment");
|
||||
if (!result.computeToCpuSlotMap.count(instance))
|
||||
llvm::report_fatal_error("merge scheduling: missing CPU slot assignment");
|
||||
if (!result.computeToAestMap.count(instance))
|
||||
llvm::report_fatal_error("merge scheduling: missing start time");
|
||||
|
||||
tasksByCpu[result.computeToCpuMap.lookup(instance)].push_back(
|
||||
{result.computeToCpuSlotMap.lookup(instance), nodeIndex});
|
||||
}
|
||||
|
||||
for (auto &entry : tasksByCpu) {
|
||||
auto &scheduledTasks = entry.second;
|
||||
llvm::sort(scheduledTasks, [](const auto &lhs, const auto &rhs) {
|
||||
if (lhs.first != rhs.first)
|
||||
return lhs.first < rhs.first;
|
||||
return lhs.second < rhs.second;
|
||||
});
|
||||
|
||||
CrossbarUsage usedCrossbars = 0;
|
||||
for (size_t slot = 0; slot < scheduledTasks.size(); ++slot) {
|
||||
if (scheduledTasks[slot].first != slot)
|
||||
llvm::report_fatal_error("merge scheduling: CPU slots are not contiguous");
|
||||
usedCrossbars = addOrMax(usedCrossbars, graph.nodes[scheduledTasks[slot].second].crossbarUsage);
|
||||
if (usedCrossbars > crossbarCapacity)
|
||||
llvm::report_fatal_error("merge scheduling: CPU crossbar capacity exceeded");
|
||||
}
|
||||
|
||||
const ComputeInstance expectedLast = graph.nodes[scheduledTasks.back().second].instance;
|
||||
auto lastIt = result.cpuToLastComputeMap.find(entry.first);
|
||||
if (lastIt == result.cpuToLastComputeMap.end() || !(lastIt->second == expectedLast))
|
||||
llvm::report_fatal_error("merge scheduling: cpuToLastComputeMap does not match slot order");
|
||||
if (!result.isLastComputeOfCpu.count(expectedLast))
|
||||
llvm::report_fatal_error("merge scheduling: missing last-compute marker");
|
||||
}
|
||||
|
||||
for (const ComputeGraphEdge &edge : graph.edges) {
|
||||
const ComputeInstance source = graph.nodes[edge.source].instance;
|
||||
const ComputeInstance target = graph.nodes[edge.target].instance;
|
||||
const size_t sourceCpu = result.computeToCpuMap.lookup(source);
|
||||
const size_t targetCpu = result.computeToCpuMap.lookup(target);
|
||||
const size_t sourceSlot = result.computeToCpuSlotMap.lookup(source);
|
||||
const size_t targetSlot = result.computeToCpuSlotMap.lookup(target);
|
||||
const Time sourceStart = static_cast<Time>(result.computeToAestMap.lookup(source));
|
||||
const Time targetStart = static_cast<Time>(result.computeToAestMap.lookup(target));
|
||||
if (sourceCpu == targetCpu && sourceSlot >= targetSlot)
|
||||
llvm::report_fatal_error("merge scheduling: same-CPU dependency order is invalid");
|
||||
|
||||
Time earliestTargetStart = addOrMax(sourceStart, graph.nodes[edge.source].weight);
|
||||
if (sourceCpu != targetCpu)
|
||||
earliestTargetStart = addOrMax(earliestTargetStart, edge.transferCost);
|
||||
if (targetStart < earliestTargetStart) {
|
||||
std::string message = llvm::formatv("merge scheduling: dependency legality failed between tasks {0} and {1}",
|
||||
graph.nodes[edge.source].originalOrder,
|
||||
graph.nodes[edge.target].originalOrder)
|
||||
.str();
|
||||
llvm::report_fatal_error(llvm::StringRef(message));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
MergeSchedulingAnalysis::MergeSchedulingAnalysis(mlir::Operation *op)
|
||||
: entryOp(op) {
|
||||
result = run();
|
||||
}
|
||||
|
||||
MergeScheduleResult MergeSchedulingAnalysis::run() {
|
||||
verifyExplicitPimCoreCount();
|
||||
ComputeGraph graph = buildComputeGraph(entryOp);
|
||||
if (!verifyAcyclic(graph))
|
||||
llvm::report_fatal_error("merge scheduling: compute graph is cyclic");
|
||||
|
||||
MergeSchedulingOptions options;
|
||||
options.kind = getSchedulerKind();
|
||||
if (coresCount.getValue() > 0)
|
||||
options.processorCount = static_cast<size_t>(coresCount.getValue());
|
||||
|
||||
MergeScheduleResult schedule;
|
||||
if (options.kind == MergeSchedulerKind::Peft) {
|
||||
schedule = runPeftScheduler(
|
||||
graph,
|
||||
PeftScheduleOptions {options.processorCount, static_cast<CrossbarUsage>(crossbarCountInCore.getValue()),
|
||||
entryOp->getContext()});
|
||||
}
|
||||
else {
|
||||
schedule = runDcpScheduler(
|
||||
graph,
|
||||
DcpScheduleOptions {
|
||||
options.processorCount,
|
||||
dcpCriticalWindowSize.getValue(),
|
||||
options.allowDcpFallbackForAutoCoreCount
|
||||
},
|
||||
entryOp->getContext());
|
||||
}
|
||||
|
||||
verifySchedule(graph, schedule, static_cast<CrossbarUsage>(crossbarCountInCore.getValue()));
|
||||
return schedule;
|
||||
}
|
||||
|
||||
} // namespace spatial
|
||||
} // namespace onnx_mlir
|
||||
+36
@@ -0,0 +1,36 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/Operation.h"
|
||||
|
||||
#include <cstddef>
|
||||
|
||||
#include "MergeSchedule.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace spatial {
|
||||
|
||||
enum class MergeSchedulerKind {
|
||||
Dcp,
|
||||
Peft,
|
||||
};
|
||||
|
||||
struct MergeSchedulingOptions {
|
||||
MergeSchedulerKind kind = MergeSchedulerKind::Peft;
|
||||
size_t processorCount = 0;
|
||||
bool allowDcpFallbackForAutoCoreCount = true;
|
||||
};
|
||||
|
||||
class MergeSchedulingAnalysis {
|
||||
public:
|
||||
explicit MergeSchedulingAnalysis(mlir::Operation *op);
|
||||
MergeScheduleResult &getResult() { return result; }
|
||||
|
||||
private:
|
||||
mlir::Operation *entryOp = nullptr;
|
||||
MergeScheduleResult result;
|
||||
|
||||
MergeScheduleResult run();
|
||||
};
|
||||
|
||||
} // namespace spatial
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,303 @@
|
||||
#include "mlir/IR/Threading.h"
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/Support/ErrorHandling.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
|
||||
#include <limits>
|
||||
#include <queue>
|
||||
#include <vector>
|
||||
|
||||
#include "PeftScheduler.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace spatial {
|
||||
|
||||
namespace {
|
||||
|
||||
struct ScheduledTask {
|
||||
size_t processor = std::numeric_limits<size_t>::max();
|
||||
Time startTime = 0;
|
||||
Time endTime = 0;
|
||||
size_t slot = 0;
|
||||
};
|
||||
|
||||
std::vector<std::vector<size_t>> buildReverseLevels(const ComputeGraph& graph) {
|
||||
std::vector<size_t> remainingSuccessors(graph.nodes.size(), 0);
|
||||
std::queue<size_t> readySinks;
|
||||
std::vector<std::vector<size_t>> reverseLevels;
|
||||
|
||||
for (size_t node = 0; node < graph.nodes.size(); ++node) {
|
||||
remainingSuccessors[node] = graph.successors[node].size();
|
||||
if (remainingSuccessors[node] == 0)
|
||||
readySinks.push(node);
|
||||
}
|
||||
|
||||
size_t levelizedCount = 0;
|
||||
while (!readySinks.empty()) {
|
||||
size_t levelSize = readySinks.size();
|
||||
std::vector<size_t> levelNodes;
|
||||
levelNodes.reserve(levelSize);
|
||||
for (size_t i = 0; i < levelSize; ++i) {
|
||||
size_t node = readySinks.front();
|
||||
readySinks.pop();
|
||||
levelNodes.push_back(node);
|
||||
++levelizedCount;
|
||||
for (const auto& [pred, weight] : graph.predecessors[node]) {
|
||||
assert(remainingSuccessors[pred] > 0 && "remaining successor count underflow");
|
||||
if (--remainingSuccessors[pred] == 0)
|
||||
readySinks.push(pred);
|
||||
}
|
||||
}
|
||||
reverseLevels.push_back(std::move(levelNodes));
|
||||
}
|
||||
|
||||
if (levelizedCount != graph.nodes.size())
|
||||
llvm::report_fatal_error("PEFT scheduler: compute graph is cyclic or malformed");
|
||||
|
||||
return reverseLevels;
|
||||
}
|
||||
|
||||
void verifyOctTableSize(size_t nodeCount, size_t processorCount) {
|
||||
constexpr size_t kMaxOctTableBytes = 1ull << 30;
|
||||
if (nodeCount == 0 || processorCount == 0)
|
||||
return;
|
||||
if (processorCount > std::numeric_limits<size_t>::max() / sizeof(Time))
|
||||
llvm::report_fatal_error("PEFT scheduler: OCT table size overflow");
|
||||
size_t rowBytes = processorCount * sizeof(Time);
|
||||
if (nodeCount > std::numeric_limits<size_t>::max() / rowBytes)
|
||||
llvm::report_fatal_error("PEFT scheduler: OCT table size overflow");
|
||||
size_t totalBytes = nodeCount * rowBytes;
|
||||
if (totalBytes > kMaxOctTableBytes) {
|
||||
std::string message = llvm::formatv("PEFT scheduler: OCT table would require {0} MiB, exceeding the 1024 MiB guard",
|
||||
totalBytes / (1024 * 1024))
|
||||
.str();
|
||||
llvm::report_fatal_error(llvm::StringRef(message));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftScheduleOptions& options) {
|
||||
const size_t nodeCount = graph.nodes.size();
|
||||
const size_t processorCount = options.processorCount;
|
||||
if (processorCount == 0)
|
||||
llvm::report_fatal_error("PEFT scheduler: processor count must be positive");
|
||||
|
||||
verifyOctTableSize(nodeCount, processorCount);
|
||||
std::vector<std::vector<size_t>> reverseLevels = buildReverseLevels(graph);
|
||||
|
||||
// MOCK: Replace this with your actual heterogeneous cost lookup.
|
||||
// If graph.nodes[task] is modified to hold a vector of weights per processor, access it here.
|
||||
auto getComputeCost = [&](size_t task, size_t processor) -> Time { return graph.nodes[task].weight; };
|
||||
|
||||
std::vector<Time> oct(nodeCount * processorCount, 0);
|
||||
std::vector<Time> minOctPlusComp(nodeCount, 0);
|
||||
|
||||
// 1. O(P(E+V)) Heterogeneous OCT Calculation
|
||||
for (const std::vector<size_t>& levelNodes : reverseLevels) {
|
||||
auto computeNodeOct = [&](size_t levelIndex) {
|
||||
size_t task = levelNodes[levelIndex];
|
||||
std::vector<Time> maxVals(processorCount, 0);
|
||||
|
||||
for (const auto& [succ, comm] : graph.successors[task]) {
|
||||
Time valDifferentCpu = addOrMax(minOctPlusComp[succ], comm);
|
||||
for (size_t processor = 0; processor < processorCount; ++processor) {
|
||||
Time valSameCpu = addOrMax(oct[succ * processorCount + processor], getComputeCost(succ, processor));
|
||||
Time bestSucc = std::min(valSameCpu, valDifferentCpu);
|
||||
maxVals[processor] = std::max(maxVals[processor], bestSucc);
|
||||
}
|
||||
}
|
||||
|
||||
Time minForPreds = std::numeric_limits<Time>::max();
|
||||
for (size_t processor = 0; processor < processorCount; ++processor) {
|
||||
oct[task * processorCount + processor] = maxVals[processor];
|
||||
minForPreds = std::min(minForPreds, addOrMax(maxVals[processor], getComputeCost(task, processor)));
|
||||
}
|
||||
minOctPlusComp[task] = minForPreds == std::numeric_limits<Time>::max() ? 0 : minForPreds;
|
||||
};
|
||||
|
||||
if (options.context != nullptr)
|
||||
mlir::parallelFor(options.context, 0, levelNodes.size(), computeNodeOct);
|
||||
else
|
||||
for (size_t i = 0; i < levelNodes.size(); ++i)
|
||||
computeNodeOct(i);
|
||||
}
|
||||
|
||||
struct RankEntry {
|
||||
long double rank = 0.0L;
|
||||
size_t node = 0;
|
||||
size_t originalOrder = 0;
|
||||
};
|
||||
std::vector<RankEntry> ranks(nodeCount);
|
||||
auto computeRank = [&](size_t node) {
|
||||
long double rank = 0.0L;
|
||||
for (size_t processor = 0; processor < processorCount; ++processor)
|
||||
rank += static_cast<long double>(oct[node * processorCount + processor]);
|
||||
ranks[node] = {rank, node, graph.nodes[node].originalOrder};
|
||||
};
|
||||
|
||||
if (options.context != nullptr)
|
||||
mlir::parallelFor(options.context, 0, nodeCount, computeRank);
|
||||
else
|
||||
for (size_t node = 0; node < nodeCount; ++node)
|
||||
computeRank(node);
|
||||
|
||||
auto readyCompare = [&](size_t lhs, size_t rhs) {
|
||||
const RankEntry& lhsRank = ranks[lhs];
|
||||
const RankEntry& rhsRank = ranks[rhs];
|
||||
if (lhsRank.rank != rhsRank.rank)
|
||||
return lhsRank.rank < rhsRank.rank;
|
||||
if (lhsRank.originalOrder != rhsRank.originalOrder)
|
||||
return lhsRank.originalOrder > rhsRank.originalOrder;
|
||||
return lhs > rhs;
|
||||
};
|
||||
|
||||
std::vector<int> remainingParents(nodeCount, 0);
|
||||
std::priority_queue<size_t, std::vector<size_t>, decltype(readyCompare)> readyQueue(readyCompare);
|
||||
for (size_t node = 0; node < nodeCount; ++node) {
|
||||
remainingParents[node] = graph.predecessors[node].size();
|
||||
if (remainingParents[node] == 0)
|
||||
readyQueue.push(node);
|
||||
}
|
||||
|
||||
std::vector<char> scheduled(nodeCount, false);
|
||||
std::vector<CrossbarUsage> processorCrossbars(processorCount, 0);
|
||||
std::vector<ScheduledTask> schedules(nodeCount);
|
||||
std::vector<std::vector<size_t>> tasksByProcessor(processorCount);
|
||||
|
||||
size_t scheduledCount = 0;
|
||||
while (!readyQueue.empty()) {
|
||||
size_t task = readyQueue.top();
|
||||
readyQueue.pop();
|
||||
if (scheduled[task])
|
||||
continue;
|
||||
|
||||
size_t bestProcessor = std::numeric_limits<size_t>::max();
|
||||
Time bestEst = 0;
|
||||
Time bestEft = 0;
|
||||
Time bestOeft = std::numeric_limits<Time>::max();
|
||||
bool crossbarRejected = false;
|
||||
|
||||
for (size_t processor = 0; processor < processorCount; ++processor) {
|
||||
if (graph.nodes[task].crossbarUsage != 0
|
||||
&& addOrMax(processorCrossbars[processor], graph.nodes[task].crossbarUsage) > options.crossbarCapacity) {
|
||||
crossbarRejected = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
Time dataReady = 0;
|
||||
for (const auto& [pred, comm] : graph.predecessors[task]) {
|
||||
const ScheduledTask& predSchedule = schedules[pred];
|
||||
Time commPenalty = predSchedule.processor == processor ? 0 : comm;
|
||||
dataReady = std::max(dataReady, addOrMax(predSchedule.endTime, commPenalty));
|
||||
}
|
||||
|
||||
// 2. PEFT Gap-Filling EST Calculation (Maintains optimal scheduling math)
|
||||
Time compWeight = getComputeCost(task, processor);
|
||||
Time est = dataReady;
|
||||
Time currentEnd = 0;
|
||||
bool foundGap = false;
|
||||
|
||||
for (size_t schedTaskIndex : tasksByProcessor[processor]) {
|
||||
const ScheduledTask& schedTask = schedules[schedTaskIndex];
|
||||
Time gapStart = std::max(currentEnd, dataReady);
|
||||
|
||||
if (addOrMax(gapStart, compWeight) <= schedTask.startTime) {
|
||||
est = gapStart;
|
||||
foundGap = true;
|
||||
break;
|
||||
}
|
||||
currentEnd = schedTask.endTime;
|
||||
}
|
||||
|
||||
if (!foundGap)
|
||||
est = std::max(currentEnd, dataReady);
|
||||
|
||||
Time eft = addOrMax(est, compWeight);
|
||||
Time oeft = addOrMax(eft, oct[task * processorCount + processor]);
|
||||
|
||||
if (oeft < bestOeft || (oeft == bestOeft && eft < bestEft)
|
||||
|| (oeft == bestOeft && eft == bestEft && est < bestEst)
|
||||
|| (oeft == bestOeft && eft == bestEft && est == bestEst && processor < bestProcessor)) {
|
||||
bestProcessor = processor;
|
||||
bestEst = est;
|
||||
bestEft = eft;
|
||||
bestOeft = oeft;
|
||||
}
|
||||
}
|
||||
|
||||
if (bestProcessor == std::numeric_limits<size_t>::max()) {
|
||||
if (crossbarRejected) {
|
||||
std::string message =
|
||||
llvm::formatv("PEFT scheduler: no valid processor for task {0}; crossbar capacity {1} is exhausted",
|
||||
graph.nodes[task].originalOrder,
|
||||
options.crossbarCapacity)
|
||||
.str();
|
||||
llvm::report_fatal_error(llvm::StringRef(message));
|
||||
}
|
||||
std::string message = llvm::formatv("PEFT scheduler: no valid processor for task {0} with {1} processors",
|
||||
graph.nodes[task].originalOrder,
|
||||
processorCount)
|
||||
.str();
|
||||
llvm::report_fatal_error(llvm::StringRef(message));
|
||||
}
|
||||
|
||||
schedules[task] = {bestProcessor, bestEst, bestEft, 0};
|
||||
scheduled[task] = true;
|
||||
++scheduledCount;
|
||||
processorCrossbars[bestProcessor] = addOrMax(processorCrossbars[bestProcessor], graph.nodes[task].crossbarUsage);
|
||||
|
||||
// 3. CRITICAL FIX: Topological Append
|
||||
// Because the readyQueue pops in strict topological order, simply pushing to the
|
||||
// back guarantees the Monoliths will be physically generated cycle-free.
|
||||
// The hardware will still benefit from the processor assignment chosen by PEFT.
|
||||
tasksByProcessor[bestProcessor].push_back(task);
|
||||
|
||||
for (const auto& [child, weight] : graph.successors[task]) {
|
||||
(void) weight;
|
||||
assert(remainingParents[child] > 0 && "remaining parent count underflow");
|
||||
if (--remainingParents[child] == 0)
|
||||
readyQueue.push(child);
|
||||
}
|
||||
}
|
||||
|
||||
if (scheduledCount != nodeCount)
|
||||
llvm::report_fatal_error("PEFT scheduler: failed to schedule every compute node");
|
||||
|
||||
// 4. Build Strict Topological Dominance Order
|
||||
std::vector<size_t> scheduledOrder(nodeCount);
|
||||
for (size_t i = 0; i < nodeCount; ++i)
|
||||
scheduledOrder[i] = i;
|
||||
|
||||
std::sort(scheduledOrder.begin(), scheduledOrder.end(), [&](size_t a, size_t b) {
|
||||
return graph.nodes[a].originalOrder < graph.nodes[b].originalOrder;
|
||||
});
|
||||
|
||||
// 5. Populate Final Result
|
||||
MergeScheduleResult result;
|
||||
result.dominanceOrderCompute.reserve(nodeCount);
|
||||
|
||||
for (size_t task : scheduledOrder)
|
||||
result.dominanceOrderCompute.push_back(graph.nodes[task].instance);
|
||||
|
||||
for (size_t processor = 0; processor < processorCount; ++processor) {
|
||||
size_t currentSlot = 0;
|
||||
for (size_t task : tasksByProcessor[processor]) {
|
||||
const ComputeInstance instance = graph.nodes[task].instance;
|
||||
result.computeToCpuMap[instance] = processor;
|
||||
result.computeToCpuSlotMap[instance] = currentSlot++;
|
||||
result.computeToAestMap[instance] = schedules[task].startTime;
|
||||
}
|
||||
if (!tasksByProcessor[processor].empty()) {
|
||||
const ComputeInstance lastInstance = graph.nodes[tasksByProcessor[processor].back()].instance;
|
||||
result.cpuToLastComputeMap[processor] = lastInstance;
|
||||
result.isLastComputeOfCpu.insert(lastInstance);
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
} // namespace spatial
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
|
||||
#include "ComputeGraph.hpp"
|
||||
#include "MergeSchedule.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace spatial {
|
||||
|
||||
struct PeftScheduleOptions {
|
||||
size_t processorCount = 0;
|
||||
CrossbarUsage crossbarCapacity = 0;
|
||||
mlir::MLIRContext *context = nullptr;
|
||||
};
|
||||
|
||||
MergeScheduleResult runPeftScheduler(const ComputeGraph &graph, const PeftScheduleOptions &options);
|
||||
|
||||
} // namespace spatial
|
||||
} // namespace onnx_mlir
|
||||
@@ -7,7 +7,7 @@ add_pim_library(OMPimPasses
|
||||
PimCodegen/HostConstantFolding/Patterns/Subview.cpp
|
||||
PimCodegen/MaterializeHostConstantsPass.cpp
|
||||
PimCodegen/VerificationPass.cpp
|
||||
PimCodegen/EmitPimJsonPass.cpp
|
||||
PimCodegen/EmitPimCodePass.cpp
|
||||
|
||||
EXCLUDE_FROM_OM_LIBS
|
||||
|
||||
|
||||
@@ -15,6 +15,8 @@ std::unique_ptr<mlir::Pass> createSpatialToPimPass();
|
||||
|
||||
std::unique_ptr<mlir::Pass> createPimBufferizationPass();
|
||||
|
||||
std::unique_ptr<mlir::Pass> createPimStaticMemoryCoalescingPass();
|
||||
|
||||
std::unique_ptr<mlir::Pass> createMergeComputeNodesPass();
|
||||
|
||||
std::unique_ptr<mlir::Pass> createPimHostConstantFoldingPass();
|
||||
@@ -23,7 +25,7 @@ std::unique_ptr<mlir::Pass> createPimMaterializeHostConstantsPass();
|
||||
|
||||
std::unique_ptr<mlir::Pass> createPimVerificationPass();
|
||||
|
||||
std::unique_ptr<mlir::Pass> createEmitPimJsonPass();
|
||||
std::unique_ptr<mlir::Pass> createEmitPimCodePass();
|
||||
|
||||
std::unique_ptr<mlir::Pass> createMessagePass(std::string message);
|
||||
|
||||
|
||||
@@ -0,0 +1,39 @@
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
#include "Common/PimCommon.hpp"
|
||||
#include "Compiler/PimCodeGen.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
namespace {
|
||||
|
||||
struct EmitPimCodePass : PassWrapper<EmitPimCodePass, OperationPass<ModuleOp>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(EmitPimCodePass);
|
||||
StringRef getArgument() const override { return "emit-pim-code-pass"; }
|
||||
StringRef getDescription() const override { return "Emit PIM simulator code artifacts"; }
|
||||
|
||||
EmitPimCodePass() {}
|
||||
EmitPimCodePass(const EmitPimCodePass& pass) {}
|
||||
|
||||
void runOnOperation() override {
|
||||
ModuleOp moduleOp = getOperation();
|
||||
|
||||
std::string pimDir = getOutputDir() + "/pim";
|
||||
createDirectory(pimDir);
|
||||
|
||||
int compiler_error_code = compileToPimCode(moduleOp, pimDir);
|
||||
if (compiler_error_code != CompilerSuccess) {
|
||||
moduleOp.emitError() << "failed to emit PIM simulator code artifacts; compiler error code "
|
||||
<< compiler_error_code;
|
||||
signalPassFailure();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<Pass> createEmitPimCodePass() { return std::make_unique<EmitPimCodePass>(); }
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -1,36 +0,0 @@
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
#include "Common/PimCommon.hpp"
|
||||
#include "Compiler/PimCodeGen.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
namespace {
|
||||
|
||||
struct EmitPimJsonPass : PassWrapper<EmitPimJsonPass, OperationPass<ModuleOp>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(EmitPimJsonPass);
|
||||
StringRef getArgument() const override { return "emit-pim-json-pass"; }
|
||||
StringRef getDescription() const override { return "Emit json code for the pim simulators"; }
|
||||
|
||||
EmitPimJsonPass() {}
|
||||
EmitPimJsonPass(const EmitPimJsonPass& pass) {}
|
||||
|
||||
void runOnOperation() override {
|
||||
ModuleOp moduleOp = getOperation();
|
||||
|
||||
std::string pimDir = getOutputDir() + "/pim";
|
||||
createDirectory(pimDir);
|
||||
|
||||
int compiler_error_code = compileToPimJson(moduleOp, pimDir);
|
||||
if (compiler_error_code != CompilerSuccess)
|
||||
signalPassFailure();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<Pass> createEmitPimJsonPass() { return std::make_unique<EmitPimJsonPass>(); }
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -41,30 +41,6 @@ struct DenseSubviewKeyInfo {
|
||||
|
||||
} // namespace
|
||||
|
||||
Value stripMemRefCasts(Value value) {
|
||||
while (auto castOp = value.getDefiningOp<memref::CastOp>())
|
||||
value = castOp.getSource();
|
||||
return value;
|
||||
}
|
||||
|
||||
Value stripMemRefViewOps(Value value) {
|
||||
while (true) {
|
||||
if (auto castOp = value.getDefiningOp<memref::CastOp>()) {
|
||||
value = castOp.getSource();
|
||||
continue;
|
||||
}
|
||||
if (auto collapseOp = value.getDefiningOp<memref::CollapseShapeOp>()) {
|
||||
value = collapseOp.getSrc();
|
||||
continue;
|
||||
}
|
||||
if (auto expandOp = value.getDefiningOp<memref::ExpandShapeOp>()) {
|
||||
value = expandOp.getSrc();
|
||||
continue;
|
||||
}
|
||||
return value;
|
||||
}
|
||||
}
|
||||
|
||||
memref::GlobalOp createFoldedGlobal(ModuleOp moduleOp,
|
||||
Location loc,
|
||||
MemRefType globalType,
|
||||
@@ -177,48 +153,4 @@ FailureOr<DenseElementsAttr> foldDenseSourceToType(ModuleOp moduleOp, Value sour
|
||||
return *denseAttr;
|
||||
}
|
||||
|
||||
FailureOr<StaticSubviewInfo> getStaticSubviewInfo(Value value) {
|
||||
value = stripMemRefViewOps(value);
|
||||
auto subviewOp = value.getDefiningOp<memref::SubViewOp>();
|
||||
if (!subviewOp)
|
||||
return failure();
|
||||
|
||||
auto source = stripMemRefCasts(subviewOp.getSource());
|
||||
auto sourceType = dyn_cast<MemRefType>(source.getType());
|
||||
auto subviewType = dyn_cast<MemRefType>(subviewOp.getType());
|
||||
if (!sourceType || !subviewType || !sourceType.hasStaticShape() || !subviewType.hasStaticShape())
|
||||
return failure();
|
||||
|
||||
StaticSubviewInfo info;
|
||||
info.source = source;
|
||||
info.sourceShape.assign(sourceType.getShape().begin(), sourceType.getShape().end());
|
||||
SmallVector<OpFoldResult> mixedOffsets = subviewOp.getMixedOffsets();
|
||||
info.offsets.assign(mixedOffsets.begin(), mixedOffsets.end());
|
||||
for (OpFoldResult size : subviewOp.getMixedSizes()) {
|
||||
auto staticSize = getConstantIntValue(size);
|
||||
if (!staticSize)
|
||||
return failure();
|
||||
info.sizes.push_back(*staticSize);
|
||||
}
|
||||
for (OpFoldResult stride : subviewOp.getMixedStrides()) {
|
||||
auto staticStride = getConstantIntValue(stride);
|
||||
if (!staticStride)
|
||||
return failure();
|
||||
info.strides.push_back(*staticStride);
|
||||
}
|
||||
return info;
|
||||
}
|
||||
|
||||
FailureOr<SmallVector<int64_t>> getStaticSubviewOffsets(const StaticSubviewInfo& info) {
|
||||
SmallVector<int64_t> staticOffsets;
|
||||
staticOffsets.reserve(info.offsets.size());
|
||||
for (OpFoldResult offset : info.offsets) {
|
||||
auto staticOffset = getConstantIntValue(offset);
|
||||
if (!staticOffset)
|
||||
return failure();
|
||||
staticOffsets.push_back(*staticOffset);
|
||||
}
|
||||
return staticOffsets;
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -6,23 +6,12 @@
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
struct StaticSubviewInfo {
|
||||
mlir::Value source;
|
||||
llvm::SmallVector<int64_t> sourceShape;
|
||||
llvm::SmallVector<mlir::OpFoldResult> offsets;
|
||||
llvm::SmallVector<int64_t> sizes;
|
||||
llvm::SmallVector<int64_t> strides;
|
||||
};
|
||||
|
||||
mlir::Value stripMemRefCasts(mlir::Value value);
|
||||
|
||||
mlir::Value stripMemRefViewOps(mlir::Value value);
|
||||
|
||||
mlir::memref::GlobalOp createFoldedGlobal(mlir::ModuleOp moduleOp,
|
||||
mlir::Location loc,
|
||||
mlir::MemRefType globalType,
|
||||
@@ -39,9 +28,4 @@ llvm::FailureOr<mlir::DenseElementsAttr> getDenseGlobalValue(mlir::ModuleOp modu
|
||||
llvm::FailureOr<mlir::DenseElementsAttr>
|
||||
foldDenseSourceToType(mlir::ModuleOp moduleOp, mlir::Value source, mlir::MemRefType resultType);
|
||||
|
||||
llvm::FailureOr<StaticSubviewInfo> getStaticSubviewInfo(mlir::Value value);
|
||||
|
||||
/// Returns the offsets in `info` as int64_t, failing if any offset is dynamic.
|
||||
llvm::FailureOr<llvm::SmallVector<int64_t>> getStaticSubviewOffsets(const StaticSubviewInfo& info);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -32,14 +32,16 @@ struct HostConstantFoldingPass : PassWrapper<HostConstantFoldingPass, OperationP
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
ModuleOp moduleOp = getOperation();
|
||||
GreedyRewriteConfig config;
|
||||
config.enableFolding();
|
||||
if (failed(applyPatternsGreedily(getOperation(), *patterns, config))) {
|
||||
if (failed(applyPatternsGreedily(moduleOp, *patterns, config))) {
|
||||
moduleOp.emitError("PIM host constant folding failed in the greedy rewrite driver");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
dumpModule(getOperation(), "pim2_folded");
|
||||
dumpModule(moduleOp, "pim3_folded");
|
||||
}
|
||||
|
||||
std::shared_ptr<const FrozenRewritePatternSet> patterns;
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user