Compare commits
46 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 7f3c7464b4 | |||
| c77ffa9c56 | |||
| 495186503c | |||
| 2c1da813b5 | |||
| 8337a11ce9 | |||
| d136136d22 | |||
| 074eb183c7 | |||
| 43ed3914b8 | |||
| 6aaf1c0870 | |||
| fe35b3ed43 | |||
| 90a9339686 | |||
| a50e77ff38 | |||
| f56c4159b5 | |||
| 5637c861b4 | |||
| 94157a8404 | |||
| 68a3521978 | |||
| a103ba328b | |||
| e263e05f56 | |||
| 34c29fdec4 | |||
| aa088e2ba5 | |||
| 2836e759ab | |||
| 8071ebab0b | |||
| f1602c0550 | |||
| de0a2f4561 | |||
| 1c4a5bde76 | |||
| 78242e2887 | |||
| fe244d5aa1 | |||
| d09e76c8f9 | |||
| c5e608fa5b | |||
| 43f3ccdd21 | |||
| 8d95c604a6 | |||
| 55eda487dc | |||
| 061139aefb | |||
| ea61540e08 | |||
| 324178cba8 | |||
| e71ba07cd5 | |||
| 64a3805619 | |||
| 9f9e7c0892 | |||
| 03eab42971 | |||
| c15aba5d96 | |||
| 4821e8a55e | |||
| 88bb223bb1 | |||
| 623ee62a04 | |||
| ad56888b0b | |||
| f993840641 | |||
| 0c7db55a24 |
+1
-1
@@ -3,4 +3,4 @@
|
||||
url = https://github.com/onnx/onnx-mlir.git
|
||||
[submodule "backend-simulators/pim/pimsim-nn"]
|
||||
path = backend-simulators/pim/pimsim-nn
|
||||
url = https://github.com/wangxy-2000/pimsim-nn.git
|
||||
url = https://github.com/HEAPLab/pimsim-nn.git
|
||||
|
||||
+92
-24
@@ -3,31 +3,99 @@ cmake_minimum_required(VERSION 3.20.0)
|
||||
|
||||
project(raptor)
|
||||
|
||||
# Add symlink to PIM as accelerator in onnx-mlir
|
||||
function(raptor_ensure_symlink link_path target_path)
|
||||
get_filename_component(link_parent "${link_path}" DIRECTORY)
|
||||
# Materialize a CMake shim directory
|
||||
function(raptor_write_external_cmake_shim shim_dir external_source_dir description)
|
||||
get_filename_component(real_external_source_dir "${external_source_dir}" REALPATH)
|
||||
file(RELATIVE_PATH relative_external_source_dir "${shim_dir}" "${real_external_source_dir}")
|
||||
|
||||
if(NOT EXISTS "${link_parent}")
|
||||
message(FATAL_ERROR "Directory not found: ${link_parent}")
|
||||
endif()
|
||||
|
||||
if(NOT EXISTS "${link_path}")
|
||||
message(STATUS "Creating symlink ${link_path} -> ${target_path}")
|
||||
file(CREATE_LINK
|
||||
"${target_path}"
|
||||
"${link_path}"
|
||||
SYMBOLIC
|
||||
if (NOT EXISTS "${real_external_source_dir}/CMakeLists.txt")
|
||||
message(FATAL_ERROR
|
||||
"External CMake source directory not found or missing CMakeLists.txt:\n"
|
||||
" ${real_external_source_dir}"
|
||||
)
|
||||
endif()
|
||||
endif ()
|
||||
|
||||
if (IS_SYMLINK "${shim_dir}")
|
||||
message(STATUS "Removing old full-directory symlink: ${shim_dir}")
|
||||
file(REMOVE "${shim_dir}")
|
||||
endif ()
|
||||
|
||||
if (EXISTS "${shim_dir}" AND NOT IS_DIRECTORY "${shim_dir}")
|
||||
message(FATAL_ERROR "Expected directory or absent path, got file: ${shim_dir}")
|
||||
endif ()
|
||||
|
||||
file(MAKE_DIRECTORY "${shim_dir}")
|
||||
|
||||
set(shim_file "${shim_dir}/CMakeLists.txt")
|
||||
set(shim_contents
|
||||
"get_filename_component(raptor_external_source_dir
|
||||
\"\${CMAKE_CURRENT_LIST_DIR}/${relative_external_source_dir}\"
|
||||
REALPATH
|
||||
)
|
||||
add_subdirectory(
|
||||
\"\${raptor_external_source_dir}\"
|
||||
\"\${CMAKE_CURRENT_BINARY_DIR}/raptor-external\"
|
||||
)
|
||||
if (DEFINED PIM_ENABLED)
|
||||
set(PIM_ENABLED \"\${PIM_ENABLED}\" PARENT_SCOPE)
|
||||
endif ()
|
||||
"
|
||||
)
|
||||
|
||||
if (EXISTS "${shim_file}")
|
||||
file(READ "${shim_file}" old_contents)
|
||||
else ()
|
||||
set(old_contents "")
|
||||
endif ()
|
||||
|
||||
if (NOT old_contents STREQUAL shim_contents)
|
||||
file(WRITE "${shim_file}" "${shim_contents}")
|
||||
message(STATUS "Wrote CMake shim for ${description}: ${shim_file}")
|
||||
else ()
|
||||
message(STATUS "CMake shim already up to date for ${description}")
|
||||
endif ()
|
||||
|
||||
# Mirror the external tree's first-level entries into the shim directory
|
||||
# so legacy includes like src/Accelerators/PIM/Compiler/... keep working.
|
||||
file(GLOB children RELATIVE "${real_external_source_dir}" "${real_external_source_dir}/*")
|
||||
|
||||
foreach (child IN LISTS children)
|
||||
if (child STREQUAL "CMakeLists.txt")
|
||||
continue()
|
||||
endif ()
|
||||
|
||||
set(real_child "${real_external_source_dir}/${child}")
|
||||
set(shim_child "${shim_dir}/${child}")
|
||||
|
||||
if (IS_SYMLINK "${shim_child}")
|
||||
file(READ_SYMLINK "${shim_child}" existing_link_target)
|
||||
if (existing_link_target STREQUAL real_child)
|
||||
continue()
|
||||
endif ()
|
||||
file(REMOVE_RECURSE "${shim_child}")
|
||||
elseif (EXISTS "${shim_child}")
|
||||
# Do not delete real files/directories. This protects the generated shim.
|
||||
continue()
|
||||
endif ()
|
||||
|
||||
file(CREATE_LINK
|
||||
"${real_child}"
|
||||
"${shim_child}"
|
||||
SYMBOLIC
|
||||
)
|
||||
endforeach ()
|
||||
endfunction()
|
||||
|
||||
raptor_ensure_symlink(
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/onnx-mlir/src/Accelerators/PIM"
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/src/PIM"
|
||||
raptor_write_external_cmake_shim(
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/onnx-mlir/src/Accelerators/PIM"
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/src/PIM"
|
||||
"PIM accelerator"
|
||||
)
|
||||
raptor_ensure_symlink(
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/onnx-mlir/test/accelerators/PIM"
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/test/PIM"
|
||||
|
||||
raptor_write_external_cmake_shim(
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/onnx-mlir/test/accelerators/PIM"
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/test/PIM"
|
||||
"PIM accelerator tests"
|
||||
)
|
||||
|
||||
# Patch onnx-mlir sources for PIM accelerator support.
|
||||
@@ -38,21 +106,21 @@ function(raptor_apply_patch file_path anchor replacement description)
|
||||
|
||||
# Already applied – replacement text is present
|
||||
string(FIND "${contents}" "${replacement}" already_applied_pos)
|
||||
if(NOT already_applied_pos EQUAL -1)
|
||||
if (NOT already_applied_pos EQUAL -1)
|
||||
message(STATUS "Patch already applied: ${description}")
|
||||
return()
|
||||
endif()
|
||||
endif ()
|
||||
|
||||
# Anchor must exist for the patch to be applicable
|
||||
string(FIND "${contents}" "${anchor}" anchor_pos)
|
||||
if(anchor_pos EQUAL -1)
|
||||
if (anchor_pos EQUAL -1)
|
||||
message(FATAL_ERROR
|
||||
"Patch anchor not found – onnx-mlir may have changed.\n"
|
||||
" Patch : ${description}\n"
|
||||
" File : ${file_path}\n"
|
||||
" Anchor: ${anchor}"
|
||||
)
|
||||
endif()
|
||||
endif ()
|
||||
|
||||
string(REPLACE "${anchor}" "${replacement}" patched "${contents}")
|
||||
file(WRITE "${file_path}" "${patched}")
|
||||
|
||||
@@ -114,7 +114,9 @@ Pass these on the `onnx-mlir` command line when compiling for PIM:
|
||||
run only the codegen tail.
|
||||
- `--crossbar-size=<N>` / `--crossbar-count=<N>` — crossbar dimensions and
|
||||
per-core count.
|
||||
- `--core-count=<N>` — number of cores (`-1` picks the minimum).
|
||||
- `--core-count=<N>` — number of cores. Required for PIM compilation.
|
||||
- `--pim-merge-scheduler={peft,dcp}` — scheduler used by the Spatial
|
||||
merge-compute-nodes pass (default: `peft`).
|
||||
- `--dcp-critical-window-size=<N>` — DCP coarsening window (0 = legacy).
|
||||
- `--use-experimental-conv-impl` — alternative convolution lowering.
|
||||
- `--ignore-concat-error` — soft-fail corner case in `ConcatOp`.
|
||||
@@ -129,7 +131,8 @@ Per-operation validation (from `validation/`):
|
||||
```
|
||||
validate.py \
|
||||
--raptor-path ../cmake-build-release/Release/bin/onnx-mlir \
|
||||
--onnx-include-dir ../onnx-mlir/include
|
||||
--onnx-include-dir ../onnx-mlir/include \
|
||||
--core-count 1000
|
||||
```
|
||||
|
||||
End-to-end network validation (example: first 4 layers of YOLOv11n):
|
||||
@@ -142,6 +145,46 @@ validate.py \
|
||||
--crossbar-size 2048 --crossbar-count 256 --core-count 1000
|
||||
```
|
||||
|
||||
Each validation run writes debugging artifacts into the benchmark's workspace
|
||||
directory (for example `validation/operations/gemm/small/`):
|
||||
- `inputs/` — generated input CSVs used for the run.
|
||||
- `outputs/` — reference outputs dumped by the native ONNX runner.
|
||||
- `raptor/` — compiler artifacts:
|
||||
`*.onnx.mlir`, `dialects/spatial0.mlir`, `dialects/spatial1_dcp_merged.mlir`,
|
||||
`dialects/pim0.mlir`, `dialects/pim1_buff.mlir`, `dialects/pim2_coalesced.mlir`,
|
||||
`dialects/pim3_folded.mlir`, `dialects/pim4_materialized.mlir`,
|
||||
`pim/config.json`, `pim/core_*.pim`, `pim/memory.bin`, and reports under
|
||||
`raptor/reports/` such as `dcp_merge_report.txt`,
|
||||
`memory_report.txt`, and `static_memory_coalescing_report.txt`.
|
||||
- `runner/` — generated reference runner source, build tree, and shared library.
|
||||
- `simulation/out.bin` — raw simulator output dump used for output comparison.
|
||||
|
||||
That means you usually do not need to rerun standalone `--EmitSpatial` or
|
||||
`--EmitPim` commands while debugging validation failures: the per-pass dialect
|
||||
dumps are already available under `raptor/dialects/`.
|
||||
|
||||
The validator does not currently expose a simulator tracing flag, but once a
|
||||
validation has produced `raptor/pim/` you can rerun the simulator manually with
|
||||
tracing enabled:
|
||||
|
||||
```bash
|
||||
cd backend-simulators/pim/pim-simulator
|
||||
cargo run --no-default-features --features tracing --release \
|
||||
--package pim-simulator --bin pim-simulator -- \
|
||||
-f /path/to/workspace/raptor/pim \
|
||||
-o /path/to/workspace/simulation/out.bin \
|
||||
-d <addr0>,<size0>,<addr1>,<size1>,...
|
||||
```
|
||||
|
||||
With `--features tracing`, the simulator writes per-core traces as
|
||||
`simulation/TraceCore0`, `simulation/TraceCore1`, ... next to `simulation/out.bin`.
|
||||
The validator normally computes the `-d` dump ranges from `raptor/pim/config.json`
|
||||
and the model output shapes. If you need a clean slate before rerunning, use:
|
||||
|
||||
```bash
|
||||
validate.py --clean
|
||||
```
|
||||
|
||||
Available networks under `validation/networks/`: `vgg16`, `yolo11n`.
|
||||
Available operations under `validation/operations/`: `add`, `conv`, `div`,
|
||||
`gather`, `gemm`, `gemv`, `mul`, `pool`, `reduce_mean`, `relu`, `resize`,
|
||||
|
||||
+19
@@ -1030,6 +1030,15 @@ version = "0.2.16"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b6d2cec3eae94f9f509c767b45932f1ada8350c4bdb85af2fcab4a3c14807981"
|
||||
|
||||
[[package]]
|
||||
name = "libmimalloc-sys"
|
||||
version = "0.1.47"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2d1eacfa31c33ec25e873c136ba5669f00f9866d0688bea7be4d3f7e43067df6"
|
||||
dependencies = [
|
||||
"cc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "linux-raw-sys"
|
||||
version = "0.12.1"
|
||||
@@ -1095,6 +1104,15 @@ version = "2.8.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79"
|
||||
|
||||
[[package]]
|
||||
name = "mimalloc"
|
||||
version = "0.1.50"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b3627c4272df786b9260cabaa46aec1d59c93ede723d4c3ef646c503816b0640"
|
||||
dependencies = [
|
||||
"libmimalloc-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mime"
|
||||
version = "0.3.17"
|
||||
@@ -1414,6 +1432,7 @@ dependencies = [
|
||||
"faer-traits",
|
||||
"glob",
|
||||
"hex",
|
||||
"mimalloc",
|
||||
"paste",
|
||||
"plotly",
|
||||
"rayon",
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
|
||||
[package]
|
||||
name = "pim-simulator"
|
||||
version = "0.1.0"
|
||||
@@ -34,3 +33,4 @@ plotly = {version="0.8", optional=true}
|
||||
rayon = "1.12.0"
|
||||
faer = "0.24.0"
|
||||
faer-traits = "0.24.0"
|
||||
mimalloc = "0.1.50"
|
||||
|
||||
@@ -1,14 +1,20 @@
|
||||
use mimalloc::MiMalloc;
|
||||
|
||||
#[global_allocator]
|
||||
static GLOBAL: MiMalloc = MiMalloc;
|
||||
|
||||
use anyhow::{Context, Result, bail};
|
||||
use clap::Parser;
|
||||
use glob::glob;
|
||||
use pimcore::binary_to_instruction::binary_to_executor;
|
||||
use pimcore::cpu::crossbar::Crossbar;
|
||||
use pimcore::json_to_instruction::json_to_executor;
|
||||
use pimcore::memory_manager::CoreMemory;
|
||||
use pimcore::tracing::TRACER;
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
use std::fs::{self, read_link};
|
||||
use std::io::Write;
|
||||
use std::fs::{self, File, read_link};
|
||||
use std::io::{BufReader, Write};
|
||||
use std::path::PathBuf;
|
||||
|
||||
/// Program to simulate core execution configuration
|
||||
@@ -44,18 +50,24 @@ fn main() -> Result<()> {
|
||||
let args = Args::parse();
|
||||
|
||||
let config_json = retrive_config(&args)?;
|
||||
let core_jsons = retrive_cores(&args)?;
|
||||
let mut core_inputs = retrive_cores(&args)?;
|
||||
let memory = retrive_memory(&args)?;
|
||||
let global_crossbars = get_crossbars(&config_json, &args).unwrap();
|
||||
let crossbars = map_crossbars_to_cores(&config_json, &args, &global_crossbars);
|
||||
let mut executor =
|
||||
json_to_executor::json_to_executor(config_json, core_jsons.iter(), crossbars);
|
||||
let mut executor = match &mut core_inputs {
|
||||
CoreInputs::Json(core_jsons) => {
|
||||
json_to_executor::json_to_executor(config_json, core_jsons, crossbars)
|
||||
}
|
||||
CoreInputs::Binary(core_bins) => {
|
||||
binary_to_executor(config_json, core_bins.iter(), crossbars)?
|
||||
}
|
||||
};
|
||||
set_memory(&mut executor, memory);
|
||||
TRACER
|
||||
.lock()
|
||||
.unwrap()
|
||||
.init(executor.cpu().num_core(), args.output.clone());
|
||||
executor.execute();
|
||||
executor.execute()?;
|
||||
dump_memory(executor, &args)?;
|
||||
Ok(())
|
||||
}
|
||||
@@ -65,7 +77,7 @@ fn map_crossbars_to_cores<'c>(
|
||||
args: &Args,
|
||||
global_crossbars: &'c HashMap<String, Crossbar>,
|
||||
) -> Vec<Vec<&'c Crossbar>> {
|
||||
let mut res = Vec::new();
|
||||
let mut res = vec![Vec::new()];
|
||||
let num_cores = config.get("core_cnt").unwrap().as_i64().unwrap() as i32;
|
||||
|
||||
if let Some(folder) = args.folder.as_ref() {
|
||||
@@ -140,8 +152,7 @@ fn get_crossbars(config: &Value, args: &Args) -> anyhow::Result<HashMap<String,
|
||||
}
|
||||
|
||||
let bytes = std::fs::read(weight_file.path()).expect("Failed to read binary file");
|
||||
let mut crossbar =
|
||||
Crossbar::new(column_corssbar * 4, rows_crossbar, CoreMemory::new());
|
||||
let mut crossbar = Crossbar::new(column_corssbar * 4, rows_crossbar, CoreMemory::new());
|
||||
crossbar.execute_store(&bytes).unwrap();
|
||||
res.insert(
|
||||
weight_file
|
||||
@@ -214,45 +225,82 @@ fn retrive_memory(args: &Args) -> Result<Vec<u8>> {
|
||||
Ok(memory_vector)
|
||||
}
|
||||
|
||||
fn retrive_cores(args: &Args) -> Result<Vec<Value>, anyhow::Error> {
|
||||
let mut core_jsons: Vec<Value> = Vec::new();
|
||||
if let Some(cores_override) = &args.cores {
|
||||
for core in cores_override {
|
||||
let content = fs::read_to_string(core)
|
||||
.with_context(|| format!("Failed to read core file: {:?}", cores_override))?;
|
||||
let json: Value =
|
||||
serde_json::from_str(&content).context("Failed to parse core json override")?;
|
||||
core_jsons.push(json);
|
||||
}
|
||||
} else if let Some(folder) = args.folder.as_ref() {
|
||||
let pattern = folder.join("core*.json");
|
||||
let pattern_str = pattern.to_str().context("Invalid path encoding")?;
|
||||
let mut paths: Vec<_> = glob(pattern_str)?.map(|x| x.unwrap()).collect();
|
||||
paths.sort_by_cached_key(|x| {
|
||||
let mut x = x
|
||||
.file_stem()
|
||||
.expect("Extracting the stem")
|
||||
.to_str()
|
||||
.expect("File not utf-8");
|
||||
x = &x[5..];
|
||||
x.parse::<i32>().unwrap()
|
||||
});
|
||||
enum CoreInputs {
|
||||
Json(Vec<BufReader<File>>),
|
||||
Binary(Vec<Vec<u8>>),
|
||||
}
|
||||
|
||||
if paths.is_empty() {
|
||||
bail!("No core*.json files found in {:?}", folder);
|
||||
fn retrive_cores(args: &Args) -> Result<CoreInputs, anyhow::Error> {
|
||||
if let Some(cores_override) = &args.cores {
|
||||
let first_extension = cores_override
|
||||
.first()
|
||||
.and_then(|path| path.extension())
|
||||
.and_then(|ext| ext.to_str())
|
||||
.unwrap_or_default();
|
||||
if first_extension == "pim" {
|
||||
let mut core_bins = Vec::with_capacity(cores_override.len());
|
||||
for core in cores_override {
|
||||
core_bins.push(
|
||||
fs::read(core)
|
||||
.with_context(|| format!("Failed to read binary core file: {:?}", core))?,
|
||||
);
|
||||
}
|
||||
return Ok(CoreInputs::Binary(core_bins));
|
||||
}
|
||||
for entry in paths {
|
||||
let path = entry;
|
||||
let content = fs::read_to_string(&path)
|
||||
.with_context(|| format!("Failed to read core file: {:?}", path))?;
|
||||
let json: Value = serde_json::from_str(&content)
|
||||
.with_context(|| format!("Failed to parse JSON in {:?}", path))?;
|
||||
core_jsons.push(json);
|
||||
let mut core_jsons_reader: Vec<BufReader<File>> = Vec::with_capacity(cores_override.len());
|
||||
for core in cores_override {
|
||||
let file = File::open(core)?;
|
||||
let reader = BufReader::new(file);
|
||||
core_jsons_reader.push(reader);
|
||||
}
|
||||
} else {
|
||||
bail!("Either --core or --folder must be provided to find core definitions.");
|
||||
return Ok(CoreInputs::Json(core_jsons_reader));
|
||||
}
|
||||
Ok(core_jsons)
|
||||
|
||||
if let Some(folder) = args.folder.as_ref() {
|
||||
let binary_pattern = folder.join("core*.pim");
|
||||
let binary_pattern_str = binary_pattern.to_str().context("Invalid path encoding")?;
|
||||
let mut binary_paths: Vec<_> = glob(binary_pattern_str)?.map(|x| x.unwrap()).collect();
|
||||
binary_paths.sort_by_cached_key(core_sort_key);
|
||||
if !binary_paths.is_empty() {
|
||||
let mut core_bins = Vec::with_capacity(binary_paths.len());
|
||||
for path in binary_paths {
|
||||
core_bins.push(
|
||||
fs::read(&path)
|
||||
.with_context(|| format!("Failed to read core file: {:?}", path))?,
|
||||
);
|
||||
}
|
||||
return Ok(CoreInputs::Binary(core_bins));
|
||||
}
|
||||
|
||||
let json_pattern = folder.join("core*.json");
|
||||
let json_pattern_str = json_pattern.to_str().context("Invalid path encoding")?;
|
||||
let mut json_paths: Vec<_> = glob(json_pattern_str)?.map(|x| x.unwrap()).collect();
|
||||
json_paths.sort_by_cached_key(core_sort_key);
|
||||
|
||||
if json_paths.is_empty() {
|
||||
bail!("No core*.pim or core*.json files found in {:?}", folder);
|
||||
}
|
||||
|
||||
let mut core_json_reader: Vec<BufReader<File>> = Vec::with_capacity(json_paths.len());
|
||||
for path in json_paths {
|
||||
let file = File::open(path)?;
|
||||
let reader = BufReader::new(file);
|
||||
core_json_reader.push(reader);
|
||||
}
|
||||
return Ok(CoreInputs::Json(core_json_reader));
|
||||
}
|
||||
|
||||
bail!("Either --core or --folder must be provided to find core definitions.");
|
||||
}
|
||||
|
||||
fn core_sort_key(path: &PathBuf) -> i32 {
|
||||
let mut stem = path
|
||||
.file_stem()
|
||||
.expect("Extracting the stem")
|
||||
.to_str()
|
||||
.expect("File not utf-8");
|
||||
stem = &stem[5..];
|
||||
stem.parse::<i32>().unwrap()
|
||||
}
|
||||
|
||||
fn retrive_config(args: &Args) -> Result<Value, anyhow::Error> {
|
||||
|
||||
@@ -0,0 +1,497 @@
|
||||
use crate::{
|
||||
CoreInstructionsBuilder, Executable,
|
||||
cpu::{CPU, crossbar::Crossbar},
|
||||
instruction_set::{InstructionsBuilder, instruction_data::InstructionDataBuilder, isa::*},
|
||||
};
|
||||
use anyhow::{Context, Result, bail, ensure};
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
use std::convert::TryFrom;
|
||||
use std::sync::LazyLock;
|
||||
|
||||
const MAGIC: &[u8; 4] = b"PIMB";
|
||||
const VERSION: u32 = 1;
|
||||
const HEADER_SIZE: usize = 12;
|
||||
const RECORD_SIZE: usize = 20;
|
||||
|
||||
macro_rules! add_name {
|
||||
($storage:ident, $opcode:literal, $name:literal) => {
|
||||
$storage.insert($opcode, $name);
|
||||
};
|
||||
}
|
||||
|
||||
static INSTRUCTIONS: LazyLock<HashMap<usize, &'static str>> = LazyLock::new(|| {
|
||||
let mut hash = HashMap::new();
|
||||
add_name!(hash, 0, "nop");
|
||||
add_name!(hash, 1, "sldi");
|
||||
add_name!(hash, 2, "sld");
|
||||
add_name!(hash, 3, "sadd");
|
||||
add_name!(hash, 4, "ssub");
|
||||
add_name!(hash, 5, "smul");
|
||||
add_name!(hash, 6, "saddi");
|
||||
add_name!(hash, 7, "smuli");
|
||||
add_name!(hash, 8, "setbw");
|
||||
add_name!(hash, 9, "mvmul");
|
||||
add_name!(hash, 10, "vvadd");
|
||||
add_name!(hash, 11, "vvsub");
|
||||
add_name!(hash, 12, "vvmul");
|
||||
add_name!(hash, 13, "vvdmul");
|
||||
add_name!(hash, 14, "vvmax");
|
||||
add_name!(hash, 15, "vvsll");
|
||||
add_name!(hash, 16, "vvsra");
|
||||
add_name!(hash, 17, "vavg");
|
||||
add_name!(hash, 18, "vrelu");
|
||||
add_name!(hash, 19, "vtanh");
|
||||
add_name!(hash, 20, "vsigm");
|
||||
add_name!(hash, 21, "vsoftmax");
|
||||
add_name!(hash, 22, "vmv");
|
||||
add_name!(hash, 23, "vrsu");
|
||||
add_name!(hash, 24, "vrsl");
|
||||
add_name!(hash, 25, "ld");
|
||||
add_name!(hash, 26, "st");
|
||||
add_name!(hash, 27, "lldi");
|
||||
add_name!(hash, 28, "lmv");
|
||||
add_name!(hash, 29, "send");
|
||||
add_name!(hash, 30, "recv");
|
||||
add_name!(hash, 31, "wait");
|
||||
add_name!(hash, 32, "sync");
|
||||
hash
|
||||
});
|
||||
|
||||
#[derive(Clone, Copy, Debug, Default)]
|
||||
struct InstructionRecord {
|
||||
opcode: u8,
|
||||
rd: u8,
|
||||
r1: u8,
|
||||
r2_or_imm: i32,
|
||||
generic1: i32,
|
||||
generic2: i32,
|
||||
generic3: i32,
|
||||
flags: u8,
|
||||
}
|
||||
|
||||
fn read_u32_le(bytes: &[u8], offset: usize) -> u32 {
|
||||
u32::from_le_bytes(bytes[offset..offset + 4].try_into().unwrap())
|
||||
}
|
||||
|
||||
fn read_i32_le(bytes: &[u8], offset: usize) -> i32 {
|
||||
i32::from_le_bytes(bytes[offset..offset + 4].try_into().unwrap())
|
||||
}
|
||||
|
||||
fn parse_binary_records(bytes: &[u8]) -> Result<Vec<InstructionRecord>> {
|
||||
ensure!(bytes.len() >= HEADER_SIZE, "binary core file too small");
|
||||
ensure!(&bytes[0..4] == MAGIC, "invalid PIM binary magic");
|
||||
|
||||
let version = read_u32_le(bytes, 4);
|
||||
ensure!(
|
||||
version == VERSION,
|
||||
"unsupported PIM binary version {version}"
|
||||
);
|
||||
|
||||
let instruction_count = read_u32_le(bytes, 8) as usize;
|
||||
let expected_len = HEADER_SIZE + instruction_count * RECORD_SIZE;
|
||||
ensure!(
|
||||
bytes.len() == expected_len,
|
||||
"PIM binary size mismatch: expected {expected_len} bytes, got {}",
|
||||
bytes.len()
|
||||
);
|
||||
|
||||
let mut records = Vec::with_capacity(instruction_count);
|
||||
for index in 0..instruction_count {
|
||||
let base = HEADER_SIZE + index * RECORD_SIZE;
|
||||
records.push(InstructionRecord {
|
||||
opcode: bytes[base],
|
||||
rd: bytes[base + 1],
|
||||
r1: bytes[base + 2],
|
||||
flags: bytes[base + 3],
|
||||
r2_or_imm: read_i32_le(bytes, base + 4),
|
||||
generic1: read_i32_le(bytes, base + 8),
|
||||
generic2: read_i32_le(bytes, base + 12),
|
||||
generic3: read_i32_le(bytes, base + 16),
|
||||
});
|
||||
}
|
||||
|
||||
Ok(records)
|
||||
}
|
||||
|
||||
fn append_record(
|
||||
inst_builder: &mut InstructionsBuilder,
|
||||
inst_data_builder: &mut InstructionDataBuilder,
|
||||
record: InstructionRecord,
|
||||
) -> Result<()> {
|
||||
let InstructionRecord {
|
||||
opcode,
|
||||
rd,
|
||||
r1,
|
||||
r2_or_imm,
|
||||
generic1,
|
||||
generic2,
|
||||
generic3,
|
||||
flags: _,
|
||||
} = record;
|
||||
|
||||
match opcode {
|
||||
0 => {}
|
||||
1 => {
|
||||
inst_data_builder.set_rd_u8(rd).set_imm(r2_or_imm);
|
||||
inst_builder.make_inst(sldi, inst_data_builder.build());
|
||||
}
|
||||
2 => {
|
||||
inst_data_builder
|
||||
.set_rd_u8(rd)
|
||||
.set_r1_u8(r1)
|
||||
.set_offset_select(generic1)
|
||||
.set_offset_value(generic2);
|
||||
inst_builder.make_inst(sld, inst_data_builder.build());
|
||||
}
|
||||
3 => {
|
||||
inst_data_builder.set_rdr1r2_u8(rd, r1, r2_or_imm);
|
||||
inst_builder.make_inst(sadd, inst_data_builder.build());
|
||||
}
|
||||
4 => {
|
||||
inst_data_builder.set_rdr1r2_u8(rd, r1, r2_or_imm);
|
||||
inst_builder.make_inst(ssub, inst_data_builder.build());
|
||||
}
|
||||
5 => {
|
||||
inst_data_builder.set_rdr1r2_u8(rd, r1, r2_or_imm);
|
||||
inst_builder.make_inst(smul, inst_data_builder.build());
|
||||
}
|
||||
6 => {
|
||||
inst_data_builder.set_rdr1imm_u8(rd, r1, r2_or_imm);
|
||||
inst_builder.make_inst(saddi, inst_data_builder.build());
|
||||
}
|
||||
7 => {
|
||||
inst_data_builder.set_rdr1imm_u8(rd, r1, r2_or_imm);
|
||||
inst_builder.make_inst(smuli, inst_data_builder.build());
|
||||
}
|
||||
8 => {
|
||||
inst_data_builder.set_ibiw_obiw(generic1, generic2);
|
||||
inst_builder.make_inst(setbw, inst_data_builder.build());
|
||||
}
|
||||
9 => {
|
||||
inst_data_builder
|
||||
.set_rd_u8(rd)
|
||||
.set_r1_u8(r1)
|
||||
.set_mbiw_immrelu_immgroup(r2_or_imm, generic1, generic2);
|
||||
inst_builder.make_inst(mvmul, inst_data_builder.build());
|
||||
}
|
||||
10 => {
|
||||
inst_data_builder
|
||||
.set_rdr1r2_u8(rd, r1, r2_or_imm)
|
||||
.set_imm_len(generic3)
|
||||
.set_offset_select_value(generic1, generic2);
|
||||
inst_builder.make_inst(vvadd, inst_data_builder.build());
|
||||
}
|
||||
11 => {
|
||||
inst_data_builder
|
||||
.set_rdr1r2_u8(rd, r1, r2_or_imm)
|
||||
.set_imm_len(generic3)
|
||||
.set_offset_select_value(generic1, generic2);
|
||||
inst_builder.make_inst(vvsub, inst_data_builder.build());
|
||||
}
|
||||
12 => {
|
||||
inst_data_builder
|
||||
.set_rdr1r2_u8(rd, r1, r2_or_imm)
|
||||
.set_imm_len(generic3)
|
||||
.set_offset_select_value(generic1, generic2);
|
||||
inst_builder.make_inst(vvmul, inst_data_builder.build());
|
||||
}
|
||||
13 => {
|
||||
inst_data_builder
|
||||
.set_rdr1r2_u8(rd, r1, r2_or_imm)
|
||||
.set_imm_len(generic3)
|
||||
.set_offset_select_value(generic1, generic2);
|
||||
inst_builder.make_inst(vvdmul, inst_data_builder.build());
|
||||
}
|
||||
14 => {
|
||||
inst_data_builder
|
||||
.set_rdr1r2_u8(rd, r1, r2_or_imm)
|
||||
.set_imm_len(generic3)
|
||||
.set_offset_select_value(generic1, generic2);
|
||||
inst_builder.make_inst(vvmax, inst_data_builder.build());
|
||||
}
|
||||
15 => {
|
||||
inst_data_builder
|
||||
.set_rdr1r2_u8(rd, r1, r2_or_imm)
|
||||
.set_imm_len(generic3)
|
||||
.set_offset_select_value(generic1, generic2);
|
||||
inst_builder.make_inst(vvsll, inst_data_builder.build());
|
||||
}
|
||||
16 => {
|
||||
inst_data_builder
|
||||
.set_rdr1r2_u8(rd, r1, r2_or_imm)
|
||||
.set_imm_len(generic3)
|
||||
.set_offset_select_value(generic1, generic2);
|
||||
inst_builder.make_inst(vvsra, inst_data_builder.build());
|
||||
}
|
||||
17 => {
|
||||
inst_data_builder
|
||||
.set_rdr1r2_u8(rd, r1, r2_or_imm)
|
||||
.set_imm_len(generic3)
|
||||
.set_offset_select_value(generic1, generic2);
|
||||
inst_builder.make_inst(vavg, inst_data_builder.build());
|
||||
}
|
||||
18 => {
|
||||
inst_data_builder
|
||||
.set_rdr1_u8(rd, r1)
|
||||
.set_imm_len(generic3)
|
||||
.set_offset_select_value(generic1, generic2);
|
||||
inst_builder.make_inst(vrelu, inst_data_builder.build());
|
||||
}
|
||||
19 => {
|
||||
inst_data_builder
|
||||
.set_rdr1_u8(rd, r1)
|
||||
.set_imm_len(generic3)
|
||||
.set_offset_select_value(generic1, generic2);
|
||||
inst_builder.make_inst(vtanh, inst_data_builder.build());
|
||||
}
|
||||
20 => {
|
||||
inst_data_builder
|
||||
.set_rdr1_u8(rd, r1)
|
||||
.set_imm_len(generic3)
|
||||
.set_offset_select_value(generic1, generic2);
|
||||
inst_builder.make_inst(vsigm, inst_data_builder.build());
|
||||
}
|
||||
21 => {
|
||||
inst_data_builder
|
||||
.set_rdr1_u8(rd, r1)
|
||||
.set_imm_len(generic3)
|
||||
.set_offset_select_value(generic1, generic2);
|
||||
inst_builder.make_inst(vsoftmax, inst_data_builder.build());
|
||||
}
|
||||
22 => {
|
||||
inst_data_builder
|
||||
.set_rdr1r2_u8(rd, r1, r2_or_imm)
|
||||
.set_imm_len(generic3)
|
||||
.set_offset_select_value(generic1, generic2);
|
||||
inst_builder.make_inst(vmv, inst_data_builder.build());
|
||||
}
|
||||
23 => {
|
||||
inst_data_builder
|
||||
.set_rdr1r2_u8(rd, r1, r2_or_imm)
|
||||
.set_imm_len(generic3)
|
||||
.set_offset_select_value(generic1, generic2);
|
||||
inst_builder.make_inst(vrsu, inst_data_builder.build());
|
||||
}
|
||||
24 => {
|
||||
inst_data_builder
|
||||
.set_rdr1r2_u8(rd, r1, r2_or_imm)
|
||||
.set_imm_len(generic3)
|
||||
.set_offset_select_value(generic1, generic2);
|
||||
inst_builder.make_inst(vrsl, inst_data_builder.build());
|
||||
}
|
||||
25 => {
|
||||
inst_data_builder
|
||||
.set_rdr1_u8(rd, r1)
|
||||
.set_imm_len(generic3)
|
||||
.set_offset_select_value(generic1, generic2);
|
||||
inst_builder.make_inst(ld, inst_data_builder.build());
|
||||
}
|
||||
26 => {
|
||||
inst_data_builder
|
||||
.set_rdr1_u8(rd, r1)
|
||||
.set_imm_len(generic3)
|
||||
.set_offset_select_value(generic1, generic2);
|
||||
inst_builder.make_inst(st, inst_data_builder.build());
|
||||
}
|
||||
27 => {
|
||||
inst_data_builder
|
||||
.set_rd_u8(rd)
|
||||
.set_imm(r2_or_imm)
|
||||
.set_imm_len(generic3)
|
||||
.set_offset_select_value(generic1, generic2);
|
||||
inst_builder.make_inst(lldi, inst_data_builder.build());
|
||||
}
|
||||
28 => {
|
||||
inst_data_builder
|
||||
.set_rdr1_u8(rd, r1)
|
||||
.set_imm_len(generic3)
|
||||
.set_offset_select_value(generic1, generic2);
|
||||
inst_builder.make_inst(lmv, inst_data_builder.build());
|
||||
}
|
||||
29 => {
|
||||
inst_data_builder
|
||||
.set_rd_u8(rd)
|
||||
.set_imm_core(r2_or_imm + 1)
|
||||
.set_imm_len(generic3)
|
||||
.set_offset_select_value(generic1, generic2);
|
||||
inst_builder.make_inst(send, inst_data_builder.build());
|
||||
}
|
||||
30 => {
|
||||
inst_data_builder
|
||||
.set_rd_u8(rd)
|
||||
.set_imm_core(r2_or_imm + 1)
|
||||
.set_imm_len(generic3)
|
||||
.set_offset_select_value(generic1, generic2);
|
||||
inst_builder.make_inst(recv, inst_data_builder.build());
|
||||
}
|
||||
31 => {
|
||||
inst_builder.make_inst(wait, inst_data_builder.build());
|
||||
}
|
||||
32 => {
|
||||
inst_builder.make_inst(sync, inst_data_builder.build());
|
||||
}
|
||||
_ => bail!("unsupported PIM binary opcode {opcode}"),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn binary_to_instructions(
|
||||
core_bytes: &[u8],
|
||||
core_index: i32,
|
||||
) -> Result<Vec<crate::instruction_set::Instruction>> {
|
||||
let records = parse_binary_records(core_bytes)?;
|
||||
let mut insts_builder = InstructionsBuilder::new();
|
||||
let mut inst_data_builder = InstructionDataBuilder::new();
|
||||
inst_data_builder
|
||||
.set_core_indx_u16(u16::try_from(core_index).expect("core index does not fit in u16"))
|
||||
.fix_core_indx();
|
||||
|
||||
for record in records {
|
||||
let opcode = record.opcode;
|
||||
let name = INSTRUCTIONS
|
||||
.get(&(opcode as usize))
|
||||
.copied()
|
||||
.unwrap_or("<unknown>");
|
||||
|
||||
append_record(&mut insts_builder, &mut inst_data_builder, record).with_context(|| {
|
||||
format!(
|
||||
"while decoding binary instruction for core {core_index}: opcode {opcode} ({name})"
|
||||
)
|
||||
})?;
|
||||
}
|
||||
|
||||
Ok(insts_builder.build())
|
||||
}
|
||||
|
||||
pub fn binary_to_executor<'a, 'b>(
|
||||
config: Value,
|
||||
cores: impl Iterator<Item = &'b Vec<u8>>,
|
||||
crossbars: Vec<Vec<&'a Crossbar>>,
|
||||
) -> Result<Executable<'a>> {
|
||||
let core_cnt = config
|
||||
.get("core_cnt")
|
||||
.context("missing core_cnt in config")?
|
||||
.as_i64()
|
||||
.context("core_cnt is not an integer")? as i32;
|
||||
|
||||
let cpu = CPU::new(core_cnt, crossbars);
|
||||
let mut core_insts_builder = CoreInstructionsBuilder::new(core_cnt as usize);
|
||||
for (external_core_indx, core_bytes) in cores.enumerate() {
|
||||
let core_indx = external_core_indx as i32 + 1;
|
||||
let instructions = binary_to_instructions(core_bytes, core_indx)?;
|
||||
core_insts_builder.set_core(core_indx, instructions);
|
||||
}
|
||||
|
||||
Ok(Executable::new(cpu, core_insts_builder.build()))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{
|
||||
HEADER_SIZE, InstructionRecord, MAGIC, RECORD_SIZE, VERSION, binary_to_instructions,
|
||||
};
|
||||
use crate::{
|
||||
functor_to_name,
|
||||
instruction_set::{InstructionsBuilder, instruction_data::InstructionDataBuilder},
|
||||
json_to_instruction::json_isa::json_to_instruction,
|
||||
};
|
||||
|
||||
fn encode_record(record: InstructionRecord, dst: &mut Vec<u8>) {
|
||||
dst.push(record.opcode);
|
||||
dst.push(record.rd);
|
||||
dst.push(record.r1);
|
||||
dst.push(record.flags);
|
||||
dst.extend_from_slice(&record.r2_or_imm.to_le_bytes());
|
||||
dst.extend_from_slice(&record.generic1.to_le_bytes());
|
||||
dst.extend_from_slice(&record.generic2.to_le_bytes());
|
||||
dst.extend_from_slice(&record.generic3.to_le_bytes());
|
||||
}
|
||||
|
||||
fn binary_blob(records: &[InstructionRecord]) -> Vec<u8> {
|
||||
let mut blob = Vec::with_capacity(HEADER_SIZE + records.len() * RECORD_SIZE);
|
||||
blob.extend_from_slice(MAGIC);
|
||||
blob.extend_from_slice(&VERSION.to_le_bytes());
|
||||
blob.extend_from_slice(&(records.len() as u32).to_le_bytes());
|
||||
for &record in records {
|
||||
encode_record(record, &mut blob);
|
||||
}
|
||||
blob
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn json_and_binary_decoders_match_for_representative_ops() {
|
||||
let json_program = [
|
||||
r#"{"imm":64,"op":"sldi","rd":0}"#,
|
||||
r#"{"imm":128,"op":"sldi","rd":1}"#,
|
||||
r#"{"len":16,"offset":{"offset_select":0,"offset_value":0},"op":"lmv","rd":0,"rs1":1}"#,
|
||||
r#"{"group":3,"mbiw":8,"op":"mvmul","rd":0,"relu":0,"rs1":1}"#,
|
||||
r#"{"len":16,"offset":{"offset_select":0,"offset_value":0},"op":"vvadd","rd":0,"rs1":1,"rs2":2}"#,
|
||||
r#"{"core":2,"offset":{"offset_select":0,"offset_value":0},"op":"send","rd":0,"size":16}"#,
|
||||
];
|
||||
|
||||
let binary_program = binary_blob(&[
|
||||
InstructionRecord {
|
||||
opcode: 1,
|
||||
rd: 0,
|
||||
r2_or_imm: 64,
|
||||
..Default::default()
|
||||
},
|
||||
InstructionRecord {
|
||||
opcode: 1,
|
||||
rd: 1,
|
||||
r2_or_imm: 128,
|
||||
..Default::default()
|
||||
},
|
||||
InstructionRecord {
|
||||
opcode: 28,
|
||||
rd: 0,
|
||||
r1: 1,
|
||||
generic3: 16,
|
||||
..Default::default()
|
||||
},
|
||||
InstructionRecord {
|
||||
opcode: 9,
|
||||
rd: 0,
|
||||
r1: 1,
|
||||
r2_or_imm: 8,
|
||||
generic2: 3,
|
||||
..Default::default()
|
||||
},
|
||||
InstructionRecord {
|
||||
opcode: 10,
|
||||
rd: 0,
|
||||
r1: 1,
|
||||
r2_or_imm: 2,
|
||||
generic3: 16,
|
||||
..Default::default()
|
||||
},
|
||||
InstructionRecord {
|
||||
opcode: 29,
|
||||
rd: 0,
|
||||
r2_or_imm: 2,
|
||||
generic3: 16,
|
||||
..Default::default()
|
||||
},
|
||||
]);
|
||||
|
||||
let mut json_builder = InstructionsBuilder::new();
|
||||
let mut json_data_builder = InstructionDataBuilder::new();
|
||||
json_data_builder.set_core_indx(1).fix_core_indx();
|
||||
for inst in json_program {
|
||||
let value = serde_json::from_str(inst).unwrap();
|
||||
json_to_instruction(&mut json_builder, &mut json_data_builder, &value);
|
||||
}
|
||||
let json_instructions = json_builder.build();
|
||||
let binary_instructions = binary_to_instructions(&binary_program, 1).unwrap();
|
||||
|
||||
assert_eq!(json_instructions.len(), binary_instructions.len());
|
||||
for (json_inst, binary_inst) in json_instructions.iter().zip(binary_instructions.iter()) {
|
||||
assert_eq!(
|
||||
functor_to_name(json_inst.functor as usize),
|
||||
functor_to_name(binary_inst.functor as usize)
|
||||
);
|
||||
assert_eq!(json_inst.data, binary_inst.data);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,10 +1,11 @@
|
||||
use paste::paste;
|
||||
use std::convert::TryFrom;
|
||||
|
||||
#[derive(Clone, Copy, Debug, Default)]
|
||||
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
|
||||
pub struct InstructionData {
|
||||
core_indx: i32,
|
||||
rd: i32,
|
||||
r1: i32,
|
||||
core_indx: u16,
|
||||
rd: u8,
|
||||
r1: u8,
|
||||
//r2 imm mbiw imm_core
|
||||
r2_or_imm: i32,
|
||||
//offset_select imm_relu ibiw
|
||||
@@ -16,18 +17,30 @@ pub struct InstructionData {
|
||||
}
|
||||
|
||||
impl InstructionData {
|
||||
pub fn core_indx(&self) -> i32 {
|
||||
pub fn core_indx_u16(&self) -> u16 {
|
||||
self.core_indx
|
||||
}
|
||||
|
||||
pub fn rd(&self) -> i32 {
|
||||
pub fn core_indx(&self) -> i32 {
|
||||
i32::from(self.core_indx)
|
||||
}
|
||||
|
||||
pub fn rd_u8(&self) -> u8 {
|
||||
self.rd
|
||||
}
|
||||
|
||||
pub fn r1(&self) -> i32 {
|
||||
pub fn rd(&self) -> i32 {
|
||||
i32::from(self.rd)
|
||||
}
|
||||
|
||||
pub fn r1_u8(&self) -> u8 {
|
||||
self.r1
|
||||
}
|
||||
|
||||
pub fn r1(&self) -> i32 {
|
||||
i32::from(self.r1)
|
||||
}
|
||||
|
||||
pub fn r2(&self) -> i32 {
|
||||
self.r2_or_imm
|
||||
}
|
||||
@@ -49,26 +62,26 @@ impl InstructionData {
|
||||
}
|
||||
|
||||
pub fn get_core_rd_r1(&self) -> (i32, i32, i32) {
|
||||
(self.core_indx, self.rd, self.r1)
|
||||
(self.core_indx(), self.rd(), self.r1())
|
||||
}
|
||||
|
||||
pub fn get_core_rd_r1_r2(&self) -> (i32, i32, i32, i32) {
|
||||
(self.core_indx, self.rd, self.r1, self.r2_or_imm)
|
||||
(self.core_indx(), self.rd(), self.r1(), self.r2_or_imm)
|
||||
}
|
||||
|
||||
pub fn get_core_rd_imm(&self) -> (i32, i32, i32) {
|
||||
(self.core_indx, self.rd, self.r2_or_imm)
|
||||
(self.core_indx(), self.rd(), self.r2_or_imm)
|
||||
}
|
||||
|
||||
pub fn get_core_rd_r1_imm(&self) -> (i32, i32, i32, i32) {
|
||||
(self.core_indx, self.rd, self.r1, self.r2_or_imm)
|
||||
(self.core_indx(), self.rd(), self.r1(), self.r2_or_imm)
|
||||
}
|
||||
|
||||
pub fn get_core_rd_r1_r2_immlen_offset(&self) -> (i32, i32, i32, i32, i32, i32, i32) {
|
||||
(
|
||||
self.core_indx,
|
||||
self.rd,
|
||||
self.r1,
|
||||
self.core_indx(),
|
||||
self.rd(),
|
||||
self.r1(),
|
||||
self.r2_or_imm,
|
||||
self.generic3,
|
||||
self.generic1,
|
||||
@@ -78,9 +91,9 @@ impl InstructionData {
|
||||
|
||||
pub fn get_core_rd_r1_mbiw_immrelu_immgroup(&self) -> (i32, i32, i32, i32, i32, i32) {
|
||||
(
|
||||
self.core_indx,
|
||||
self.rd,
|
||||
self.r1,
|
||||
self.core_indx(),
|
||||
self.rd(),
|
||||
self.r1(),
|
||||
self.r2_or_imm,
|
||||
self.generic1,
|
||||
self.generic2,
|
||||
@@ -100,7 +113,7 @@ impl InstructionData {
|
||||
}
|
||||
|
||||
pub(crate) fn get_core_immcore(&self) -> (i32, i32) {
|
||||
(self.core_indx, self.r2_or_imm)
|
||||
(self.core_indx(), self.r2_or_imm)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -216,6 +229,18 @@ impl InstructionDataBuilder {
|
||||
common_getter_setter![imm_group];
|
||||
common_getter_setter![imm_core];
|
||||
|
||||
pub fn set_core_indx_u16(&mut self, val: u16) -> &mut Self {
|
||||
self.set_core_indx(i32::from(val))
|
||||
}
|
||||
|
||||
pub fn set_rd_u8(&mut self, val: u8) -> &mut Self {
|
||||
self.set_rd(i32::from(val))
|
||||
}
|
||||
|
||||
pub fn set_r1_u8(&mut self, val: u8) -> &mut Self {
|
||||
self.set_r1(i32::from(val))
|
||||
}
|
||||
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
core_indx: Fixer::Edit(0),
|
||||
@@ -254,20 +279,16 @@ impl InstructionDataBuilder {
|
||||
|
||||
fn check_sanity(&self) {
|
||||
assert!(!(self.get_r2() != 0 && self.get_imm() != 0 && self.get_mbiw() != 0 && self.get_imm_core() != 0));
|
||||
assert!(
|
||||
!(self.get_ibiw() != 0 && self.get_offset_select() != 0 && self.get_imm_relu() != 0)
|
||||
);
|
||||
assert!(
|
||||
!(self.get_obiw() != 0 && self.get_offset_value() != 0 && self.get_imm_group() != 0)
|
||||
);
|
||||
assert!(!(self.get_ibiw() != 0 && self.get_offset_select() != 0 && self.get_imm_relu() != 0));
|
||||
assert!(!(self.get_obiw() != 0 && self.get_offset_value() != 0 && self.get_imm_group() != 0));
|
||||
}
|
||||
|
||||
pub fn build(&mut self) -> InstructionData {
|
||||
self.check_sanity();
|
||||
let inst_data = InstructionData {
|
||||
core_indx: self.get_core_indx(),
|
||||
rd: self.get_rd(),
|
||||
r1: self.get_r1(),
|
||||
core_indx: u16::try_from(self.get_core_indx()).expect("core index does not fit in u16"),
|
||||
rd: u8::try_from(self.get_rd()).expect("rd does not fit in u8"),
|
||||
r1: u8::try_from(self.get_r1()).expect("r1 does not fit in u8"),
|
||||
r2_or_imm: self.get_r2() + self.get_imm() + self.get_mbiw() + self.get_imm_core(),
|
||||
generic1: self.get_offset_select() + self.get_ibiw() + self.get_imm_relu(),
|
||||
generic2: self.get_offset_value() + self.get_obiw() + self.get_imm_group(),
|
||||
@@ -281,6 +302,10 @@ impl InstructionDataBuilder {
|
||||
self.set_rd(rd).set_r1(r1).set_r2(r2)
|
||||
}
|
||||
|
||||
pub fn set_rdr1r2_u8(&mut self, rd: u8, r1: u8, r2: i32) -> &mut Self {
|
||||
self.set_rd_u8(rd).set_r1_u8(r1).set_r2(r2)
|
||||
}
|
||||
|
||||
pub fn set_offset_select_value(&mut self, offset_select: i32, offset_value: i32) -> &mut Self {
|
||||
self.set_offset_select(offset_select)
|
||||
.set_offset_value(offset_value)
|
||||
@@ -290,14 +315,26 @@ impl InstructionDataBuilder {
|
||||
self.set_rd(rd).set_r1(r1).set_imm(imm)
|
||||
}
|
||||
|
||||
pub fn set_rdr1imm_u8(&mut self, rd: u8, r1: u8, imm: i32) -> &mut Self {
|
||||
self.set_rd_u8(rd).set_r1_u8(r1).set_imm(imm)
|
||||
}
|
||||
|
||||
pub fn set_rdr1(&mut self, rd: i32, r1: i32) -> &mut Self {
|
||||
self.set_rd(rd).set_r1(r1)
|
||||
}
|
||||
|
||||
pub fn set_rdr1_u8(&mut self, rd: u8, r1: u8) -> &mut Self {
|
||||
self.set_rd_u8(rd).set_r1_u8(r1)
|
||||
}
|
||||
|
||||
pub fn set_rdimm(&mut self, rd: i32, imm: i32) -> &mut Self {
|
||||
self.set_rd(rd).set_imm(imm)
|
||||
}
|
||||
|
||||
pub fn set_rdimm_u8(&mut self, rd: u8, imm: i32) -> &mut Self {
|
||||
self.set_rd_u8(rd).set_imm(imm)
|
||||
}
|
||||
|
||||
pub fn set_ibiw_obiw(&mut self, ibiw: i32, obiw: i32) -> &mut Self {
|
||||
self.set_ibiw(ibiw).set_obiw(obiw)
|
||||
}
|
||||
|
||||
@@ -16,7 +16,7 @@ use anyhow::{Context, Result, ensure};
|
||||
use rayon::prelude::*;
|
||||
|
||||
use paste::paste;
|
||||
use std::{borrow::Cow, cell::OnceCell, collections::HashMap};
|
||||
use std::{borrow::Cow, cell::OnceCell, collections::HashMap };
|
||||
use std::{collections::HashSet, sync::LazyLock};
|
||||
|
||||
macro_rules! add_name {
|
||||
@@ -35,7 +35,7 @@ macro_rules! add_name_simd {
|
||||
};
|
||||
}
|
||||
|
||||
static NAMES: LazyLock<HashMap<usize, &'static str>> = LazyLock::new(|| {
|
||||
pub static NAMES: LazyLock<HashMap<usize, &'static str>> = LazyLock::new(|| {
|
||||
let mut hash = HashMap::new();
|
||||
add_name!(hash, sldi);
|
||||
add_name!(hash, sld);
|
||||
@@ -81,6 +81,7 @@ pub fn functor_to_name(functor: usize) -> &'static str {
|
||||
///////////////////////////////////////////////////////////////
|
||||
/////////////////Scalar/register Instructions//////////////////
|
||||
///////////////////////////////////////////////////////////////
|
||||
#[inline(never)]
|
||||
pub fn sldi(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
TRACER.lock().unwrap().pre_sldi(cores, data);
|
||||
let (core_indx, rd, imm) = data.get_core_rd_imm();
|
||||
@@ -90,6 +91,7 @@ pub fn sldi(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
||||
Ok(InstructionStatus::Completed)
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn sld(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
TRACER.lock().unwrap().pre_sld(cores, data);
|
||||
let (core_indx, rd, r1) = data.get_core_rd_r1();
|
||||
@@ -104,6 +106,7 @@ pub fn sld(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
||||
Ok(InstructionStatus::Completed)
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn sadd(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
TRACER.lock().unwrap().pre_sadd(cores, data);
|
||||
let (core_indx, rd, r1, r2) = data.get_core_rd_r1_r2();
|
||||
@@ -114,6 +117,7 @@ pub fn sadd(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
||||
Ok(InstructionStatus::Completed)
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn ssub(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
TRACER.lock().unwrap().pre_ssub(cores, data);
|
||||
let (core_indx, rd, r1, r2) = data.get_core_rd_r1_r2();
|
||||
@@ -124,6 +128,7 @@ pub fn ssub(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
||||
Ok(InstructionStatus::Completed)
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn smul(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
TRACER.lock().unwrap().pre_smul(cores, data);
|
||||
let (core_indx, rd, r1, r2) = data.get_core_rd_r1_r2();
|
||||
@@ -134,6 +139,7 @@ pub fn smul(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
||||
Ok(InstructionStatus::Completed)
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn saddi(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
TRACER.lock().unwrap().pre_saddi(cores, data);
|
||||
let (core_indx, rd, r1, imm) = data.get_core_rd_r1_imm();
|
||||
@@ -143,6 +149,7 @@ pub fn saddi(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus
|
||||
Ok(InstructionStatus::Completed)
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn smuli(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
TRACER.lock().unwrap().pre_smuli(cores, data);
|
||||
let (core_indx, rd, r1, imm) = data.get_core_rd_r1_imm();
|
||||
@@ -217,14 +224,17 @@ pub fn is_setbw(functor: InstructionType) -> bool {
|
||||
functor as usize == setbw as *const () as usize
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn setbw(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
panic!("You are calling a placeholder, this instruction is resolved in the construction phase");
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn mvmul(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
panic!("You are calling a placeholder, the real call is the generic version");
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub(super) fn mvm_impl_internal<F, M, T>(
|
||||
cores: &mut CPU,
|
||||
data: InstructionData,
|
||||
@@ -309,6 +319,7 @@ where
|
||||
Ok(InstructionStatus::Completed)
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub(super) fn mvmul_impl<F, T>(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
||||
where
|
||||
[F]: UpcastSlice<T> + UpcastSlice<f32> + UpcastSlice<f64>,
|
||||
@@ -329,10 +340,12 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn vvadd(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
panic!("You are calling a placeholder, the real call is the generic version");
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub(super) fn vvadd_impl<F, T>(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
||||
where
|
||||
[F]: UpcastSlice<T>,
|
||||
@@ -371,10 +384,12 @@ where
|
||||
Ok(InstructionStatus::Completed)
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn vvsub(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
panic!("You are calling a placeholder, the real call is the generic version");
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub(super) fn vvsub_impl<F, T>(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
||||
where
|
||||
[F]: UpcastSlice<T>,
|
||||
@@ -416,6 +431,7 @@ pub fn vvmul(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus
|
||||
panic!("You are calling a placeholder, the real call is the generic version");
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub(super) fn vvmul_impl<F, T>(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
||||
where
|
||||
[F]: UpcastSlice<T>,
|
||||
@@ -452,10 +468,12 @@ where
|
||||
Ok(InstructionStatus::Completed)
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn vvdmul(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
panic!("You are calling a placeholder, the real call is the generic version");
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub(super) fn vvdmul_impl<F, T>(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
||||
where
|
||||
[F]: UpcastSlice<T>,
|
||||
@@ -488,10 +506,12 @@ where
|
||||
Ok(InstructionStatus::Completed)
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn vvmax(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
panic!("You are calling a placeholder, the real call is the generic version");
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub(super) fn vvmax_impl<F, T>(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
||||
where
|
||||
[F]: UpcastSlice<T>,
|
||||
@@ -525,22 +545,26 @@ where
|
||||
Ok(InstructionStatus::Completed)
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn vvsll(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
panic!(
|
||||
"Shift left on floating point what does it means? who has generated this instruction???"
|
||||
);
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn vvsra(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
panic!(
|
||||
"Shift right on floating point what does it means? who has generated this instruction???"
|
||||
);
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn vavg(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
panic!("You are calling a placeholder, the real call is the generic version");
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub(super) fn vavg_impl<F, T>(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
||||
where
|
||||
[F]: UpcastSlice<T>,
|
||||
@@ -570,10 +594,12 @@ where
|
||||
Ok(InstructionStatus::Completed)
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn vrelu(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
panic!("You are calling a placeholder, the real call is the generic version");
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub(super) fn vrelu_impl<F, T>(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
||||
where
|
||||
[F]: UpcastSlice<T>,
|
||||
@@ -600,10 +626,12 @@ where
|
||||
Ok(InstructionStatus::Completed)
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn vtanh(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
panic!("You are calling a placeholder, the real call is the generic version");
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub(super) fn vtanh_impl<F, T>(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
||||
where
|
||||
[F]: UpcastSlice<T>,
|
||||
@@ -628,10 +656,12 @@ where
|
||||
Ok(InstructionStatus::Completed)
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn vsigm(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
panic!("You are calling a placeholder, the real call is the generic version");
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub(super) fn vsigm_impl<F, T>(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
||||
where
|
||||
[F]: UpcastSlice<T>,
|
||||
@@ -654,10 +684,12 @@ where
|
||||
Ok(InstructionStatus::Completed)
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn vsoftmax(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
panic!("You are calling a placeholder, the real call is the generic version");
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub(super) fn vsoftmax_impl<F, T>(
|
||||
cores: &mut CPU,
|
||||
data: InstructionData,
|
||||
@@ -696,14 +728,17 @@ where
|
||||
Ok(InstructionStatus::Completed)
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn vmv(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn vrsu(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn vrsl(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
todo!()
|
||||
}
|
||||
@@ -711,6 +746,7 @@ pub fn vrsl(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
||||
///////////////////////////////////////////////////////////////
|
||||
///Communication/synchronization Instructions/////////////////
|
||||
///////////////////////////////////////////////////////////////
|
||||
#[inline(never)]
|
||||
pub fn ld(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
TRACER.lock().unwrap().pre_ld(cores, data);
|
||||
let (core, rd, r1, _, imm_len, offset_select, offset_value) =
|
||||
@@ -727,6 +763,7 @@ pub fn ld(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
Ok(InstructionStatus::Completed)
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn st(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
TRACER.lock().unwrap().pre_st(cores, data);
|
||||
let (core, rd, r1, _, imm_len, offset_select, offset_value) =
|
||||
@@ -743,6 +780,7 @@ pub fn st(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
Ok(InstructionStatus::Completed)
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn lldi(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
TRACER.lock().unwrap().pre_lldi(cores, data);
|
||||
let (core, rd, imm) = data.get_core_rd_imm();
|
||||
@@ -759,6 +797,7 @@ pub fn lldi(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
||||
Ok(InstructionStatus::Completed)
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn lmv(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
TRACER.lock().unwrap().pre_lmv(cores, data);
|
||||
let (core, rd, r1, _, imm_len, offset_select, offset_value) =
|
||||
@@ -775,18 +814,32 @@ pub fn lmv(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
||||
Ok(InstructionStatus::Completed)
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn isa_send(functor : usize) -> bool{
|
||||
(send as *const () as usize) == functor
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn send(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
Ok(InstructionStatus::Sending(data))
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn isa_recv(functor : usize) -> bool{
|
||||
(recv as *const () as usize) == functor
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn recv(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
Ok(InstructionStatus::Reciving(data))
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn wait(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
Ok(InstructionStatus::Waiting(data))
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn sync(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
Ok(InstructionStatus::Sync(data))
|
||||
}
|
||||
|
||||
@@ -14,7 +14,7 @@ pub mod helper;
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
pub struct Instruction {
|
||||
pub data: InstructionData,
|
||||
functor: InstructionType,
|
||||
pub functor: InstructionType,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, Default)]
|
||||
|
||||
@@ -567,7 +567,7 @@ fn json_to_send(
|
||||
let (offset_select, offset_value) = json_to_offset(json.get("offset").unwrap());
|
||||
inst_data_builder
|
||||
.set_rd(rd)
|
||||
.set_imm_core(core)
|
||||
.set_imm_core(core + 1)
|
||||
.set_imm_len(size)
|
||||
.set_offset_select(offset_select)
|
||||
.set_offset_value(offset_value);
|
||||
@@ -588,7 +588,7 @@ fn json_to_recv(
|
||||
let (offset_select, offset_value) = json_to_offset(json.get("offset").unwrap());
|
||||
inst_data_builder
|
||||
.set_rd(rd)
|
||||
.set_imm_core(core)
|
||||
.set_imm_core(core + 1)
|
||||
.set_imm_len(size)
|
||||
.set_offset_select(offset_select)
|
||||
.set_offset_value(offset_value);
|
||||
|
||||
+14
-27
@@ -1,45 +1,32 @@
|
||||
use core::panic;
|
||||
use std::collections::HashMap;
|
||||
|
||||
use serde_json::{Map, Value};
|
||||
use serde_json::Value;
|
||||
use std::{fs::File, io::BufReader};
|
||||
|
||||
use crate::{
|
||||
CoreInstructionsBuilder, Executable,
|
||||
cpu::{CPU, crossbar::{self, Crossbar}},
|
||||
instruction_set::{
|
||||
InstructionsBuilder,
|
||||
instruction_data::{self, InstructionData, InstructionDataBuilder},
|
||||
},
|
||||
json_to_instruction::{self, json_isa},
|
||||
memory_manager::type_traits::TryToUsize,
|
||||
cpu::{CPU, crossbar::Crossbar},
|
||||
instruction_set::{InstructionsBuilder, instruction_data::InstructionDataBuilder},
|
||||
json_to_instruction::json_isa,
|
||||
};
|
||||
|
||||
|
||||
pub fn json_to_executor<'a, 'b>(
|
||||
config: Value,
|
||||
mut cores: impl Iterator<Item = &'b Value>,
|
||||
crossbars : Vec<Vec<&'a Crossbar>>
|
||||
cores: &'b mut Vec<BufReader<File>>,
|
||||
crossbars: Vec<Vec<&'a Crossbar>>,
|
||||
) -> Executable<'a> {
|
||||
let cell_precision = config.get("cell_precision").unwrap().as_i64().unwrap() as i32;
|
||||
let core_cnt = config.get("core_cnt").unwrap().as_i64().unwrap() as i32 - 1;
|
||||
let xbar_count = config.get("xbar_array_count").unwrap().as_i64().unwrap() as i32;
|
||||
let xbar_size = config.get("xbar_size").unwrap().as_array().unwrap();
|
||||
let rows_crossbar = xbar_size[0].as_i64().unwrap() as i32;
|
||||
let column_corssbar = xbar_size[1].as_i64().unwrap() as i32;
|
||||
let core_cnt = config.get("core_cnt").unwrap().as_i64().unwrap() as i32;
|
||||
|
||||
let mut cpu = CPU::new(core_cnt, crossbars);
|
||||
let cpu = CPU::new(core_cnt, crossbars);
|
||||
let mut core_insts_builder = CoreInstructionsBuilder::new(core_cnt as usize);
|
||||
cores.next();
|
||||
for core_indx in 1..=core_cnt {
|
||||
for (external_core_indx, json_core_reader) in cores.iter_mut().enumerate() {
|
||||
let core_indx = external_core_indx as i32 + 1;
|
||||
let mut insts_builder = InstructionsBuilder::new();
|
||||
let mut inst_data_builder = InstructionDataBuilder::new();
|
||||
inst_data_builder.set_core_indx(core_indx).fix_core_indx();
|
||||
let json_core = cores
|
||||
.next()
|
||||
.unwrap_or_else(|| panic!("cores files less than {}", core_indx ));
|
||||
let json_core: Value = serde_json::from_reader(json_core_reader)
|
||||
.unwrap_or_else(|err| panic!("failed to parse core{}: {}", external_core_indx, err));
|
||||
let json_core_insts = json_core
|
||||
.as_array()
|
||||
.unwrap_or_else(|| panic!("core{} has not a list of instruction", core_indx));
|
||||
.unwrap_or_else(|| panic!("core{} has not a list of instruction", external_core_indx));
|
||||
for json_inst in json_core_insts {
|
||||
json_isa::json_to_instruction(&mut insts_builder, &mut inst_data_builder, json_inst);
|
||||
}
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
mod json_isa;
|
||||
pub(crate) mod json_isa;
|
||||
pub mod json_to_executor;
|
||||
|
||||
@@ -1,14 +1,22 @@
|
||||
#![allow(unused)]
|
||||
|
||||
use std::time::{Duration, SystemTime};
|
||||
use anyhow::{Result, bail};
|
||||
use std::{
|
||||
collections::{HashMap, HashSet},
|
||||
time::{Duration, SystemTime},
|
||||
};
|
||||
|
||||
use crate::{
|
||||
cpu::CPU,
|
||||
instruction_set::{Instruction, InstructionStatus, Instructions, isa::functor_to_name},
|
||||
instruction_set::{
|
||||
Instruction, InstructionStatus, Instructions,
|
||||
isa::{NAMES, functor_to_name, isa_recv, isa_send},
|
||||
},
|
||||
memory_manager::type_traits::TryToUsize,
|
||||
send_recv::{SendRecv, handle_send_recv},
|
||||
tracing::TRACER,
|
||||
};
|
||||
pub mod binary_to_instruction;
|
||||
pub mod cpu;
|
||||
pub mod instruction_set;
|
||||
pub mod json_to_instruction;
|
||||
@@ -80,6 +88,11 @@ pub struct Executable<'a> {
|
||||
send_recv: SendRecv,
|
||||
}
|
||||
|
||||
struct DeadlockInfo {
|
||||
cycle: String,
|
||||
states: String,
|
||||
}
|
||||
|
||||
fn print_status(core_instructions: &[CoreInstructions]) {
|
||||
let mut tot_instructions = 0;
|
||||
let mut progress = 0;
|
||||
@@ -111,7 +124,7 @@ impl<'a> Executable<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn execute<'b>(&'b mut self)
|
||||
pub fn execute<'b>(&'b mut self) -> Result<()>
|
||||
where
|
||||
'a: 'b,
|
||||
{
|
||||
@@ -144,8 +157,15 @@ impl<'a> Executable<'a> {
|
||||
cpu_progressed = 0;
|
||||
*program_counter += 1;
|
||||
}
|
||||
if (now.elapsed().unwrap() > Duration::from_secs(1)) {
|
||||
print_status(&cores_instructions);
|
||||
if (now.elapsed().unwrap() > Duration::from_secs(5)) {
|
||||
print_status(cores_instructions);
|
||||
if let Some(deadlock) = detect_deadlock(cores_instructions) {
|
||||
bail!(
|
||||
"Deadlock cycle detected: {} [{}]",
|
||||
deadlock.cycle,
|
||||
deadlock.states
|
||||
);
|
||||
}
|
||||
now = SystemTime::now();
|
||||
}
|
||||
}
|
||||
@@ -170,8 +190,23 @@ impl<'a> Executable<'a> {
|
||||
}
|
||||
print_status(cores_instructions);
|
||||
|
||||
#[cfg(feature = "profile_time")]
|
||||
if let Some(deadlock) = detect_deadlock(cores_instructions) {
|
||||
bail!(
|
||||
"Deadlock cycle detected: {} [{}]",
|
||||
deadlock.cycle,
|
||||
deadlock.states
|
||||
);
|
||||
}
|
||||
if cores_instructions
|
||||
.iter()
|
||||
.any(|core_inst| core_inst.program_counter < core_inst.instructions.len())
|
||||
{
|
||||
bail!("Execution stalled with unfinished instructions");
|
||||
}
|
||||
|
||||
#[cfg(feature = "profile_time")]
|
||||
TRACER.lock().unwrap().report();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn cpu(&self) -> &CPU<'a> {
|
||||
@@ -193,6 +228,125 @@ impl<'a> Executable<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
fn detect_deadlock(cores_instructions: &[CoreInstructions]) -> Option<DeadlockInfo> {
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
enum CoreState {
|
||||
SendingTo(i32, i32),
|
||||
ReceivingFrom(i32, i32),
|
||||
Working,
|
||||
Halted,
|
||||
}
|
||||
|
||||
let mut states = HashMap::new();
|
||||
|
||||
for core_inst in cores_instructions.iter() {
|
||||
if core_inst.program_counter >= core_inst.instructions.len() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let Instruction { data, functor } = core_inst.instructions[core_inst.program_counter];
|
||||
let functor_address = functor as usize;
|
||||
|
||||
let (this_core, target_core) = data.get_core_immcore();
|
||||
|
||||
if isa_recv(functor_address) {
|
||||
states.insert(this_core, CoreState::ReceivingFrom(target_core, data.imm_len()));
|
||||
} else if isa_send(functor_address) {
|
||||
states.insert(this_core, CoreState::SendingTo(target_core, data.imm_len()));
|
||||
} else {
|
||||
states.insert(this_core, CoreState::Working);
|
||||
}
|
||||
}
|
||||
|
||||
let mut wait_for = HashMap::new();
|
||||
|
||||
for (&core_id, state) in states.iter() {
|
||||
match state {
|
||||
CoreState::SendingTo(target_core, size) => {
|
||||
let target_state = states.get(target_core).unwrap_or(&CoreState::Halted);
|
||||
if target_state != &CoreState::ReceivingFrom(core_id, *size) {
|
||||
wait_for.insert(core_id, *target_core);
|
||||
}
|
||||
}
|
||||
CoreState::ReceivingFrom(target_core, size) => {
|
||||
let target_state = states.get(target_core).unwrap_or(&CoreState::Halted);
|
||||
if target_state != &CoreState::SendingTo(core_id, *size) {
|
||||
wait_for.insert(core_id, *target_core);
|
||||
}
|
||||
}
|
||||
CoreState::Working | CoreState::Halted => {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut visited = HashSet::new();
|
||||
|
||||
for &start_core in wait_for.keys() {
|
||||
if visited.contains(&start_core) {
|
||||
continue;
|
||||
}
|
||||
|
||||
let mut path = Vec::new();
|
||||
let mut current_core = start_core;
|
||||
let mut in_path = HashSet::new();
|
||||
|
||||
while let Some(&waiting_for) = wait_for.get(¤t_core) {
|
||||
path.push(current_core);
|
||||
in_path.insert(current_core);
|
||||
visited.insert(current_core);
|
||||
|
||||
// Found a closed loop!
|
||||
if in_path.contains(&waiting_for) {
|
||||
let cycle_start = path.iter().position(|&c| c == waiting_for).unwrap();
|
||||
let cycle = &path[cycle_start..];
|
||||
let format_core = |core: &i32| (core - 1).to_string();
|
||||
|
||||
let cycle_str = cycle
|
||||
.iter()
|
||||
.map(format_core)
|
||||
.collect::<Vec<_>>()
|
||||
.join(" -> ");
|
||||
|
||||
let cycle = cycle
|
||||
.iter()
|
||||
.copied()
|
||||
.chain(std::iter::once(waiting_for))
|
||||
.collect::<Vec<_>>();
|
||||
let cycle_msg = format!("{} -> {}", cycle_str, waiting_for - 1);
|
||||
let states_msg = cycle
|
||||
.iter()
|
||||
.filter_map(|core| {
|
||||
states.get(core).map(|state| match state {
|
||||
CoreState::SendingTo(target, size) => {
|
||||
format!("core {} send {}B -> {}", core - 1, size, target - 1)
|
||||
}
|
||||
CoreState::ReceivingFrom(source, size) => {
|
||||
format!("core {} recv {}B <- {}", core - 1, size, source - 1)
|
||||
}
|
||||
CoreState::Working => format!("core {} working", core - 1),
|
||||
CoreState::Halted => format!("core {} halted", core - 1),
|
||||
})
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ");
|
||||
|
||||
return Some(DeadlockInfo {
|
||||
cycle: cycle_msg,
|
||||
states: states_msg,
|
||||
});
|
||||
}
|
||||
|
||||
// Hit a known branch that didn't result in a cycle
|
||||
if visited.contains(&waiting_for) {
|
||||
break;
|
||||
}
|
||||
|
||||
current_core = waiting_for;
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
fn handle_wait_sync<'a, 'b, 'c>(
|
||||
cpu: &'b mut CPU<'a>,
|
||||
core_instructions: &'c mut [CoreInstructions],
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
use std::path::Path;
|
||||
|
||||
use pimcore::{Executable, cpu::CPU, instruction_set::{InstructionsBuilder, instruction_data::InstructionDataBuilder, isa::*}};
|
||||
use pimcore::{
|
||||
Executable,
|
||||
cpu::crossbar::Crossbar,
|
||||
instruction_set::{InstructionsBuilder, instruction_data::InstructionDataBuilder, isa::*},
|
||||
memory_manager::CoreMemory,
|
||||
};
|
||||
|
||||
fn simple_read(path: &Path) -> Vec<f32> {
|
||||
if !path.exists() {
|
||||
@@ -17,14 +22,12 @@ fn simple_read(path: &Path) -> Vec<f32> {
|
||||
fn mvmul_f32(err: &str)
|
||||
where
|
||||
{
|
||||
let mut cpu = CPU::new(0);
|
||||
cpu.reserve_crossbar(1, 1024 * size_of::<f32>(), 1024);
|
||||
let (memory, crossbars) = cpu.host().get_memory_crossbar();
|
||||
let matrix = simple_read(Path::new("B.txt")) ;
|
||||
|
||||
|
||||
crossbars.get_mut(0).unwrap().execute_store( &matrix).unwrap();
|
||||
let vector = simple_read(Path::new("A.txt"));
|
||||
let matrix = simple_read(Path::new("tests/B.txt"));
|
||||
let mut crossbar = Crossbar::new(1024 * size_of::<f32>(), 1024, CoreMemory::new());
|
||||
crossbar.execute_store(&matrix).unwrap();
|
||||
let mut cpu = pimcore::cpu::CPU::new(0, vec![vec![&crossbar]]);
|
||||
let (memory, _) = cpu.host().get_memory_crossbar();
|
||||
let vector = simple_read(Path::new("tests/A.txt"));
|
||||
memory.execute_store(0, &vector).unwrap();
|
||||
|
||||
let mut inst_builder = InstructionsBuilder::new();
|
||||
@@ -57,7 +60,7 @@ where
|
||||
.cpu_mut()
|
||||
.host()
|
||||
.load::<f32>(1024 * size_of::<f32>(), 1024*size_of::<f32>()).unwrap()[0].iter().zip(
|
||||
simple_read(Path::new("X.txt")) ).all(|(&a,b) : (&f32, f32)| {a-b < 0.001}),
|
||||
simple_read(Path::new("tests/X.txt")) ).all(|(&a,b) : (&f32, f32)| {a-b < 0.001}),
|
||||
"Wrong result for {}",
|
||||
err
|
||||
);
|
||||
@@ -69,5 +72,3 @@ fn mvmul_big_test() {
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
use pimcore::cpu::CPU;
|
||||
|
||||
pub fn empty_cpu(num_cores: usize) -> CPU<'static> {
|
||||
CPU::new(num_cores, vec![Vec::new(); num_cores + 1])
|
||||
}
|
||||
@@ -1,51 +1,103 @@
|
||||
use std::{fs, io::BufReader, path::Path};
|
||||
use std::{
|
||||
fs::{self, File},
|
||||
io::BufReader,
|
||||
path::{Path, PathBuf},
|
||||
};
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use pimcore::json_to_instruction::json_to_executor;
|
||||
use pimcore::{
|
||||
cpu::crossbar::Crossbar,
|
||||
json_to_instruction::json_to_executor,
|
||||
memory_manager::CoreMemory,
|
||||
};
|
||||
use serde_json::Value;
|
||||
|
||||
fn collect_json_from_subfolders<P: AsRef<Path>>(root: P) -> Result<Vec<(Value, Vec<Value>)>> {
|
||||
fn collect_examples<P: AsRef<Path>>(root: P) -> Result<Vec<PathBuf>> {
|
||||
let mut result = Vec::new();
|
||||
for entry in fs::read_dir(root)? {
|
||||
let entry = entry.context("Root not found")?;
|
||||
let path = entry.path();
|
||||
if path.is_dir() {
|
||||
let mut cores = Vec::new();
|
||||
let mut config: Option<Value> = None;
|
||||
for sub_entry in fs::read_dir(&path)
|
||||
.with_context(|| format!("File {} not readable", path.display()))?
|
||||
{
|
||||
let sub_entry =
|
||||
sub_entry.with_context(|| format!("File {} not readable", path.display()))?;
|
||||
let sub_path = sub_entry.path();
|
||||
if sub_path.is_file()
|
||||
&& sub_path.extension().and_then(|s| s.to_str()) == Some("json")
|
||||
{
|
||||
let file = fs::File::open(&sub_path)
|
||||
.with_context(|| format!("Subpath {} not opened", sub_path.display()))?;
|
||||
let reader = BufReader::new(file);
|
||||
let val: Value = serde_json::from_reader(reader).with_context(|| format!(
|
||||
"Serde reader fail for subpath {}",
|
||||
sub_path.display()
|
||||
))?;
|
||||
if sub_path.file_name().unwrap() == "config.json" {
|
||||
config = Some(val);
|
||||
} else {
|
||||
cores.push(val);
|
||||
}
|
||||
}
|
||||
}
|
||||
result.push((config.unwrap(), cores));
|
||||
result.push(path);
|
||||
}
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
fn core_sort_key(path: &Path) -> i32 {
|
||||
let stem = path.file_stem().unwrap().to_str().unwrap();
|
||||
stem[5..].parse::<i32>().unwrap()
|
||||
}
|
||||
|
||||
fn crossbar_sort_key(path: &Path) -> i32 {
|
||||
let stem = path.file_stem().unwrap().to_str().unwrap();
|
||||
stem[9..].parse::<i32>().unwrap()
|
||||
}
|
||||
|
||||
fn load_crossbars(folder: &Path, config: &Value) -> Result<Vec<Vec<Crossbar>>> {
|
||||
let xbar_size = config.get("xbar_size").unwrap().as_array().unwrap();
|
||||
let rows = xbar_size[0].as_i64().unwrap() as usize;
|
||||
let cols = xbar_size[1].as_i64().unwrap() as usize;
|
||||
let core_cnt = config.get("core_cnt").unwrap().as_i64().unwrap() as usize;
|
||||
let mut owned_crossbars = Vec::with_capacity(core_cnt + 1);
|
||||
owned_crossbars.push(Vec::new());
|
||||
|
||||
for core_idx in 0..core_cnt {
|
||||
let core_folder = folder.join(format!("core_{core_idx}"));
|
||||
let mut core_crossbars = Vec::new();
|
||||
if core_folder.is_dir() {
|
||||
let mut paths: Vec<_> = fs::read_dir(&core_folder)?
|
||||
.map(|entry| entry.map(|entry| entry.path()))
|
||||
.collect::<std::io::Result<Vec<_>>>()?;
|
||||
paths.sort_by_cached_key(|path| crossbar_sort_key(path));
|
||||
for path in paths {
|
||||
if path.extension().and_then(|ext| ext.to_str()) != Some("bin") {
|
||||
continue;
|
||||
}
|
||||
let bytes = fs::read(&path)
|
||||
.with_context(|| format!("failed to read crossbar {}", path.display()))?;
|
||||
let mut crossbar = Crossbar::new(cols * 4, rows, CoreMemory::new());
|
||||
crossbar.execute_store(&bytes)?;
|
||||
core_crossbars.push(crossbar);
|
||||
}
|
||||
}
|
||||
owned_crossbars.push(core_crossbars);
|
||||
}
|
||||
Ok(owned_crossbars)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn json_folder_tester() {
|
||||
let examples = collect_json_from_subfolders("data").unwrap();
|
||||
for example in examples {
|
||||
let (config, cores) = example;
|
||||
json_to_executor::json_to_executor(config, cores.iter()).execute();
|
||||
let examples = collect_examples("tests/data").unwrap();
|
||||
for folder in examples {
|
||||
let config_path = folder.join("config.json");
|
||||
let config_file = File::open(&config_path).unwrap();
|
||||
let config: Value = serde_json::from_reader(BufReader::new(config_file)).unwrap();
|
||||
|
||||
let core_cnt = config.get("core_cnt").unwrap().as_i64().unwrap() as usize;
|
||||
let mut core_paths: Vec<_> = fs::read_dir(&folder)
|
||||
.unwrap()
|
||||
.map(|entry| entry.unwrap().path())
|
||||
.filter(|path| path.extension().and_then(|ext| ext.to_str()) == Some("json"))
|
||||
.filter(|path| path.file_name().unwrap() != "config.json")
|
||||
.collect();
|
||||
core_paths.sort_by_cached_key(|path| core_sort_key(path));
|
||||
assert_eq!(core_paths.len(), core_cnt);
|
||||
|
||||
let mut core_readers: Vec<_> = core_paths
|
||||
.into_iter()
|
||||
.map(|path| BufReader::new(File::open(path).unwrap()))
|
||||
.collect();
|
||||
|
||||
let owned_crossbars = load_crossbars(&folder, &config).unwrap();
|
||||
let crossbars = owned_crossbars
|
||||
.iter()
|
||||
.map(|core_crossbars| core_crossbars.iter().collect())
|
||||
.collect();
|
||||
|
||||
let mut executable = json_to_executor::json_to_executor(config, &mut core_readers, crossbars);
|
||||
let memory = fs::read(folder.join("memory.bin")).unwrap();
|
||||
executable.cpu_mut().host().execute_store(0, &memory).unwrap();
|
||||
executable.execute();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,11 +1,17 @@
|
||||
mod common;
|
||||
|
||||
use pimcore::{Executable, cpu::CPU, instruction_set::{InstructionType, InstructionsBuilder, instruction_data::InstructionDataBuilder, isa::*}};
|
||||
use pimcore::{
|
||||
Executable,
|
||||
instruction_set::{
|
||||
InstructionType, InstructionsBuilder, instruction_data::InstructionDataBuilder, isa::*,
|
||||
},
|
||||
};
|
||||
|
||||
|
||||
#[test]
|
||||
#[should_panic(expected = "Function not found for the requested size") ]
|
||||
fn wrong_size_place_holder() {
|
||||
let cpu = CPU::new(0);
|
||||
let cpu = common::empty_cpu(0);
|
||||
let mut inst_builder = InstructionsBuilder::new();
|
||||
let mut idata_build = InstructionDataBuilder::new();
|
||||
idata_build.set_core_indx(0).fix_core_indx();
|
||||
@@ -30,7 +36,7 @@ fn wrong_size_place_holder() {
|
||||
|
||||
|
||||
fn place_holder(inst : InstructionType) {
|
||||
let mut cpu = CPU::new(0);
|
||||
let mut cpu = common::empty_cpu(0);
|
||||
let mut idata_build = InstructionDataBuilder::new();
|
||||
idata_build.set_core_indx(0).fix_core_indx();
|
||||
inst(&mut cpu, idata_build.build()).unwrap();
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
mod common;
|
||||
|
||||
use pimcore::{
|
||||
Executable,
|
||||
cpu::CPU,
|
||||
cpu::crossbar::Crossbar,
|
||||
instruction_set::{InstructionsBuilder, instruction_data::InstructionDataBuilder, isa::*},
|
||||
memory_manager::{MemoryStorable, type_traits::UpcastDestTraits},
|
||||
memory_manager::{CoreMemory, MemoryStorable, type_traits::UpcastDestTraits},
|
||||
};
|
||||
|
||||
/// VVADD Test
|
||||
@@ -11,7 +13,7 @@ where
|
||||
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
|
||||
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable,
|
||||
{
|
||||
let mut cpu = CPU::new(0);
|
||||
let mut cpu = common::empty_cpu(0);
|
||||
let buff: [F; _] = [
|
||||
1.0.into(),
|
||||
2.0.into(),
|
||||
@@ -115,7 +117,7 @@ where
|
||||
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
|
||||
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable,
|
||||
{
|
||||
let mut cpu = CPU::new(0);
|
||||
let mut cpu = common::empty_cpu(0);
|
||||
let buff: [F; _] = [
|
||||
1.0.into(),
|
||||
2.0.into(),
|
||||
@@ -219,7 +221,7 @@ where
|
||||
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
|
||||
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable,
|
||||
{
|
||||
let mut cpu = CPU::new(0);
|
||||
let mut cpu = common::empty_cpu(0);
|
||||
let buff: [F; _] = [
|
||||
1.0.into(),
|
||||
2.0.into(),
|
||||
@@ -323,7 +325,7 @@ where
|
||||
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
|
||||
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable,
|
||||
{
|
||||
let mut cpu = CPU::new(0);
|
||||
let mut cpu = common::empty_cpu(0);
|
||||
let buff: [F; _] = [
|
||||
1.0.into(),
|
||||
2.0.into(),
|
||||
@@ -420,7 +422,7 @@ where
|
||||
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
|
||||
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable,
|
||||
{
|
||||
let mut cpu = CPU::new(0);
|
||||
let mut cpu = common::empty_cpu(0);
|
||||
let buff: [F; _] = [
|
||||
9.0.into(),
|
||||
2.0.into(),
|
||||
@@ -524,7 +526,7 @@ where
|
||||
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
|
||||
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable,
|
||||
{
|
||||
let mut cpu = CPU::new(0);
|
||||
let mut cpu = common::empty_cpu(0);
|
||||
let buff: [F; _] = [
|
||||
9.0.into(),
|
||||
2.0.into(),
|
||||
@@ -562,6 +564,7 @@ where
|
||||
vavg,
|
||||
idata_build
|
||||
.set_rdr1r2(3, 1, 1)
|
||||
.set_offset_select(1)
|
||||
.set_imm_len(8 * size_of::<F>() as i32)
|
||||
.build(),
|
||||
);
|
||||
@@ -617,7 +620,7 @@ where
|
||||
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
|
||||
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable,
|
||||
{
|
||||
let mut cpu = CPU::new(0);
|
||||
let mut cpu = common::empty_cpu(0);
|
||||
let buff: [F; _] = [
|
||||
(-9.0).into(),
|
||||
2.0.into(),
|
||||
@@ -717,7 +720,7 @@ where
|
||||
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
|
||||
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable + UpcastDestTraits<T>,
|
||||
{
|
||||
let mut cpu = CPU::new(0);
|
||||
let mut cpu = common::empty_cpu(0);
|
||||
let buff: [F; _] = [
|
||||
0.1.into(),
|
||||
0.2.into(),
|
||||
@@ -819,7 +822,7 @@ where
|
||||
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
|
||||
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable + UpcastDestTraits<T>,
|
||||
{
|
||||
let mut cpu = CPU::new(0);
|
||||
let mut cpu = common::empty_cpu(0);
|
||||
let buff: [F; _] = [
|
||||
0.1.into(),
|
||||
0.2.into(),
|
||||
@@ -923,9 +926,6 @@ where
|
||||
M: From<f32> + std::fmt::Debug + PartialEq<M> + MemoryStorable,
|
||||
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable + UpcastDestTraits<T>,
|
||||
{
|
||||
let mut cpu = CPU::new(0);
|
||||
cpu.reserve_crossbar(1, 4 * size_of::<M>(), 4);
|
||||
let (memory, crossbars) = cpu.host().get_memory_crossbar();
|
||||
let matrix: [M; _] = [
|
||||
1.0.into(),
|
||||
2.0.into(),
|
||||
@@ -944,7 +944,10 @@ where
|
||||
15.0.into(),
|
||||
16.0.into(),
|
||||
];
|
||||
crossbars.get_mut(0).unwrap().execute_store( &matrix).unwrap();
|
||||
let mut crossbar = Crossbar::new(4 * size_of::<M>(), 4, CoreMemory::new());
|
||||
crossbar.execute_store(&matrix).unwrap();
|
||||
let mut cpu = pimcore::cpu::CPU::new(0, vec![vec![&crossbar]]);
|
||||
let (memory, _) = cpu.host().get_memory_crossbar();
|
||||
let vector: [F; _] = [
|
||||
1.0.into(),
|
||||
2.0.into(),
|
||||
@@ -1054,5 +1057,3 @@ fn mvmul_test() {
|
||||
mvmul_test_generic::<f64,f64,f64>("mvmul<f64,f64,f64>",1);
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
mod common;
|
||||
|
||||
use pimcore::{
|
||||
Executable, CoreInstructionsBuilder,
|
||||
cpu::CPU,
|
||||
instruction_set::{InstructionsBuilder, instruction_data::InstructionDataBuilder, isa::*},
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn ld_test() {
|
||||
let mut cpu = CPU::new(1);
|
||||
let mut cpu = common::empty_cpu(1);
|
||||
let mut core_instruction_builder = CoreInstructionsBuilder::new(1);
|
||||
let buff: [f32; _] = [
|
||||
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
|
||||
@@ -41,7 +42,7 @@ fn ld_test() {
|
||||
|
||||
#[test]
|
||||
fn st_test() {
|
||||
let mut cpu = CPU::new(1);
|
||||
let mut cpu = common::empty_cpu(1);
|
||||
let mut core_instruction_builder = CoreInstructionsBuilder::new(1);
|
||||
let buff: [f32; _] = [
|
||||
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
|
||||
@@ -76,7 +77,7 @@ fn st_test() {
|
||||
|
||||
#[test]
|
||||
fn lldi_test() {
|
||||
let cpu = CPU::new(1);
|
||||
let cpu = common::empty_cpu(1);
|
||||
let mut core_instruction_builder = CoreInstructionsBuilder::new(1);
|
||||
let mut inst_builder = InstructionsBuilder::new();
|
||||
let mut idata_build = InstructionDataBuilder::new();
|
||||
@@ -106,7 +107,7 @@ fn lldi_test() {
|
||||
|
||||
#[test]
|
||||
fn lmv_test() {
|
||||
let mut cpu = CPU::new(1);
|
||||
let mut cpu = common::empty_cpu(1);
|
||||
let mut core_instruction_builder = CoreInstructionsBuilder::new(1);
|
||||
let buff: [f32; _] = [
|
||||
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
|
||||
@@ -148,7 +149,7 @@ fn lmv_test() {
|
||||
|
||||
#[test]
|
||||
fn simple_send_recv_test() {
|
||||
let mut cpu = CPU::new(2);
|
||||
let mut cpu = common::empty_cpu(2);
|
||||
let mut core_instruction_builder = CoreInstructionsBuilder::new(2);
|
||||
let buff: [f32; _] = [
|
||||
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
|
||||
@@ -207,7 +208,7 @@ fn simple_send_recv_test() {
|
||||
|
||||
#[test]
|
||||
fn multiple_send_recv_test() {
|
||||
let mut cpu = CPU::new(4);
|
||||
let mut cpu = common::empty_cpu(4);
|
||||
let mut core_instruction_builder = CoreInstructionsBuilder::new(4);
|
||||
let buff: [f32; _] = [
|
||||
1.0, 1.0, 1.0, 1.0, 1.0
|
||||
@@ -226,7 +227,7 @@ fn multiple_send_recv_test() {
|
||||
];
|
||||
cpu.core(4).execute_store(0, &buff).unwrap();
|
||||
|
||||
let send_inst = |cpu :&mut CPU, core_instruction_builder: &mut CoreInstructionsBuilder, inst_builder: &mut InstructionsBuilder, from : i32, to : i32| {
|
||||
let send_inst = |inst_builder: &mut InstructionsBuilder, from: i32, to: i32| {
|
||||
let mut idata_build = InstructionDataBuilder::new();
|
||||
idata_build.set_core_indx(from).fix_core_indx();
|
||||
inst_builder.make_inst(sldi, idata_build.set_rdimm(1, from*size_of::<f32>() as i32).build());
|
||||
@@ -240,7 +241,7 @@ fn multiple_send_recv_test() {
|
||||
);
|
||||
};
|
||||
|
||||
let recv_inst = |cpu :&mut CPU, core_instruction_builder: &mut CoreInstructionsBuilder, mut inst_builder: &mut InstructionsBuilder, to : i32, from : i32| {
|
||||
let recv_inst = |inst_builder: &mut InstructionsBuilder, to: i32, from: i32| {
|
||||
let mut idata_build = InstructionDataBuilder::new();
|
||||
idata_build.set_core_indx(to).fix_core_indx();
|
||||
inst_builder.make_inst(sldi, idata_build.set_rdimm(1, from*size_of::<f32>() as i32).build());
|
||||
@@ -257,26 +258,26 @@ fn multiple_send_recv_test() {
|
||||
|
||||
|
||||
// 1 -> 3
|
||||
send_inst(&mut cpu,&mut core_instruction_builder,&mut inst_builder,1, 3);
|
||||
send_inst(&mut inst_builder, 1, 3);
|
||||
core_instruction_builder.set_core(1, inst_builder.build());
|
||||
|
||||
// 2 -> 3
|
||||
// 2 <- 4
|
||||
send_inst(&mut cpu,&mut core_instruction_builder,&mut inst_builder,2, 3);
|
||||
recv_inst(&mut cpu,&mut core_instruction_builder,&mut inst_builder,2, 4);
|
||||
send_inst(&mut inst_builder, 2, 3);
|
||||
recv_inst(&mut inst_builder, 2, 4);
|
||||
core_instruction_builder.set_core(2, inst_builder.build());
|
||||
|
||||
// 3 <- 2
|
||||
// 3 <- 4
|
||||
// 3 <- 1
|
||||
recv_inst(&mut cpu,&mut core_instruction_builder,&mut inst_builder,3, 2);
|
||||
recv_inst(&mut cpu,&mut core_instruction_builder,&mut inst_builder,3, 4);
|
||||
recv_inst(&mut cpu,&mut core_instruction_builder,&mut inst_builder,3, 1);
|
||||
recv_inst(&mut inst_builder, 3, 2);
|
||||
recv_inst(&mut inst_builder, 3, 4);
|
||||
recv_inst(&mut inst_builder, 3, 1);
|
||||
core_instruction_builder.set_core(3, inst_builder.build());
|
||||
// 4 -> 2
|
||||
// 4 -> 3
|
||||
send_inst(&mut cpu,&mut core_instruction_builder,&mut inst_builder,4, 2);
|
||||
send_inst(&mut cpu,&mut core_instruction_builder,&mut inst_builder,4, 3);
|
||||
send_inst(&mut inst_builder, 4, 2);
|
||||
send_inst(&mut inst_builder, 4, 3);
|
||||
core_instruction_builder.set_core(4, inst_builder.build());
|
||||
|
||||
let mut executable = Executable::new(cpu, core_instruction_builder.build());
|
||||
|
||||
Submodule backend-simulators/pim/pimsim-nn updated: 3e3442b663...6d3b898e6b
@@ -10,6 +10,56 @@ set(PIM_INCLUDE_PATH ${CMAKE_INCLUDE_OUTPUT_DIRECTORY})
|
||||
set(PIM_ONNX_MLIR_SRC_ROOT ${ONNX_MLIR_SRC_ROOT})
|
||||
set(PIM_ONNX_MLIR_BIN_ROOT ${ONNX_MLIR_BIN_ROOT})
|
||||
|
||||
set(PIM_GENERATED_PATH_SHIM_TARGET "")
|
||||
get_filename_component(PIM_BIN_ROOT_NAME "${PIM_BIN_ROOT}" NAME)
|
||||
if (PIM_BIN_ROOT_NAME STREQUAL "raptor-external")
|
||||
get_filename_component(PIM_GENERATED_PATH_SHIM_ROOT "${PIM_BIN_ROOT}" DIRECTORY)
|
||||
set(PIM_GENERATED_PATH_SHIM_OUTPUTS)
|
||||
|
||||
function(add_pim_generated_path_shim relative_path)
|
||||
set(real_file "${PIM_BIN_ROOT}/${relative_path}")
|
||||
set(shim_file "${PIM_GENERATED_PATH_SHIM_ROOT}/${relative_path}")
|
||||
get_filename_component(shim_dir "${shim_file}" DIRECTORY)
|
||||
|
||||
add_custom_command(
|
||||
OUTPUT "${shim_file}"
|
||||
DEPENDS "${real_file}"
|
||||
COMMAND "${CMAKE_COMMAND}" -E make_directory "${shim_dir}"
|
||||
COMMAND "${CMAKE_COMMAND}" -E rm -f "${shim_file}"
|
||||
COMMAND "${CMAKE_COMMAND}" -E create_symlink "${real_file}" "${shim_file}"
|
||||
VERBATIM
|
||||
)
|
||||
|
||||
list(APPEND PIM_GENERATED_PATH_SHIM_OUTPUTS "${shim_file}")
|
||||
set(PIM_GENERATED_PATH_SHIM_OUTPUTS "${PIM_GENERATED_PATH_SHIM_OUTPUTS}" PARENT_SCOPE)
|
||||
endfunction()
|
||||
|
||||
file(GLOB_RECURSE pim_generated_path_scan_sources
|
||||
CONFIGURE_DEPENDS
|
||||
"${PIM_SRC_ROOT}/*.cpp"
|
||||
"${PIM_SRC_ROOT}/*.hpp"
|
||||
)
|
||||
|
||||
set(pim_generated_path_shims)
|
||||
foreach (source_file IN LISTS pim_generated_path_scan_sources)
|
||||
file(READ "${source_file}" source_contents)
|
||||
string(REGEX MATCHALL "#include \"src/Accelerators/PIM/[^\"]+\\.inc\"" source_inc_matches "${source_contents}")
|
||||
|
||||
foreach (inc_match IN LISTS source_inc_matches)
|
||||
string(REGEX REPLACE "^#include \"src/Accelerators/PIM/(.+)\"$" "\\1" relative_inc_path "${inc_match}")
|
||||
list(APPEND pim_generated_path_shims "${relative_inc_path}")
|
||||
endforeach ()
|
||||
endforeach ()
|
||||
|
||||
list(REMOVE_DUPLICATES pim_generated_path_shims)
|
||||
foreach (relative_inc_path IN LISTS pim_generated_path_shims)
|
||||
add_pim_generated_path_shim("${relative_inc_path}")
|
||||
endforeach ()
|
||||
|
||||
add_custom_target(OMPimGeneratedPathShims DEPENDS ${PIM_GENERATED_PATH_SHIM_OUTPUTS})
|
||||
set(PIM_GENERATED_PATH_SHIM_TARGET OMPimGeneratedPathShims)
|
||||
endif ()
|
||||
|
||||
set(PIM_PUBLIC_INCLUDE_DIRS
|
||||
${ONNX_MLIR_SRC_ROOT}/include
|
||||
${ONNX_MLIR_SRC_ROOT}
|
||||
@@ -37,6 +87,9 @@ set(PIM_GENERATED_INCLUDE_DIRS
|
||||
|
||||
function(add_pim_library name)
|
||||
add_onnx_mlir_library(${name} STATIC ${ARGN})
|
||||
if (PIM_GENERATED_PATH_SHIM_TARGET)
|
||||
add_dependencies(${name} ${PIM_GENERATED_PATH_SHIM_TARGET})
|
||||
endif ()
|
||||
endfunction()
|
||||
|
||||
add_subdirectory(Dialect)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
add_pim_library(OMPimCommon
|
||||
IR/AddressAnalysis.cpp
|
||||
IR/ConstantUtils.cpp
|
||||
IR/CoreBlockUtils.cpp
|
||||
IR/EntryPointUtils.cpp
|
||||
IR/ShapeUtils.cpp
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
|
||||
@@ -55,6 +57,47 @@ mlir::Value resolveLoopCarriedAliasImpl(mlir::Value value, const StaticValueKnow
|
||||
}
|
||||
|
||||
llvm::FailureOr<int64_t> resolveOpFoldResult(mlir::OpFoldResult ofr, const StaticValueKnowledge* knowledge);
|
||||
llvm::FailureOr<int64_t> resolveIndexValueImpl(mlir::Value value, const StaticValueKnowledge* knowledge);
|
||||
|
||||
static llvm::FailureOr<int64_t> resolveConstantGlobalLoad(mlir::memref::LoadOp loadOp,
|
||||
const StaticValueKnowledge* knowledge) {
|
||||
auto getGlobalOp = loadOp.getMemRef().getDefiningOp<mlir::memref::GetGlobalOp>();
|
||||
if (!getGlobalOp)
|
||||
return mlir::failure();
|
||||
|
||||
auto moduleOp = loadOp->getParentOfType<mlir::ModuleOp>();
|
||||
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
||||
if (!globalOp || !globalOp.getConstant() || !globalOp.getInitialValue())
|
||||
return mlir::failure();
|
||||
|
||||
auto denseAttr = mlir::dyn_cast<mlir::DenseElementsAttr>(*globalOp.getInitialValue());
|
||||
auto globalType = mlir::dyn_cast<mlir::MemRefType>(getGlobalOp.getType());
|
||||
if (!denseAttr || !globalType || !globalType.hasStaticShape())
|
||||
return mlir::failure();
|
||||
|
||||
auto elementType = denseAttr.getElementType();
|
||||
if (!elementType.isIndex() && !elementType.isInteger())
|
||||
return mlir::failure();
|
||||
|
||||
llvm::SmallVector<int64_t> indices;
|
||||
indices.reserve(loadOp.getIndices().size());
|
||||
for (mlir::Value index : loadOp.getIndices()) {
|
||||
auto resolvedIndex = resolveIndexValueImpl(index, knowledge);
|
||||
if (failed(resolvedIndex))
|
||||
return mlir::failure();
|
||||
indices.push_back(*resolvedIndex);
|
||||
}
|
||||
|
||||
if (indices.size() != static_cast<size_t>(globalType.getRank()))
|
||||
return mlir::failure();
|
||||
|
||||
auto strides = computeRowMajorStrides(globalType.getShape());
|
||||
int64_t linearIndex = linearizeIndex(indices, strides);
|
||||
if (linearIndex < 0 || linearIndex >= globalType.getNumElements())
|
||||
return mlir::failure();
|
||||
|
||||
return denseAttr.getValues<llvm::APInt>()[linearIndex].getSExtValue();
|
||||
}
|
||||
|
||||
llvm::FailureOr<int64_t> resolveIndexValueImpl(mlir::Value value, const StaticValueKnowledge* knowledge) {
|
||||
value = resolveAlias(value, knowledge);
|
||||
@@ -110,6 +153,14 @@ llvm::FailureOr<int64_t> resolveIndexValueImpl(mlir::Value value, const StaticVa
|
||||
return static_cast<int64_t>(static_cast<uint64_t>(*lhs) / static_cast<uint64_t>(*rhs));
|
||||
}
|
||||
|
||||
if (auto minOp = mlir::dyn_cast<mlir::arith::MinUIOp>(definingOp)) {
|
||||
auto lhs = resolveIndexValueImpl(minOp.getLhs(), knowledge);
|
||||
auto rhs = resolveIndexValueImpl(minOp.getRhs(), knowledge);
|
||||
if (failed(lhs) || failed(rhs))
|
||||
return mlir::failure();
|
||||
return static_cast<int64_t>(std::min(static_cast<uint64_t>(*lhs), static_cast<uint64_t>(*rhs)));
|
||||
}
|
||||
|
||||
if (auto remOp = mlir::dyn_cast<mlir::arith::RemUIOp>(definingOp)) {
|
||||
auto lhs = resolveIndexValueImpl(remOp.getLhs(), knowledge);
|
||||
auto rhs = resolveIndexValueImpl(remOp.getRhs(), knowledge);
|
||||
@@ -118,6 +169,9 @@ llvm::FailureOr<int64_t> resolveIndexValueImpl(mlir::Value value, const StaticVa
|
||||
return static_cast<int64_t>(static_cast<uint64_t>(*lhs) % static_cast<uint64_t>(*rhs));
|
||||
}
|
||||
|
||||
if (auto loadOp = mlir::dyn_cast<mlir::memref::LoadOp>(definingOp))
|
||||
return resolveConstantGlobalLoad(loadOp, knowledge);
|
||||
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
@@ -210,7 +264,7 @@ llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddressImpl(mlir::Va
|
||||
return mlir::failure();
|
||||
|
||||
auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
|
||||
byteOffset += linearizeIndex(offsets, sourceStrides) * subviewType.getElementTypeBitWidth() / 8;
|
||||
byteOffset += linearizeIndex(offsets, sourceStrides) * getElementTypeSizeInBytes(subviewType.getElementType());
|
||||
value = resolveAlias(subviewOp.getSource(), knowledge);
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -0,0 +1,62 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
|
||||
#include "ConstantUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
Block* getHostConstantBlock(Operation* anchorOp) {
|
||||
assert(anchorOp && "expected a valid anchor operation");
|
||||
|
||||
for (Operation* current = anchorOp; current; current = current->getParentOp())
|
||||
if (isa<spatial::SpatCompute, spatial::SpatComputeBatch, pim::PimCoreOp, pim::PimCoreBatchOp>(current))
|
||||
return current->getBlock();
|
||||
|
||||
if (auto funcOp = anchorOp->getParentOfType<func::FuncOp>())
|
||||
return &funcOp.getBody().front();
|
||||
if (auto moduleOp = anchorOp->getParentOfType<ModuleOp>())
|
||||
return moduleOp.getBody();
|
||||
return anchorOp->getBlock();
|
||||
}
|
||||
|
||||
Value getOrCreateHostConstant(Operation* anchorOp, Attribute value, Type type, OperationFolder& folder) {
|
||||
assert(anchorOp && "expected a valid anchor operation");
|
||||
Block* hostBlock = getHostConstantBlock(anchorOp);
|
||||
for (Operation& op : *hostBlock) {
|
||||
auto constantOp = dyn_cast<arith::ConstantOp>(&op);
|
||||
if (!constantOp || constantOp.getType() != type || constantOp.getValue() != value)
|
||||
continue;
|
||||
return constantOp.getResult();
|
||||
}
|
||||
|
||||
auto* arithDialect = anchorOp->getContext()->getOrLoadDialect<arith::ArithDialect>();
|
||||
return folder.getOrCreateConstant(hostBlock, arithDialect, value, type);
|
||||
}
|
||||
|
||||
Value getOrCreateHostConstantLike(arith::ConstantOp constantOp, OperationFolder& folder) {
|
||||
return getOrCreateHostConstant(constantOp.getOperation(), constantOp.getValue(), constantOp.getType(), folder);
|
||||
}
|
||||
|
||||
Value getOrCreateHostIndexConstant(Operation* anchorOp, int64_t value, OperationFolder& folder) {
|
||||
Builder builder(anchorOp->getContext());
|
||||
return getOrCreateHostConstant(anchorOp, builder.getIndexAttr(value), builder.getIndexType(), folder);
|
||||
}
|
||||
|
||||
Value getOrCreateHostI32Constant(Operation* anchorOp, int32_t value, OperationFolder& folder) {
|
||||
Builder builder(anchorOp->getContext());
|
||||
return getOrCreateHostConstant(anchorOp, builder.getI32IntegerAttr(value), builder.getI32Type(), folder);
|
||||
}
|
||||
|
||||
Value getOrCreateHostI64Constant(Operation* anchorOp, int64_t value, OperationFolder& folder) {
|
||||
Builder builder(anchorOp->getContext());
|
||||
return getOrCreateHostConstant(anchorOp, builder.getI64IntegerAttr(value), builder.getI64Type(), folder);
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,28 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/Transforms/FoldUtils.h"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
mlir::Block* getHostConstantBlock(mlir::Operation* anchorOp);
|
||||
|
||||
mlir::Value getOrCreateHostConstant(mlir::Operation* anchorOp,
|
||||
mlir::Attribute value,
|
||||
mlir::Type type,
|
||||
mlir::OperationFolder& folder);
|
||||
|
||||
mlir::Value getOrCreateHostConstantLike(mlir::arith::ConstantOp constantOp, mlir::OperationFolder& folder);
|
||||
|
||||
mlir::Value getOrCreateHostIndexConstant(mlir::Operation* anchorOp, int64_t value, mlir::OperationFolder& folder);
|
||||
|
||||
mlir::Value getOrCreateHostI32Constant(mlir::Operation* anchorOp, int32_t value, mlir::OperationFolder& folder);
|
||||
|
||||
mlir::Value getOrCreateHostI64Constant(mlir::Operation* anchorOp, int64_t value, mlir::OperationFolder& folder);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -1,4 +1,5 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/CoreBlockUtils.hpp"
|
||||
@@ -12,6 +13,7 @@ bool isCoreStaticAddressOp(mlir::Operation* op) {
|
||||
mlir::arith::SubIOp,
|
||||
mlir::arith::MulIOp,
|
||||
mlir::arith::DivUIOp,
|
||||
mlir::arith::MinUIOp,
|
||||
mlir::arith::RemUIOp,
|
||||
mlir::arith::IndexCastOp,
|
||||
mlir::memref::AllocOp,
|
||||
@@ -29,6 +31,9 @@ walkPimCoreBlock(mlir::Block& block,
|
||||
for (mlir::Operation& op : block) {
|
||||
if (mlir::isa<pim::PimHaltOp, mlir::scf::YieldOp>(op) || isCoreStaticAddressOp(&op))
|
||||
continue;
|
||||
if (auto loadOp = mlir::dyn_cast<mlir::memref::LoadOp>(op);
|
||||
loadOp && succeeded(resolveIndexValue(loadOp.getResult(), knowledge)))
|
||||
continue;
|
||||
|
||||
if (auto forOp = mlir::dyn_cast<mlir::scf::ForOp>(op)) {
|
||||
mlir::Block& loopBody = forOp.getRegion().front();
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/Support/ErrorHandling.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
||||
|
||||
@@ -35,6 +36,30 @@ int64_t getNumElements(llvm::ArrayRef<int64_t> shape) {
|
||||
return numElements;
|
||||
}
|
||||
|
||||
bool hasByteSizedElementType(mlir::Type elementType) {
|
||||
if (mlir::isa<mlir::IndexType>(elementType))
|
||||
return true;
|
||||
if (auto intType = mlir::dyn_cast<mlir::IntegerType>(elementType))
|
||||
return intType.getWidth() > 0 && intType.getWidth() % 8 == 0;
|
||||
if (auto floatType = mlir::dyn_cast<mlir::FloatType>(elementType))
|
||||
return floatType.getWidth() > 0 && floatType.getWidth() % 8 == 0;
|
||||
return false;
|
||||
}
|
||||
|
||||
size_t getElementTypeSizeInBytes(mlir::Type elementType) {
|
||||
if (mlir::isa<mlir::IndexType>(elementType))
|
||||
return mlir::IndexType::kInternalStorageBitWidth / 8;
|
||||
if (auto intType = mlir::dyn_cast<mlir::IntegerType>(elementType))
|
||||
return static_cast<size_t>(intType.getWidth() / 8);
|
||||
if (auto floatType = mlir::dyn_cast<mlir::FloatType>(elementType))
|
||||
return static_cast<size_t>(floatType.getWidth() / 8);
|
||||
llvm_unreachable("expected byte-sized integer, float, or index element type");
|
||||
}
|
||||
|
||||
size_t getShapedTypeSizeInBytes(mlir::ShapedType shapedType) {
|
||||
return static_cast<size_t>(shapedType.getNumElements()) * getElementTypeSizeInBytes(shapedType.getElementType());
|
||||
}
|
||||
|
||||
bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape,
|
||||
llvm::ArrayRef<int64_t> offsets,
|
||||
llvm::ArrayRef<int64_t> sizes,
|
||||
|
||||
@@ -1,8 +1,13 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include <cstddef>
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
llvm::SmallVector<int64_t> computeRowMajorStrides(llvm::ArrayRef<int64_t> shape);
|
||||
@@ -14,6 +19,12 @@ int64_t linearizeIndex(llvm::ArrayRef<int64_t> indices, llvm::ArrayRef<int64_t>
|
||||
|
||||
int64_t getNumElements(llvm::ArrayRef<int64_t> shape);
|
||||
|
||||
bool hasByteSizedElementType(mlir::Type elementType);
|
||||
|
||||
size_t getElementTypeSizeInBytes(mlir::Type elementType);
|
||||
|
||||
size_t getShapedTypeSizeInBytes(mlir::ShapedType shapedType);
|
||||
|
||||
bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape,
|
||||
llvm::ArrayRef<int64_t> offsets,
|
||||
llvm::ArrayRef<int64_t> sizes,
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
|
||||
|
||||
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
@@ -21,12 +21,15 @@ namespace {
|
||||
|
||||
template <typename MVMOpTy, typename VMMOpTy, typename ParentOpTy>
|
||||
bool hasMvmVmmWeightUse(ParentOpTy parentOp, unsigned weightIndex) {
|
||||
auto weightArg = parentOp.getWeightArgument(weightIndex);
|
||||
if (!weightArg)
|
||||
return false;
|
||||
bool found = false;
|
||||
parentOp.walk([&](mlir::Operation* op) {
|
||||
if (auto mvmOp = mlir::dyn_cast<MVMOpTy>(op))
|
||||
found |= mvmOp.getWeightIndex() == weightIndex;
|
||||
found |= mvmOp.getWeight() == *weightArg;
|
||||
else if (auto vmmOp = mlir::dyn_cast<VMMOpTy>(op))
|
||||
found |= vmmOp.getWeightIndex() == weightIndex;
|
||||
found |= vmmOp.getWeight() == *weightArg;
|
||||
});
|
||||
return found;
|
||||
}
|
||||
@@ -35,13 +38,19 @@ template <typename MVMOpTy, typename VMMOpTy, typename ParentOpTy>
|
||||
void walkMvmVmmWeightUses(ParentOpTy parentOp, llvm::function_ref<void(mlir::OpOperand&)> callback) {
|
||||
auto weights = parentOp.getWeights();
|
||||
llvm::SmallSet<unsigned, 8> visited;
|
||||
auto walkWeightIndex = [&](unsigned weightIndex) {
|
||||
if (weightIndex < weights.size() && visited.insert(weightIndex).second)
|
||||
callback(parentOp->getOpOperand(weightIndex));
|
||||
auto walkWeight = [&](mlir::Value weight) {
|
||||
for (unsigned weightIndex = 0; weightIndex < weights.size(); ++weightIndex) {
|
||||
auto weightArg = parentOp.getWeightArgument(weightIndex);
|
||||
if (!weightArg || *weightArg != weight)
|
||||
continue;
|
||||
if (visited.insert(weightIndex).second)
|
||||
callback(parentOp->getOpOperand(weightIndex));
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
parentOp.walk([&](MVMOpTy op) { walkWeightIndex(op.getWeightIndex()); });
|
||||
parentOp.walk([&](VMMOpTy op) { walkWeightIndex(op.getWeightIndex()); });
|
||||
parentOp.walk([&](MVMOpTy op) { walkWeight(op.getWeight()); });
|
||||
parentOp.walk([&](VMMOpTy op) { walkWeight(op.getWeight()); });
|
||||
}
|
||||
|
||||
} // namespace
|
||||
@@ -90,18 +99,21 @@ void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref<void(mlir
|
||||
assert(root && "expected valid root op");
|
||||
root->walk([&](pim::PimCoreOp coreOp) {
|
||||
coreOp.walk([&](pim::PimVMMOp vmmOp) {
|
||||
auto weights = coreOp.getWeights();
|
||||
unsigned weightIndex = vmmOp.getWeightIndex();
|
||||
if (weightIndex < weights.size())
|
||||
callback(coreOp->getOpOperand(weightIndex));
|
||||
for (unsigned weightIndex = 0; weightIndex < coreOp.getWeights().size(); ++weightIndex)
|
||||
if (coreOp.getWeightArgument(weightIndex) == vmmOp.getWeight()) {
|
||||
callback(coreOp->getOpOperand(weightIndex));
|
||||
break;
|
||||
}
|
||||
});
|
||||
});
|
||||
root->walk([&](pim::PimCoreBatchOp coreBatchOp) {
|
||||
auto weights = coreBatchOp.getWeights();
|
||||
for (auto weight : weights)
|
||||
for (mlir::OpOperand& use : weight.getUses())
|
||||
if (use.getOwner() == coreBatchOp.getOperation())
|
||||
callback(use);
|
||||
coreBatchOp.walk([&](pim::PimVMMOp vmmOp) {
|
||||
for (unsigned weightIndex = 0; weightIndex < coreBatchOp.getWeights().size(); ++weightIndex)
|
||||
if (coreBatchOp.getWeightArgument(weightIndex) == vmmOp.getWeight()) {
|
||||
callback(coreBatchOp->getOpOperand(weightIndex));
|
||||
break;
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/CoreBlockUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/EntryPointUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
||||
|
||||
@@ -7,10 +7,34 @@
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
|
||||
#include <cstdint>
|
||||
#include <system_error>
|
||||
|
||||
namespace onnx_mlir::pim {
|
||||
|
||||
struct CappedDiagnosticReporter {
|
||||
explicit CappedDiagnosticReporter(int64_t maxReportedFailures = 8)
|
||||
: maxReportedFailures(maxReportedFailures) {}
|
||||
|
||||
template <typename EmitFn>
|
||||
void report(mlir::Operation* op, EmitFn&& emit) {
|
||||
numFailures++;
|
||||
if (numFailures <= maxReportedFailures)
|
||||
emit(op);
|
||||
}
|
||||
|
||||
void emitSuppressedSummary(mlir::Operation* op, llvm::StringRef failureDescription) const {
|
||||
if (numFailures > maxReportedFailures)
|
||||
op->emitError() << "suppressed " << (numFailures - maxReportedFailures) << " additional " << failureDescription;
|
||||
}
|
||||
|
||||
bool hasFailure() const { return numFailures != 0; }
|
||||
|
||||
private:
|
||||
int64_t maxReportedFailures;
|
||||
int64_t numFailures = 0;
|
||||
};
|
||||
|
||||
/// Emits a consistent diagnostic for target paths that require static shapes.
|
||||
mlir::InFlightDiagnostic emitUnsupportedStaticShapeDiagnostic(mlir::Operation* op, llvm::StringRef valueDescription);
|
||||
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
#include "src/Accelerators/PIM/Common/Support/ReportUtils.hpp"
|
||||
|
||||
#include "llvm/Support/Format.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/Support/FileSystemUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/Support/ReportUtils.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
#pragma once
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
#include <cstdint>
|
||||
#include <fstream>
|
||||
#include <limits>
|
||||
#include <string>
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimArtifactWriter.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimBinaryFormat.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
|
||||
@@ -19,21 +20,6 @@ using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
OnnxMlirCompilerErrorCodes writeHostCoreJson(StringRef outputDirPath) {
|
||||
std::error_code errorCode;
|
||||
std::string outputHostCorePath = outputDirPath.str() + "/core_0.json";
|
||||
raw_fd_ostream hostFileStream(outputHostCorePath, errorCode);
|
||||
if (errorCode) {
|
||||
errs() << "Error while opening host core file `" << outputHostCorePath << "`: " << errorCode.message() << '\n';
|
||||
return InvalidOutputFileAccess;
|
||||
}
|
||||
|
||||
// The host core json contains two no-op-like instructions to satisfy pimsim-nn.
|
||||
hostFileStream << "[{\"imm\":0,\"op\":\"sldi\",\"rd\":0},{\"imm\":0,\"op\":\"sldi\",\"rd\":0}]";
|
||||
hostFileStream.close();
|
||||
return CompilerSuccess;
|
||||
}
|
||||
|
||||
OnnxMlirCompilerErrorCodes
|
||||
writeMemoryBinary(ModuleOp moduleOp, func::FuncOp funcOp, PimAcceleratorMemory& memory, StringRef outputDirPath) {
|
||||
auto memoryFilePath = (outputDirPath + "/memory.bin").str();
|
||||
@@ -91,9 +77,6 @@ OnnxMlirCompilerErrorCodes writeConfigJson(func::FuncOp funcOp,
|
||||
json::Object configJson;
|
||||
|
||||
configJson["core_cnt"] = maxCoreId + 1;
|
||||
configJson["adc_count"] = 16;
|
||||
configJson["cell_precision"] = 2;
|
||||
configJson["xbar_array_count"] = crossbarCountInCore.getValue();
|
||||
configJson["xbar_size"] = {crossbarSize.getValue(), crossbarSize.getValue()};
|
||||
configJson["array_group_map"] = std::move(xbarsPerArrayGroup);
|
||||
|
||||
|
||||
@@ -12,7 +12,6 @@ namespace onnx_mlir {
|
||||
|
||||
class PimAcceleratorMemory;
|
||||
|
||||
OnnxMlirCompilerErrorCodes writeHostCoreJson(llvm::StringRef outputDirPath);
|
||||
OnnxMlirCompilerErrorCodes writeMemoryBinary(mlir::ModuleOp moduleOp,
|
||||
mlir::func::FuncOp funcOp,
|
||||
PimAcceleratorMemory& memory,
|
||||
|
||||
@@ -1,7 +1,11 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp"
|
||||
|
||||
@@ -24,113 +28,166 @@ static SmallVector<int32_t> getLaneChunkCoreIds(ArrayRef<int32_t> coreIds, size_
|
||||
return laneCoreIds;
|
||||
}
|
||||
|
||||
static void scalarizeBatchOpsInCore(pim::PimCoreOp scalarCore, size_t laneCount, unsigned lane) {
|
||||
IRRewriter rewriter(scalarCore.getContext());
|
||||
SmallVector<Operation*> batchOps;
|
||||
scalarCore.walk([&](Operation* op) {
|
||||
if (isa<pim::PimSendBatchOp,
|
||||
pim::PimSendTensorBatchOp,
|
||||
pim::PimReceiveBatchOp,
|
||||
pim::PimReceiveTensorBatchOp,
|
||||
pim::PimMemCopyHostToDevBatchOp>(op)) {
|
||||
batchOps.push_back(op);
|
||||
}
|
||||
});
|
||||
static Value getOrCloneCapturedValue(OpBuilder& builder, Block& oldBlock, Value value, IRMapping& mapper) {
|
||||
if (Value mapped = mapper.lookupOrNull(value))
|
||||
return mapped;
|
||||
|
||||
for (Operation* op : batchOps) {
|
||||
rewriter.setInsertionPoint(op);
|
||||
if (auto blockArgument = dyn_cast<BlockArgument>(value)) {
|
||||
assert(blockArgument.getOwner() != &oldBlock && "expected block argument to be mapped before cloning");
|
||||
assert(false && "unexpected captured block argument while scalarizing pim.core_batch");
|
||||
}
|
||||
|
||||
Operation* definingOp = value.getDefiningOp();
|
||||
assert(definingOp && "expected captured value to be defined by an operation");
|
||||
assert(definingOp->getBlock() != &oldBlock && "expected in-block value to be mapped before cloning");
|
||||
|
||||
for (Value operand : definingOp->getOperands())
|
||||
(void) getOrCloneCapturedValue(builder, oldBlock, operand, mapper);
|
||||
|
||||
Operation* cloned = builder.clone(*definingOp, mapper);
|
||||
for (auto [originalResult, clonedResult] : llvm::zip(definingOp->getResults(), cloned->getResults()))
|
||||
mapper.map(originalResult, clonedResult);
|
||||
return mapper.lookup(value);
|
||||
}
|
||||
|
||||
static void cloneScalarizedLaneBody(OpBuilder& builder,
|
||||
pim::PimCoreBatchOp coreBatchOp,
|
||||
unsigned lane,
|
||||
OperationFolder& constantFolder) {
|
||||
Block& oldBlock = coreBatchOp.getBody().front();
|
||||
Operation* anchorOp = builder.getInsertionBlock()->getParentOp();
|
||||
size_t laneCount = static_cast<size_t>(coreBatchOp.getLaneCount());
|
||||
size_t weightCount = coreBatchOp.getWeights().size();
|
||||
|
||||
IRMapping mapper;
|
||||
for (auto [argIndex, blockArg] : llvm::enumerate(oldBlock.getArguments())) {
|
||||
if (blockArg.getType().isIndex()) {
|
||||
mapper.map(blockArg, getOrCreateHostIndexConstant(anchorOp, static_cast<int64_t>(lane), constantFolder));
|
||||
continue;
|
||||
}
|
||||
|
||||
if (argIndex <= weightCount) {
|
||||
auto scalarCoreOp = cast<pim::PimCoreOp>(anchorOp);
|
||||
mapper.map(blockArg, scalarCoreOp.getWeightArgument(argIndex - 1));
|
||||
continue;
|
||||
}
|
||||
|
||||
size_t inputIndex = argIndex - 1 - weightCount;
|
||||
assert(inputIndex < coreBatchOp.getInputs().size() && "pim.core_batch block input index out of range");
|
||||
mapper.map(blockArg, coreBatchOp.getInputs()[inputIndex]);
|
||||
}
|
||||
|
||||
for (Operation& op : oldBlock) {
|
||||
if (isa<pim::PimHaltOp>(op))
|
||||
continue;
|
||||
|
||||
for (Value operand : op.getOperands())
|
||||
(void) getOrCloneCapturedValue(builder, oldBlock, operand, mapper);
|
||||
|
||||
if (auto sendBatchOp = dyn_cast<pim::PimSendBatchOp>(op)) {
|
||||
pim::PimSendOp::create(rewriter,
|
||||
sendBatchOp.getLoc(),
|
||||
sendBatchOp.getInput(),
|
||||
sendBatchOp.getSizeAttr(),
|
||||
rewriter.getI32IntegerAttr(sendBatchOp.getTargetCoreIds()[lane]));
|
||||
rewriter.eraseOp(op);
|
||||
pim::PimSendOp::create(
|
||||
builder,
|
||||
sendBatchOp.getLoc(),
|
||||
mapper.lookup(sendBatchOp.getInput()),
|
||||
sendBatchOp.getSizeAttr(),
|
||||
getOrCreateHostIndexConstant(anchorOp, sendBatchOp.getTargetCoreIds()[lane], constantFolder));
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto sendTensorBatchOp = dyn_cast<pim::PimSendTensorBatchOp>(op)) {
|
||||
pim::PimSendTensorOp::create(
|
||||
rewriter,
|
||||
builder,
|
||||
sendTensorBatchOp.getLoc(),
|
||||
sendTensorBatchOp.getInput(),
|
||||
rewriter.getDenseI32ArrayAttr(getLaneChunkCoreIds(sendTensorBatchOp.getTargetCoreIds(), laneCount, lane)));
|
||||
rewriter.eraseOp(op);
|
||||
mapper.lookup(sendTensorBatchOp.getInput()),
|
||||
builder.getDenseI32ArrayAttr(getLaneChunkCoreIds(sendTensorBatchOp.getTargetCoreIds(), laneCount, lane)));
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto receiveBatchOp = dyn_cast<pim::PimReceiveBatchOp>(op)) {
|
||||
auto scalarReceive =
|
||||
pim::PimReceiveOp::create(rewriter,
|
||||
receiveBatchOp.getLoc(),
|
||||
receiveBatchOp.getOutput().getType(),
|
||||
receiveBatchOp.getOutputBuffer(),
|
||||
receiveBatchOp.getSizeAttr(),
|
||||
rewriter.getI32IntegerAttr(receiveBatchOp.getSourceCoreIds()[lane]));
|
||||
rewriter.replaceOp(op, scalarReceive->getResults());
|
||||
auto scalarReceive = pim::PimReceiveOp::create(
|
||||
builder,
|
||||
receiveBatchOp.getLoc(),
|
||||
receiveBatchOp.getOutput().getType(),
|
||||
mapper.lookup(receiveBatchOp.getOutputBuffer()),
|
||||
receiveBatchOp.getSizeAttr(),
|
||||
getOrCreateHostIndexConstant(anchorOp, receiveBatchOp.getSourceCoreIds()[lane], constantFolder));
|
||||
mapper.map(receiveBatchOp.getOutput(), scalarReceive.getOutput());
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto receiveTensorBatchOp = dyn_cast<pim::PimReceiveTensorBatchOp>(op)) {
|
||||
auto scalarReceive = pim::PimReceiveTensorOp::create(
|
||||
rewriter,
|
||||
builder,
|
||||
receiveTensorBatchOp.getLoc(),
|
||||
receiveTensorBatchOp.getOutput().getType(),
|
||||
receiveTensorBatchOp.getOutputBuffer(),
|
||||
rewriter.getDenseI32ArrayAttr(getLaneChunkCoreIds(receiveTensorBatchOp.getSourceCoreIds(), laneCount, lane)));
|
||||
rewriter.replaceOp(op, scalarReceive->getResults());
|
||||
mapper.lookup(receiveTensorBatchOp.getOutputBuffer()),
|
||||
builder.getDenseI32ArrayAttr(getLaneChunkCoreIds(receiveTensorBatchOp.getSourceCoreIds(), laneCount, lane)));
|
||||
mapper.map(receiveTensorBatchOp.getOutput(), scalarReceive.getOutput());
|
||||
continue;
|
||||
}
|
||||
|
||||
auto memcpBatchOp = cast<pim::PimMemCopyHostToDevBatchOp>(op);
|
||||
auto scalarCopy = pim::PimMemCopyHostToDevOp::create(rewriter,
|
||||
memcpBatchOp.getLoc(),
|
||||
memcpBatchOp.getOutput().getType(),
|
||||
memcpBatchOp.getDeviceTarget(),
|
||||
memcpBatchOp.getHostSource(),
|
||||
memcpBatchOp.getDeviceTargetOffsetAttr(),
|
||||
memcpBatchOp.getHostSourceOffsetAttr(),
|
||||
memcpBatchOp.getSizeAttr());
|
||||
rewriter.replaceOp(op, scalarCopy->getResults());
|
||||
if (auto memcpBatchOp = dyn_cast<pim::PimMemCopyHostToDevBatchOp>(op)) {
|
||||
auto scalarCopy = pim::PimMemCopyHostToDevOp::create(
|
||||
builder,
|
||||
memcpBatchOp.getLoc(),
|
||||
memcpBatchOp.getOutput().getType(),
|
||||
getOrCreateHostIndexConstant(anchorOp, memcpBatchOp.getDeviceTargetOffset(), constantFolder),
|
||||
getOrCreateHostIndexConstant(anchorOp, memcpBatchOp.getHostSourceOffset(), constantFolder),
|
||||
mapper.lookup(memcpBatchOp.getDeviceTarget()),
|
||||
mapper.lookup(memcpBatchOp.getHostSource()),
|
||||
memcpBatchOp.getSizeAttr());
|
||||
mapper.map(memcpBatchOp.getOutput(), scalarCopy.getOutput());
|
||||
continue;
|
||||
}
|
||||
|
||||
Operation* cloned = builder.clone(op, mapper);
|
||||
for (auto [originalResult, clonedResult] : llvm::zip(op.getResults(), cloned->getResults()))
|
||||
mapper.map(originalResult, clonedResult);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
LogicalResult withScalarCoreFromBatchLane(pim::PimCoreBatchOp coreBatchOp,
|
||||
unsigned lane,
|
||||
llvm::function_ref<LogicalResult(pim::PimCoreOp)> callback) {
|
||||
LogicalResult withScalarCoreFromBatchLanes(pim::PimCoreBatchOp coreBatchOp,
|
||||
ArrayRef<unsigned> lanes,
|
||||
llvm::function_ref<LogicalResult(pim::PimCoreOp)> callback) {
|
||||
assert(!lanes.empty() && "expected at least one batch lane");
|
||||
|
||||
OwningOpRef<ModuleOp> scratchModule = ModuleOp::create(coreBatchOp.getLoc());
|
||||
OpBuilder builder(scratchModule->getContext());
|
||||
OperationFolder constantFolder(scratchModule->getContext());
|
||||
builder.setInsertionPointToStart(scratchModule->getBody());
|
||||
|
||||
size_t laneCount = static_cast<size_t>(coreBatchOp.getLaneCount());
|
||||
size_t weightsPerLane = coreBatchOp.getWeights().size() / laneCount;
|
||||
SmallVector<Value> laneWeights;
|
||||
laneWeights.reserve(weightsPerLane);
|
||||
for (size_t weightIndex = 0; weightIndex < weightsPerLane; ++weightIndex)
|
||||
laneWeights.push_back(coreBatchOp.getWeights()[lane * weightsPerLane + weightIndex]);
|
||||
|
||||
SmallVector<Value> weights(coreBatchOp.getWeights().begin(), coreBatchOp.getWeights().end());
|
||||
auto coreIds = getBatchCoreIds(coreBatchOp);
|
||||
auto scalarCore = pim::PimCoreOp::create(
|
||||
builder, coreBatchOp.getLoc(), ValueRange(laneWeights), builder.getI32IntegerAttr(coreIds[lane]));
|
||||
Block* block = builder.createBlock(&scalarCore.getBody(), scalarCore.getBody().end());
|
||||
IRMapping mapper;
|
||||
if (coreBatchOp.getBody().front().getNumArguments() == 1)
|
||||
mapper.map(coreBatchOp.getBody().front().getArgument(0), coreBatchOp.getInputs()[lane]);
|
||||
int32_t coreId = coreIds[lanes.front()];
|
||||
for (unsigned lane : lanes)
|
||||
assert(coreIds[lane] == coreId && "all grouped lanes must target the same core");
|
||||
|
||||
builder.setInsertionPointToEnd(block);
|
||||
for (Operation& op : coreBatchOp.getBody().front()) {
|
||||
Operation* cloned = builder.clone(op, mapper);
|
||||
for (auto [originalResult, clonedResult] : llvm::zip(op.getResults(), cloned->getResults()))
|
||||
mapper.map(originalResult, clonedResult);
|
||||
auto scalarCore =
|
||||
pim::PimCoreOp::create(builder, coreBatchOp.getLoc(), ValueRange(weights), builder.getI32IntegerAttr(coreId));
|
||||
SmallVector<Type> weightTypes;
|
||||
SmallVector<Location> weightLocs;
|
||||
weightTypes.reserve(weights.size());
|
||||
weightLocs.reserve(weights.size());
|
||||
for (Value weight : weights) {
|
||||
weightTypes.push_back(weight.getType());
|
||||
weightLocs.push_back(weight.getLoc());
|
||||
}
|
||||
|
||||
Block* block =
|
||||
builder.createBlock(&scalarCore.getBody(), scalarCore.getBody().end(), TypeRange(weightTypes), weightLocs);
|
||||
builder.setInsertionPointToEnd(block);
|
||||
for (unsigned lane : lanes)
|
||||
cloneScalarizedLaneBody(builder, coreBatchOp, lane, constantFolder);
|
||||
if (block->empty() || !isa<pim::PimHaltOp>(block->back()))
|
||||
pim::PimHaltOp::create(builder, coreBatchOp.getLoc());
|
||||
scalarizeBatchOpsInCore(scalarCore, laneCount, lane);
|
||||
return callback(scalarCore);
|
||||
}
|
||||
|
||||
LogicalResult withScalarCoreFromBatchLane(pim::PimCoreBatchOp coreBatchOp,
|
||||
unsigned lane,
|
||||
llvm::function_ref<LogicalResult(pim::PimCoreOp)> callback) {
|
||||
return withScalarCoreFromBatchLanes(coreBatchOp, ArrayRef<unsigned> {lane}, callback);
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -9,5 +9,8 @@ namespace onnx_mlir {
|
||||
mlir::LogicalResult withScalarCoreFromBatchLane(pim::PimCoreBatchOp coreBatchOp,
|
||||
unsigned lane,
|
||||
llvm::function_ref<mlir::LogicalResult(pim::PimCoreOp)> callback);
|
||||
mlir::LogicalResult withScalarCoreFromBatchLanes(pim::PimCoreBatchOp coreBatchOp,
|
||||
llvm::ArrayRef<unsigned> lanes,
|
||||
llvm::function_ref<mlir::LogicalResult(pim::PimCoreOp)> callback);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -0,0 +1,374 @@
|
||||
#pragma once
|
||||
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/Support/Endian.h"
|
||||
#include "llvm/Support/JSON.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
#include <array>
|
||||
#include <cassert>
|
||||
#include <limits>
|
||||
|
||||
namespace onnx_mlir::pim_binary {
|
||||
|
||||
inline constexpr char kMagic[4] = {'P', 'I', 'M', 'B'};
|
||||
inline constexpr uint32_t kVersion = 1;
|
||||
inline constexpr uint64_t kCountOffset = 8;
|
||||
inline constexpr size_t kHeaderSize = 12;
|
||||
inline constexpr size_t kRecordSize = 20;
|
||||
|
||||
enum class Opcode : uint32_t {
|
||||
nop = 0,
|
||||
sldi = 1,
|
||||
sld = 2,
|
||||
sadd = 3,
|
||||
ssub = 4,
|
||||
smul = 5,
|
||||
saddi = 6,
|
||||
smuli = 7,
|
||||
setbw = 8,
|
||||
mvmul = 9,
|
||||
vvadd = 10,
|
||||
vvsub = 11,
|
||||
vvmul = 12,
|
||||
vvdmul = 13,
|
||||
vvmax = 14,
|
||||
vvsll = 15,
|
||||
vvsra = 16,
|
||||
vavg = 17,
|
||||
vrelu = 18,
|
||||
vtanh = 19,
|
||||
vsigm = 20,
|
||||
vsoftmax = 21,
|
||||
vmv = 22,
|
||||
vrsu = 23,
|
||||
vrsl = 24,
|
||||
ld = 25,
|
||||
st = 26,
|
||||
lldi = 27,
|
||||
lmv = 28,
|
||||
send = 29,
|
||||
recv = 30,
|
||||
wait = 31,
|
||||
sync = 32,
|
||||
};
|
||||
|
||||
struct InstructionRecord {
|
||||
Opcode opcode = Opcode::nop;
|
||||
uint8_t rd = 0;
|
||||
uint8_t r1 = 0;
|
||||
int32_t r2OrImm = 0;
|
||||
int32_t generic1 = 0;
|
||||
int32_t generic2 = 0;
|
||||
int32_t generic3 = 0;
|
||||
uint8_t flags = 0;
|
||||
};
|
||||
|
||||
inline void writeUint32LE(llvm::raw_ostream& os, uint32_t value) {
|
||||
std::array<char, sizeof(uint32_t)> bytes;
|
||||
llvm::support::endian::write32le(bytes.data(), value);
|
||||
os.write(bytes.data(), bytes.size());
|
||||
}
|
||||
|
||||
inline void writeInt32LE(llvm::raw_ostream& os, int32_t value) { writeUint32LE(os, static_cast<uint32_t>(value)); }
|
||||
|
||||
inline void writeHeader(llvm::raw_ostream& os) {
|
||||
os.write(kMagic, sizeof(kMagic));
|
||||
writeUint32LE(os, kVersion);
|
||||
writeUint32LE(os, 0);
|
||||
}
|
||||
|
||||
inline void patchInstructionCount(llvm::raw_pwrite_stream& os, uint32_t instructionCount) {
|
||||
std::array<char, sizeof(uint32_t)> bytes;
|
||||
llvm::support::endian::write32le(bytes.data(), instructionCount);
|
||||
os.pwrite(bytes.data(), bytes.size(), kCountOffset);
|
||||
}
|
||||
|
||||
inline void writeInstructionRecord(llvm::raw_ostream& os, const InstructionRecord& record) {
|
||||
os << static_cast<char>(static_cast<uint8_t>(record.opcode));
|
||||
os << static_cast<char>(record.rd);
|
||||
os << static_cast<char>(record.r1);
|
||||
os << static_cast<char>(record.flags);
|
||||
writeInt32LE(os, record.r2OrImm);
|
||||
writeInt32LE(os, record.generic1);
|
||||
writeInt32LE(os, record.generic2);
|
||||
writeInt32LE(os, record.generic3);
|
||||
}
|
||||
|
||||
inline int32_t toI32(int64_t value) {
|
||||
assert(value >= std::numeric_limits<int32_t>::min() && value <= std::numeric_limits<int32_t>::max()
|
||||
&& "PIM binary field out of int32 range");
|
||||
return static_cast<int32_t>(value);
|
||||
}
|
||||
|
||||
inline uint8_t toU8(int64_t value) {
|
||||
assert(value >= 0 && value <= std::numeric_limits<uint8_t>::max() && "PIM binary field out of uint8 range");
|
||||
return static_cast<uint8_t>(value);
|
||||
}
|
||||
|
||||
inline int32_t getOptionalInt(const llvm::json::Object& object, llvm::StringRef key, int32_t defaultValue = 0) {
|
||||
if (std::optional<int64_t> value = object.getInteger(key))
|
||||
return toI32(*value);
|
||||
return defaultValue;
|
||||
}
|
||||
|
||||
inline Opcode opcodeFromString(llvm::StringRef opName) {
|
||||
if (opName == "nop")
|
||||
return Opcode::nop;
|
||||
if (opName == "sldi")
|
||||
return Opcode::sldi;
|
||||
if (opName == "sld")
|
||||
return Opcode::sld;
|
||||
if (opName == "sadd")
|
||||
return Opcode::sadd;
|
||||
if (opName == "ssub")
|
||||
return Opcode::ssub;
|
||||
if (opName == "smul")
|
||||
return Opcode::smul;
|
||||
if (opName == "saddi")
|
||||
return Opcode::saddi;
|
||||
if (opName == "smuli")
|
||||
return Opcode::smuli;
|
||||
if (opName == "setbw")
|
||||
return Opcode::setbw;
|
||||
if (opName == "mvmul")
|
||||
return Opcode::mvmul;
|
||||
if (opName == "vvadd")
|
||||
return Opcode::vvadd;
|
||||
if (opName == "vvsub")
|
||||
return Opcode::vvsub;
|
||||
if (opName == "vvmul")
|
||||
return Opcode::vvmul;
|
||||
if (opName == "vvdmul")
|
||||
return Opcode::vvdmul;
|
||||
if (opName == "vvmax")
|
||||
return Opcode::vvmax;
|
||||
if (opName == "vvsll")
|
||||
return Opcode::vvsll;
|
||||
if (opName == "vvsra")
|
||||
return Opcode::vvsra;
|
||||
if (opName == "vavg")
|
||||
return Opcode::vavg;
|
||||
if (opName == "vrelu")
|
||||
return Opcode::vrelu;
|
||||
if (opName == "vtanh")
|
||||
return Opcode::vtanh;
|
||||
if (opName == "vsigm")
|
||||
return Opcode::vsigm;
|
||||
if (opName == "vsoftmax")
|
||||
return Opcode::vsoftmax;
|
||||
if (opName == "vmv")
|
||||
return Opcode::vmv;
|
||||
if (opName == "vrsu")
|
||||
return Opcode::vrsu;
|
||||
if (opName == "vrsl")
|
||||
return Opcode::vrsl;
|
||||
if (opName == "ld")
|
||||
return Opcode::ld;
|
||||
if (opName == "st")
|
||||
return Opcode::st;
|
||||
if (opName == "lldi")
|
||||
return Opcode::lldi;
|
||||
if (opName == "lmv")
|
||||
return Opcode::lmv;
|
||||
if (opName == "send")
|
||||
return Opcode::send;
|
||||
if (opName == "recv")
|
||||
return Opcode::recv;
|
||||
if (opName == "wait")
|
||||
return Opcode::wait;
|
||||
if (opName == "sync")
|
||||
return Opcode::sync;
|
||||
llvm_unreachable("Unsupported PIM binary opcode");
|
||||
}
|
||||
|
||||
inline llvm::StringRef opcodeToString(Opcode opcode) {
|
||||
switch (opcode) {
|
||||
case Opcode::nop: return "nop";
|
||||
case Opcode::sldi: return "sldi";
|
||||
case Opcode::sld: return "sld";
|
||||
case Opcode::sadd: return "sadd";
|
||||
case Opcode::ssub: return "ssub";
|
||||
case Opcode::smul: return "smul";
|
||||
case Opcode::saddi: return "saddi";
|
||||
case Opcode::smuli: return "smuli";
|
||||
case Opcode::setbw: return "setbw";
|
||||
case Opcode::mvmul: return "mvmul";
|
||||
case Opcode::vvadd: return "vvadd";
|
||||
case Opcode::vvsub: return "vvsub";
|
||||
case Opcode::vvmul: return "vvmul";
|
||||
case Opcode::vvdmul: return "vvdmul";
|
||||
case Opcode::vvmax: return "vvmax";
|
||||
case Opcode::vvsll: return "vvsll";
|
||||
case Opcode::vvsra: return "vvsra";
|
||||
case Opcode::vavg: return "vavg";
|
||||
case Opcode::vrelu: return "vrelu";
|
||||
case Opcode::vtanh: return "vtanh";
|
||||
case Opcode::vsigm: return "vsigm";
|
||||
case Opcode::vsoftmax: return "vsoftmax";
|
||||
case Opcode::vmv: return "vmv";
|
||||
case Opcode::vrsu: return "vrsu";
|
||||
case Opcode::vrsl: return "vrsl";
|
||||
case Opcode::ld: return "ld";
|
||||
case Opcode::st: return "st";
|
||||
case Opcode::lldi: return "lldi";
|
||||
case Opcode::lmv: return "lmv";
|
||||
case Opcode::send: return "send";
|
||||
case Opcode::recv: return "recv";
|
||||
case Opcode::wait: return "wait";
|
||||
case Opcode::sync: return "sync";
|
||||
}
|
||||
llvm_unreachable("Unsupported PIM binary opcode");
|
||||
}
|
||||
|
||||
inline InstructionRecord makeInstructionRecord(const llvm::json::Object& instruction) {
|
||||
InstructionRecord record;
|
||||
std::optional<llvm::StringRef> opName = instruction.getString("op");
|
||||
assert(opName && "Missing op field in PIM instruction");
|
||||
record.opcode = opcodeFromString(*opName);
|
||||
record.rd = toU8(getOptionalInt(instruction, "rd"));
|
||||
record.r1 = toU8(getOptionalInt(instruction, "rs1"));
|
||||
|
||||
switch (record.opcode) {
|
||||
case Opcode::sldi:
|
||||
case Opcode::saddi:
|
||||
case Opcode::smuli:
|
||||
case Opcode::lldi: record.r2OrImm = getOptionalInt(instruction, "imm"); break;
|
||||
case Opcode::mvmul:
|
||||
record.r2OrImm = getOptionalInt(instruction, "mbiw");
|
||||
record.generic1 = getOptionalInt(instruction, "relu");
|
||||
record.generic2 = getOptionalInt(instruction, "group");
|
||||
break;
|
||||
case Opcode::setbw:
|
||||
record.generic1 = getOptionalInt(instruction, "ibiw");
|
||||
record.generic2 = getOptionalInt(instruction, "obiw");
|
||||
break;
|
||||
case Opcode::send:
|
||||
case Opcode::recv:
|
||||
record.r2OrImm = getOptionalInt(instruction, "core");
|
||||
record.generic3 = getOptionalInt(instruction, "size");
|
||||
break;
|
||||
default: record.r2OrImm = getOptionalInt(instruction, "rs2"); break;
|
||||
}
|
||||
|
||||
if (record.opcode != Opcode::mvmul && record.opcode != Opcode::setbw) {
|
||||
if (auto* offsetValue = instruction.getObject("offset")) {
|
||||
record.generic1 = getOptionalInt(*offsetValue, "offset_select");
|
||||
record.generic2 = getOptionalInt(*offsetValue, "offset_value");
|
||||
}
|
||||
}
|
||||
|
||||
if (instruction.get("len"))
|
||||
record.generic3 = getOptionalInt(instruction, "len");
|
||||
else if (instruction.get("size") && record.opcode != Opcode::send && record.opcode != Opcode::recv)
|
||||
record.generic3 = getOptionalInt(instruction, "size");
|
||||
|
||||
return record;
|
||||
}
|
||||
|
||||
inline llvm::json::Object makeInstructionJson(const InstructionRecord& record) {
|
||||
llvm::json::Object instruction;
|
||||
instruction["op"] = opcodeToString(record.opcode).str();
|
||||
|
||||
auto addOffset = [&](int32_t offsetSelect, int32_t offsetValue) {
|
||||
llvm::json::Object offset;
|
||||
offset["offset_select"] = offsetSelect;
|
||||
offset["offset_value"] = offsetValue;
|
||||
instruction["offset"] = std::move(offset);
|
||||
};
|
||||
|
||||
switch (record.opcode) {
|
||||
case Opcode::sldi:
|
||||
instruction["rd"] = static_cast<int64_t>(record.rd);
|
||||
instruction["imm"] = record.r2OrImm;
|
||||
break;
|
||||
case Opcode::sld:
|
||||
instruction["rd"] = static_cast<int64_t>(record.rd);
|
||||
instruction["rs1"] = static_cast<int64_t>(record.r1);
|
||||
addOffset(record.generic1, record.generic2);
|
||||
break;
|
||||
case Opcode::sadd:
|
||||
case Opcode::ssub:
|
||||
case Opcode::smul:
|
||||
instruction["rd"] = static_cast<int64_t>(record.rd);
|
||||
instruction["rs1"] = static_cast<int64_t>(record.r1);
|
||||
instruction["rs2"] = record.r2OrImm;
|
||||
break;
|
||||
case Opcode::saddi:
|
||||
case Opcode::smuli:
|
||||
instruction["rd"] = static_cast<int64_t>(record.rd);
|
||||
instruction["rs1"] = static_cast<int64_t>(record.r1);
|
||||
instruction["imm"] = record.r2OrImm;
|
||||
break;
|
||||
case Opcode::setbw:
|
||||
instruction["ibiw"] = record.generic1;
|
||||
instruction["obiw"] = record.generic2;
|
||||
break;
|
||||
case Opcode::mvmul:
|
||||
instruction["rd"] = static_cast<int64_t>(record.rd);
|
||||
instruction["rs1"] = static_cast<int64_t>(record.r1);
|
||||
instruction["mbiw"] = record.r2OrImm;
|
||||
instruction["relu"] = record.generic1;
|
||||
instruction["group"] = record.generic2;
|
||||
break;
|
||||
case Opcode::vvadd:
|
||||
case Opcode::vvsub:
|
||||
case Opcode::vvmul:
|
||||
case Opcode::vvdmul:
|
||||
case Opcode::vvmax:
|
||||
case Opcode::vvsll:
|
||||
case Opcode::vvsra:
|
||||
case Opcode::vavg:
|
||||
case Opcode::vmv:
|
||||
case Opcode::vrsu:
|
||||
case Opcode::vrsl:
|
||||
instruction["rd"] = static_cast<int64_t>(record.rd);
|
||||
instruction["rs1"] = static_cast<int64_t>(record.r1);
|
||||
instruction["rs2"] = record.r2OrImm;
|
||||
addOffset(record.generic1, record.generic2);
|
||||
instruction["len"] = record.generic3;
|
||||
break;
|
||||
case Opcode::vrelu:
|
||||
case Opcode::vtanh:
|
||||
case Opcode::vsigm:
|
||||
case Opcode::vsoftmax:
|
||||
instruction["rd"] = static_cast<int64_t>(record.rd);
|
||||
instruction["rs1"] = static_cast<int64_t>(record.r1);
|
||||
addOffset(record.generic1, record.generic2);
|
||||
instruction["len"] = record.generic3;
|
||||
break;
|
||||
case Opcode::ld:
|
||||
case Opcode::st:
|
||||
instruction["rd"] = static_cast<int64_t>(record.rd);
|
||||
instruction["rs1"] = static_cast<int64_t>(record.r1);
|
||||
addOffset(record.generic1, record.generic2);
|
||||
instruction["size"] = record.generic3;
|
||||
break;
|
||||
case Opcode::lldi:
|
||||
instruction["rd"] = static_cast<int64_t>(record.rd);
|
||||
instruction["imm"] = record.r2OrImm;
|
||||
addOffset(record.generic1, record.generic2);
|
||||
instruction["len"] = record.generic3;
|
||||
break;
|
||||
case Opcode::lmv:
|
||||
instruction["rd"] = static_cast<int64_t>(record.rd);
|
||||
instruction["rs1"] = static_cast<int64_t>(record.r1);
|
||||
addOffset(record.generic1, record.generic2);
|
||||
instruction["len"] = record.generic3;
|
||||
break;
|
||||
case Opcode::send:
|
||||
case Opcode::recv:
|
||||
instruction["rd"] = static_cast<int64_t>(record.rd);
|
||||
instruction["core"] = record.r2OrImm;
|
||||
addOffset(record.generic1, record.generic2);
|
||||
instruction["size"] = record.generic3;
|
||||
break;
|
||||
case Opcode::wait:
|
||||
case Opcode::sync:
|
||||
case Opcode::nop: break;
|
||||
}
|
||||
|
||||
return instruction;
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir::pim_binary
|
||||
+258
-186
@@ -30,6 +30,7 @@
|
||||
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimArtifactWriter.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimBinaryFormat.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimWeightEmitter.hpp"
|
||||
@@ -40,15 +41,10 @@ using namespace mlir;
|
||||
using namespace onnx_mlir;
|
||||
using namespace onnx_mlir::compact_asm;
|
||||
|
||||
static size_t getValueSizeInBytes(mlir::Value value) {
|
||||
auto type = cast<ShapedType>(value.getType());
|
||||
return type.getNumElements() * type.getElementTypeBitWidth() / 8;
|
||||
}
|
||||
|
||||
MemEntry* PimMemory::gatherMemEntry(mlir::Value value) {
|
||||
auto type = cast<ShapedType>(value.getType());
|
||||
assert("Only static shape is supported" && type.hasStaticShape());
|
||||
size_t allocSize = type.getNumElements() * type.getElementType().getIntOrFloatBitWidth() / 8;
|
||||
size_t allocSize = getShapedTypeSizeInBytes(type);
|
||||
MemEntry memEntry = {0, allocSize};
|
||||
return &memEntries.emplace_back(memEntry, value).first;
|
||||
}
|
||||
@@ -116,25 +112,29 @@ void PimMemory::allocateCore(Operation* op) {
|
||||
|
||||
static void printHostMemoryReportRow(raw_ostream& os, const MemoryReportRow& row) {
|
||||
llvm::SmallVector<ReportField, 2> fields = {
|
||||
{"Number of globals", std::to_string(row.numGlobal)},
|
||||
{"Global memory", formatReportMemory(row.sizeGlobal)}};
|
||||
{"Number of globals", std::to_string(row.numGlobal) },
|
||||
{"Global memory", formatReportMemory(row.sizeGlobal)}
|
||||
};
|
||||
printReportFlatFields(os, fields);
|
||||
}
|
||||
|
||||
static void printCoreMemoryReportRow(raw_ostream& os, const MemoryReportEntry& entry) {
|
||||
llvm::SmallVector<ReportField, 2> fields = {
|
||||
{"Number of allocas", std::to_string(entry.row.numAlloca)},
|
||||
{"Allocated memory", formatReportMemory(entry.row.sizeAlloca)}};
|
||||
{"Number of allocas", std::to_string(entry.row.numAlloca) },
|
||||
{"Allocated memory", formatReportMemory(entry.row.sizeAlloca)}
|
||||
};
|
||||
printReportFlatFields(os, fields);
|
||||
}
|
||||
|
||||
static void printBatchMemoryReportRow(raw_ostream& os, const MemoryReportEntry& entry) {
|
||||
llvm::SmallVector<ReportField, 2> perCoreFields = {
|
||||
{"Number of allocas", std::to_string(entry.row.numAlloca)},
|
||||
{"Allocated memory", formatReportMemory(entry.row.sizeAlloca)}};
|
||||
{"Number of allocas", std::to_string(entry.row.numAlloca) },
|
||||
{"Allocated memory", formatReportMemory(entry.row.sizeAlloca)}
|
||||
};
|
||||
llvm::SmallVector<ReportField, 2> totalFields = {
|
||||
{"Number of allocas", std::to_string(entry.totalAllocaCount)},
|
||||
{"Batch memory", formatReportMemory(entry.totalAllocaBytes)}};
|
||||
{"Number of allocas", std::to_string(entry.totalAllocaCount) },
|
||||
{"Batch memory", formatReportMemory(entry.totalAllocaBytes)}
|
||||
};
|
||||
printReportPerCoreAndTotalFields(os, perCoreFields, totalFields);
|
||||
}
|
||||
|
||||
@@ -215,12 +215,8 @@ size_t PimAcceleratorMemory::getValueAddress(mlir::Value value, const StaticValu
|
||||
void PimAcceleratorMemory::reportHost() { hostReportRow = hostMem.getReportRow(); }
|
||||
|
||||
void PimAcceleratorMemory::recordCoreReport(size_t coreId, const MemoryReportRow& row) {
|
||||
reportEntries.push_back({MemoryReportEntry::Kind::Core,
|
||||
coreId,
|
||||
{static_cast<int32_t>(coreId)},
|
||||
row,
|
||||
row.numAlloca,
|
||||
row.sizeAlloca});
|
||||
reportEntries.push_back(
|
||||
{MemoryReportEntry::Kind::Core, coreId, {static_cast<int32_t>(coreId)}, row, row.numAlloca, row.sizeAlloca});
|
||||
}
|
||||
|
||||
void PimAcceleratorMemory::recordBatchReport(uint64_t batchId,
|
||||
@@ -250,7 +246,8 @@ void PimAcceleratorMemory::flushReport() {
|
||||
|
||||
llvm::SmallVector<ReportField, 2> totalFields = {
|
||||
{"Global memory", formatReportMemory(totalGlobalMemory)},
|
||||
{"Cores memory", formatReportMemory(totalCoresMemory)}};
|
||||
{"Cores memory", formatReportMemory(totalCoresMemory) }
|
||||
};
|
||||
printReportTotalsBlock(os, totalFields);
|
||||
|
||||
if (hostReportRow.has_value()) {
|
||||
@@ -312,36 +309,25 @@ void PimAcceleratorMemory::clean(mlir::Operation* op) {
|
||||
}
|
||||
}
|
||||
|
||||
json::Object PimCodeGen::createEmptyOffset() {
|
||||
json::Object offset;
|
||||
offset["offset_select"] = 0;
|
||||
offset["offset_value"] = 0;
|
||||
return offset;
|
||||
}
|
||||
|
||||
size_t PimCodeGen::remapCoreId(size_t coreId) const {
|
||||
auto it = emittedCoreIds.find(coreId);
|
||||
assert(it != emittedCoreIds.end() && "Missing emitted core id remapping");
|
||||
return it->second;
|
||||
}
|
||||
|
||||
static json::Object createRs1OnlyOffset() {
|
||||
json::Object offset;
|
||||
offset["offset_select"] = 1;
|
||||
offset["offset_value"] = 0;
|
||||
return offset;
|
||||
}
|
||||
|
||||
void PimCodeGen::emitInstruction(json::Object instruction) const {
|
||||
coreFileStream << json::Value(std::move(instruction)) << ',';
|
||||
void PimCodeGen::emitInstruction(const pim_binary::InstructionRecord& instruction) const {
|
||||
pim_binary::writeInstructionRecord(coreBinaryStream, instruction);
|
||||
++emittedInstructionCount;
|
||||
if (coreJsonStream)
|
||||
*coreJsonStream << json::Value(pim_binary::makeInstructionJson(instruction)) << ',';
|
||||
}
|
||||
|
||||
void PimCodeGen::genSetRegisterImmediateUnsigned(size_t registerNumber, size_t immediate) const {
|
||||
json::Object json;
|
||||
json["op"] = "sldi";
|
||||
json["rd"] = registerNumber;
|
||||
json["imm"] = immediate;
|
||||
emitInstruction(std::move(json));
|
||||
pim_binary::InstructionRecord instruction;
|
||||
instruction.opcode = pim_binary::Opcode::sldi;
|
||||
instruction.rd = static_cast<uint8_t>(registerNumber);
|
||||
instruction.r2OrImm = static_cast<int32_t>(immediate);
|
||||
emitInstruction(instruction);
|
||||
}
|
||||
|
||||
void PimCodeGen::setupRd(size_t rdAddress, size_t rdOffset) const {
|
||||
@@ -369,55 +355,66 @@ void PimCodeGen::emitMemCopyOp(StringRef opName,
|
||||
StringRef sizeFieldName) const {
|
||||
setupRdRs1(rdAddr, rdOffset, rs1Addr, rs1Offset);
|
||||
|
||||
json::Object json;
|
||||
json["op"] = opName;
|
||||
json["rd"] = 0;
|
||||
json["rs1"] = 1;
|
||||
json[sizeFieldName] = size;
|
||||
json["offset"] = createEmptyOffset();
|
||||
emitInstruction(std::move(json));
|
||||
pim_binary::InstructionRecord instruction;
|
||||
instruction.opcode = pim_binary::opcodeFromString(opName);
|
||||
instruction.rd = 0;
|
||||
instruction.r1 = 1;
|
||||
instruction.generic1 = 0;
|
||||
instruction.generic2 = 0;
|
||||
instruction.generic3 = static_cast<int32_t>(size);
|
||||
(void) sizeFieldName;
|
||||
emitInstruction(instruction);
|
||||
}
|
||||
|
||||
void PimCodeGen::emitCommunicationOp(StringRef opName, size_t bufferAddr, size_t coreId, size_t size) const {
|
||||
setupRd(bufferAddr, 0);
|
||||
|
||||
json::Object json;
|
||||
json["op"] = opName;
|
||||
json["rd"] = 0;
|
||||
json["core"] = remapCoreId(coreId);
|
||||
json["size"] = size;
|
||||
json["offset"] = createEmptyOffset();
|
||||
emitInstruction(std::move(json));
|
||||
pim_binary::InstructionRecord instruction;
|
||||
instruction.opcode = pim_binary::opcodeFromString(opName);
|
||||
instruction.rd = 0;
|
||||
instruction.r2OrImm = static_cast<int32_t>(remapCoreId(coreId));
|
||||
instruction.generic1 = 0;
|
||||
instruction.generic2 = 0;
|
||||
instruction.generic3 = static_cast<int32_t>(size);
|
||||
emitInstruction(instruction);
|
||||
}
|
||||
|
||||
void PimCodeGen::emitMvmOp(size_t groupId, size_t rdAddr, size_t rdOffset, size_t rs1Addr, size_t rs1Offset) const {
|
||||
setupRdRs1(rdAddr, rdOffset, rs1Addr, rs1Offset);
|
||||
|
||||
json::Object json;
|
||||
json["op"] = "mvmul";
|
||||
json["rd"] = 0;
|
||||
json["rs1"] = 1;
|
||||
json["group"] = groupId;
|
||||
json["relu"] = 0;
|
||||
json["mbiw"] = 8;
|
||||
emitInstruction(std::move(json));
|
||||
pim_binary::InstructionRecord instruction;
|
||||
instruction.opcode = pim_binary::Opcode::mvmul;
|
||||
instruction.rd = 0;
|
||||
instruction.r1 = 1;
|
||||
instruction.r2OrImm = 8;
|
||||
instruction.generic1 = 0;
|
||||
instruction.generic2 = static_cast<int32_t>(groupId);
|
||||
emitInstruction(instruction);
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp, const StaticValueKnowledge& knowledge) const {
|
||||
auto deviceTargetOffset = resolveIndexValue(loadOp.getDeviceTargetOffset(), knowledge);
|
||||
auto hostSourceOffset = resolveIndexValue(loadOp.getHostSourceOffset(), knowledge);
|
||||
assert(succeeded(deviceTargetOffset) && succeeded(hostSourceOffset)
|
||||
&& "pim.memcp_hd offsets must be statically resolvable during codegen");
|
||||
emitMemCopyOp("ld",
|
||||
addressOf(loadOp.getDeviceTarget(), knowledge),
|
||||
loadOp.getDeviceTargetOffset(),
|
||||
*deviceTargetOffset,
|
||||
addressOf(loadOp.getHostSource(), knowledge),
|
||||
loadOp.getHostSourceOffset(),
|
||||
*hostSourceOffset,
|
||||
loadOp.getSize());
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp, const StaticValueKnowledge& knowledge) const {
|
||||
auto hostTargetOffset = resolveIndexValue(storeOp.getHostTargetOffset(), knowledge);
|
||||
auto deviceSourceOffset = resolveIndexValue(storeOp.getDeviceSourceOffset(), knowledge);
|
||||
assert(succeeded(hostTargetOffset) && succeeded(deviceSourceOffset)
|
||||
&& "pim.memcp_dh offsets must be statically resolvable during codegen");
|
||||
emitMemCopyOp("st",
|
||||
addressOf(storeOp.getHostTarget(), knowledge),
|
||||
storeOp.getHostTargetOffset(),
|
||||
*hostTargetOffset,
|
||||
addressOf(storeOp.getDeviceSource(), knowledge),
|
||||
storeOp.getDeviceSourceOffset(),
|
||||
*deviceSourceOffset,
|
||||
storeOp.getSize());
|
||||
}
|
||||
|
||||
@@ -432,25 +429,30 @@ void PimCodeGen::codeGenLmvOp(pim::PimMemCopyOp lmvOp, const StaticValueKnowledg
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenReceiveOp(pim::PimReceiveOp receiveOp, const StaticValueKnowledge& knowledge) const {
|
||||
emitCommunicationOp(
|
||||
"recv", addressOf(receiveOp.getOutputBuffer(), knowledge), receiveOp.getSourceCoreId(), receiveOp.getSize());
|
||||
auto sourceCoreId = resolveIndexValue(receiveOp.getSourceCoreId(), knowledge);
|
||||
assert(succeeded(sourceCoreId) && "pim.receive source core id must be statically resolvable during codegen");
|
||||
emitCommunicationOp("recv", addressOf(receiveOp.getOutputBuffer(), knowledge), *sourceCoreId, receiveOp.getSize());
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenReceiveTensorOp(pim::PimReceiveTensorOp receiveTensorOp,
|
||||
const StaticValueKnowledge& knowledge) const {
|
||||
size_t outputAddr = addressOf(receiveTensorOp.getOutputBuffer(), knowledge);
|
||||
size_t chunkSize = getValueSizeInBytes(receiveTensorOp.getOutputBuffer()) / receiveTensorOp.getSourceCoreIds().size();
|
||||
size_t chunkSize = getShapedTypeSizeInBytes(cast<ShapedType>(receiveTensorOp.getOutputBuffer().getType()))
|
||||
/ receiveTensorOp.getSourceCoreIds().size();
|
||||
for (auto [chunkIndex, sourceCoreId] : llvm::enumerate(receiveTensorOp.getSourceCoreIds()))
|
||||
emitCommunicationOp("recv", outputAddr + chunkIndex * chunkSize, sourceCoreId, chunkSize);
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenSendOp(pim::PimSendOp sendOp, const StaticValueKnowledge& knowledge) const {
|
||||
emitCommunicationOp("send", addressOf(sendOp.getInput(), knowledge), sendOp.getTargetCoreId(), sendOp.getSize());
|
||||
auto targetCoreId = resolveIndexValue(sendOp.getTargetCoreId(), knowledge);
|
||||
assert(succeeded(targetCoreId) && "pim.send target core id must be statically resolvable during codegen");
|
||||
emitCommunicationOp("send", addressOf(sendOp.getInput(), knowledge), *targetCoreId, sendOp.getSize());
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenSendTensorOp(pim::PimSendTensorOp sendTensorOp, const StaticValueKnowledge& knowledge) const {
|
||||
size_t inputAddr = addressOf(sendTensorOp.getInput(), knowledge);
|
||||
size_t chunkSize = getValueSizeInBytes(sendTensorOp.getInput()) / sendTensorOp.getTargetCoreIds().size();
|
||||
size_t chunkSize = getShapedTypeSizeInBytes(cast<ShapedType>(sendTensorOp.getInput().getType()))
|
||||
/ sendTensorOp.getTargetCoreIds().size();
|
||||
for (auto [chunkIndex, targetCoreId] : llvm::enumerate(sendTensorOp.getTargetCoreIds()))
|
||||
emitCommunicationOp("send", inputAddr + chunkIndex * chunkSize, targetCoreId, chunkSize);
|
||||
}
|
||||
@@ -461,7 +463,7 @@ void PimCodeGen::codeGenConcatOp(pim::PimConcatOp concatOp, const StaticValueKno
|
||||
|
||||
int64_t axis = concatOp.getAxis();
|
||||
ArrayRef<int64_t> outputShape = outputType.getShape();
|
||||
size_t elementSize = outputType.getElementTypeBitWidth() / 8;
|
||||
size_t elementSize = getElementTypeSizeInBytes(outputType.getElementType());
|
||||
size_t outputAddr = addressOf(concatOp.getOutputBuffer(), knowledge);
|
||||
|
||||
size_t outerCount = 1;
|
||||
@@ -508,14 +510,13 @@ void PimCodeGen::codeGenVVAddOp(pim::PimVVAddOp vvaddOp, const StaticValueKnowle
|
||||
auto rhsAddr = addressOf(vvaddOp.getRhs(), knowledge);
|
||||
setupRdRs1Rs2(outputBufferAddr, 0, lhsAddr, 0, rhsAddr, 0);
|
||||
|
||||
json::Object json;
|
||||
json["op"] = "vvadd";
|
||||
json["rd"] = 0;
|
||||
json["rs1"] = 1;
|
||||
json["rs2"] = 2;
|
||||
json["offset"] = createEmptyOffset();
|
||||
json["len"] = getValueSizeInBytes(vvaddOp.getLhs());
|
||||
emitInstruction(std::move(json));
|
||||
pim_binary::InstructionRecord instruction;
|
||||
instruction.opcode = pim_binary::Opcode::vvadd;
|
||||
instruction.rd = 0;
|
||||
instruction.r1 = 1;
|
||||
instruction.r2OrImm = 2;
|
||||
instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vvaddOp.getLhs().getType())));
|
||||
emitInstruction(instruction);
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenVVSubOp(pim::PimVVSubOp vvsubOp, const StaticValueKnowledge& knowledge) const {
|
||||
@@ -524,14 +525,13 @@ void PimCodeGen::codeGenVVSubOp(pim::PimVVSubOp vvsubOp, const StaticValueKnowle
|
||||
auto rhsAddr = addressOf(vvsubOp.getRhs(), knowledge);
|
||||
setupRdRs1Rs2(outputBufferAddr, 0, lhsAddr, 0, rhsAddr, 0);
|
||||
|
||||
json::Object json;
|
||||
json["op"] = "vvsub";
|
||||
json["rd"] = 0;
|
||||
json["rs1"] = 1;
|
||||
json["rs2"] = 2;
|
||||
json["offset"] = createEmptyOffset();
|
||||
json["len"] = getValueSizeInBytes(vvsubOp.getLhs());
|
||||
emitInstruction(std::move(json));
|
||||
pim_binary::InstructionRecord instruction;
|
||||
instruction.opcode = pim_binary::Opcode::vvsub;
|
||||
instruction.rd = 0;
|
||||
instruction.r1 = 1;
|
||||
instruction.r2OrImm = 2;
|
||||
instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vvsubOp.getLhs().getType())));
|
||||
emitInstruction(instruction);
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenVVMulOp(pim::PimVVMulOp vvmulOp, const StaticValueKnowledge& knowledge) const {
|
||||
@@ -540,14 +540,13 @@ void PimCodeGen::codeGenVVMulOp(pim::PimVVMulOp vvmulOp, const StaticValueKnowle
|
||||
auto rhsAddr = addressOf(vvmulOp.getRhs(), knowledge);
|
||||
setupRdRs1Rs2(outputBufferAddr, 0, lhsAddr, 0, rhsAddr, 0);
|
||||
|
||||
json::Object json;
|
||||
json["op"] = "vvmul";
|
||||
json["rd"] = 0;
|
||||
json["rs1"] = 1;
|
||||
json["rs2"] = 2;
|
||||
json["offset"] = createEmptyOffset();
|
||||
json["len"] = getValueSizeInBytes(vvmulOp.getLhs());
|
||||
emitInstruction(std::move(json));
|
||||
pim_binary::InstructionRecord instruction;
|
||||
instruction.opcode = pim_binary::Opcode::vvmul;
|
||||
instruction.rd = 0;
|
||||
instruction.r1 = 1;
|
||||
instruction.r2OrImm = 2;
|
||||
instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vvmulOp.getLhs().getType())));
|
||||
emitInstruction(instruction);
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenVVMaxOp(pim::PimVVMaxOp vvmaxOp, const StaticValueKnowledge& knowledge) const {
|
||||
@@ -556,14 +555,13 @@ void PimCodeGen::codeGenVVMaxOp(pim::PimVVMaxOp vvmaxOp, const StaticValueKnowle
|
||||
auto rhsAddr = addressOf(vvmaxOp.getRhs(), knowledge);
|
||||
setupRdRs1Rs2(outputBufferAddr, 0, lhsAddr, 0, rhsAddr, 0);
|
||||
|
||||
json::Object json;
|
||||
json["op"] = "vvmax";
|
||||
json["rd"] = 0;
|
||||
json["rs1"] = 1;
|
||||
json["rs2"] = 2;
|
||||
json["offset"] = createEmptyOffset();
|
||||
json["len"] = getValueSizeInBytes(vvmaxOp.getLhs());
|
||||
emitInstruction(std::move(json));
|
||||
pim_binary::InstructionRecord instruction;
|
||||
instruction.opcode = pim_binary::Opcode::vvmax;
|
||||
instruction.rd = 0;
|
||||
instruction.r1 = 1;
|
||||
instruction.r2OrImm = 2;
|
||||
instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vvmaxOp.getLhs().getType())));
|
||||
emitInstruction(instruction);
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenVVDMulOp(pim::PimVVDMulOp vvdmulOp, const StaticValueKnowledge& knowledge) const {
|
||||
@@ -572,14 +570,13 @@ void PimCodeGen::codeGenVVDMulOp(pim::PimVVDMulOp vvdmulOp, const StaticValueKno
|
||||
auto rhsAddr = addressOf(vvdmulOp.getRhs(), knowledge);
|
||||
setupRdRs1Rs2(outputBufferAddr, 0, lhsAddr, 0, rhsAddr, 0);
|
||||
|
||||
json::Object json;
|
||||
json["op"] = "vvdmul";
|
||||
json["rd"] = 0;
|
||||
json["rs1"] = 1;
|
||||
json["rs2"] = 2;
|
||||
json["offset"] = createEmptyOffset();
|
||||
json["len"] = getValueSizeInBytes(vvdmulOp.getLhs());
|
||||
emitInstruction(std::move(json));
|
||||
pim_binary::InstructionRecord instruction;
|
||||
instruction.opcode = pim_binary::Opcode::vvdmul;
|
||||
instruction.rd = 0;
|
||||
instruction.r1 = 1;
|
||||
instruction.r2OrImm = 2;
|
||||
instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vvdmulOp.getLhs().getType())));
|
||||
emitInstruction(instruction);
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenVAvgOp(pim::PimVAvgOp vavgOp, const StaticValueKnowledge& knowledge) const {
|
||||
@@ -587,14 +584,14 @@ void PimCodeGen::codeGenVAvgOp(pim::PimVAvgOp vavgOp, const StaticValueKnowledge
|
||||
auto inputAddr = addressOf(vavgOp.getInput(), knowledge);
|
||||
setupRdRs1(outputBufferAddr, 0, inputAddr, 0);
|
||||
|
||||
json::Object json;
|
||||
json["op"] = "vavg";
|
||||
json["rd"] = 0;
|
||||
json["rs1"] = 1;
|
||||
json["rs2"] = 1;
|
||||
json["offset"] = createRs1OnlyOffset();
|
||||
json["len"] = getValueSizeInBytes(vavgOp.getInput());
|
||||
emitInstruction(std::move(json));
|
||||
pim_binary::InstructionRecord instruction;
|
||||
instruction.opcode = pim_binary::Opcode::vavg;
|
||||
instruction.rd = 0;
|
||||
instruction.r1 = 1;
|
||||
instruction.r2OrImm = 1;
|
||||
instruction.generic1 = 1;
|
||||
instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vavgOp.getInput().getType())));
|
||||
emitInstruction(instruction);
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenVReluOp(pim::PimVReluOp vreluOp, const StaticValueKnowledge& knowledge) const {
|
||||
@@ -602,13 +599,12 @@ void PimCodeGen::codeGenVReluOp(pim::PimVReluOp vreluOp, const StaticValueKnowle
|
||||
auto inputAddr = addressOf(vreluOp.getInput(), knowledge);
|
||||
setupRdRs1(outputBufferAddr, 0, inputAddr, 0);
|
||||
|
||||
json::Object json;
|
||||
json["op"] = "vrelu";
|
||||
json["rd"] = 0;
|
||||
json["rs1"] = 1;
|
||||
json["offset"] = createEmptyOffset();
|
||||
json["len"] = getValueSizeInBytes(vreluOp.getInput());
|
||||
emitInstruction(std::move(json));
|
||||
pim_binary::InstructionRecord instruction;
|
||||
instruction.opcode = pim_binary::Opcode::vrelu;
|
||||
instruction.rd = 0;
|
||||
instruction.r1 = 1;
|
||||
instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vreluOp.getInput().getType())));
|
||||
emitInstruction(instruction);
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenVTanhOp(pim::PimVTanhOp vtanhOp, const StaticValueKnowledge& knowledge) const {
|
||||
@@ -616,13 +612,12 @@ void PimCodeGen::codeGenVTanhOp(pim::PimVTanhOp vtanhOp, const StaticValueKnowle
|
||||
auto inputAddr = addressOf(vtanhOp.getInput(), knowledge);
|
||||
setupRdRs1(outputBufferAddr, 0, inputAddr, 0);
|
||||
|
||||
json::Object json;
|
||||
json["op"] = "vtanh";
|
||||
json["rd"] = 0;
|
||||
json["rs1"] = 1;
|
||||
json["offset"] = createEmptyOffset();
|
||||
json["len"] = getValueSizeInBytes(vtanhOp.getInput());
|
||||
emitInstruction(std::move(json));
|
||||
pim_binary::InstructionRecord instruction;
|
||||
instruction.opcode = pim_binary::Opcode::vtanh;
|
||||
instruction.rd = 0;
|
||||
instruction.r1 = 1;
|
||||
instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vtanhOp.getInput().getType())));
|
||||
emitInstruction(instruction);
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenVSigmOp(pim::PimVSigmOp vsigmOp, const StaticValueKnowledge& knowledge) const {
|
||||
@@ -630,13 +625,12 @@ void PimCodeGen::codeGenVSigmOp(pim::PimVSigmOp vsigmOp, const StaticValueKnowle
|
||||
auto inputAddr = addressOf(vsigmOp.getInput(), knowledge);
|
||||
setupRdRs1(outputBufferAddr, 0, inputAddr, 0);
|
||||
|
||||
json::Object json;
|
||||
json["op"] = "vsigm";
|
||||
json["rd"] = 0;
|
||||
json["rs1"] = 1;
|
||||
json["offset"] = createEmptyOffset();
|
||||
json["len"] = getValueSizeInBytes(vsigmOp.getInput());
|
||||
emitInstruction(std::move(json));
|
||||
pim_binary::InstructionRecord instruction;
|
||||
instruction.opcode = pim_binary::Opcode::vsigm;
|
||||
instruction.rd = 0;
|
||||
instruction.r1 = 1;
|
||||
instruction.generic3 = static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vsigmOp.getInput().getType())));
|
||||
emitInstruction(instruction);
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGenVSoftmaxOp(pim::PimVSoftmaxOp vsoftmaxOp, const StaticValueKnowledge& knowledge) const {
|
||||
@@ -644,13 +638,13 @@ void PimCodeGen::codeGenVSoftmaxOp(pim::PimVSoftmaxOp vsoftmaxOp, const StaticVa
|
||||
auto inputAddr = addressOf(vsoftmaxOp.getInput(), knowledge);
|
||||
setupRdRs1(outputBufferAddr, 0, inputAddr, 0);
|
||||
|
||||
json::Object json;
|
||||
json["op"] = "vsoftmax";
|
||||
json["rd"] = 0;
|
||||
json["rs1"] = 1;
|
||||
json["offset"] = createEmptyOffset();
|
||||
json["len"] = getValueSizeInBytes(vsoftmaxOp.getInput());
|
||||
emitInstruction(std::move(json));
|
||||
pim_binary::InstructionRecord instruction;
|
||||
instruction.opcode = pim_binary::Opcode::vsoftmax;
|
||||
instruction.rd = 0;
|
||||
instruction.r1 = 1;
|
||||
instruction.generic3 =
|
||||
static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(vsoftmaxOp.getInput().getType())));
|
||||
emitInstruction(instruction);
|
||||
}
|
||||
|
||||
void PimCodeGen::codeGetGlobalOp(memref::GetGlobalOp getGlobalOp, const StaticValueKnowledge& knowledge) const {}
|
||||
@@ -662,7 +656,7 @@ void PimCodeGen::codeGenTransposeOp(pim::PimTransposeOp transposeOp, const Stati
|
||||
auto srcType = cast<ShapedType>(transposeOp.getInput().getType());
|
||||
auto srcShape = srcType.getShape();
|
||||
size_t rank = srcShape.size();
|
||||
size_t elementSize = srcType.getElementTypeBitWidth() / 8;
|
||||
size_t elementSize = getElementTypeSizeInBytes(srcType.getElementType());
|
||||
size_t totalElements = srcType.getNumElements();
|
||||
|
||||
// Read permutation. Destination dim i corresponds to source dim perm[i].
|
||||
@@ -682,6 +676,30 @@ void PimCodeGen::codeGenTransposeOp(pim::PimTransposeOp transposeOp, const Stati
|
||||
dstStrides[i] = dstStrides[i + 1] * dstShape[i + 1];
|
||||
}
|
||||
|
||||
bool storagePreserving = true;
|
||||
for (size_t srcFlat = 0; srcFlat < totalElements; srcFlat++) {
|
||||
SmallVector<size_t> srcIdx(rank);
|
||||
size_t remaining = srcFlat;
|
||||
for (size_t d = 0; d < rank; d++) {
|
||||
srcIdx[d] = remaining / srcStrides[d];
|
||||
remaining %= srcStrides[d];
|
||||
}
|
||||
|
||||
size_t dstFlat = 0;
|
||||
for (size_t d = 0; d < rank; d++)
|
||||
dstFlat += srcIdx[perm[d]] * dstStrides[d];
|
||||
|
||||
if (dstFlat != srcFlat) {
|
||||
storagePreserving = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (storagePreserving) {
|
||||
emitMemCopyOp("lmv", dstAddr, 0, srcAddr, 0, totalElements * elementSize, "len");
|
||||
return;
|
||||
}
|
||||
|
||||
// Emit element-by-element copy with transposed addressing
|
||||
for (size_t srcFlat = 0; srcFlat < totalElements; srcFlat++) {
|
||||
// Decompose flat source index into multi-dimensional index
|
||||
@@ -719,12 +737,19 @@ std::string getMemorySizeAsString(size_t size) {
|
||||
|
||||
static SmallVector<unsigned, 8> getUsedWeightIndices(Block& block) {
|
||||
SmallVector<unsigned, 8> indices;
|
||||
auto addIndex = [&](unsigned weightIndex) {
|
||||
if (!llvm::is_contained(indices, weightIndex))
|
||||
indices.push_back(weightIndex);
|
||||
auto coreOp = dyn_cast<pim::PimCoreOp>(block.getParentOp());
|
||||
auto addWeight = [&](mlir::Value weight) {
|
||||
if (!coreOp)
|
||||
return;
|
||||
for (unsigned weightIndex = 0; weightIndex < coreOp.getWeights().size(); ++weightIndex) {
|
||||
if (coreOp.getWeightArgument(weightIndex) != weight)
|
||||
continue;
|
||||
if (!llvm::is_contained(indices, weightIndex))
|
||||
indices.push_back(weightIndex);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
block.walk([&](pim::PimVMMOp vmmOp) { addIndex(vmmOp.getWeightIndex()); });
|
||||
block.walk([&](pim::PimVMMOp vmmOp) { addWeight(vmmOp.getWeight()); });
|
||||
llvm::sort(indices);
|
||||
return indices;
|
||||
}
|
||||
@@ -747,9 +772,25 @@ static SmallVector<Operation*> collectTopLevelCoreLikeOps(func::FuncOp funcOp) {
|
||||
return coreLikeOps;
|
||||
}
|
||||
|
||||
static SmallDenseMap<memref::GlobalOp, MemEntry, 16>
|
||||
collectMaterializedHostGlobals(ModuleOp moduleOp, func::FuncOp funcOp, const PimAcceleratorMemory& memory) {
|
||||
SmallDenseMap<memref::GlobalOp, MemEntry, 16> materializedHostGlobals;
|
||||
funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
|
||||
if (hasWeightAlways(getGlobalOp))
|
||||
return;
|
||||
auto targetGlobal = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
||||
if (!targetGlobal || materializedHostGlobals.contains(targetGlobal))
|
||||
return;
|
||||
auto it = memory.memEntriesMap.find(getGlobalOp.getResult());
|
||||
if (it != memory.memEntriesMap.end())
|
||||
materializedHostGlobals[targetGlobal] = it->second;
|
||||
});
|
||||
return materializedHostGlobals;
|
||||
}
|
||||
|
||||
static void aliasMaterializedHostGlobals(ModuleOp moduleOp,
|
||||
func::FuncOp funcOp,
|
||||
pim::PimCoreOp coreOp,
|
||||
const SmallDenseMap<memref::GlobalOp, MemEntry, 16>& materializedHostGlobals,
|
||||
PimAcceleratorMemory& memory) {
|
||||
coreOp.walk([&](memref::GetGlobalOp getGlobalOp) {
|
||||
if (hasWeightAlways(getGlobalOp) || memory.memEntriesMap.contains(getGlobalOp.getResult()))
|
||||
@@ -759,16 +800,9 @@ static void aliasMaterializedHostGlobals(ModuleOp moduleOp,
|
||||
if (!targetGlobal)
|
||||
return;
|
||||
|
||||
mlir::Value aliasedValue;
|
||||
funcOp.walk([&](memref::GetGlobalOp candidate) {
|
||||
if (aliasedValue || candidate == getGlobalOp || !memory.memEntriesMap.contains(candidate.getResult()))
|
||||
return;
|
||||
if (lookupGlobalForGetGlobal(moduleOp, candidate) == targetGlobal)
|
||||
aliasedValue = candidate.getResult();
|
||||
});
|
||||
|
||||
if (aliasedValue)
|
||||
memory.memEntriesMap[getGlobalOp.getResult()] = memory.memEntriesMap[aliasedValue];
|
||||
auto it = materializedHostGlobals.find(targetGlobal);
|
||||
if (it != materializedHostGlobals.end())
|
||||
memory.memEntriesMap[getGlobalOp.getResult()] = it->second;
|
||||
});
|
||||
}
|
||||
|
||||
@@ -777,6 +811,15 @@ static void aliasMaterializedHostGlobals(ModuleOp moduleOp,
|
||||
/// fully resolved before the JSON instructions are emitted.
|
||||
/// Returns the number of emitted instructions, or -1 on failure.
|
||||
static int64_t codeGenCoreOps(Block& block, PimCodeGen& coreCodeGen) {
|
||||
auto resolveWeightIndex = [&](pim::PimVMMOp vmmOp) -> std::optional<unsigned> {
|
||||
auto coreOp = vmmOp->getParentOfType<pim::PimCoreOp>();
|
||||
if (!coreOp)
|
||||
return std::nullopt;
|
||||
for (unsigned weightIndex = 0; weightIndex < coreOp.getWeights().size(); ++weightIndex)
|
||||
if (coreOp.getWeightArgument(weightIndex) == vmmOp.getWeight())
|
||||
return weightIndex;
|
||||
return std::nullopt;
|
||||
};
|
||||
size_t processedOperations = 0;
|
||||
auto result =
|
||||
walkPimCoreBlock(block, StaticValueKnowledge {}, [&](Operation& op, const StaticValueKnowledge& knowledge) {
|
||||
@@ -796,8 +839,12 @@ static int64_t codeGenCoreOps(Block& block, PimCodeGen& coreCodeGen) {
|
||||
coreCodeGen.codeGenSendTensorOp(sendTensorOp, knowledge);
|
||||
else if (auto concatOp = dyn_cast<pim::PimConcatOp>(op))
|
||||
coreCodeGen.codeGenConcatOp(concatOp, knowledge);
|
||||
else if (auto vmmOp = dyn_cast<pim::PimVMMOp>(op))
|
||||
coreCodeGen.codeGenMVMLikeOp<pim::PimVMMOp>(vmmOp.getWeightIndex(), vmmOp, true, knowledge);
|
||||
else if (auto vmmOp = dyn_cast<pim::PimVMMOp>(op)) {
|
||||
auto weightIndex = resolveWeightIndex(vmmOp);
|
||||
if (!weightIndex)
|
||||
return failure();
|
||||
coreCodeGen.codeGenMVMLikeOp<pim::PimVMMOp>(*weightIndex, vmmOp, true, knowledge);
|
||||
}
|
||||
else if (auto transposeOp = dyn_cast<pim::PimTransposeOp>(op))
|
||||
coreCodeGen.codeGenTransposeOp(transposeOp, knowledge);
|
||||
else if (auto vvaddOp = dyn_cast<pim::PimVVAddOp>(op))
|
||||
@@ -837,7 +884,7 @@ static int64_t codeGenCoreOps(Block& block, PimCodeGen& coreCodeGen) {
|
||||
return failed(result) ? -1 : static_cast<int64_t>(processedOperations);
|
||||
}
|
||||
|
||||
OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::string& outputDirPath) {
|
||||
OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimCode(ModuleOp& moduleOp, std::string& outputDirPath) {
|
||||
if (!outputDirPath.empty()) {
|
||||
if (auto error = sys::fs::create_directory(outputDirPath)) {
|
||||
errs() << "Error creating output directory: " << outputDirPath << ": " << error.message() << '\n';
|
||||
@@ -857,11 +904,6 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::
|
||||
if (auto err = writeMemoryBinary(moduleOp, funcOp, memory, outputDirPath))
|
||||
return err;
|
||||
|
||||
if (auto err = writeHostCoreJson(outputDirPath))
|
||||
return err;
|
||||
|
||||
// For each core, specify the number of crossbar per array group.
|
||||
// This implementation always assigns one crossbar per group.
|
||||
json::Object xbarsPerArrayGroup;
|
||||
size_t maxCoreId = 0;
|
||||
uint64_t nextBatchReportId = 0;
|
||||
@@ -870,8 +912,10 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::
|
||||
auto mapCoreWeightToFileName = createAndPopulateWeightFolder(funcOp, outputDirPath);
|
||||
|
||||
SmallVector<Operation*> coreLikeOps = collectTopLevelCoreLikeOps(funcOp);
|
||||
SmallDenseMap<memref::GlobalOp, MemEntry, 16> materializedHostGlobals =
|
||||
collectMaterializedHostGlobals(moduleOp, funcOp, memory);
|
||||
llvm::DenseMap<size_t, size_t> emittedCoreIds;
|
||||
size_t nextEmittedCoreId = 1;
|
||||
size_t nextEmittedCoreId = 0;
|
||||
|
||||
for (Operation* op : coreLikeOps) {
|
||||
if (auto coreOp = dyn_cast<pim::PimCoreOp>(op)) {
|
||||
@@ -899,16 +943,30 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::
|
||||
maxCoreId = std::max(maxCoreId, coreId);
|
||||
|
||||
std::error_code errorCode;
|
||||
auto outputCorePath = outputDirPath + "/core_" + std::to_string(coreId) + ".json";
|
||||
raw_fd_ostream coreFileStream(outputCorePath, errorCode);
|
||||
auto outputCorePath = outputDirPath + "/core_" + std::to_string(coreId) + ".pim";
|
||||
raw_fd_ostream coreBinaryStream(outputCorePath, errorCode, sys::fs::OF_None);
|
||||
if (errorCode) {
|
||||
errs() << "Error while opening core file `" << outputCorePath << "`: " << errorCode.message() << '\n';
|
||||
return InvalidOutputFileAccess;
|
||||
}
|
||||
coreFileStream << '[';
|
||||
|
||||
PimCodeGen coreCodeGen(memory, coreFileStream, emittedCoreIds);
|
||||
aliasMaterializedHostGlobals(moduleOp, funcOp, coreOp, memory);
|
||||
std::unique_ptr<raw_fd_ostream> coreJsonStream;
|
||||
if (pimEmitJson.getValue()) {
|
||||
std::string outputCoreJsonPath = outputDirPath + "/core_" + std::to_string(coreId) + ".json";
|
||||
errorCode = std::error_code();
|
||||
coreJsonStream = std::make_unique<raw_fd_ostream>(outputCoreJsonPath, errorCode);
|
||||
if (errorCode) {
|
||||
errs() << "Error while opening core json file `" << outputCoreJsonPath << "`: " << errorCode.message()
|
||||
<< '\n';
|
||||
return InvalidOutputFileAccess;
|
||||
}
|
||||
*coreJsonStream << '[';
|
||||
}
|
||||
|
||||
pim_binary::writeHeader(coreBinaryStream);
|
||||
|
||||
PimCodeGen coreCodeGen(memory, coreBinaryStream, coreJsonStream.get(), emittedCoreIds);
|
||||
aliasMaterializedHostGlobals(moduleOp, coreOp, materializedHostGlobals, memory);
|
||||
auto& deviceMemory = memory.getOrCreateDeviceMem(coreId);
|
||||
deviceMemory.allocateCore(coreOp);
|
||||
|
||||
@@ -920,9 +978,14 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::
|
||||
if (reportRow)
|
||||
*reportRow = deviceMemory.getReportRow();
|
||||
|
||||
coreFileStream.seek(coreFileStream.tell() - 1);
|
||||
coreFileStream << ']';
|
||||
coreFileStream.close();
|
||||
pim_binary::patchInstructionCount(coreBinaryStream, coreCodeGen.getEmittedInstructionCount());
|
||||
coreBinaryStream.close();
|
||||
|
||||
if (coreJsonStream) {
|
||||
coreJsonStream->seek(coreJsonStream->tell() - 1);
|
||||
*coreJsonStream << ']';
|
||||
coreJsonStream->close();
|
||||
}
|
||||
|
||||
auto coreWeightsDirPath = outputDirPath + "/core_" + std::to_string(coreId);
|
||||
if (auto error = sys::fs::create_directory(coreWeightsDirPath)) {
|
||||
@@ -970,10 +1033,19 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::
|
||||
reportedCoreIds.reserve(batchCoreIds.size());
|
||||
MemoryReportRow batchRow;
|
||||
std::optional<MemoryReportRow> batchPerCoreRow;
|
||||
llvm::DenseMap<size_t, SmallVector<unsigned>> lanesByCoreId;
|
||||
SmallVector<size_t> orderedOriginalCoreIds;
|
||||
for (unsigned lane = 0; lane < static_cast<unsigned>(coreBatchOp.getLaneCount()); ++lane) {
|
||||
size_t originalCoreId = static_cast<size_t>(batchCoreIds[lane]);
|
||||
auto [it, inserted] = lanesByCoreId.try_emplace(originalCoreId);
|
||||
if (inserted)
|
||||
orderedOriginalCoreIds.push_back(originalCoreId);
|
||||
it->second.push_back(lane);
|
||||
}
|
||||
|
||||
for (size_t originalCoreId : orderedOriginalCoreIds) {
|
||||
OnnxMlirCompilerErrorCodes laneResult = CompilerSuccess;
|
||||
if (failed(withScalarCoreFromBatchLane(coreBatchOp, lane, [&](pim::PimCoreOp coreOp) {
|
||||
size_t originalCoreId = static_cast<size_t>(batchCoreIds[lane]);
|
||||
if (failed(withScalarCoreFromBatchLanes(coreBatchOp, lanesByCoreId[originalCoreId], [&](pim::PimCoreOp coreOp) {
|
||||
size_t coreId = emittedCoreIds.lookup(originalCoreId);
|
||||
reportedCoreIds.push_back(static_cast<int32_t>(coreId));
|
||||
MemoryReportRow laneRow;
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
#include "onnx-mlir/Compiler/OMCompilerTypes.h"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Common/Support/ReportUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimBinaryFormat.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
@@ -104,16 +105,17 @@ public:
|
||||
|
||||
class PimCodeGen {
|
||||
PimAcceleratorMemory& memory;
|
||||
llvm::raw_fd_ostream& coreFileStream;
|
||||
llvm::raw_fd_ostream& coreBinaryStream;
|
||||
llvm::raw_fd_ostream* coreJsonStream;
|
||||
const llvm::DenseMap<size_t, size_t>& emittedCoreIds;
|
||||
mutable uint32_t emittedInstructionCount = 0;
|
||||
|
||||
size_t addressOf(mlir::Value value, const StaticValueKnowledge& knowledge) const {
|
||||
return memory.getValueAddress(value, knowledge);
|
||||
}
|
||||
size_t remapCoreId(size_t coreId) const;
|
||||
|
||||
static llvm::json::Object createEmptyOffset();
|
||||
void emitInstruction(llvm::json::Object instruction) const;
|
||||
void emitInstruction(const pim_binary::InstructionRecord& instruction) const;
|
||||
|
||||
void genSetRegisterImmediateUnsigned(size_t registerNumber, size_t immediate) const;
|
||||
void setupRd(size_t rdAddress, size_t rdOffset) const;
|
||||
@@ -133,9 +135,12 @@ class PimCodeGen {
|
||||
|
||||
public:
|
||||
PimCodeGen(PimAcceleratorMemory& memory,
|
||||
llvm::raw_fd_ostream& coreJson,
|
||||
llvm::raw_fd_ostream& coreBinary,
|
||||
llvm::raw_fd_ostream* coreJson,
|
||||
const llvm::DenseMap<size_t, size_t>& emittedCoreIds)
|
||||
: memory(memory), coreFileStream(coreJson), emittedCoreIds(emittedCoreIds) {}
|
||||
: memory(memory), coreBinaryStream(coreBinary), coreJsonStream(coreJson), emittedCoreIds(emittedCoreIds) {}
|
||||
|
||||
uint32_t getEmittedInstructionCount() const { return emittedInstructionCount; }
|
||||
|
||||
void codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp, const StaticValueKnowledge& knowledge) const;
|
||||
void codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp, const StaticValueKnowledge& knowledge) const;
|
||||
@@ -164,6 +169,6 @@ public:
|
||||
void codeGenTransposeOp(pim::PimTransposeOp transposeOp, const StaticValueKnowledge& knowledge) const;
|
||||
};
|
||||
|
||||
OnnxMlirCompilerErrorCodes compileToPimJson(mlir::ModuleOp& moduleOpRef, std::string& outputDirName);
|
||||
OnnxMlirCompilerErrorCodes compileToPimCode(mlir::ModuleOp& moduleOpRef, std::string& outputDirName);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
#include "llvm/Support/ErrorHandling.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
|
||||
#define DEBUG_TYPE "PimCompilerOptions"
|
||||
@@ -13,6 +15,14 @@ llvm::cl::opt<PimEmissionTargetType> pimEmissionTarget(
|
||||
llvm::cl::init(EmitPimCodegen),
|
||||
llvm::cl::cat(OnnxMlirOptions));
|
||||
|
||||
llvm::cl::opt<PimMergeSchedulerType>
|
||||
pimMergeScheduler("pim-merge-scheduler",
|
||||
llvm::cl::desc("Scheduler used by the Spatial merge-compute-nodes pass"),
|
||||
llvm::cl::values(clEnumValN(MergeSchedulerPeft, "peft", "Use PEFT scheduling")),
|
||||
llvm::cl::values(clEnumValN(MergeSchedulerDcp, "dcp", "Use the legacy DCP-inspired scheduler")),
|
||||
llvm::cl::init(MergeSchedulerPeft),
|
||||
llvm::cl::cat(OnnxMlirOptions));
|
||||
|
||||
llvm::cl::opt<bool>
|
||||
pimOnlyCodegen("pim-only-codegen",
|
||||
llvm::cl::desc("Only generate code for PIM (assume input is already in bufferized PIM IR)"),
|
||||
@@ -24,20 +34,25 @@ llvm::cl::opt<bool> useExperimentalConvImpl("use-experimental-conv-impl",
|
||||
llvm::cl::init(false),
|
||||
llvm::cl::cat(OnnxMlirOptions));
|
||||
|
||||
llvm::cl::opt<bool> pimEmitJson("pim-emit-json",
|
||||
llvm::cl::desc("Also emit per-core JSON instruction files alongside binary .pim files"),
|
||||
llvm::cl::init(false),
|
||||
llvm::cl::cat(OnnxMlirOptions));
|
||||
|
||||
llvm::cl::opt<size_t>
|
||||
crossbarSize("crossbar-size", llvm::cl::desc("Width and heigth of a single crossbar"), llvm::cl::init(2));
|
||||
crossbarSize("crossbar-size", llvm::cl::desc("Width and height of a single crossbar"), llvm::cl::init(2));
|
||||
|
||||
llvm::cl::opt<size_t>
|
||||
crossbarCountInCore("crossbar-count", llvm::cl::desc("Number of crossbars in each core"), llvm::cl::init(256));
|
||||
|
||||
llvm::cl::opt<long> coresCount("core-count",
|
||||
llvm::cl::desc("Number of cores in the chip. `-1` to use the minimum amount of cores."),
|
||||
llvm::cl::desc("Number of cores in the chip. Required for PIM compilation."),
|
||||
llvm::cl::init(-1));
|
||||
|
||||
llvm::cl::opt<size_t> dcpCriticalWindowSize(
|
||||
"dcp-critical-window-size",
|
||||
llvm::cl::desc("Number of lowest-slack virtual nodes considered by each DCP coarsening iteration. "
|
||||
"Use 0 to run the legacy full-graph DCP analysis."),
|
||||
"Use 0 to run the legacy full-graph DCP analysis. Only used by the DCP scheduler."),
|
||||
llvm::cl::init(4000));
|
||||
|
||||
llvm::cl::opt<bool>
|
||||
@@ -45,4 +60,13 @@ llvm::cl::opt<bool>
|
||||
llvm::cl::desc("Ignore ConcatOp corner case: do not assert and do a simplification"),
|
||||
llvm::cl::init(false));
|
||||
|
||||
bool hasExplicitPimCoreCount() { return coresCount.getNumOccurrences() != 0; }
|
||||
|
||||
void verifyExplicitPimCoreCount() {
|
||||
if (!hasExplicitPimCoreCount())
|
||||
llvm::report_fatal_error("PIM compilation requires an explicit --core-count=<positive integer>");
|
||||
if (coresCount.getValue() <= 0)
|
||||
llvm::report_fatal_error("PIM compilation requires --core-count to be a positive integer");
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -20,17 +20,27 @@ typedef enum {
|
||||
EmitPimCodegen = 3
|
||||
} PimEmissionTargetType;
|
||||
|
||||
typedef enum {
|
||||
MergeSchedulerPeft = 0,
|
||||
MergeSchedulerDcp = 1,
|
||||
} PimMergeSchedulerType;
|
||||
|
||||
extern llvm::cl::OptionCategory OnnxMlirOptions;
|
||||
extern llvm::cl::opt<PimEmissionTargetType> pimEmissionTarget;
|
||||
extern llvm::cl::opt<PimMergeSchedulerType> pimMergeScheduler;
|
||||
|
||||
extern llvm::cl::opt<bool> pimOnlyCodegen;
|
||||
extern llvm::cl::opt<bool> useExperimentalConvImpl;
|
||||
extern llvm::cl::opt<bool> pimEmitJson;
|
||||
|
||||
extern llvm::cl::opt<size_t> crossbarSize;
|
||||
extern llvm::cl::opt<size_t> crossbarCountInCore;
|
||||
extern llvm::cl::opt<long> coresCount;
|
||||
extern llvm::cl::opt<size_t> dcpCriticalWindowSize;
|
||||
|
||||
bool hasExplicitPimCoreCount();
|
||||
void verifyExplicitPimCoreCount();
|
||||
|
||||
// This option, by default set to false, will ignore an error when resolving a
|
||||
// specific tiles of the operands of a concat. This specific case is when the
|
||||
// wanted tile is generated by two separate operands of the concat. If this is
|
||||
|
||||
@@ -17,6 +17,7 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
|
||||
PassManager& pm,
|
||||
EmissionTargetType& emissionTarget,
|
||||
std::string outputNameNoExt) {
|
||||
verifyExplicitPimCoreCount();
|
||||
|
||||
if (pimOnlyCodegen) {
|
||||
// Skip all the lowering passes and directly generate code for PIM.
|
||||
@@ -52,9 +53,9 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
|
||||
pm.addPass(createPimMaterializeHostConstantsPass());
|
||||
pm.addPass(createPimVerificationPass());
|
||||
pm.addPass(createMessagePass("Pim verified"));
|
||||
pm.addPass(createEmitPimJsonPass());
|
||||
pm.addPass(createEmitPimCodePass());
|
||||
// pm.addPass(createCountInstructionPass());
|
||||
pm.addPass(createMessagePass("Pim json code emitted"));
|
||||
pm.addPass(createMessagePass("Pim code emitted"));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -33,7 +33,7 @@ struct DenseWeightView {
|
||||
};
|
||||
|
||||
FailureOr<DenseWeightView> resolveDenseWeightView(ModuleOp moduleOp, mlir::Value weight) {
|
||||
SmallVector<memref::SubViewOp> subviews;
|
||||
SmallVector<Operation*> viewOps;
|
||||
mlir::Value current = weight;
|
||||
memref::GetGlobalOp getGlobalOp;
|
||||
|
||||
@@ -46,7 +46,7 @@ FailureOr<DenseWeightView> resolveDenseWeightView(ModuleOp moduleOp, mlir::Value
|
||||
if (auto subview = dyn_cast<memref::SubViewOp>(defOp)) {
|
||||
if (!hasAllStaticSubviewParts(subview))
|
||||
return failure();
|
||||
subviews.push_back(subview);
|
||||
viewOps.push_back(subview);
|
||||
current = subview.getSource();
|
||||
continue;
|
||||
}
|
||||
@@ -54,6 +54,24 @@ FailureOr<DenseWeightView> resolveDenseWeightView(ModuleOp moduleOp, mlir::Value
|
||||
current = cast.getSource();
|
||||
continue;
|
||||
}
|
||||
if (auto collapse = dyn_cast<memref::CollapseShapeOp>(defOp)) {
|
||||
auto srcType = dyn_cast<MemRefType>(collapse.getSrc().getType());
|
||||
auto resultType = dyn_cast<MemRefType>(collapse.getResult().getType());
|
||||
if (!srcType || !resultType || !srcType.hasStaticShape() || !resultType.hasStaticShape())
|
||||
return failure();
|
||||
viewOps.push_back(collapse);
|
||||
current = collapse.getSrc();
|
||||
continue;
|
||||
}
|
||||
if (auto expand = dyn_cast<memref::ExpandShapeOp>(defOp)) {
|
||||
auto srcType = dyn_cast<MemRefType>(expand.getSrc().getType());
|
||||
auto resultType = dyn_cast<MemRefType>(expand.getResult().getType());
|
||||
if (!srcType || !resultType || !srcType.hasStaticShape() || !resultType.hasStaticShape())
|
||||
return failure();
|
||||
viewOps.push_back(expand);
|
||||
current = expand.getSrc();
|
||||
continue;
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
|
||||
@@ -70,16 +88,39 @@ FailureOr<DenseWeightView> resolveDenseWeightView(ModuleOp moduleOp, mlir::Value
|
||||
view.shape.assign(denseAttr.getType().getShape().begin(), denseAttr.getType().getShape().end());
|
||||
view.strides = computeRowMajorStrides(view.shape);
|
||||
|
||||
for (memref::SubViewOp subview : llvm::reverse(subviews)) {
|
||||
SmallVector<int64_t> nextStrides;
|
||||
nextStrides.reserve(subview.getStaticStrides().size());
|
||||
for (auto [offset, stride, sourceStride] :
|
||||
llvm::zip_equal(subview.getStaticOffsets(), subview.getStaticStrides(), view.strides)) {
|
||||
view.offset += offset * sourceStride;
|
||||
nextStrides.push_back(stride * sourceStride);
|
||||
for (Operation* viewOp : llvm::reverse(viewOps)) {
|
||||
if (auto subview = dyn_cast<memref::SubViewOp>(viewOp)) {
|
||||
SmallVector<int64_t> nextStrides;
|
||||
nextStrides.reserve(subview.getStaticStrides().size());
|
||||
for (auto [offset, stride, sourceStride] :
|
||||
llvm::zip_equal(subview.getStaticOffsets(), subview.getStaticStrides(), view.strides)) {
|
||||
view.offset += offset * sourceStride;
|
||||
nextStrides.push_back(stride * sourceStride);
|
||||
}
|
||||
view.shape.assign(subview.getStaticSizes().begin(), subview.getStaticSizes().end());
|
||||
view.strides = std::move(nextStrides);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Collapse/expand are accepted only as contiguous static reshapes of a
|
||||
// dense global view, so a row-major stride recomputation preserves layout.
|
||||
if (auto collapse = dyn_cast<memref::CollapseShapeOp>(viewOp)) {
|
||||
if (view.strides != computeRowMajorStrides(view.shape))
|
||||
return failure();
|
||||
auto resultType = cast<MemRefType>(collapse.getResult().getType());
|
||||
view.shape.assign(resultType.getShape().begin(), resultType.getShape().end());
|
||||
view.strides = computeRowMajorStrides(view.shape);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto expand = dyn_cast<memref::ExpandShapeOp>(viewOp)) {
|
||||
if (view.strides != computeRowMajorStrides(view.shape))
|
||||
return failure();
|
||||
auto resultType = cast<MemRefType>(expand.getResult().getType());
|
||||
view.shape.assign(resultType.getShape().begin(), resultType.getShape().end());
|
||||
view.strides = computeRowMajorStrides(view.shape);
|
||||
continue;
|
||||
}
|
||||
view.shape.assign(subview.getStaticSizes().begin(), subview.getStaticSizes().end());
|
||||
view.strides = std::move(nextStrides);
|
||||
}
|
||||
|
||||
return view;
|
||||
@@ -87,12 +128,20 @@ FailureOr<DenseWeightView> resolveDenseWeightView(ModuleOp moduleOp, mlir::Value
|
||||
|
||||
SmallVector<unsigned, 8> getUsedWeightIndices(Block& block) {
|
||||
SmallVector<unsigned, 8> indices;
|
||||
auto addIndex = [&](unsigned weightIndex) {
|
||||
if (!llvm::is_contained(indices, weightIndex))
|
||||
indices.push_back(weightIndex);
|
||||
auto coreOp = dyn_cast<pim::PimCoreOp>(block.getParentOp());
|
||||
auto addWeight = [&](mlir::Value weight) {
|
||||
if (!coreOp)
|
||||
return;
|
||||
for (unsigned weightIndex = 0; weightIndex < coreOp.getWeights().size(); ++weightIndex) {
|
||||
if (coreOp.getWeightArgument(weightIndex) != weight)
|
||||
continue;
|
||||
if (!llvm::is_contained(indices, weightIndex))
|
||||
indices.push_back(weightIndex);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
block.walk([&](pim::PimVMMOp vmmOp) { addIndex(vmmOp.getWeightIndex()); });
|
||||
block.walk([&](pim::PimVMMOp vmmOp) { addWeight(vmmOp.getWeight()); });
|
||||
llvm::sort(indices);
|
||||
return indices;
|
||||
}
|
||||
@@ -159,7 +208,7 @@ createAndPopulateWeightFolder(func::FuncOp funcOp, StringRef outputDirPath) {
|
||||
int64_t numCols = shape[1];
|
||||
assert(numRows <= xbarSize && numCols <= xbarSize && "Weight dimensions must not exceed crossbar size");
|
||||
|
||||
size_t elementByteWidth = denseAttr.getElementType().getIntOrFloatBitWidth() / 8;
|
||||
size_t elementByteWidth = getElementTypeSizeInBytes(denseAttr.getElementType());
|
||||
|
||||
std::string newFileName = "crossbar_" + std::to_string(indexFileName++) + ".bin";
|
||||
auto weightFilePath = (coreWeightsDirPath + "/" + newFileName).str();
|
||||
|
||||
@@ -18,13 +18,17 @@ namespace detail {
|
||||
|
||||
inline mlir::ValueRange getBlockArgs(mlir::Block* block) { return mlir::ValueRange(block->getArguments()); }
|
||||
|
||||
inline mlir::ValueRange getInputBlockArgs(mlir::Block* block, size_t weightCount) {
|
||||
return mlir::ValueRange(block->getArguments()).drop_front(weightCount);
|
||||
}
|
||||
|
||||
template <typename Fn, size_t... Is>
|
||||
decltype(auto) invokeWithBlockArgs(Fn&& fn, mlir::Block* block, std::index_sequence<Is...>) {
|
||||
return std::forward<Fn>(fn)(block->getArgument(Is)...);
|
||||
}
|
||||
|
||||
template <typename Fn, size_t... Is>
|
||||
decltype(auto) invokeWithValues(Fn&& fn, mlir::ArrayRef<mlir::Value> values, std::index_sequence<Is...>) {
|
||||
decltype(auto) invokeWithValues(Fn&& fn, mlir::ValueRange values, std::index_sequence<Is...>) {
|
||||
return std::forward<Fn>(fn)(values[Is]...);
|
||||
}
|
||||
|
||||
@@ -85,6 +89,8 @@ auto createSpatCompute(RewriterT& rewriter,
|
||||
auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs);
|
||||
|
||||
auto* block = new mlir::Block();
|
||||
for (mlir::Value weight : weights)
|
||||
block->addArgument(weight.getType(), loc);
|
||||
for (mlir::Value input : inputs)
|
||||
block->addArgument(input.getType(), loc);
|
||||
|
||||
@@ -93,14 +99,17 @@ auto createSpatCompute(RewriterT& rewriter,
|
||||
|
||||
using BodyResult = detail::InvokeWithBlockArgsResultT<std::decay_t<BodyFn>, std::make_index_sequence<NumInputs>>;
|
||||
if constexpr (std::is_same_v<BodyResult, void>) {
|
||||
detail::invokeWithBlockArgs(std::forward<BodyFn>(body), block, std::make_index_sequence<NumInputs> {});
|
||||
detail::invokeWithValues(std::forward<BodyFn>(body),
|
||||
detail::getInputBlockArgs(block, weights.size()),
|
||||
std::make_index_sequence<NumInputs> {});
|
||||
|
||||
rewriter.setInsertionPointAfter(computeOp);
|
||||
return computeOp;
|
||||
}
|
||||
else {
|
||||
auto bodyResult =
|
||||
detail::invokeWithBlockArgs(std::forward<BodyFn>(body), block, std::make_index_sequence<NumInputs> {});
|
||||
auto bodyResult = detail::invokeWithValues(std::forward<BodyFn>(body),
|
||||
detail::getInputBlockArgs(block, weights.size()),
|
||||
std::make_index_sequence<NumInputs> {});
|
||||
if (mlir::failed(bodyResult)) {
|
||||
rewriter.setInsertionPointAfter(computeOp);
|
||||
rewriter.eraseOp(computeOp);
|
||||
@@ -123,6 +132,8 @@ auto createSpatCompute(RewriterT& rewriter,
|
||||
auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs);
|
||||
|
||||
auto* block = new mlir::Block();
|
||||
for (mlir::Value weight : weights)
|
||||
block->addArgument(weight.getType(), loc);
|
||||
for (mlir::Value input : inputs)
|
||||
block->addArgument(input.getType(), loc);
|
||||
|
||||
@@ -131,13 +142,13 @@ auto createSpatCompute(RewriterT& rewriter,
|
||||
|
||||
using BodyResult = detail::InvokeWithValueRangeResultT<std::decay_t<BodyFn>>;
|
||||
if constexpr (std::is_same_v<BodyResult, void>) {
|
||||
std::forward<BodyFn>(body)(detail::getBlockArgs(block));
|
||||
std::forward<BodyFn>(body)(detail::getInputBlockArgs(block, weights.size()));
|
||||
|
||||
rewriter.setInsertionPointAfter(computeOp);
|
||||
return computeOp;
|
||||
}
|
||||
else {
|
||||
auto bodyResult = std::forward<BodyFn>(body)(detail::getBlockArgs(block));
|
||||
auto bodyResult = std::forward<BodyFn>(body)(detail::getInputBlockArgs(block, weights.size()));
|
||||
if (mlir::failed(bodyResult)) {
|
||||
rewriter.setInsertionPointAfter(computeOp);
|
||||
rewriter.eraseOp(computeOp);
|
||||
|
||||
@@ -100,18 +100,27 @@ DenseMap<HSliceId, DenseMap<CoreId, SmallVector<Value>>> tileMatrix(
|
||||
return tiles;
|
||||
}
|
||||
|
||||
tensor::SplatOp
|
||||
broadcastToVector(Value scalarToBroadcast, int64_t length, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
Value broadcastToVector(Value scalarToBroadcast, int64_t length, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
auto oldType = cast<RankedTensorType>(scalarToBroadcast.getType());
|
||||
Type elementType = oldType.getElementType();
|
||||
int64_t shape[2] = {1, length};
|
||||
Type type = oldType.cloneWith(ArrayRef(shape), elementType);
|
||||
|
||||
auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0).getResult();
|
||||
SmallVector<Value> index(oldType.getRank(), zero);
|
||||
auto elementValue = tensor::ExtractOp::create(rewriter, loc, scalarToBroadcast, index).getResult();
|
||||
auto buildBroadcast = [&](Value input) -> Value {
|
||||
auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0).getResult();
|
||||
SmallVector<Value> index(oldType.getRank(), zero);
|
||||
auto elementValue = tensor::ExtractOp::create(rewriter, loc, input, index).getResult();
|
||||
return tensor::SplatOp::create(rewriter, loc, type, elementValue);
|
||||
};
|
||||
|
||||
return tensor::SplatOp::create(rewriter, loc, type, elementValue);
|
||||
if (isHostFoldableValue(scalarToBroadcast))
|
||||
return buildBroadcast(scalarToBroadcast);
|
||||
|
||||
auto broadcastCompute =
|
||||
createSpatCompute<1>(rewriter, loc, TypeRange {type}, {}, ValueRange {scalarToBroadcast}, [&](Value input) {
|
||||
spatial::SpatYieldOp::create(rewriter, loc, buildBroadcast(input));
|
||||
});
|
||||
return broadcastCompute.getResult(0);
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -136,9 +136,9 @@ tileMatrix(mlir::Value& matrixToTile,
|
||||
mlir::ConversionPatternRewriter& rewriter,
|
||||
mlir::Location& loc);
|
||||
|
||||
mlir::tensor::SplatOp broadcastToVector(mlir::Value scalarToBroadcast,
|
||||
int64_t length,
|
||||
mlir::ConversionPatternRewriter& rewriter,
|
||||
mlir::Location loc);
|
||||
mlir::Value broadcastToVector(mlir::Value scalarToBroadcast,
|
||||
int64_t length,
|
||||
mlir::ConversionPatternRewriter& rewriter,
|
||||
mlir::Location loc);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -1,8 +1,12 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallBitVector.h"
|
||||
#include "llvm/ADT/SmallPtrSet.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
@@ -18,6 +22,11 @@ static bool hasStaticUnitStrides(tensor::ExtractSliceOp extractSliceOp) {
|
||||
return llvm::all_of(extractSliceOp.getStaticStrides(), [](int64_t stride) { return stride == 1; });
|
||||
}
|
||||
|
||||
static bool hasConstantIndices(tensor::ExtractOp extractOp) {
|
||||
return llvm::all_of(extractOp.getIndices(),
|
||||
[](Value index) { return isa_and_nonnull<arith::ConstantIndexOp>(index.getDefiningOp()); });
|
||||
}
|
||||
|
||||
static bool isStaticTensorResult(Operation* op) {
|
||||
return llvm::all_of(op->getResultTypes(), [](Type type) {
|
||||
auto shapedType = dyn_cast<ShapedType>(type);
|
||||
@@ -25,6 +34,167 @@ static bool isStaticTensorResult(Operation* op) {
|
||||
});
|
||||
}
|
||||
|
||||
static SmallVector<int64_t> computeRowMajorStrides(ArrayRef<int64_t> shape) {
|
||||
SmallVector<int64_t> strides(shape.size(), 1);
|
||||
for (int64_t dim = static_cast<int64_t>(shape.size()) - 2; dim >= 0; --dim)
|
||||
strides[dim] = strides[dim + 1] * shape[dim + 1];
|
||||
return strides;
|
||||
}
|
||||
|
||||
static FailureOr<DenseElementsAttr> transposeDenseElements(DenseElementsAttr denseAttr, ArrayRef<int64_t> perms) {
|
||||
auto tensorType = dyn_cast<RankedTensorType>(denseAttr.getType());
|
||||
if (!tensorType)
|
||||
return failure();
|
||||
|
||||
int64_t rank = tensorType.getRank();
|
||||
if (static_cast<int64_t>(perms.size()) != rank)
|
||||
return failure();
|
||||
|
||||
llvm::SmallBitVector seen(rank);
|
||||
SmallVector<int64_t> transposedShape;
|
||||
transposedShape.reserve(rank);
|
||||
for (int64_t perm : perms) {
|
||||
if (perm < 0 || perm >= rank || seen.test(perm))
|
||||
return failure();
|
||||
seen.set(perm);
|
||||
transposedShape.push_back(tensorType.getShape()[perm]);
|
||||
}
|
||||
|
||||
auto transposedType = RankedTensorType::get(transposedShape, tensorType.getElementType(), tensorType.getEncoding());
|
||||
if (denseAttr.isSplat())
|
||||
return DenseElementsAttr::get(transposedType, denseAttr.getSplatValue<Attribute>());
|
||||
|
||||
SmallVector<Attribute> originalValues(denseAttr.getValues<Attribute>());
|
||||
SmallVector<Attribute> transposedValues(originalValues.size());
|
||||
SmallVector<int64_t> originalStrides = computeRowMajorStrides(tensorType.getShape());
|
||||
SmallVector<int64_t> transposedStrides = computeRowMajorStrides(transposedShape);
|
||||
SmallVector<int64_t> originalIndices(rank);
|
||||
|
||||
for (auto [linearIndex, value] : llvm::enumerate(originalValues)) {
|
||||
int64_t remaining = static_cast<int64_t>(linearIndex);
|
||||
for (int64_t dim = 0; dim < rank; ++dim) {
|
||||
originalIndices[dim] = remaining / originalStrides[dim];
|
||||
remaining %= originalStrides[dim];
|
||||
}
|
||||
|
||||
int64_t transposedLinearIndex = 0;
|
||||
for (int64_t dim = 0; dim < rank; ++dim)
|
||||
transposedLinearIndex += originalIndices[perms[dim]] * transposedStrides[dim];
|
||||
|
||||
transposedValues[transposedLinearIndex] = value;
|
||||
}
|
||||
|
||||
return DenseElementsAttr::get(transposedType, transposedValues);
|
||||
}
|
||||
|
||||
static FailureOr<DenseElementsAttr> reshapeDenseElements(DenseElementsAttr denseAttr, RankedTensorType resultType) {
|
||||
auto sourceType = dyn_cast<RankedTensorType>(denseAttr.getType());
|
||||
if (!sourceType || !resultType || sourceType.getNumElements() != resultType.getNumElements())
|
||||
return failure();
|
||||
|
||||
if (denseAttr.isSplat())
|
||||
return DenseElementsAttr::get(resultType, denseAttr.getSplatValue<Attribute>());
|
||||
|
||||
SmallVector<Attribute> values(denseAttr.getValues<Attribute>());
|
||||
return DenseElementsAttr::get(resultType, values);
|
||||
}
|
||||
|
||||
static FailureOr<DenseElementsAttr> extractSliceDenseElements(DenseElementsAttr denseAttr,
|
||||
tensor::ExtractSliceOp extractSliceOp) {
|
||||
auto sourceType = dyn_cast<RankedTensorType>(denseAttr.getType());
|
||||
auto resultType = dyn_cast<RankedTensorType>(extractSliceOp.getType());
|
||||
if (!sourceType || !resultType || !sourceType.hasStaticShape() || !resultType.hasStaticShape())
|
||||
return failure();
|
||||
|
||||
ArrayRef<int64_t> offsets = extractSliceOp.getStaticOffsets();
|
||||
ArrayRef<int64_t> sizes = extractSliceOp.getStaticSizes();
|
||||
ArrayRef<int64_t> strides = extractSliceOp.getStaticStrides();
|
||||
if (llvm::any_of(offsets, [](int64_t value) { return ShapedType::isDynamic(value); })
|
||||
|| llvm::any_of(sizes, [](int64_t value) { return ShapedType::isDynamic(value); })
|
||||
|| llvm::any_of(strides, [](int64_t stride) { return ShapedType::isDynamic(stride) || stride != 1; }))
|
||||
return failure();
|
||||
|
||||
if (denseAttr.isSplat())
|
||||
return DenseElementsAttr::get(resultType, denseAttr.getSplatValue<Attribute>());
|
||||
|
||||
SmallVector<Attribute> sourceValues(denseAttr.getValues<Attribute>());
|
||||
SmallVector<int64_t> sourceStrides = computeRowMajorStrides(sourceType.getShape());
|
||||
SmallVector<int64_t> resultStrides = computeRowMajorStrides(resultType.getShape());
|
||||
SmallVector<Attribute> resultValues;
|
||||
resultValues.reserve(resultType.getNumElements());
|
||||
|
||||
for (int64_t linearIndex = 0; linearIndex < resultType.getNumElements(); ++linearIndex) {
|
||||
int64_t remaining = linearIndex;
|
||||
int64_t sourceLinearIndex = 0;
|
||||
for (int64_t dim = 0; dim < resultType.getRank(); ++dim) {
|
||||
const int64_t resultIndex = resultStrides.empty() ? 0 : remaining / resultStrides[dim];
|
||||
remaining = resultStrides.empty() ? 0 : remaining % resultStrides[dim];
|
||||
sourceLinearIndex += (offsets[dim] + resultIndex) * sourceStrides[dim];
|
||||
}
|
||||
resultValues.push_back(sourceValues[sourceLinearIndex]);
|
||||
}
|
||||
|
||||
return DenseElementsAttr::get(resultType, resultValues);
|
||||
}
|
||||
|
||||
static DenseElementsAttr getDirectDenseConstantAttr(Value value) {
|
||||
if (auto constantOp = value.getDefiningOp<arith::ConstantOp>())
|
||||
return dyn_cast<DenseElementsAttr>(constantOp.getValue());
|
||||
if (auto constantOp = value.getDefiningOp<ONNXConstantOp>())
|
||||
return dyn_cast_or_null<DenseElementsAttr>(constantOp.getValueAttr());
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
static DenseElementsAttr getHostFoldableDenseElementsAttrImpl(Value value, llvm::SmallPtrSetImpl<Operation*>& visited) {
|
||||
auto* definingOp = value.getDefiningOp();
|
||||
if (!definingOp || !visited.insert(definingOp).second)
|
||||
return nullptr;
|
||||
|
||||
// Rebuild dense attributes through view-only host-foldable chains so later
|
||||
// lowering stages can still recognize grouped/sliced constants.
|
||||
if (auto denseAttr = getDirectDenseConstantAttr(value))
|
||||
return denseAttr;
|
||||
|
||||
if (auto transposeOp = dyn_cast<ONNXTransposeOp>(definingOp)) {
|
||||
auto inputAttr = getHostFoldableDenseElementsAttrImpl(transposeOp.getData(), visited);
|
||||
if (!inputAttr)
|
||||
return nullptr;
|
||||
|
||||
SmallVector<int64_t> perm;
|
||||
perm.reserve(transposeOp.getPermAttr().size());
|
||||
for (IntegerAttr attr : transposeOp.getPermAttr().getAsRange<IntegerAttr>())
|
||||
perm.push_back(attr.getInt());
|
||||
auto transposedAttr = transposeDenseElements(inputAttr, perm);
|
||||
return succeeded(transposedAttr) ? *transposedAttr : nullptr;
|
||||
}
|
||||
|
||||
if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(definingOp)) {
|
||||
auto inputAttr = getHostFoldableDenseElementsAttrImpl(collapseShapeOp.getSrc(), visited);
|
||||
if (!inputAttr)
|
||||
return nullptr;
|
||||
auto reshapedAttr = reshapeDenseElements(inputAttr, cast<RankedTensorType>(collapseShapeOp.getType()));
|
||||
return succeeded(reshapedAttr) ? *reshapedAttr : nullptr;
|
||||
}
|
||||
|
||||
if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(definingOp)) {
|
||||
auto inputAttr = getHostFoldableDenseElementsAttrImpl(expandShapeOp.getSrc(), visited);
|
||||
if (!inputAttr)
|
||||
return nullptr;
|
||||
auto reshapedAttr = reshapeDenseElements(inputAttr, cast<RankedTensorType>(expandShapeOp.getType()));
|
||||
return succeeded(reshapedAttr) ? *reshapedAttr : nullptr;
|
||||
}
|
||||
|
||||
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(definingOp)) {
|
||||
auto inputAttr = getHostFoldableDenseElementsAttrImpl(extractSliceOp.getSource(), visited);
|
||||
if (!inputAttr)
|
||||
return nullptr;
|
||||
auto slicedAttr = extractSliceDenseElements(inputAttr, extractSliceOp);
|
||||
return succeeded(slicedAttr) ? *slicedAttr : nullptr;
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
static bool isHostFoldableOpImpl(Operation* op, llvm::SmallPtrSetImpl<Operation*>& visited) {
|
||||
if (!op || !visited.insert(op).second)
|
||||
return false;
|
||||
@@ -32,6 +202,9 @@ static bool isHostFoldableOpImpl(Operation* op, llvm::SmallPtrSetImpl<Operation*
|
||||
if (isa<arith::ConstantOp, ONNXConstantOp, ONNXNoneOp>(op))
|
||||
return true;
|
||||
|
||||
if (auto extractOp = dyn_cast<tensor::ExtractOp>(op))
|
||||
return hasConstantIndices(extractOp) && isHostFoldableValue(extractOp.getTensor());
|
||||
|
||||
if (!isStaticTensorResult(op))
|
||||
return false;
|
||||
|
||||
@@ -47,6 +220,9 @@ static bool isHostFoldableOpImpl(Operation* op, llvm::SmallPtrSetImpl<Operation*
|
||||
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op))
|
||||
return hasStaticUnitStrides(extractSliceOp) && isHostFoldableValue(extractSliceOp.getSource());
|
||||
|
||||
if (auto splatOp = dyn_cast<tensor::SplatOp>(op))
|
||||
return isHostFoldableValue(splatOp.getInput());
|
||||
|
||||
if (auto extractRowsOp = dyn_cast<spatial::SpatExtractRowsOp>(op))
|
||||
return isHostFoldableValue(extractRowsOp.getInput());
|
||||
|
||||
@@ -72,4 +248,9 @@ bool isHostFoldableOp(Operation* op) {
|
||||
return isHostFoldableOpImpl(op, visited);
|
||||
}
|
||||
|
||||
DenseElementsAttr getHostFoldableDenseElementsAttr(Value value) {
|
||||
llvm::SmallPtrSet<Operation*, 8> visited;
|
||||
return getHostFoldableDenseElementsAttrImpl(value, visited);
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
|
||||
@@ -9,4 +10,6 @@ bool isHostFoldableValue(mlir::Value value);
|
||||
|
||||
bool isHostFoldableOp(mlir::Operation* op);
|
||||
|
||||
mlir::DenseElementsAttr getHostFoldableDenseElementsAttr(mlir::Value value);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostLegality.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
@@ -11,7 +12,7 @@ using namespace mlir;
|
||||
namespace onnx_mlir {
|
||||
|
||||
LogicalResult verifyONNXToSpatialHostLegality(func::FuncOp funcOp) {
|
||||
bool hasFailure = false;
|
||||
pim::CappedDiagnosticReporter diagnostics;
|
||||
|
||||
for (Operation& op : funcOp.getFunctionBody().front()) {
|
||||
if (isa<func::ReturnOp, spatial::SpatCompute, spatial::SpatComputeBatch>(&op))
|
||||
@@ -19,11 +20,15 @@ LogicalResult verifyONNXToSpatialHostLegality(func::FuncOp funcOp) {
|
||||
if (isHostFoldableOp(&op))
|
||||
continue;
|
||||
|
||||
op.emitOpError("non-foldable top-level runtime op remains after ONNX-to-Spatial; lower it inside spat.compute");
|
||||
hasFailure = true;
|
||||
diagnostics.report(&op, [](Operation* illegalOp) {
|
||||
illegalOp->emitOpError("non-foldable top-level runtime op remains after ONNX-to-Spatial; lower it inside "
|
||||
"spat.compute");
|
||||
});
|
||||
}
|
||||
|
||||
return success(!hasFailure);
|
||||
diagnostics.emitSuppressedSummary(funcOp, "ONNX-to-Spatial host legality failures");
|
||||
|
||||
return success(!diagnostics.hasFailure());
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -5,17 +5,15 @@
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
|
||||
#include "Common/Common.hpp"
|
||||
#include "Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostLegality.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PostPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PrePatterns.hpp"
|
||||
@@ -46,7 +44,8 @@ static void populateEmptyFunction(func::FuncOp funcOp) {
|
||||
IRRewriter rewriter(funcOp.getContext());
|
||||
IRMapping mapper;
|
||||
SmallVector<spatial::SpatCompute> computes(funcOp.getOps<spatial::SpatCompute>());
|
||||
if (!computes.empty())
|
||||
SmallVector<spatial::SpatComputeBatch> computeBatches(funcOp.getOps<spatial::SpatComputeBatch>());
|
||||
if (!computes.empty() || !computeBatches.empty())
|
||||
return;
|
||||
|
||||
auto returnOp = cast<func::ReturnOp>(funcOp.getFunctionBody().front().getTerminator());
|
||||
@@ -87,17 +86,68 @@ static void populateEmptyFunction(func::FuncOp funcOp) {
|
||||
returnOp.setOperand(index, computeResult);
|
||||
}
|
||||
|
||||
static void wrapTopLevelRuntimeTransposes(func::FuncOp funcOp) {
|
||||
IRRewriter rewriter(funcOp.getContext());
|
||||
Block& entryBlock = funcOp.getFunctionBody().front();
|
||||
|
||||
for (Operation& op : llvm::make_early_inc_range(entryBlock)) {
|
||||
auto transposeOp = dyn_cast<ONNXTransposeOp>(&op);
|
||||
if (!transposeOp || isHostFoldableOp(transposeOp))
|
||||
continue;
|
||||
|
||||
// Transpose stays globally legal because constant/view-only cases are
|
||||
// allowed on the host. Any residual runtime transpose must be sunk into
|
||||
// spat.compute before the host legality check.
|
||||
auto resultType = transposeOp.getResult().getType();
|
||||
rewriter.setInsertionPoint(transposeOp);
|
||||
auto computeOp = createSpatCompute<1>(
|
||||
rewriter, transposeOp.getLoc(), TypeRange {resultType}, {}, ValueRange {transposeOp.getData()}, [&](Value input) {
|
||||
Value transposed =
|
||||
ONNXTransposeOp::create(rewriter, transposeOp.getLoc(), resultType, input, transposeOp.getPermAttr());
|
||||
spatial::SpatYieldOp::create(rewriter, transposeOp.getLoc(), transposed);
|
||||
});
|
||||
rewriter.replaceOp(transposeOp, computeOp.getResult(0));
|
||||
}
|
||||
}
|
||||
|
||||
void ONNXToSpatialPass::runOnOperation() {
|
||||
ModuleOp moduleOp = getOperation();
|
||||
MLIRContext* ctx = &getContext();
|
||||
|
||||
ConversionTarget preTarget(*ctx);
|
||||
preTarget.addLegalDialect<spatial::SpatialDialect,
|
||||
ONNXDialect,
|
||||
tensor::TensorDialect,
|
||||
arith::ArithDialect,
|
||||
scf::SCFDialect>();
|
||||
preTarget.addIllegalOp<ONNXConstantOp, ONNXFlattenOp>();
|
||||
|
||||
RewritePatternSet prePatterns(ctx);
|
||||
populatePrePatterns(prePatterns, ctx);
|
||||
if (failed(applyPatternsGreedily(moduleOp, std::move(prePatterns))))
|
||||
moduleOp.emitWarning("failed to apply ONNX-to-Spatial pre-patterns; continuing");
|
||||
if (failed(applyPartialConversion(moduleOp, preTarget, std::move(prePatterns)))) {
|
||||
moduleOp.emitError("failed to apply ONNX-to-Spatial pre-rewrites");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
auto entryFunc = getPimEntryFunc(moduleOp);
|
||||
if (failed(entryFunc)) {
|
||||
moduleOp.emitError("failed to locate the PIM entry function during ONNX-to-Spatial lowering");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
RewritePatternSet matmulPatterns(ctx);
|
||||
populateMatMulRewritePatterns(matmulPatterns, ctx);
|
||||
walkAndApplyPatterns(moduleOp, std::move(matmulPatterns));
|
||||
|
||||
bool hasUnloweredMatMul = false;
|
||||
moduleOp.walk([&](ONNXMatMulOp matmulOp) {
|
||||
hasUnloweredMatMul = true;
|
||||
matmulOp.emitOpError("remaining ONNX MatMul before the required ONNX-to-Spatial conversion");
|
||||
});
|
||||
if (hasUnloweredMatMul) {
|
||||
moduleOp.emitError("failed to lower all ONNX MatMul ops before ONNX-to-Spatial conversion");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
@@ -130,30 +180,17 @@ void ONNXToSpatialPass::runOnOperation() {
|
||||
RewritePatternSet conversionPatterns(ctx);
|
||||
populateConversionPatterns(conversionPatterns, ctx);
|
||||
if (failed(applyPartialConversion(moduleOp, target, std::move(conversionPatterns)))) {
|
||||
moduleOp.emitError("failed to convert required ONNX ops to Spatial ops");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
RewritePatternSet earlyPostPatterns(ctx);
|
||||
populateEarlyPostPatterns(earlyPostPatterns, ctx);
|
||||
if (failed(applyPatternsGreedily(*entryFunc, std::move(earlyPostPatterns)))) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
if (coresCount != -1) {
|
||||
int computeOpsCount = 0;
|
||||
for (Operation& op : entryFunc->getFunctionBody().front().getOperations())
|
||||
if (isa<spatial::SpatCompute>(op))
|
||||
computeOpsCount++;
|
||||
|
||||
if (computeOpsCount > coresCount) {
|
||||
entryFunc->emitError() << "number of compute ops (" << computeOpsCount << ") exceeds the core count ("
|
||||
<< coresCount << ")";
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
}
|
||||
ConversionTarget earlyPostTarget(*ctx);
|
||||
earlyPostTarget.addLegalDialect<spatial::SpatialDialect,
|
||||
ONNXDialect,
|
||||
tensor::TensorDialect,
|
||||
arith::ArithDialect,
|
||||
scf::SCFDialect>();
|
||||
|
||||
PassManager cleanupPM(ctx);
|
||||
cleanupPM.addPass(createCanonicalizerPass());
|
||||
@@ -162,14 +199,29 @@ void ONNXToSpatialPass::runOnOperation() {
|
||||
|
||||
annotateWeightsConstants(*entryFunc);
|
||||
|
||||
ConversionTarget postTarget(*ctx);
|
||||
postTarget.addLegalDialect<spatial::SpatialDialect,
|
||||
ONNXDialect,
|
||||
tensor::TensorDialect,
|
||||
arith::ArithDialect,
|
||||
scf::SCFDialect>();
|
||||
postTarget.addDynamicallyLegalOp<spatial::SpatCompute>(
|
||||
[](spatial::SpatCompute computeOp) { return !requiresPostRewrite(computeOp); });
|
||||
postTarget.addDynamicallyLegalOp<spatial::SpatComputeBatch>(
|
||||
[](spatial::SpatComputeBatch computeOp) { return !requiresPostRewrite(computeOp); });
|
||||
|
||||
RewritePatternSet postPatterns(ctx);
|
||||
populatePostPatterns(postPatterns, ctx);
|
||||
if (failed(applyPatternsGreedily(*entryFunc, std::move(postPatterns)))) {
|
||||
if (failed(applyPartialConversion(*entryFunc, postTarget, std::move(postPatterns)))) {
|
||||
moduleOp.emitError("failed to normalize weight-like Spatial compute operands before Spatial-to-PIM lowering");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
wrapTopLevelRuntimeTransposes(*entryFunc);
|
||||
|
||||
if (failed(verifyONNXToSpatialHostLegality(*entryFunc))) {
|
||||
moduleOp.emitError("ONNX-to-Spatial host legality verification failed");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
@@ -27,16 +28,6 @@ struct ConvToGemm : OpConversionPattern<ONNXConvOp> {
|
||||
ConversionPatternRewriter& rewriter) const override;
|
||||
};
|
||||
|
||||
static DenseElementsAttr getDenseConstantAttr(Value value) {
|
||||
if (auto constantOp = value.getDefiningOp<arith::ConstantOp>())
|
||||
return dyn_cast<DenseElementsAttr>(constantOp.getValue());
|
||||
|
||||
if (auto constantOp = value.getDefiningOp<ONNXConstantOp>())
|
||||
return dyn_cast_or_null<DenseElementsAttr>(constantOp.getValueAttr());
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
static int64_t getI64FromArrayAttr(ArrayAttr arr, size_t idx) { return cast<IntegerAttr>(arr[idx]).getInt(); }
|
||||
|
||||
static Value expandBiasIfNeeded(Value bias, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
@@ -355,49 +346,22 @@ static Value createCollectedConvOutput(ValueRange gemmRows,
|
||||
return collectComputeOp.getResult(0);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
||||
ONNXConvOpAdaptor convOpAdaptor,
|
||||
ConversionPatternRewriter& rewriter) const {
|
||||
Location loc = convOp.getLoc();
|
||||
Value x = convOpAdaptor.getX();
|
||||
Value w = convOpAdaptor.getW();
|
||||
Value b = convOpAdaptor.getB();
|
||||
|
||||
auto xType = cast<RankedTensorType>(x.getType());
|
||||
auto wType = cast<RankedTensorType>(w.getType());
|
||||
auto outType = cast<RankedTensorType>(convOp.getY().getType());
|
||||
|
||||
if (!xType.hasStaticShape()) {
|
||||
pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv input");
|
||||
return failure();
|
||||
}
|
||||
if (!wType.hasStaticShape()) {
|
||||
pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv weight");
|
||||
return failure();
|
||||
}
|
||||
if (!outType.hasStaticShape()) {
|
||||
pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv result");
|
||||
return failure();
|
||||
}
|
||||
if (xType.getRank() != 4) {
|
||||
pim::emitUnsupportedRankDiagnostic(convOp, "conv input", xType.getRank(), {4});
|
||||
return failure();
|
||||
}
|
||||
if (wType.getRank() != 4) {
|
||||
pim::emitUnsupportedRankDiagnostic(convOp, "conv weight", wType.getRank(), {4});
|
||||
return failure();
|
||||
}
|
||||
if (outType.getRank() != 4) {
|
||||
pim::emitUnsupportedRankDiagnostic(convOp, "conv result", outType.getRank(), {4});
|
||||
return failure();
|
||||
}
|
||||
if (convOp.getGroup() != 1) {
|
||||
convOp.emitOpError("only group=1 convolution is supported for Spatial lowering");
|
||||
return failure();
|
||||
}
|
||||
|
||||
static Value lowerSingleConvGroup(Value x,
|
||||
Value w,
|
||||
Value b,
|
||||
RankedTensorType xType,
|
||||
RankedTensorType wType,
|
||||
RankedTensorType outType,
|
||||
int64_t padHeightBegin,
|
||||
int64_t padHeightEnd,
|
||||
int64_t padWidthBegin,
|
||||
int64_t padWidthEnd,
|
||||
int64_t strideHeight,
|
||||
int64_t strideWidth,
|
||||
int64_t dilationHeight,
|
||||
int64_t dilationWidth,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
const int64_t batchSize = xType.getDimSize(0);
|
||||
const int64_t numChannelsIn = xType.getDimSize(1);
|
||||
const int64_t xHeight = xType.getDimSize(2);
|
||||
@@ -408,71 +372,6 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
||||
const int64_t outHeight = outType.getDimSize(2);
|
||||
const int64_t outWidth = outType.getDimSize(3);
|
||||
|
||||
// Read optional conv attributes (ONNX defaults: stride=1, dilation=1, pad=0)
|
||||
const auto stridesAttr = convOp.getStrides();
|
||||
const auto dilationsAttr = convOp.getDilations();
|
||||
const auto padsAttr = convOp.getPads();
|
||||
|
||||
if (stridesAttr && stridesAttr->size() != 2) {
|
||||
convOp.emitOpError("requires exactly two stride values for Spatial lowering");
|
||||
return failure();
|
||||
}
|
||||
if (dilationsAttr && dilationsAttr->size() != 2) {
|
||||
convOp.emitOpError("requires exactly two dilation values for Spatial lowering");
|
||||
return failure();
|
||||
}
|
||||
if (padsAttr && padsAttr->size() != 4) {
|
||||
convOp.emitOpError("requires exactly four pad values for 2D Spatial lowering");
|
||||
return failure();
|
||||
}
|
||||
|
||||
const int64_t strideHeight = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 0) : 1;
|
||||
const int64_t strideWidth = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 1) : 1;
|
||||
const int64_t dilationHeight = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 0) : 1;
|
||||
const int64_t dilationWidth = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 1) : 1;
|
||||
|
||||
int64_t padHeightBegin = 0;
|
||||
int64_t padHeightEnd = 0;
|
||||
int64_t padWidthBegin = 0;
|
||||
int64_t padWidthEnd = 0;
|
||||
|
||||
if (padsAttr) {
|
||||
padHeightBegin = getI64FromArrayAttr(*padsAttr, 0);
|
||||
padWidthBegin = getI64FromArrayAttr(*padsAttr, 1);
|
||||
padHeightEnd = getI64FromArrayAttr(*padsAttr, 2);
|
||||
padWidthEnd = getI64FromArrayAttr(*padsAttr, 3);
|
||||
}
|
||||
else {
|
||||
// Compute padding from auto_pad attribute
|
||||
const auto autoPad = convOp.getAutoPad();
|
||||
if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") {
|
||||
const int64_t effectiveKernelH = (wHeight - 1) * dilationHeight + 1;
|
||||
const int64_t effectiveKernelW = (wWidth - 1) * dilationWidth + 1;
|
||||
const int64_t totalPadH =
|
||||
std::max(static_cast<int64_t>(0), (outHeight - 1) * strideHeight + effectiveKernelH - xHeight);
|
||||
const int64_t totalPadW =
|
||||
std::max(static_cast<int64_t>(0), (outWidth - 1) * strideWidth + effectiveKernelW - xWidth);
|
||||
|
||||
if (autoPad == "SAME_UPPER") {
|
||||
padHeightBegin = totalPadH / 2;
|
||||
padHeightEnd = totalPadH - padHeightBegin;
|
||||
padWidthBegin = totalPadW / 2;
|
||||
padWidthEnd = totalPadW - padWidthBegin;
|
||||
}
|
||||
else { // SAME_LOWER
|
||||
padHeightEnd = totalPadH / 2;
|
||||
padHeightBegin = totalPadH - padHeightEnd;
|
||||
padWidthEnd = totalPadW / 2;
|
||||
padWidthBegin = totalPadW - padWidthEnd;
|
||||
}
|
||||
}
|
||||
else if (autoPad != "NOTSET" && autoPad != "VALID") {
|
||||
convOp.emitOpError() << "unsupported auto_pad value `" << autoPad << "` for Spatial lowering";
|
||||
return failure();
|
||||
}
|
||||
// "NOTSET" or "VALID" -> all pads stay 0
|
||||
}
|
||||
|
||||
// im2col layout (flipped with respect to the standard, so filters sit in B = crossbar):
|
||||
// A (im2col): [numPatches, patchSize] -- one row per output spatial position
|
||||
// B (weights): [patchSize, cOut] -- W^T, stored in crossbar columns
|
||||
@@ -492,7 +391,7 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
||||
const int64_t xbarSize = static_cast<int64_t>(crossbarSize.getValue());
|
||||
const int64_t wMaxDim = std::max(patchSize, numChannelsOut);
|
||||
const int64_t maxParallelPixels = std::max<int64_t>(1, xbarSize / wMaxDim);
|
||||
auto wDenseAttr = getDenseConstantAttr(w);
|
||||
auto wDenseAttr = getHostFoldableDenseElementsAttr(w);
|
||||
|
||||
// Prepare weight matrix W for crossbar storage:
|
||||
// W: [numChannelsOut, numChannelsIn, wHeight, wWidth] -> [numChannelsOut, patchSize] -> [patchSize, numChannelsOut]
|
||||
@@ -513,7 +412,7 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
||||
DenseElementsAttr biasDenseAttr;
|
||||
if (hasB) {
|
||||
gemmBias = b;
|
||||
biasDenseAttr = getDenseConstantAttr(b);
|
||||
biasDenseAttr = getHostFoldableDenseElementsAttr(b);
|
||||
biasMatrix = expandBiasIfNeeded(b, rewriter, loc);
|
||||
}
|
||||
const bool canPackWeightsAsConstants = static_cast<bool>(wDenseAttr);
|
||||
@@ -589,17 +488,246 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
||||
rewriter.getBoolAttr(false))
|
||||
.getY();
|
||||
|
||||
rewriter.replaceOp(convOp,
|
||||
createCollectedConvOutput(ValueRange {gemmRows},
|
||||
convOp.getType(),
|
||||
gemmOutType,
|
||||
nhwcType,
|
||||
outType,
|
||||
numPatches,
|
||||
numChannelsOut,
|
||||
effectiveMaxParallelPixels,
|
||||
rewriter,
|
||||
loc));
|
||||
return createCollectedConvOutput(ValueRange {gemmRows},
|
||||
outType,
|
||||
gemmOutType,
|
||||
nhwcType,
|
||||
outType,
|
||||
numPatches,
|
||||
numChannelsOut,
|
||||
effectiveMaxParallelPixels,
|
||||
rewriter,
|
||||
loc);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
||||
ONNXConvOpAdaptor convOpAdaptor,
|
||||
ConversionPatternRewriter& rewriter) const {
|
||||
Location loc = convOp.getLoc();
|
||||
Value x = convOpAdaptor.getX();
|
||||
Value w = convOpAdaptor.getW();
|
||||
Value b = convOpAdaptor.getB();
|
||||
|
||||
auto xType = cast<RankedTensorType>(x.getType());
|
||||
auto wType = cast<RankedTensorType>(w.getType());
|
||||
auto outType = cast<RankedTensorType>(convOp.getY().getType());
|
||||
|
||||
if (!xType.hasStaticShape()) {
|
||||
pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv input");
|
||||
return failure();
|
||||
}
|
||||
if (!wType.hasStaticShape()) {
|
||||
pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv weight");
|
||||
return failure();
|
||||
}
|
||||
if (!outType.hasStaticShape()) {
|
||||
pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv result");
|
||||
return failure();
|
||||
}
|
||||
if (xType.getRank() != 4) {
|
||||
pim::emitUnsupportedRankDiagnostic(convOp, "conv input", xType.getRank(), {4});
|
||||
return failure();
|
||||
}
|
||||
if (wType.getRank() != 4) {
|
||||
pim::emitUnsupportedRankDiagnostic(convOp, "conv weight", wType.getRank(), {4});
|
||||
return failure();
|
||||
}
|
||||
if (outType.getRank() != 4) {
|
||||
pim::emitUnsupportedRankDiagnostic(convOp, "conv result", outType.getRank(), {4});
|
||||
return failure();
|
||||
}
|
||||
if (convOp.getGroup() < 1) {
|
||||
convOp.emitOpError("requires group >= 1 for Spatial lowering");
|
||||
return failure();
|
||||
}
|
||||
|
||||
const int64_t batchSize = xType.getDimSize(0);
|
||||
const int64_t numChannelsIn = xType.getDimSize(1);
|
||||
const int64_t xHeight = xType.getDimSize(2);
|
||||
const int64_t xWidth = xType.getDimSize(3);
|
||||
const int64_t numChannelsOut = wType.getDimSize(0);
|
||||
const int64_t wHeight = wType.getDimSize(2);
|
||||
const int64_t wWidth = wType.getDimSize(3);
|
||||
const int64_t outHeight = outType.getDimSize(2);
|
||||
const int64_t outWidth = outType.getDimSize(3);
|
||||
const int64_t group = convOp.getGroup();
|
||||
const bool hasB = !isa<ONNXNoneOp>(b.getDefiningOp());
|
||||
|
||||
if (numChannelsIn % group != 0) {
|
||||
convOp.emitOpError() << "requires input channels " << numChannelsIn << " to be divisible by group " << group
|
||||
<< " for Spatial lowering";
|
||||
return failure();
|
||||
}
|
||||
if (numChannelsOut % group != 0) {
|
||||
convOp.emitOpError() << "requires output channels " << numChannelsOut << " to be divisible by group " << group
|
||||
<< " for Spatial lowering";
|
||||
return failure();
|
||||
}
|
||||
|
||||
const int64_t numChannelsInPerGroup = numChannelsIn / group;
|
||||
const int64_t numChannelsOutPerGroup = numChannelsOut / group;
|
||||
if (wType.getDimSize(1) != numChannelsInPerGroup) {
|
||||
convOp.emitOpError() << "requires grouped conv weight input channels " << wType.getDimSize(1)
|
||||
<< " to match input channels per group " << numChannelsInPerGroup << " for Spatial lowering";
|
||||
return failure();
|
||||
}
|
||||
if (wType.getDimSize(0) != numChannelsOut) {
|
||||
convOp.emitOpError() << "requires weight output channels " << wType.getDimSize(0) << " to match result channels "
|
||||
<< numChannelsOut << " for Spatial lowering";
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Read optional conv attributes (ONNX defaults: stride=1, dilation=1, pad=0)
|
||||
const auto stridesAttr = convOp.getStrides();
|
||||
const auto dilationsAttr = convOp.getDilations();
|
||||
const auto padsAttr = convOp.getPads();
|
||||
|
||||
if (stridesAttr && stridesAttr->size() != 2) {
|
||||
convOp.emitOpError("requires exactly two stride values for Spatial lowering");
|
||||
return failure();
|
||||
}
|
||||
if (dilationsAttr && dilationsAttr->size() != 2) {
|
||||
convOp.emitOpError("requires exactly two dilation values for Spatial lowering");
|
||||
return failure();
|
||||
}
|
||||
if (padsAttr && padsAttr->size() != 4) {
|
||||
convOp.emitOpError("requires exactly four pad values for 2D Spatial lowering");
|
||||
return failure();
|
||||
}
|
||||
|
||||
const int64_t strideHeight = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 0) : 1;
|
||||
const int64_t strideWidth = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 1) : 1;
|
||||
const int64_t dilationHeight = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 0) : 1;
|
||||
const int64_t dilationWidth = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 1) : 1;
|
||||
|
||||
int64_t padHeightBegin = 0;
|
||||
int64_t padHeightEnd = 0;
|
||||
int64_t padWidthBegin = 0;
|
||||
int64_t padWidthEnd = 0;
|
||||
|
||||
if (padsAttr) {
|
||||
padHeightBegin = getI64FromArrayAttr(*padsAttr, 0);
|
||||
padWidthBegin = getI64FromArrayAttr(*padsAttr, 1);
|
||||
padHeightEnd = getI64FromArrayAttr(*padsAttr, 2);
|
||||
padWidthEnd = getI64FromArrayAttr(*padsAttr, 3);
|
||||
}
|
||||
else {
|
||||
// Compute padding from auto_pad attribute
|
||||
const auto autoPad = convOp.getAutoPad();
|
||||
if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") {
|
||||
const int64_t effectiveKernelH = (wHeight - 1) * dilationHeight + 1;
|
||||
const int64_t effectiveKernelW = (wWidth - 1) * dilationWidth + 1;
|
||||
const int64_t totalPadH =
|
||||
std::max(static_cast<int64_t>(0), (outHeight - 1) * strideHeight + effectiveKernelH - xHeight);
|
||||
const int64_t totalPadW =
|
||||
std::max(static_cast<int64_t>(0), (outWidth - 1) * strideWidth + effectiveKernelW - xWidth);
|
||||
|
||||
if (autoPad == "SAME_UPPER") {
|
||||
padHeightBegin = totalPadH / 2;
|
||||
padHeightEnd = totalPadH - padHeightBegin;
|
||||
padWidthBegin = totalPadW / 2;
|
||||
padWidthEnd = totalPadW - padWidthBegin;
|
||||
}
|
||||
else { // SAME_LOWER
|
||||
padHeightEnd = totalPadH / 2;
|
||||
padHeightBegin = totalPadH - padHeightEnd;
|
||||
padWidthEnd = totalPadW / 2;
|
||||
padWidthBegin = totalPadW - padWidthEnd;
|
||||
}
|
||||
}
|
||||
else if (autoPad != "NOTSET" && autoPad != "VALID") {
|
||||
convOp.emitOpError() << "unsupported auto_pad value `" << autoPad << "` for Spatial lowering";
|
||||
return failure();
|
||||
}
|
||||
// "NOTSET" or "VALID" -> all pads stay 0
|
||||
}
|
||||
|
||||
if (group == 1) {
|
||||
rewriter.replaceOp(convOp,
|
||||
lowerSingleConvGroup(x,
|
||||
w,
|
||||
b,
|
||||
xType,
|
||||
wType,
|
||||
outType,
|
||||
padHeightBegin,
|
||||
padHeightEnd,
|
||||
padWidthBegin,
|
||||
padWidthEnd,
|
||||
strideHeight,
|
||||
strideWidth,
|
||||
dilationHeight,
|
||||
dilationWidth,
|
||||
rewriter,
|
||||
loc));
|
||||
return success();
|
||||
}
|
||||
|
||||
SmallVector<Value> xSlices = sliceTensor(x, /*axis=*/1, numChannelsInPerGroup, rewriter, loc);
|
||||
SmallVector<Value> wSlices = sliceTensor(w, /*axis=*/0, numChannelsOutPerGroup, rewriter, loc);
|
||||
SmallVector<Value> bSlices;
|
||||
if (hasB) {
|
||||
auto biasType = cast<RankedTensorType>(b.getType());
|
||||
int64_t biasAxis = -1;
|
||||
if (biasType.getRank() == 1)
|
||||
biasAxis = 0;
|
||||
else if (biasType.getRank() == 2)
|
||||
biasAxis = biasType.getDimSize(0) != 1 ? 0 : 1;
|
||||
else {
|
||||
convOp.emitOpError() << "requires rank-1 or rank-2 bias for grouped convolution Spatial lowering, but got rank "
|
||||
<< biasType.getRank();
|
||||
return failure();
|
||||
}
|
||||
bSlices = sliceTensor(b, biasAxis, numChannelsOutPerGroup, rewriter, loc);
|
||||
}
|
||||
|
||||
if (xSlices.size() != static_cast<size_t>(group) || wSlices.size() != static_cast<size_t>(group)
|
||||
|| (hasB && bSlices.size() != static_cast<size_t>(group))) {
|
||||
convOp.emitOpError("failed to partition grouped convolution operands for Spatial lowering");
|
||||
return failure();
|
||||
}
|
||||
|
||||
SmallVector<Value> groupResults;
|
||||
groupResults.reserve(group);
|
||||
auto groupOutType =
|
||||
RankedTensorType::get({batchSize, numChannelsOutPerGroup, outHeight, outWidth}, outType.getElementType());
|
||||
Value noBias = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());
|
||||
for (int64_t groupId = 0; groupId < group; groupId++) {
|
||||
Value groupX = xSlices[groupId];
|
||||
Value groupW = wSlices[groupId];
|
||||
Value groupB = hasB ? bSlices[groupId] : noBias;
|
||||
groupResults.push_back(lowerSingleConvGroup(groupX,
|
||||
groupW,
|
||||
groupB,
|
||||
cast<RankedTensorType>(groupX.getType()),
|
||||
cast<RankedTensorType>(groupW.getType()),
|
||||
groupOutType,
|
||||
padHeightBegin,
|
||||
padHeightEnd,
|
||||
padWidthBegin,
|
||||
padWidthEnd,
|
||||
strideHeight,
|
||||
strideWidth,
|
||||
dilationHeight,
|
||||
dilationWidth,
|
||||
rewriter,
|
||||
loc));
|
||||
}
|
||||
|
||||
Value result;
|
||||
if (llvm::all_of(groupResults, isHostFoldableValue)) {
|
||||
result = createSpatConcat(rewriter, loc, /*axis=*/1, groupResults);
|
||||
}
|
||||
else {
|
||||
auto concatCompute = createSpatCompute(rewriter, loc, TypeRange {outType}, {}, groupResults, [&](ValueRange args) {
|
||||
spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, /*axis=*/1, args));
|
||||
});
|
||||
result = concatCompute.getResult(0);
|
||||
}
|
||||
|
||||
rewriter.replaceOp(convOp, result);
|
||||
return success();
|
||||
}
|
||||
|
||||
|
||||
@@ -402,24 +402,41 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
for (size_t aSliceId = 0; aSliceId < aHSlices[coreId].size(); aSliceId++)
|
||||
weights.push_back(bTiles[outSliceId][coreId][aSliceId]);
|
||||
|
||||
auto computeOp = createSpatCompute(
|
||||
rewriter, gemmLoc, currOutHSliceType, weights, aHSlices[coreId], [&](ValueRange aHSlicesArgs) -> LogicalResult {
|
||||
SmallVector<Value> vmmOutputs;
|
||||
vmmOutputs.reserve(aHSlicesArgs.size());
|
||||
for (auto [aHSliceId, computeArg] : llvm::enumerate(aHSlicesArgs))
|
||||
vmmOutputs.push_back(
|
||||
spatial::SpatVMMOp::create(rewriter, gemmLoc, currOutHSliceType, aHSliceId, computeArg));
|
||||
if (vmmOutputs.empty()) {
|
||||
gemmOp.emitOpError("requires at least one non-empty slice when lowering tiled Gemm to Spatial VMMs");
|
||||
return failure();
|
||||
}
|
||||
auto computeOp =
|
||||
spatial::SpatCompute::create(rewriter, gemmLoc, TypeRange {currOutHSliceType}, weights, aHSlices[coreId]);
|
||||
SmallVector<Type> blockArgTypes;
|
||||
SmallVector<Location> blockArgLocs;
|
||||
blockArgTypes.reserve(weights.size() + aHSlices[coreId].size());
|
||||
blockArgLocs.reserve(weights.size() + aHSlices[coreId].size());
|
||||
for (Value weight : weights) {
|
||||
blockArgTypes.push_back(weight.getType());
|
||||
blockArgLocs.push_back(gemmLoc);
|
||||
}
|
||||
for (Value input : aHSlices[coreId]) {
|
||||
blockArgTypes.push_back(input.getType());
|
||||
blockArgLocs.push_back(gemmLoc);
|
||||
}
|
||||
Block* body =
|
||||
rewriter.createBlock(&computeOp.getBody(), computeOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
|
||||
rewriter.setInsertionPointToEnd(body);
|
||||
|
||||
Value partialVmmSum = sumTensors(vmmOutputs, rewriter);
|
||||
spatial::SpatYieldOp::create(rewriter, gemmLoc, partialVmmSum);
|
||||
return success();
|
||||
});
|
||||
if (failed(computeOp))
|
||||
SmallVector<Value> vmmOutputs;
|
||||
vmmOutputs.reserve(aHSlices[coreId].size());
|
||||
for (auto aHSliceId : llvm::seq<size_t>(0, aHSlices[coreId].size())) {
|
||||
auto weightArg = computeOp.getWeightArgument(aHSliceId);
|
||||
auto inputArg = computeOp.getInputArgument(aHSliceId);
|
||||
if (!weightArg || !inputArg)
|
||||
return failure();
|
||||
vmmOutputs.push_back(spatial::SpatVMMOp::create(rewriter, gemmLoc, currOutHSliceType, *weightArg, *inputArg));
|
||||
}
|
||||
if (vmmOutputs.empty()) {
|
||||
gemmOp.emitOpError("requires at least one non-empty slice when lowering tiled Gemm to Spatial VMMs");
|
||||
return failure();
|
||||
}
|
||||
|
||||
Value partialVmmSum = sumTensors(vmmOutputs, rewriter);
|
||||
spatial::SpatYieldOp::create(rewriter, gemmLoc, partialVmmSum);
|
||||
rewriter.setInsertionPointAfter(computeOp);
|
||||
|
||||
partialResults.push_back(computeOp->getResult(0));
|
||||
}
|
||||
@@ -502,9 +519,6 @@ LogicalResult GemmToSpatialComputeBatch::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
}
|
||||
(void) bType;
|
||||
|
||||
if (!isHostFoldableValue(b))
|
||||
return failure();
|
||||
|
||||
Value sharedBias;
|
||||
if (hasC) {
|
||||
auto scaledC = materializeScaledConstantTensor(c, gemmOpAdaptor.getBeta().convertToFloat(), rewriter, loc);
|
||||
@@ -533,37 +547,49 @@ LogicalResult GemmToSpatialComputeBatch::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
sharedBias = c;
|
||||
}
|
||||
|
||||
SmallVector<Value> aSlices = materializeBatchRowSlices(a, aType, rewriter, loc);
|
||||
auto aSliceType = cast<RankedTensorType>(aSlices.front().getType());
|
||||
|
||||
auto outRowType = RankedTensorType::get({1, outType.getDimSize(1)}, outType.getElementType());
|
||||
SmallVector<Type> resultTypes(static_cast<size_t>(numOutRows), outRowType);
|
||||
SmallVector<Value> weights(static_cast<size_t>(numOutRows), b);
|
||||
|
||||
auto aRowType = RankedTensorType::get({1, aType.getDimSize(1)}, aType.getElementType());
|
||||
auto batchOp = spatial::SpatComputeBatch::create(rewriter,
|
||||
loc,
|
||||
TypeRange(resultTypes),
|
||||
TypeRange {outType},
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(numOutRows)),
|
||||
ValueRange(weights),
|
||||
ValueRange(aSlices));
|
||||
ValueRange {b},
|
||||
ValueRange {a});
|
||||
|
||||
Block* body = rewriter.createBlock(
|
||||
&batchOp.getBody(), batchOp.getBody().end(), TypeRange {aSliceType}, SmallVector<Location>(1, loc));
|
||||
SmallVector<Type> blockArgTypes {rewriter.getIndexType(), bType, aType, outType};
|
||||
SmallVector<Location> blockArgLocs(4, loc);
|
||||
Block* body =
|
||||
rewriter.createBlock(&batchOp.getBody(), batchOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
|
||||
rewriter.setInsertionPointToEnd(body);
|
||||
|
||||
Value vmmResult = spatial::SpatVMMOp::create(rewriter, loc, outRowType, 0, body->getArgument(0)).getResult();
|
||||
auto lane = batchOp.getLaneArgument();
|
||||
auto weight = batchOp.getWeightArgument(0);
|
||||
auto packedInput = batchOp.getInputArgument(0);
|
||||
auto packedOutput = batchOp.getOutputArgument(0);
|
||||
if (!lane || !weight || !packedInput || !packedOutput)
|
||||
return failure();
|
||||
|
||||
SmallVector<OpFoldResult> inputOffsets {*lane, rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> inputSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(aType.getDimSize(1))};
|
||||
SmallVector<OpFoldResult> unitStrides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
Value row =
|
||||
tensor::ExtractSliceOp::create(rewriter, loc, aRowType, *packedInput, inputOffsets, inputSizes, unitStrides)
|
||||
.getResult();
|
||||
|
||||
Value vmmResult = spatial::SpatVMMOp::create(rewriter, loc, outRowType, *weight, row).getResult();
|
||||
Value laneResult = vmmResult;
|
||||
if (sharedBias)
|
||||
laneResult = spatial::SpatVAddOp::create(rewriter, loc, outRowType, vmmResult, sharedBias).getResult();
|
||||
spatial::SpatYieldOp::create(rewriter, loc, laneResult);
|
||||
|
||||
auto inParallelOp = spatial::SpatInParallelOp::create(rewriter, loc);
|
||||
rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front());
|
||||
SmallVector<OpFoldResult> outputOffsets {*lane, rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> outputSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(outType.getDimSize(1))};
|
||||
tensor::ParallelInsertSliceOp::create(
|
||||
rewriter, loc, laneResult, *packedOutput, outputOffsets, outputSizes, unitStrides);
|
||||
rewriter.setInsertionPointAfter(batchOp);
|
||||
SmallVector<Value> laneResults(batchOp->result_begin(), batchOp->result_end());
|
||||
auto concatComputeOp = createSpatCompute(rewriter, loc, gemmOp.getType(), {}, laneResults, [&](ValueRange args) {
|
||||
spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, /*axis=*/0, args));
|
||||
});
|
||||
|
||||
rewriter.replaceOp(gemmOp, concatComputeOp);
|
||||
rewriter.replaceOp(gemmOp, batchOp.getResults());
|
||||
return success();
|
||||
}
|
||||
|
||||
|
||||
@@ -2,8 +2,12 @@
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include <functional>
|
||||
#include <numeric>
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||
@@ -19,6 +23,67 @@ static bool haveStaticPositiveShape(ArrayRef<int64_t> shape) {
|
||||
return llvm::all_of(shape, [](int64_t dim) { return dim > 0; });
|
||||
}
|
||||
|
||||
static int64_t getStaticShapeElementCount(ArrayRef<int64_t> shape) {
|
||||
return std::accumulate(shape.begin(), shape.end(), int64_t {1}, std::multiplies<int64_t> {});
|
||||
}
|
||||
|
||||
static FailureOr<SmallVector<int64_t>> inferSupportedBatchShape(ArrayRef<int64_t> lhsBatchShape,
|
||||
ArrayRef<int64_t> rhsBatchShape) {
|
||||
if (lhsBatchShape.empty())
|
||||
return SmallVector<int64_t>(rhsBatchShape.begin(), rhsBatchShape.end());
|
||||
if (rhsBatchShape.empty())
|
||||
return SmallVector<int64_t>(lhsBatchShape.begin(), lhsBatchShape.end());
|
||||
if (!llvm::equal(lhsBatchShape, rhsBatchShape))
|
||||
return failure();
|
||||
return SmallVector<int64_t>(lhsBatchShape.begin(), lhsBatchShape.end());
|
||||
}
|
||||
|
||||
static Value
|
||||
collapseBatchDims(Value value, int64_t batchSize, int64_t rows, int64_t cols, PatternRewriter& rewriter, Location loc) {
|
||||
auto type = cast<RankedTensorType>(value.getType());
|
||||
if (type.getRank() == 2 || type.getRank() == 3)
|
||||
return value;
|
||||
|
||||
auto collapsedType = RankedTensorType::get({batchSize, rows, cols}, type.getElementType(), type.getEncoding());
|
||||
SmallVector<ReassociationIndices> reassociation = {ReassociationIndices {},
|
||||
ReassociationIndices {static_cast<int64_t>(type.getRank() - 2)},
|
||||
ReassociationIndices {static_cast<int64_t>(type.getRank() - 1)}};
|
||||
for (int64_t dim = 0; dim < type.getRank() - 2; ++dim)
|
||||
reassociation.front().push_back(dim);
|
||||
|
||||
auto buildCollapsed = [&](Value input) -> Value {
|
||||
return tensor::CollapseShapeOp::create(rewriter, loc, collapsedType, input, reassociation);
|
||||
};
|
||||
|
||||
if (isHostFoldableValue(value))
|
||||
return buildCollapsed(value);
|
||||
|
||||
auto collapseCompute =
|
||||
createSpatCompute<1>(rewriter, loc, TypeRange {collapsedType}, {}, ValueRange {value}, [&](Value input) {
|
||||
spatial::SpatYieldOp::create(rewriter, loc, buildCollapsed(input));
|
||||
});
|
||||
return collapseCompute.getResult(0);
|
||||
}
|
||||
|
||||
static Value
|
||||
expandBatchDims(Value value, RankedTensorType outputType, size_t batchRank, PatternRewriter& rewriter, Location loc) {
|
||||
if (cast<RankedTensorType>(value.getType()) == outputType)
|
||||
return value;
|
||||
|
||||
SmallVector<ReassociationIndices> reassociation = {ReassociationIndices {},
|
||||
ReassociationIndices {static_cast<int64_t>(batchRank)},
|
||||
ReassociationIndices {static_cast<int64_t>(batchRank + 1)}};
|
||||
for (size_t dim = 0; dim < batchRank; ++dim)
|
||||
reassociation.front().push_back(static_cast<int64_t>(dim));
|
||||
|
||||
auto expandCompute =
|
||||
createSpatCompute<1>(rewriter, loc, TypeRange {outputType}, {}, ValueRange {value}, [&](Value input) {
|
||||
Value expanded = tensor::ExpandShapeOp::create(rewriter, loc, outputType, input, reassociation);
|
||||
spatial::SpatYieldOp::create(rewriter, loc, expanded);
|
||||
});
|
||||
return expandCompute.getResult(0);
|
||||
}
|
||||
|
||||
static Value extractBatchMatrix(Value value,
|
||||
int64_t batchIndex,
|
||||
int64_t batchSize,
|
||||
@@ -62,13 +127,29 @@ static Value extractBatchMatrix(Value value,
|
||||
static Value transposeLastTwoDims(Value value, PatternRewriter& rewriter, Location loc) {
|
||||
auto type = cast<RankedTensorType>(value.getType());
|
||||
auto shape = type.getShape();
|
||||
RankedTensorType transposedType;
|
||||
SmallVector<int64_t> perm;
|
||||
if (type.getRank() == 2) {
|
||||
auto transposedType = RankedTensorType::get({shape[1], shape[0]}, type.getElementType());
|
||||
return ONNXTransposeOp::create(rewriter, loc, transposedType, value, rewriter.getI64ArrayAttr({1, 0}));
|
||||
transposedType = RankedTensorType::get({shape[1], shape[0]}, type.getElementType());
|
||||
perm = {1, 0};
|
||||
}
|
||||
else {
|
||||
transposedType = RankedTensorType::get({shape[0], shape[2], shape[1]}, type.getElementType());
|
||||
perm = {0, 2, 1};
|
||||
}
|
||||
|
||||
auto transposedType = RankedTensorType::get({shape[0], shape[2], shape[1]}, type.getElementType());
|
||||
return ONNXTransposeOp::create(rewriter, loc, transposedType, value, rewriter.getI64ArrayAttr({0, 2, 1}));
|
||||
auto buildTranspose = [&](Value input) -> Value {
|
||||
return ONNXTransposeOp::create(rewriter, loc, transposedType, input, rewriter.getI64ArrayAttr(perm));
|
||||
};
|
||||
|
||||
if (isHostFoldableValue(value))
|
||||
return buildTranspose(value);
|
||||
|
||||
auto transposeCompute =
|
||||
createSpatCompute<1>(rewriter, loc, TypeRange {transposedType}, {}, ValueRange {value}, [&](Value input) {
|
||||
spatial::SpatYieldOp::create(rewriter, loc, buildTranspose(input));
|
||||
});
|
||||
return transposeCompute.getResult(0);
|
||||
}
|
||||
|
||||
static Value transposeLastTwoDimsInCompute(Value value, PatternRewriter& rewriter, Location loc) {
|
||||
@@ -120,24 +201,25 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
|
||||
if (!lhsType || !rhsType || !outType || !lhsType.hasStaticShape() || !rhsType.hasStaticShape()
|
||||
|| !outType.hasStaticShape())
|
||||
return failure();
|
||||
if ((lhsType.getRank() != 2 && lhsType.getRank() != 3) || (rhsType.getRank() != 2 && rhsType.getRank() != 3)
|
||||
|| (outType.getRank() != 2 && outType.getRank() != 3))
|
||||
if (lhsType.getRank() < 2 || rhsType.getRank() < 2 || outType.getRank() < 2)
|
||||
return failure();
|
||||
if (!haveStaticPositiveShape(lhsType.getShape()) || !haveStaticPositiveShape(rhsType.getShape())
|
||||
|| !haveStaticPositiveShape(outType.getShape()))
|
||||
return failure();
|
||||
|
||||
const int64_t lhsBatch = lhsType.getRank() == 3 ? lhsType.getDimSize(0) : 1;
|
||||
const int64_t rhsBatch = rhsType.getRank() == 3 ? rhsType.getDimSize(0) : 1;
|
||||
const int64_t batch = std::max(lhsBatch, rhsBatch);
|
||||
|
||||
if ((lhsBatch != 1 && lhsBatch != batch) || (rhsBatch != 1 && rhsBatch != batch))
|
||||
SmallVector<int64_t> lhsBatchShape(lhsType.getShape().begin(), lhsType.getShape().end() - 2);
|
||||
SmallVector<int64_t> rhsBatchShape(rhsType.getShape().begin(), rhsType.getShape().end() - 2);
|
||||
auto batchShape = inferSupportedBatchShape(lhsBatchShape, rhsBatchShape);
|
||||
if (failed(batchShape))
|
||||
return failure();
|
||||
const int64_t lhsBatch = lhsBatchShape.empty() ? 1 : getStaticShapeElementCount(lhsBatchShape);
|
||||
const int64_t rhsBatch = rhsBatchShape.empty() ? 1 : getStaticShapeElementCount(rhsBatchShape);
|
||||
const int64_t batch = batchShape->empty() ? 1 : getStaticShapeElementCount(*batchShape);
|
||||
|
||||
const int64_t m = lhsType.getRank() == 3 ? lhsType.getDimSize(1) : lhsType.getDimSize(0);
|
||||
const int64_t k = lhsType.getRank() == 3 ? lhsType.getDimSize(2) : lhsType.getDimSize(1);
|
||||
const int64_t rhsK = rhsType.getRank() == 3 ? rhsType.getDimSize(1) : rhsType.getDimSize(0);
|
||||
const int64_t n = rhsType.getRank() == 3 ? rhsType.getDimSize(2) : rhsType.getDimSize(1);
|
||||
const int64_t m = lhsType.getDimSize(lhsType.getRank() - 2);
|
||||
const int64_t k = lhsType.getDimSize(lhsType.getRank() - 1);
|
||||
const int64_t rhsK = rhsType.getDimSize(rhsType.getRank() - 2);
|
||||
const int64_t n = rhsType.getDimSize(rhsType.getRank() - 1);
|
||||
if (k != rhsK)
|
||||
return failure();
|
||||
|
||||
@@ -146,15 +228,17 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
|
||||
return failure();
|
||||
}
|
||||
else {
|
||||
if (outType.getDimSize(0) != batch || outType.getDimSize(1) != m || outType.getDimSize(2) != n)
|
||||
SmallVector<int64_t> outBatchShape(outType.getShape().begin(), outType.getShape().end() - 2);
|
||||
if (!llvm::equal(outBatchShape, *batchShape) || outType.getDimSize(outType.getRank() - 2) != m
|
||||
|| outType.getDimSize(outType.getRank() - 1) != n)
|
||||
return failure();
|
||||
}
|
||||
|
||||
Location loc = matmulOp.getLoc();
|
||||
bool useTransposedForm = isHostFoldableValue(matmulOp.getA()) && !isHostFoldableValue(matmulOp.getB());
|
||||
|
||||
Value lhs = matmulOp.getA();
|
||||
Value rhs = matmulOp.getB();
|
||||
Value lhs = collapseBatchDims(matmulOp.getA(), lhsBatch, m, k, rewriter, loc);
|
||||
Value rhs = collapseBatchDims(matmulOp.getB(), rhsBatch, k, n, rewriter, loc);
|
||||
int64_t lhsBatchForGemm = lhsBatch;
|
||||
int64_t rhsBatchForGemm = rhsBatch;
|
||||
int64_t gemmM = m;
|
||||
@@ -239,6 +323,7 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
|
||||
}
|
||||
|
||||
Value result = concatValues(batchResults, /*axis=*/0, rewriter, loc);
|
||||
result = expandBatchDims(result, outType, batchShape->size(), rewriter, loc);
|
||||
rewriter.replaceOp(matmulOp, result);
|
||||
return success();
|
||||
}
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
@@ -22,53 +23,83 @@ static SmallVector<int64_t> permuteShape(ArrayRef<int64_t> shape, ArrayRef<int64
|
||||
return permutedShape;
|
||||
}
|
||||
|
||||
static Value createSoftmaxCompute(Value input, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
static Value buildLoopSoftmaxSlice(Value input,
|
||||
Value accumulator,
|
||||
RankedTensorType inputType,
|
||||
ArrayRef<Value> outerIndices,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
int64_t rank = inputType.getRank();
|
||||
SmallVector<int64_t> sliceShape(static_cast<size_t>(rank - 1), 1);
|
||||
sliceShape.push_back(inputType.getDimSize(rank - 1));
|
||||
auto sliceType = RankedTensorType::get(sliceShape, inputType.getElementType(), inputType.getEncoding());
|
||||
|
||||
SmallVector<OpFoldResult> offsets;
|
||||
SmallVector<OpFoldResult> sizes;
|
||||
SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
|
||||
offsets.reserve(rank);
|
||||
sizes.reserve(rank);
|
||||
|
||||
for (Value outerIndex : outerIndices) {
|
||||
offsets.push_back(outerIndex);
|
||||
sizes.push_back(rewriter.getIndexAttr(1));
|
||||
}
|
||||
offsets.push_back(rewriter.getIndexAttr(0));
|
||||
sizes.push_back(rewriter.getIndexAttr(inputType.getDimSize(rank - 1)));
|
||||
|
||||
Value inputSlice = tensor::ExtractSliceOp::create(rewriter, loc, sliceType, input, offsets, sizes, strides);
|
||||
Value softmaxSlice = spatial::SpatSoftmaxOp::create(rewriter, loc, sliceType, inputSlice).getResult();
|
||||
return tensor::InsertSliceOp::create(rewriter, loc, softmaxSlice, accumulator, offsets, sizes, strides);
|
||||
}
|
||||
|
||||
static Value buildLoopSoftmaxNest(Value input,
|
||||
Value accumulator,
|
||||
RankedTensorType inputType,
|
||||
int64_t axis,
|
||||
SmallVectorImpl<Value>& outerIndices,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
if (axis == inputType.getRank() - 1)
|
||||
return buildLoopSoftmaxSlice(input, accumulator, inputType, outerIndices, rewriter, loc);
|
||||
|
||||
Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
|
||||
Value c1 = arith::ConstantIndexOp::create(rewriter, loc, 1);
|
||||
Value cUpper = arith::ConstantIndexOp::create(rewriter, loc, inputType.getDimSize(axis));
|
||||
|
||||
auto loop = scf::ForOp::create(rewriter, loc, c0, cUpper, c1, ValueRange {accumulator});
|
||||
rewriter.setInsertionPointToStart(loop.getBody());
|
||||
|
||||
Value loopIndex = loop.getInductionVar();
|
||||
Value loopAccumulator = loop.getRegionIterArgs().front();
|
||||
outerIndices.push_back(loopIndex);
|
||||
Value updatedAccumulator =
|
||||
buildLoopSoftmaxNest(input, loopAccumulator, inputType, axis + 1, outerIndices, rewriter, loc);
|
||||
outerIndices.pop_back();
|
||||
|
||||
scf::YieldOp::create(rewriter, loc, updatedAccumulator);
|
||||
rewriter.setInsertionPointAfter(loop);
|
||||
return loop.getResult(0);
|
||||
}
|
||||
|
||||
static Value createLoopSoftmaxCompute(Value input, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
auto inputType = cast<RankedTensorType>(input.getType());
|
||||
constexpr size_t numInputs = 1;
|
||||
auto computeOp =
|
||||
createSpatCompute<numInputs>(rewriter, loc, TypeRange {inputType}, {}, ValueRange {input}, [&](Value x) {
|
||||
auto softmaxOp = spatial::SpatSoftmaxOp::create(rewriter, loc, inputType, x);
|
||||
spatial::SpatYieldOp::create(rewriter, loc, softmaxOp.getResult());
|
||||
if (inputType.getRank() == 1) {
|
||||
Value softmax = spatial::SpatSoftmaxOp::create(rewriter, loc, inputType, x).getResult();
|
||||
spatial::SpatYieldOp::create(rewriter, loc, softmax);
|
||||
return;
|
||||
}
|
||||
|
||||
Value outputInit = tensor::EmptyOp::create(rewriter, loc, inputType.getShape(), inputType.getElementType());
|
||||
SmallVector<Value> outerIndices;
|
||||
Value result = buildLoopSoftmaxNest(x, outputInit, inputType, /*axis=*/0, outerIndices, rewriter, loc);
|
||||
spatial::SpatYieldOp::create(rewriter, loc, result);
|
||||
});
|
||||
return computeOp.getResult(0);
|
||||
}
|
||||
|
||||
static Value concatValues(ValueRange inputs, int64_t axis, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
auto firstType = cast<RankedTensorType>(inputs.front().getType());
|
||||
SmallVector<int64_t> outputShape(firstType.getShape().begin(), firstType.getShape().end());
|
||||
int64_t concatDimSize = 0;
|
||||
for (Value input : inputs)
|
||||
concatDimSize += cast<RankedTensorType>(input.getType()).getDimSize(axis);
|
||||
outputShape[axis] = concatDimSize;
|
||||
auto resultType = RankedTensorType::get(outputShape, firstType.getElementType(), firstType.getEncoding());
|
||||
|
||||
if (llvm::all_of(inputs, isHostFoldableValue))
|
||||
return createSpatConcat(rewriter, loc, axis, inputs);
|
||||
|
||||
auto concatCompute = createSpatCompute(rewriter, loc, TypeRange {resultType}, {}, inputs, [&](ValueRange args) {
|
||||
spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, axis, args));
|
||||
});
|
||||
return concatCompute.getResult(0);
|
||||
}
|
||||
|
||||
static Value
|
||||
buildSoftmax(Value input, int64_t softmaxAxis, int64_t axis, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
auto inputType = cast<RankedTensorType>(input.getType());
|
||||
if (axis == inputType.getRank())
|
||||
return createSoftmaxCompute(input, rewriter, loc);
|
||||
|
||||
if (axis == softmaxAxis)
|
||||
return buildSoftmax(input, softmaxAxis, axis + 1, rewriter, loc);
|
||||
|
||||
SmallVector<Value> slices = sliceTensor(input, axis, /*sliceSize=*/1, rewriter, loc);
|
||||
SmallVector<Value> rebuiltSlices;
|
||||
rebuiltSlices.reserve(slices.size());
|
||||
for (Value slice : slices)
|
||||
rebuiltSlices.push_back(buildSoftmax(slice, softmaxAxis, axis + 1, rewriter, loc));
|
||||
|
||||
return concatValues(rebuiltSlices, axis, rewriter, loc);
|
||||
}
|
||||
|
||||
struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> {
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
|
||||
@@ -86,7 +117,7 @@ struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> {
|
||||
Value input = adaptor.getInput();
|
||||
Value result;
|
||||
if (axis == inputType.getRank() - 1) {
|
||||
result = buildSoftmax(input, axis, /*axis=*/0, rewriter, softmaxOp.getLoc());
|
||||
result = createLoopSoftmaxCompute(input, rewriter, softmaxOp.getLoc());
|
||||
}
|
||||
else {
|
||||
SmallVector<int64_t> permutation;
|
||||
@@ -109,8 +140,7 @@ struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> {
|
||||
spatial::SpatYieldOp::create(rewriter, softmaxOp.getLoc(), transposed);
|
||||
});
|
||||
Value transposedInput = preTransposeCompute.getResult(0);
|
||||
Value transposedResult = buildSoftmax(
|
||||
transposedInput, /*softmaxAxis=*/inputType.getRank() - 1, /*axis=*/0, rewriter, softmaxOp.getLoc());
|
||||
Value transposedResult = createLoopSoftmaxCompute(transposedInput, rewriter, softmaxOp.getLoc());
|
||||
auto postTransposeCompute =
|
||||
createSpatCompute<1>(rewriter, softmaxOp.getLoc(), TypeRange {inputType}, {}, transposedResult, [&](Value x) {
|
||||
Value transposed = ONNXTransposeOp::create(
|
||||
|
||||
@@ -80,6 +80,22 @@ static bool inferExpandReassociation(ArrayRef<int64_t> sourceShape,
|
||||
return sourceIdx == sourceShape.size() && resultIdx == resultShape.size();
|
||||
}
|
||||
|
||||
static SmallVector<ReassociationIndices> getCollapseTo1DReassociation(size_t rank) {
|
||||
SmallVector<ReassociationIndices> reassociation(1);
|
||||
reassociation.front().reserve(rank);
|
||||
for (size_t dim = 0; dim < rank; ++dim)
|
||||
reassociation.front().push_back(static_cast<int64_t>(dim));
|
||||
return reassociation;
|
||||
}
|
||||
|
||||
static SmallVector<ReassociationIndices> getExpandFrom1DReassociation(size_t rank) {
|
||||
SmallVector<ReassociationIndices> reassociation(1);
|
||||
reassociation.front().reserve(rank);
|
||||
for (size_t dim = 0; dim < rank; ++dim)
|
||||
reassociation.front().push_back(static_cast<int64_t>(dim));
|
||||
return reassociation;
|
||||
}
|
||||
|
||||
struct Reshape : OpConversionPattern<ONNXReshapeOp> {
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
|
||||
@@ -126,6 +142,23 @@ struct Reshape : OpConversionPattern<ONNXReshapeOp> {
|
||||
return tensor::ExpandShapeOp::create(rewriter, reshapeOp.getLoc(), resultType, data, reassociation);
|
||||
});
|
||||
|
||||
if (sourceType.getNumElements() != resultType.getNumElements())
|
||||
return failure();
|
||||
|
||||
return replaceWithReshape([&](Value data) -> Value {
|
||||
Value reshaped = data;
|
||||
if (sourceType.getRank() != 1) {
|
||||
auto flatType = RankedTensorType::get({sourceType.getNumElements()}, sourceType.getElementType());
|
||||
reshaped = tensor::CollapseShapeOp::create(
|
||||
rewriter, reshapeOp.getLoc(), flatType, reshaped, getCollapseTo1DReassociation(sourceType.getRank()));
|
||||
}
|
||||
if (resultType.getRank() == 1)
|
||||
return reshaped;
|
||||
return tensor::ExpandShapeOp::create(
|
||||
rewriter, reshapeOp.getLoc(), resultType, reshaped, getExpandFrom1DReassociation(resultType.getRank()))
|
||||
.getResult();
|
||||
});
|
||||
|
||||
return failure();
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
@@ -15,42 +15,85 @@ using namespace mlir;
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
static Value
|
||||
extractSliceAt(Value input, int64_t axis, int64_t offset, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
auto inputType = cast<RankedTensorType>(input.getType());
|
||||
SmallVector<OpFoldResult> offsets(inputType.getRank(), rewriter.getIndexAttr(0));
|
||||
SmallVector<OpFoldResult> sizes;
|
||||
SmallVector<OpFoldResult> strides(inputType.getRank(), rewriter.getIndexAttr(1));
|
||||
sizes.reserve(inputType.getRank());
|
||||
for (int64_t dim : inputType.getShape())
|
||||
sizes.push_back(rewriter.getIndexAttr(dim));
|
||||
offsets[axis] = rewriter.getIndexAttr(offset);
|
||||
sizes[axis] = rewriter.getIndexAttr(1);
|
||||
return tensor::ExtractSliceOp::create(rewriter, loc, input, offsets, sizes, strides);
|
||||
static Value buildNearestAsymmetricIndex(
|
||||
Value outputIndex, int64_t inputDim, int64_t outputDim, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
Value cInputDim = arith::ConstantIndexOp::create(rewriter, loc, inputDim);
|
||||
Value cOutputDim = arith::ConstantIndexOp::create(rewriter, loc, outputDim);
|
||||
Value cInputDimLast = arith::ConstantIndexOp::create(rewriter, loc, inputDim - 1);
|
||||
Value scaledIndex = arith::MulIOp::create(rewriter, loc, outputIndex, cInputDim);
|
||||
Value inputIndex = arith::DivUIOp::create(rewriter, loc, scaledIndex, cOutputDim);
|
||||
return arith::MinUIOp::create(rewriter, loc, inputIndex, cInputDimLast);
|
||||
}
|
||||
|
||||
static int64_t nearestAsymmetricIndex(int64_t outputIndex, int64_t inputDim, int64_t outputDim) {
|
||||
return std::min<int64_t>((outputIndex * inputDim) / outputDim, inputDim - 1);
|
||||
}
|
||||
static Value buildNearestResizeLoop(Value input,
|
||||
RankedTensorType inputType,
|
||||
RankedTensorType resultType,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
auto elemType = resultType.getElementType();
|
||||
SmallVector<int64_t> unitShape(resultType.getRank(), 1);
|
||||
auto unitTensorType = RankedTensorType::get(unitShape, elemType);
|
||||
|
||||
static Value buildNearestResize(Value input,
|
||||
ArrayRef<int64_t> inputShape,
|
||||
ArrayRef<int64_t> outputShape,
|
||||
int64_t axis,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
if (axis == static_cast<int64_t>(outputShape.size()))
|
||||
return input;
|
||||
SmallVector<OpFoldResult> unitSizes(resultType.getRank(), rewriter.getIndexAttr(1));
|
||||
SmallVector<OpFoldResult> unitStrides(resultType.getRank(), rewriter.getIndexAttr(1));
|
||||
|
||||
SmallVector<Value> slices;
|
||||
slices.reserve(outputShape[axis]);
|
||||
for (int64_t outputIndex = 0; outputIndex < outputShape[axis]; ++outputIndex) {
|
||||
int64_t inputIndex = nearestAsymmetricIndex(outputIndex, inputShape[axis], outputShape[axis]);
|
||||
Value slice = extractSliceAt(input, axis, inputIndex, rewriter, loc);
|
||||
slices.push_back(buildNearestResize(slice, inputShape, outputShape, axis + 1, rewriter, loc));
|
||||
}
|
||||
Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
|
||||
Value c1 = arith::ConstantIndexOp::create(rewriter, loc, 1);
|
||||
Value cOutputN = arith::ConstantIndexOp::create(rewriter, loc, resultType.getDimSize(0));
|
||||
Value cOutputC = arith::ConstantIndexOp::create(rewriter, loc, resultType.getDimSize(1));
|
||||
Value cOutputH = arith::ConstantIndexOp::create(rewriter, loc, resultType.getDimSize(2));
|
||||
Value cOutputW = arith::ConstantIndexOp::create(rewriter, loc, resultType.getDimSize(3));
|
||||
|
||||
return createSpatConcat(rewriter, loc, axis, slices);
|
||||
Value outputInit = tensor::EmptyOp::create(rewriter, loc, resultType.getShape(), elemType);
|
||||
|
||||
auto batchLoop = scf::ForOp::create(rewriter, loc, c0, cOutputN, c1, ValueRange {outputInit});
|
||||
rewriter.setInsertionPointToStart(batchLoop.getBody());
|
||||
|
||||
Value outputN = batchLoop.getInductionVar();
|
||||
Value outputBatchAcc = batchLoop.getRegionIterArgs().front();
|
||||
Value inputN = buildNearestAsymmetricIndex(outputN, inputType.getDimSize(0), resultType.getDimSize(0), rewriter, loc);
|
||||
|
||||
auto channelLoop = scf::ForOp::create(rewriter, loc, c0, cOutputC, c1, ValueRange {outputBatchAcc});
|
||||
rewriter.setInsertionPointToStart(channelLoop.getBody());
|
||||
|
||||
Value outputC = channelLoop.getInductionVar();
|
||||
Value outputChannelAcc = channelLoop.getRegionIterArgs().front();
|
||||
Value inputC = buildNearestAsymmetricIndex(outputC, inputType.getDimSize(1), resultType.getDimSize(1), rewriter, loc);
|
||||
|
||||
auto heightLoop = scf::ForOp::create(rewriter, loc, c0, cOutputH, c1, ValueRange {outputChannelAcc});
|
||||
rewriter.setInsertionPointToStart(heightLoop.getBody());
|
||||
|
||||
Value outputH = heightLoop.getInductionVar();
|
||||
Value outputHeightAcc = heightLoop.getRegionIterArgs().front();
|
||||
Value inputH = buildNearestAsymmetricIndex(outputH, inputType.getDimSize(2), resultType.getDimSize(2), rewriter, loc);
|
||||
|
||||
auto widthLoop = scf::ForOp::create(rewriter, loc, c0, cOutputW, c1, ValueRange {outputHeightAcc});
|
||||
rewriter.setInsertionPointToStart(widthLoop.getBody());
|
||||
|
||||
Value outputW = widthLoop.getInductionVar();
|
||||
Value outputWidthAcc = widthLoop.getRegionIterArgs().front();
|
||||
Value inputW = buildNearestAsymmetricIndex(outputW, inputType.getDimSize(3), resultType.getDimSize(3), rewriter, loc);
|
||||
|
||||
SmallVector<OpFoldResult> inputOffsets = {inputN, inputC, inputH, inputW};
|
||||
Value inputSlice =
|
||||
tensor::ExtractSliceOp::create(rewriter, loc, unitTensorType, input, inputOffsets, unitSizes, unitStrides);
|
||||
|
||||
SmallVector<OpFoldResult> outputOffsets = {outputN, outputC, outputH, outputW};
|
||||
Value updatedOutput =
|
||||
tensor::InsertSliceOp::create(rewriter, loc, inputSlice, outputWidthAcc, outputOffsets, unitSizes, unitStrides);
|
||||
scf::YieldOp::create(rewriter, loc, updatedOutput);
|
||||
|
||||
rewriter.setInsertionPointAfter(widthLoop);
|
||||
scf::YieldOp::create(rewriter, loc, widthLoop.getResult(0));
|
||||
|
||||
rewriter.setInsertionPointAfter(heightLoop);
|
||||
scf::YieldOp::create(rewriter, loc, heightLoop.getResult(0));
|
||||
|
||||
rewriter.setInsertionPointAfter(channelLoop);
|
||||
scf::YieldOp::create(rewriter, loc, channelLoop.getResult(0));
|
||||
|
||||
rewriter.setInsertionPointAfter(batchLoop);
|
||||
return batchLoop.getResult(0);
|
||||
}
|
||||
|
||||
struct Resize : OpConversionPattern<ONNXResizeOp> {
|
||||
@@ -62,20 +105,22 @@ struct Resize : OpConversionPattern<ONNXResizeOp> {
|
||||
auto inputType = dyn_cast<RankedTensorType>(adaptor.getX().getType());
|
||||
auto resultType = dyn_cast<RankedTensorType>(resizeOp.getY().getType());
|
||||
if (!inputType || !resultType || !inputType.hasStaticShape() || !resultType.hasStaticShape())
|
||||
return failure();
|
||||
return rewriter.notifyMatchFailure(resizeOp, "resize lowering requires static ranked tensor types.");
|
||||
if (inputType.getRank() != 4 || resultType.getRank() != 4)
|
||||
return rewriter.notifyMatchFailure(resizeOp, "resize lowering currently supports only rank-4 NCHW tensors.");
|
||||
|
||||
if (resizeOp.getMode() != "nearest" || resizeOp.getCoordinateTransformationMode() != "asymmetric"
|
||||
|| resizeOp.getNearestMode() != "floor")
|
||||
return failure();
|
||||
return rewriter.notifyMatchFailure(resizeOp,
|
||||
"resize lowering currently supports only nearest + asymmetric + floor.");
|
||||
|
||||
if (llvm::any_of(inputType.getShape(), [](int64_t dim) { return dim <= 0; })
|
||||
|| llvm::any_of(resultType.getShape(), [](int64_t dim) { return dim <= 0; }))
|
||||
return failure();
|
||||
return rewriter.notifyMatchFailure(resizeOp, "resize lowering requires positive static dimensions.");
|
||||
|
||||
auto computeOp =
|
||||
createSpatCompute<1>(rewriter, resizeOp.getLoc(), TypeRange {resultType}, {}, adaptor.getX(), [&](Value x) {
|
||||
Value result =
|
||||
buildNearestResize(x, inputType.getShape(), resultType.getShape(), /*axis=*/0, rewriter, resizeOp.getLoc());
|
||||
Value result = buildNearestResizeLoop(x, inputType, resultType, rewriter, resizeOp.getLoc());
|
||||
spatial::SpatYieldOp::create(rewriter, resizeOp.getLoc(), result);
|
||||
});
|
||||
rewriter.replaceOp(resizeOp, computeOp.getResults());
|
||||
|
||||
@@ -27,50 +27,25 @@ static bool canPromoteInputBlockArgument(BlockArgument arg) {
|
||||
return !arg.use_empty() && llvm::all_of(arg.getUsers(), isWeightMaterializationHelperUser);
|
||||
}
|
||||
|
||||
static bool canPromoteInputBlockArgument(std::optional<BlockArgument> arg) {
|
||||
return arg && canPromoteInputBlockArgument(*arg);
|
||||
}
|
||||
|
||||
static bool isDirectConstantValue(Value value) {
|
||||
return isa_and_nonnull<arith::ConstantOp, ONNXConstantOp>(value.getDefiningOp());
|
||||
}
|
||||
|
||||
// Collapses one-lane batches so later phases do not carry batch-only structure unnecessarily.
|
||||
struct FoldSingleLaneComputeBatchPattern : OpRewritePattern<spatial::SpatComputeBatch> {
|
||||
using OpRewritePattern<spatial::SpatComputeBatch>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(spatial::SpatComputeBatch batchOp, PatternRewriter& rewriter) const override {
|
||||
if (batchOp.getLaneCount() != 1)
|
||||
return rewriter.notifyMatchFailure(batchOp, "requires a single lane");
|
||||
|
||||
auto loc = batchOp.getLoc();
|
||||
rewriter.setInsertionPoint(batchOp);
|
||||
auto computeOp =
|
||||
spatial::SpatCompute::create(rewriter, loc, batchOp.getResultTypes(), batchOp.getWeights(), batchOp.getInputs());
|
||||
computeOp.getProperties().setOperandSegmentSizes(
|
||||
{static_cast<int>(batchOp.getWeights().size()), static_cast<int>(batchOp.getInputs().size())});
|
||||
|
||||
Block& templateBlock = batchOp.getBody().front();
|
||||
SmallVector<Type> blockArgTypes;
|
||||
SmallVector<Location> blockArgLocs;
|
||||
blockArgTypes.reserve(templateBlock.getNumArguments());
|
||||
blockArgLocs.reserve(templateBlock.getNumArguments());
|
||||
for (BlockArgument arg : templateBlock.getArguments()) {
|
||||
blockArgTypes.push_back(arg.getType());
|
||||
blockArgLocs.push_back(loc);
|
||||
}
|
||||
|
||||
auto* newBlock =
|
||||
rewriter.createBlock(&computeOp.getBody(), computeOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
|
||||
IRMapping mapper;
|
||||
for (auto [oldArg, newArg] : llvm::zip(templateBlock.getArguments(), newBlock->getArguments()))
|
||||
mapper.map(oldArg, newArg);
|
||||
|
||||
rewriter.setInsertionPointToEnd(newBlock);
|
||||
for (Operation& op : templateBlock)
|
||||
rewriter.clone(op, mapper);
|
||||
|
||||
batchOp->replaceAllUsesWith(computeOp->getResults());
|
||||
rewriter.eraseOp(batchOp);
|
||||
return success();
|
||||
template <typename ComputeOpTy>
|
||||
static bool hasPromotableWeightLikeInputs(ComputeOpTy compute) {
|
||||
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
||||
if (!isWeightLikeComputeOperand(input))
|
||||
continue;
|
||||
if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(compute.getInputArgument(inputIdx)))
|
||||
continue;
|
||||
return true;
|
||||
}
|
||||
};
|
||||
return false;
|
||||
}
|
||||
|
||||
// Promotes foldable helper chains from runtime inputs to weights to avoid artificial compute inputs.
|
||||
struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatCompute> {
|
||||
@@ -81,11 +56,9 @@ struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatCom
|
||||
bool needsRewrite = false;
|
||||
Block& oldBlock = compute.getBody().front();
|
||||
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
||||
if (inputIdx >= oldBlock.getNumArguments())
|
||||
continue;
|
||||
if (!isWeightLikeComputeOperand(input))
|
||||
continue;
|
||||
if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(oldBlock.getArgument(inputIdx)))
|
||||
if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(compute.getInputArgument(inputIdx)))
|
||||
continue;
|
||||
promoteInput[inputIdx] = true;
|
||||
needsRewrite = true;
|
||||
@@ -116,8 +89,16 @@ struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatCom
|
||||
|
||||
auto newCompute =
|
||||
spatial::SpatCompute::create(rewriter, compute.getLoc(), compute.getResultTypes(), newWeights, newInputs);
|
||||
auto* newBlock =
|
||||
rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), newInputTypes, newInputLocs);
|
||||
SmallVector<Type> newBlockArgTypes;
|
||||
SmallVector<Location> newBlockArgLocs;
|
||||
for (Value weight : newWeights) {
|
||||
newBlockArgTypes.push_back(weight.getType());
|
||||
newBlockArgLocs.push_back(weight.getLoc());
|
||||
}
|
||||
llvm::append_range(newBlockArgTypes, newInputTypes);
|
||||
llvm::append_range(newBlockArgLocs, newInputLocs);
|
||||
auto* newBlock = rewriter.createBlock(
|
||||
&newCompute.getBody(), newCompute.getBody().end(), TypeRange(newBlockArgTypes), newBlockArgLocs);
|
||||
newCompute.getProperties().setOperandSegmentSizes(
|
||||
{static_cast<int>(newWeights.size()), static_cast<int>(newInputs.size())});
|
||||
rewriter.setInsertionPointToStart(newBlock);
|
||||
@@ -126,17 +107,30 @@ struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatCom
|
||||
bodyRewriter.setInsertionPointToStart(newBlock);
|
||||
|
||||
IRMapping mapper;
|
||||
for (auto [weightIndex, weight] : llvm::enumerate(compute.getWeights())) {
|
||||
auto oldWeightArg = compute.getWeightArgument(weightIndex);
|
||||
auto newWeightArg = newCompute.getWeightArgument(weightIndex);
|
||||
if (!oldWeightArg || !newWeightArg)
|
||||
return rewriter.notifyMatchFailure(compute, "missing compute weight block argument during rewrite");
|
||||
mapper.map(*oldWeightArg, *newWeightArg);
|
||||
}
|
||||
size_t newInputIdx = 0;
|
||||
for (auto [oldInputIdx, oldArg] : llvm::enumerate(oldBlock.getArguments())) {
|
||||
for (auto [oldInputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
||||
auto oldArg = compute.getInputArgument(oldInputIdx);
|
||||
if (!oldArg)
|
||||
return rewriter.notifyMatchFailure(compute, "missing compute input block argument during rewrite");
|
||||
if (!promoteInput[oldInputIdx]) {
|
||||
mapper.map(oldArg, newBlock->getArgument(newInputIdx++));
|
||||
auto newInputArg = newCompute.getInputArgument(newInputIdx++);
|
||||
if (!newInputArg)
|
||||
return rewriter.notifyMatchFailure(compute, "missing rewritten compute input block argument");
|
||||
mapper.map(*oldArg, *newInputArg);
|
||||
continue;
|
||||
}
|
||||
|
||||
auto clonedValue = materializeWeightLikeValueInBlock(compute.getInputs()[oldInputIdx], bodyRewriter, mapper);
|
||||
auto clonedValue = materializeWeightLikeValueInBlock(input, bodyRewriter, mapper);
|
||||
if (failed(clonedValue))
|
||||
return rewriter.notifyMatchFailure(compute, "failed to materialize promoted weight-like operand");
|
||||
mapper.map(oldArg, *clonedValue);
|
||||
mapper.map(*oldArg, *clonedValue);
|
||||
}
|
||||
|
||||
for (Operation& op : oldBlock.without_terminator())
|
||||
@@ -165,11 +159,9 @@ struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::Sp
|
||||
bool needsRewrite = false;
|
||||
Block& oldBlock = compute.getBody().front();
|
||||
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
||||
if (inputIdx >= oldBlock.getNumArguments())
|
||||
continue;
|
||||
if (!isWeightLikeComputeOperand(input))
|
||||
continue;
|
||||
if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(oldBlock.getArgument(inputIdx)))
|
||||
if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(compute.getInputArgument(inputIdx)))
|
||||
continue;
|
||||
promoteInput[inputIdx] = true;
|
||||
needsRewrite = true;
|
||||
@@ -205,8 +197,31 @@ struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::Sp
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(compute.getLaneCount())),
|
||||
newWeights,
|
||||
newInputs);
|
||||
auto* newBlock =
|
||||
rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), newInputTypes, newInputLocs);
|
||||
auto laneArg = compute.getLaneArgument();
|
||||
if (!laneArg)
|
||||
return rewriter.notifyMatchFailure(compute, "missing compute_batch lane block argument");
|
||||
SmallVector<Type> newBlockArgTypes;
|
||||
SmallVector<Location> newBlockArgLocs;
|
||||
newBlockArgTypes.reserve(1 + newWeights.size() + newInputTypes.size() + compute.getNumResults());
|
||||
newBlockArgLocs.reserve(1 + newWeights.size() + newInputLocs.size() + compute.getNumResults());
|
||||
newBlockArgTypes.push_back(laneArg->getType());
|
||||
newBlockArgLocs.push_back(laneArg->getLoc());
|
||||
for (Value weight : newWeights) {
|
||||
newBlockArgTypes.push_back(weight.getType());
|
||||
newBlockArgLocs.push_back(weight.getLoc());
|
||||
}
|
||||
llvm::append_range(newBlockArgTypes, newInputTypes);
|
||||
llvm::append_range(newBlockArgLocs, newInputLocs);
|
||||
for (auto [resultIndex, resultType] : llvm::enumerate(compute.getResultTypes())) {
|
||||
auto outputArg = compute.getOutputArgument(resultIndex);
|
||||
if (!outputArg)
|
||||
return rewriter.notifyMatchFailure(compute, "missing compute_batch output block argument");
|
||||
newBlockArgTypes.push_back(resultType);
|
||||
newBlockArgLocs.push_back(outputArg->getLoc());
|
||||
}
|
||||
|
||||
auto* newBlock = rewriter.createBlock(
|
||||
&newCompute.getBody(), newCompute.getBody().end(), TypeRange(newBlockArgTypes), newBlockArgLocs);
|
||||
newCompute.getProperties().setOperandSegmentSizes(
|
||||
{static_cast<int>(newWeights.size()), static_cast<int>(newInputs.size())});
|
||||
rewriter.setInsertionPointToStart(newBlock);
|
||||
@@ -215,31 +230,45 @@ struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::Sp
|
||||
bodyRewriter.setInsertionPointToStart(newBlock);
|
||||
|
||||
IRMapping mapper;
|
||||
auto newLaneArg = newCompute.getLaneArgument();
|
||||
if (!newLaneArg)
|
||||
return rewriter.notifyMatchFailure(compute, "missing rewritten compute_batch lane block argument");
|
||||
mapper.map(*laneArg, *newLaneArg);
|
||||
for (auto [weightIndex, weight] : llvm::enumerate(compute.getWeights())) {
|
||||
auto oldWeightArg = compute.getWeightArgument(weightIndex);
|
||||
auto newWeightArg = newCompute.getWeightArgument(weightIndex);
|
||||
if (!oldWeightArg || !newWeightArg)
|
||||
return rewriter.notifyMatchFailure(compute, "missing compute_batch weight block argument during rewrite");
|
||||
mapper.map(*oldWeightArg, *newWeightArg);
|
||||
}
|
||||
size_t newInputIdx = 0;
|
||||
for (auto [oldInputIdx, oldArg] : llvm::enumerate(oldBlock.getArguments())) {
|
||||
for (auto [oldInputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
||||
auto oldArg = compute.getInputArgument(oldInputIdx);
|
||||
if (!oldArg)
|
||||
return rewriter.notifyMatchFailure(compute, "missing compute_batch input block argument during rewrite");
|
||||
if (!promoteInput[oldInputIdx]) {
|
||||
mapper.map(oldArg, newBlock->getArgument(newInputIdx++));
|
||||
auto newInputArg = newCompute.getInputArgument(newInputIdx++);
|
||||
if (!newInputArg)
|
||||
return rewriter.notifyMatchFailure(compute, "missing rewritten compute_batch input block argument");
|
||||
mapper.map(*oldArg, *newInputArg);
|
||||
continue;
|
||||
}
|
||||
|
||||
auto clonedValue = materializeWeightLikeValueInBlock(compute.getInputs()[oldInputIdx], bodyRewriter, mapper);
|
||||
auto clonedValue = materializeWeightLikeValueInBlock(input, bodyRewriter, mapper);
|
||||
if (failed(clonedValue))
|
||||
return rewriter.notifyMatchFailure(compute, "failed to materialize promoted batch weight-like operand");
|
||||
mapper.map(oldArg, *clonedValue);
|
||||
mapper.map(*oldArg, *clonedValue);
|
||||
}
|
||||
for (auto resultIndex : llvm::seq<size_t>(0, compute.getNumResults())) {
|
||||
auto outputArg = compute.getOutputArgument(resultIndex);
|
||||
if (!outputArg)
|
||||
return rewriter.notifyMatchFailure(compute, "missing compute_batch output block argument during rewrite");
|
||||
mapper.map(*outputArg, newBlock->getArgument(1 + newWeights.size() + newInputs.size() + resultIndex));
|
||||
}
|
||||
|
||||
for (Operation& op : oldBlock.without_terminator())
|
||||
for (Operation& op : oldBlock)
|
||||
rewriter.clone(op, mapper);
|
||||
|
||||
auto oldYield = cast<spatial::SpatYieldOp>(oldBlock.getTerminator());
|
||||
SmallVector<Value> newYieldOperands;
|
||||
newYieldOperands.reserve(oldYield.getOutputs().size());
|
||||
for (Value operand : oldYield.getOutputs()) {
|
||||
auto mapped = mapper.lookupOrNull(operand);
|
||||
newYieldOperands.push_back(mapped ? cast<Value>(mapped) : operand);
|
||||
}
|
||||
spatial::SpatYieldOp::create(rewriter, oldYield.getLoc(), newYieldOperands);
|
||||
|
||||
rewriter.replaceOp(compute, newCompute.getResults());
|
||||
return success();
|
||||
}
|
||||
@@ -247,10 +276,6 @@ struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::Sp
|
||||
|
||||
} // namespace
|
||||
|
||||
void populateEarlyPostPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||
patterns.add<FoldSingleLaneComputeBatchPattern>(ctx);
|
||||
}
|
||||
|
||||
void populatePostPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||
patterns.add<PromoteWeightLikeComputeInputsPattern, PromoteWeightLikeComputeBatchInputsPattern>(ctx);
|
||||
}
|
||||
@@ -262,4 +287,8 @@ void annotateWeightsConstants(func::FuncOp funcOp) {
|
||||
});
|
||||
}
|
||||
|
||||
bool requiresPostRewrite(spatial::SpatCompute computeOp) { return hasPromotableWeightLikeInputs(computeOp); }
|
||||
|
||||
bool requiresPostRewrite(spatial::SpatComputeBatch computeOp) { return hasPromotableWeightLikeInputs(computeOp); }
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -3,9 +3,13 @@
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
void populateEarlyPostPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
bool requiresPostRewrite(spatial::SpatCompute computeOp);
|
||||
|
||||
bool requiresPostRewrite(spatial::SpatComputeBatch computeOp);
|
||||
|
||||
void populatePostPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
|
||||
@@ -17,9 +17,7 @@ void populatePrePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* c
|
||||
patterns.add<convAddToConvWithBiasLeft>(ctx);
|
||||
patterns.add<convAddToConvWithBiasRight>(ctx);
|
||||
patterns.add<matMulAddToGemm>(ctx);
|
||||
patterns.add<matMulToGemm>(ctx);
|
||||
patterns.add<removeFlattenSameShape>(ctx);
|
||||
populateMatMulRewritePatterns(patterns, ctx);
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -202,6 +202,7 @@ void SpatialToGraphvizPass::runOnOperation() {
|
||||
|
||||
auto entryFunc = getPimEntryFunc(module);
|
||||
if (failed(entryFunc)) {
|
||||
module.emitError("failed to locate the PIM entry function for Spatial graph visualization");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
|
||||
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "Conversion/SpatialToPim/SpatialToPimPass.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
@@ -17,6 +20,37 @@ namespace {
|
||||
|
||||
static int32_t translateSpatialCoreIdToPimCoreId(size_t spatialCoreId) { return static_cast<int32_t>(spatialCoreId); }
|
||||
|
||||
static FailureOr<int32_t> getConstantI32Value(Value value) {
|
||||
APInt constantValue;
|
||||
if (!matchPattern(value, m_ConstantInt(&constantValue)))
|
||||
return failure();
|
||||
return static_cast<int32_t>(constantValue.getSExtValue());
|
||||
}
|
||||
|
||||
static FailureOr<SmallVector<int32_t>> getConstantI32Values(ValueRange values) {
|
||||
SmallVector<int32_t> constants;
|
||||
constants.reserve(values.size());
|
||||
for (Value value : values) {
|
||||
FailureOr<int32_t> constantValue = getConstantI32Value(value);
|
||||
if (failed(constantValue))
|
||||
return failure();
|
||||
constants.push_back(*constantValue);
|
||||
}
|
||||
return constants;
|
||||
}
|
||||
|
||||
static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) {
|
||||
if (isa<pim::PimMemCopyDevToHostOp>(op))
|
||||
return operandIndex == 2;
|
||||
return false;
|
||||
}
|
||||
|
||||
static bool isUsedOnlyAsExplicitHostOperand(Value value) {
|
||||
return !value.use_empty() && llvm::all_of(value.getUses(), [](OpOperand& use) {
|
||||
return isExplicitHostOperand(use.getOwner(), use.getOperandNumber());
|
||||
});
|
||||
}
|
||||
|
||||
static SmallVector<int32_t> getPimCoreIdsForBatchOp(spatial::SpatComputeBatch computeBatchOp, size_t& fallbackCoreId) {
|
||||
if (auto coreIdsAttr = computeBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName))
|
||||
return SmallVector<int32_t>(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end());
|
||||
@@ -28,27 +62,30 @@ static SmallVector<int32_t> getPimCoreIdsForBatchOp(spatial::SpatComputeBatch co
|
||||
return coreIds;
|
||||
}
|
||||
|
||||
static void lowerChannelSendTensorBatch(spatial::SpatChannelSendTensorBatchOp sendTensorBatchOp,
|
||||
IRMapping& mapper,
|
||||
IRRewriter& rewriter) {
|
||||
SmallVector<int32_t> targetCoreIds;
|
||||
targetCoreIds.reserve(sendTensorBatchOp.getTargetCoreIds().size());
|
||||
for (int32_t targetCoreId : sendTensorBatchOp.getTargetCoreIds())
|
||||
targetCoreIds.push_back(translateSpatialCoreIdToPimCoreId(targetCoreId));
|
||||
static LogicalResult lowerChannelSendTensorBatch(spatial::SpatChannelSendTensorBatchOp sendTensorBatchOp,
|
||||
IRMapping& mapper,
|
||||
IRRewriter& rewriter) {
|
||||
FailureOr<SmallVector<int32_t>> targetCoreIds = getConstantI32Values(sendTensorBatchOp.getTargetCoreIds());
|
||||
if (failed(targetCoreIds))
|
||||
return sendTensorBatchOp.emitOpError("expected constant targetCoreIds");
|
||||
for (int32_t& targetCoreId : *targetCoreIds)
|
||||
targetCoreId = translateSpatialCoreIdToPimCoreId(targetCoreId);
|
||||
|
||||
pim::PimSendTensorBatchOp::create(rewriter,
|
||||
sendTensorBatchOp.getLoc(),
|
||||
mapper.lookup(sendTensorBatchOp.getInput()),
|
||||
rewriter.getDenseI32ArrayAttr(targetCoreIds));
|
||||
rewriter.getDenseI32ArrayAttr(*targetCoreIds));
|
||||
return success();
|
||||
}
|
||||
|
||||
static void lowerChannelReceiveTensorBatch(spatial::SpatChannelReceiveTensorBatchOp receiveTensorBatchOp,
|
||||
IRMapping& mapper,
|
||||
IRRewriter& rewriter) {
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
sourceCoreIds.reserve(receiveTensorBatchOp.getSourceCoreIds().size());
|
||||
for (int32_t sourceCoreId : receiveTensorBatchOp.getSourceCoreIds())
|
||||
sourceCoreIds.push_back(translateSpatialCoreIdToPimCoreId(sourceCoreId));
|
||||
static LogicalResult lowerChannelReceiveTensorBatch(spatial::SpatChannelReceiveTensorBatchOp receiveTensorBatchOp,
|
||||
IRMapping& mapper,
|
||||
IRRewriter& rewriter) {
|
||||
FailureOr<SmallVector<int32_t>> sourceCoreIds = getConstantI32Values(receiveTensorBatchOp.getSourceCoreIds());
|
||||
if (failed(sourceCoreIds))
|
||||
return receiveTensorBatchOp.emitOpError("expected constant sourceCoreIds");
|
||||
for (int32_t& sourceCoreId : *sourceCoreIds)
|
||||
sourceCoreId = translateSpatialCoreIdToPimCoreId(sourceCoreId);
|
||||
|
||||
auto outputType = cast<ShapedType>(receiveTensorBatchOp.getOutput().getType());
|
||||
auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveTensorBatchOp.getLoc(), outputType);
|
||||
@@ -56,26 +93,81 @@ static void lowerChannelReceiveTensorBatch(spatial::SpatChannelReceiveTensorBatc
|
||||
receiveTensorBatchOp.getLoc(),
|
||||
outputBuffer.getType(),
|
||||
outputBuffer,
|
||||
rewriter.getDenseI32ArrayAttr(sourceCoreIds))
|
||||
rewriter.getDenseI32ArrayAttr(*sourceCoreIds))
|
||||
.getOutput();
|
||||
mapper.map(receiveTensorBatchOp.getOutput(), received);
|
||||
return success();
|
||||
}
|
||||
|
||||
static FailureOr<unsigned> getDirectReturnOperandIndex(OpResult result) {
|
||||
if (!result.hasOneUse())
|
||||
return failure();
|
||||
|
||||
auto returnOp = dyn_cast<func::ReturnOp>(*result.getUsers().begin());
|
||||
if (!returnOp)
|
||||
return failure();
|
||||
return result.getUses().begin()->getOperandNumber();
|
||||
}
|
||||
|
||||
static Value createScaledIndexValue(IRRewriter& rewriter, Location loc, Value base, int64_t scale) {
|
||||
if (scale == 1)
|
||||
return base;
|
||||
|
||||
auto scaleValue = arith::ConstantIndexOp::create(rewriter, loc, scale).getResult();
|
||||
return arith::MulIOp::create(rewriter, loc, base, scaleValue).getResult();
|
||||
}
|
||||
|
||||
static Value createHostTargetOffset(IRRewriter& rewriter,
|
||||
tensor::ParallelInsertSliceOp insertSlice,
|
||||
ShapedType destinationType,
|
||||
IRMapping& mapper) {
|
||||
int64_t elementBytes = static_cast<int64_t>(getElementTypeSizeInBytes(destinationType.getElementType()));
|
||||
SmallVector<int64_t> strides(destinationType.getRank(), 1);
|
||||
ArrayRef<int64_t> shape = destinationType.getShape();
|
||||
for (int64_t dim = destinationType.getRank() - 2; dim >= 0; --dim)
|
||||
strides[dim] = strides[dim + 1] * shape[dim + 1];
|
||||
|
||||
Value totalOffset;
|
||||
Location loc = insertSlice.getLoc();
|
||||
for (auto [dim, offset] : llvm::enumerate(insertSlice.getMixedOffsets())) {
|
||||
int64_t scale = strides[dim] * elementBytes;
|
||||
Value scaledOffset;
|
||||
if (auto attr = dyn_cast<Attribute>(offset)) {
|
||||
auto intAttr = dyn_cast<IntegerAttr>(attr);
|
||||
assert(intAttr && "expected integer offset attribute");
|
||||
scaledOffset = arith::ConstantIndexOp::create(rewriter, loc, intAttr.getInt() * scale).getResult();
|
||||
}
|
||||
else {
|
||||
scaledOffset = createScaledIndexValue(rewriter, loc, mapper.lookupOrDefault(cast<Value>(offset)), scale);
|
||||
}
|
||||
|
||||
totalOffset =
|
||||
totalOffset ? arith::AddIOp::create(rewriter, loc, totalOffset, scaledOffset).getResult() : scaledOffset;
|
||||
}
|
||||
|
||||
if (!totalOffset)
|
||||
totalOffset = arith::ConstantIndexOp::create(rewriter, loc, 0).getResult();
|
||||
return totalOffset;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
LogicalResult
|
||||
lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState& state, IRRewriter& rewriter) {
|
||||
if (computeBatchOp.getNumResults() != 0)
|
||||
return computeBatchOp.emitOpError(
|
||||
"batched Spatial-to-PIM lowering currently requires channelized compute_batch with no results");
|
||||
|
||||
LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp,
|
||||
IRRewriter& rewriter) {
|
||||
Location loc = computeBatchOp.getLoc();
|
||||
Block& oldBlock = computeBatchOp.getBody().front();
|
||||
auto oldYield = cast<spatial::SpatYieldOp>(oldBlock.getTerminator());
|
||||
if (oldYield.getNumOperands() != 0)
|
||||
return computeBatchOp.emitOpError("batched Spatial-to-PIM lowering currently requires empty spat.yield");
|
||||
auto oldYield = dyn_cast<spatial::SpatYieldOp>(oldBlock.getTerminator());
|
||||
auto inParallelOp = dyn_cast<spatial::SpatInParallelOp>(oldBlock.getTerminator());
|
||||
if (computeBatchOp.getNumResults() == 0) {
|
||||
if (!oldYield || oldYield.getNumOperands() != 0)
|
||||
return computeBatchOp.emitOpError("resultless compute_batch lowering requires empty spat.yield");
|
||||
}
|
||||
else if (!inParallelOp) {
|
||||
return computeBatchOp.emitOpError(
|
||||
"resultful compute_batch lowering currently requires a spat.in_parallel terminator");
|
||||
}
|
||||
|
||||
SmallVector<int32_t> coreIds = getPimCoreIdsForBatchOp(computeBatchOp, state.nextCoreId);
|
||||
SmallVector<int32_t> coreIds = getPimCoreIdsForBatchOp(computeBatchOp, coreId);
|
||||
SmallVector<Value> batchWeights(computeBatchOp.getWeights().begin(), computeBatchOp.getWeights().end());
|
||||
SmallVector<Value> batchInputs;
|
||||
if (!computeBatchOp.getInputs().empty())
|
||||
@@ -91,9 +183,22 @@ lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState&
|
||||
{static_cast<int>(batchWeights.size()), static_cast<int>(batchInputs.size())});
|
||||
coreBatchOp->setAttr(onnx_mlir::kCoreIdsAttrName, rewriter.getDenseI32ArrayAttr(coreIds));
|
||||
|
||||
SmallVector<unsigned> returnOperandIndices;
|
||||
if (computeBatchOp.getNumResults() != 0) {
|
||||
returnOperandIndices.resize(computeBatchOp.getNumResults());
|
||||
for (auto [resultIndex, result] : llvm::enumerate(computeBatchOp.getResults())) {
|
||||
FailureOr<unsigned> returnOperandIndex = getDirectReturnOperandIndex(cast<OpResult>(result));
|
||||
if (failed(returnOperandIndex))
|
||||
return computeBatchOp.emitOpError(
|
||||
"resultful compute_batch lowering currently requires each result to be used directly by func.return");
|
||||
returnOperandIndices[resultIndex] = *returnOperandIndex;
|
||||
}
|
||||
}
|
||||
|
||||
SmallVector<Type> blockArgTypes;
|
||||
SmallVector<Location> blockArgLocs;
|
||||
for (BlockArgument arg : oldBlock.getArguments()) {
|
||||
unsigned inputArgLimit = 1 + computeBatchOp.getWeights().size() + computeBatchOp.getInputs().size();
|
||||
for (BlockArgument arg : oldBlock.getArguments().take_front(inputArgLimit)) {
|
||||
blockArgTypes.push_back(arg.getType());
|
||||
blockArgLocs.push_back(arg.getLoc());
|
||||
}
|
||||
@@ -102,7 +207,21 @@ lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState&
|
||||
|
||||
IRMapping mapper;
|
||||
rewriter.setInsertionPointToStart(newBlock);
|
||||
for (auto [oldArg, newArg] : llvm::zip(oldBlock.getArguments(), newBlock->getArguments())) {
|
||||
auto oldLaneArg = computeBatchOp.getLaneArgument();
|
||||
if (!oldLaneArg)
|
||||
return computeBatchOp.emitOpError("expected compute_batch lane block argument before lowering");
|
||||
mapper.map(*oldLaneArg, coreBatchOp.getLaneArgument());
|
||||
for (unsigned weightIndex = 0; weightIndex < computeBatchOp.getWeights().size(); ++weightIndex) {
|
||||
auto oldWeightArg = computeBatchOp.getWeightArgument(weightIndex);
|
||||
if (!oldWeightArg)
|
||||
return computeBatchOp.emitOpError("expected compute_batch weight block arguments before lowering");
|
||||
mapper.map(*oldWeightArg, coreBatchOp.getWeightArgument(weightIndex));
|
||||
}
|
||||
for (unsigned inputIndex = 0; inputIndex < computeBatchOp.getInputs().size(); ++inputIndex) {
|
||||
auto oldArg = computeBatchOp.getInputArgument(inputIndex);
|
||||
if (!oldArg)
|
||||
return computeBatchOp.emitOpError("expected compute_batch input block arguments before lowering");
|
||||
BlockArgument newArg = coreBatchOp.getInputArgument(inputIndex);
|
||||
auto newArgType = cast<ShapedType>(newArg.getType());
|
||||
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, newArgType);
|
||||
auto copied = pim::PimMemCopyHostToDevBatchOp::create(rewriter,
|
||||
@@ -114,7 +233,7 @@ lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState&
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
getTensorSizeInBytesAttr(rewriter, newArg))
|
||||
.getOutput();
|
||||
mapper.map(oldArg, copied);
|
||||
mapper.map(*oldArg, copied);
|
||||
}
|
||||
|
||||
auto materializeCapturedTensor = [&](Value capturedTensor) -> Value {
|
||||
@@ -136,26 +255,81 @@ lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState&
|
||||
return copied;
|
||||
};
|
||||
|
||||
SmallVector<Value> hostOutputTensors(returnOperandIndices.size());
|
||||
auto getOrCreateHostOutputTensor = [&](unsigned resultIndex, Location resultLoc) -> Value {
|
||||
Value& hostOutputTensor = hostOutputTensors[resultIndex];
|
||||
if (hostOutputTensor)
|
||||
return hostOutputTensor;
|
||||
|
||||
hostOutputTensor = outputTensors[returnOperandIndices[resultIndex]](rewriter, resultLoc);
|
||||
return hostOutputTensor;
|
||||
};
|
||||
|
||||
rewriter.setInsertionPointToEnd(newBlock);
|
||||
for (Operation& op : oldBlock) {
|
||||
if (isa<spatial::SpatYieldOp>(op))
|
||||
continue;
|
||||
|
||||
if (auto parallelOp = dyn_cast<spatial::SpatInParallelOp>(op)) {
|
||||
auto firstOutputArg = computeBatchOp.getOutputArgument(0);
|
||||
if (!firstOutputArg)
|
||||
return computeBatchOp.emitOpError("expected compute_batch output block arguments before lowering");
|
||||
for (Operation& nestedOp : parallelOp.getRegion().front()) {
|
||||
auto insertSlice = dyn_cast<tensor::ParallelInsertSliceOp>(&nestedOp);
|
||||
if (!insertSlice)
|
||||
return parallelOp.emitOpError("expected only tensor.parallel_insert_slice in spat.in_parallel");
|
||||
|
||||
auto outputArg = dyn_cast<BlockArgument>(insertSlice.getDest());
|
||||
if (!outputArg || outputArg.getOwner() != &oldBlock)
|
||||
return insertSlice.emitOpError("expected compute_batch output block argument destination");
|
||||
|
||||
unsigned resultIndex = outputArg.getArgNumber() - firstOutputArg->getArgNumber();
|
||||
if (resultIndex >= returnOperandIndices.size())
|
||||
return insertSlice.emitOpError("result index out of range while lowering host batch output");
|
||||
|
||||
Value mappedSource = mapper.lookup(insertSlice.getSource());
|
||||
Value hostTarget = getOrCreateHostOutputTensor(resultIndex, insertSlice.getLoc());
|
||||
auto hostTargetType = cast<ShapedType>(hostTarget.getType());
|
||||
Value hostTargetOffset = createHostTargetOffset(rewriter, insertSlice, hostTargetType, mapper);
|
||||
Value zeroOffset = arith::ConstantIndexOp::create(rewriter, insertSlice.getLoc(), 0).getResult();
|
||||
pim::PimMemCopyDevToHostOp::create(rewriter,
|
||||
insertSlice.getLoc(),
|
||||
hostTarget.getType(),
|
||||
hostTargetOffset,
|
||||
zeroOffset,
|
||||
hostTarget,
|
||||
mappedSource,
|
||||
getTensorSizeInBytesAttr(rewriter, mappedSource));
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto sendBatchOp = dyn_cast<spatial::SpatChannelSendBatchOp>(op)) {
|
||||
FailureOr<SmallVector<int32_t>> targetCoreIds = getConstantI32Values(sendBatchOp.getTargetCoreIds());
|
||||
if (failed(targetCoreIds))
|
||||
return sendBatchOp.emitOpError("expected constant targetCoreIds");
|
||||
for (int32_t& targetCoreId : *targetCoreIds)
|
||||
targetCoreId = translateSpatialCoreIdToPimCoreId(targetCoreId);
|
||||
pim::PimSendBatchOp::create(rewriter,
|
||||
loc,
|
||||
mapper.lookup(sendBatchOp.getInput()),
|
||||
getTensorSizeInBytesAttr(rewriter, mapper.lookup(sendBatchOp.getInput())),
|
||||
sendBatchOp.getTargetCoreIdsAttr());
|
||||
rewriter.getDenseI32ArrayAttr(*targetCoreIds));
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto sendTensorBatchOp = dyn_cast<spatial::SpatChannelSendTensorBatchOp>(op)) {
|
||||
lowerChannelSendTensorBatch(sendTensorBatchOp, mapper, rewriter);
|
||||
if (failed(lowerChannelSendTensorBatch(sendTensorBatchOp, mapper, rewriter)))
|
||||
return failure();
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto receiveBatchOp = dyn_cast<spatial::SpatChannelReceiveBatchOp>(op)) {
|
||||
FailureOr<SmallVector<int32_t>> sourceCoreIds = getConstantI32Values(receiveBatchOp.getSourceCoreIds());
|
||||
if (failed(sourceCoreIds))
|
||||
return receiveBatchOp.emitOpError("expected constant sourceCoreIds");
|
||||
for (int32_t& sourceCoreId : *sourceCoreIds)
|
||||
sourceCoreId = translateSpatialCoreIdToPimCoreId(sourceCoreId);
|
||||
auto outputType = cast<ShapedType>(receiveBatchOp.getOutput().getType());
|
||||
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, outputType);
|
||||
auto received = pim::PimReceiveBatchOp::create(rewriter,
|
||||
@@ -163,14 +337,15 @@ lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState&
|
||||
outputBuffer.getType(),
|
||||
outputBuffer,
|
||||
getTensorSizeInBytesAttr(rewriter, receiveBatchOp.getOutput()),
|
||||
receiveBatchOp.getSourceCoreIdsAttr())
|
||||
rewriter.getDenseI32ArrayAttr(*sourceCoreIds))
|
||||
.getOutput();
|
||||
mapper.map(receiveBatchOp.getOutput(), received);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto receiveTensorBatchOp = dyn_cast<spatial::SpatChannelReceiveTensorBatchOp>(op)) {
|
||||
lowerChannelReceiveTensorBatch(receiveTensorBatchOp, mapper, rewriter);
|
||||
if (failed(lowerChannelReceiveTensorBatch(receiveTensorBatchOp, mapper, rewriter)))
|
||||
return failure();
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -178,6 +353,10 @@ lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState&
|
||||
if (isa_and_present<memref::GetGlobalOp>(toTensorOp.getBuffer().getDefiningOp())) {
|
||||
Operation* cloned = rewriter.clone(op, mapper);
|
||||
auto clonedTensor = cloned->getResult(0);
|
||||
if (isUsedOnlyAsExplicitHostOperand(toTensorOp.getResult())) {
|
||||
mapper.map(toTensorOp.getResult(), clonedTensor);
|
||||
continue;
|
||||
}
|
||||
auto clonedType = cast<ShapedType>(clonedTensor.getType());
|
||||
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, clonedType);
|
||||
auto copied = pim::PimMemCopyHostToDevBatchOp::create(rewriter,
|
||||
@@ -194,9 +373,11 @@ lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState&
|
||||
}
|
||||
}
|
||||
|
||||
for (Value operand : op.getOperands()) {
|
||||
for (auto [operandIndex, operand] : llvm::enumerate(op.getOperands())) {
|
||||
if (!isa<TensorType>(operand.getType()) || mapper.contains(operand))
|
||||
continue;
|
||||
if (isExplicitHostOperand(&op, operandIndex))
|
||||
continue;
|
||||
|
||||
Operation* definingOp = operand.getDefiningOp();
|
||||
if (definingOp && definingOp->getBlock() == &oldBlock)
|
||||
|
||||
@@ -1,10 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
mlir::LogicalResult
|
||||
lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState& state, mlir::IRRewriter& rewriter);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -6,7 +6,6 @@ add_pim_library(OMSpatialToPim
|
||||
SpatialToPimPass.cpp
|
||||
BatchCoreLoweringPatterns.cpp
|
||||
ChannelLoweringPatterns.cpp
|
||||
Cleanup.cpp
|
||||
Common.cpp
|
||||
ComputeLikeRegionUtils.cpp
|
||||
CoreLoweringPatterns.cpp
|
||||
@@ -22,6 +21,8 @@ add_pim_library(OMSpatialToPim
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRSCFDialect
|
||||
MLIRSCFUtils
|
||||
MLIRTransformUtils
|
||||
MLIRTosaDialect
|
||||
OMCompilerOptions
|
||||
OMPimCommon
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ChannelLoweringPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
||||
@@ -12,15 +13,24 @@ namespace {
|
||||
|
||||
static int32_t toPimCoreId(int32_t spatialCoreId) { return spatialCoreId; }
|
||||
|
||||
static FailureOr<SmallVector<int32_t>> getConstantI32Values(ValueRange values) {
|
||||
SmallVector<int32_t> constants;
|
||||
constants.reserve(values.size());
|
||||
for (Value value : values) {
|
||||
APInt constantValue;
|
||||
if (!matchPattern(value, m_ConstantInt(&constantValue)))
|
||||
return failure();
|
||||
constants.push_back(static_cast<int32_t>(constantValue.getSExtValue()));
|
||||
}
|
||||
return constants;
|
||||
}
|
||||
|
||||
struct ChannelSendLowering : OpRewritePattern<spatial::SpatChannelSendOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(spatial::SpatChannelSendOp op, PatternRewriter& rewriter) const override {
|
||||
pim::PimSendOp::create(rewriter,
|
||||
op.getLoc(),
|
||||
op.getInput(),
|
||||
getTensorSizeInBytesAttr(rewriter, op.getInput()),
|
||||
rewriter.getI32IntegerAttr(toPimCoreId(op.getTargetCoreId())));
|
||||
pim::PimSendOp::create(
|
||||
rewriter, op.getLoc(), op.getInput(), getTensorSizeInBytesAttr(rewriter, op.getInput()), op.getTargetCoreId());
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
}
|
||||
@@ -42,7 +52,7 @@ struct ChannelReceiveLowering : OpRewritePattern<spatial::SpatChannelReceiveOp>
|
||||
op.getResult().getType(),
|
||||
outputBuffer,
|
||||
getTensorSizeInBytesAttr(rewriter, op.getResult()),
|
||||
rewriter.getI32IntegerAttr(toPimCoreId(op.getSourceCoreId())))
|
||||
op.getSourceCoreId())
|
||||
.getOutput();
|
||||
rewriter.replaceOp(op, received);
|
||||
return success();
|
||||
@@ -53,11 +63,12 @@ struct ChannelSendTensorLowering : OpRewritePattern<spatial::SpatChannelSendTens
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(spatial::SpatChannelSendTensorOp op, PatternRewriter& rewriter) const override {
|
||||
SmallVector<int32_t> targetCoreIds;
|
||||
targetCoreIds.reserve(op.getTargetCoreIds().size());
|
||||
for (int32_t targetCoreId : op.getTargetCoreIds())
|
||||
targetCoreIds.push_back(toPimCoreId(targetCoreId));
|
||||
pim::PimSendTensorOp::create(rewriter, op.getLoc(), op.getInput(), rewriter.getDenseI32ArrayAttr(targetCoreIds));
|
||||
FailureOr<SmallVector<int32_t>> targetCoreIds = getConstantI32Values(op.getTargetCoreIds());
|
||||
if (failed(targetCoreIds))
|
||||
return rewriter.notifyMatchFailure(op, "expected constant targetCoreIds");
|
||||
for (int32_t& targetCoreId : *targetCoreIds)
|
||||
targetCoreId = toPimCoreId(targetCoreId);
|
||||
pim::PimSendTensorOp::create(rewriter, op.getLoc(), op.getInput(), rewriter.getDenseI32ArrayAttr(*targetCoreIds));
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
}
|
||||
@@ -67,16 +78,17 @@ struct ChannelReceiveTensorLowering : OpRewritePattern<spatial::SpatChannelRecei
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(spatial::SpatChannelReceiveTensorOp op, PatternRewriter& rewriter) const override {
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
sourceCoreIds.reserve(op.getSourceCoreIds().size());
|
||||
for (int32_t sourceCoreId : op.getSourceCoreIds())
|
||||
sourceCoreIds.push_back(toPimCoreId(sourceCoreId));
|
||||
FailureOr<SmallVector<int32_t>> sourceCoreIds = getConstantI32Values(op.getSourceCoreIds());
|
||||
if (failed(sourceCoreIds))
|
||||
return rewriter.notifyMatchFailure(op, "expected constant sourceCoreIds");
|
||||
for (int32_t& sourceCoreId : *sourceCoreIds)
|
||||
sourceCoreId = toPimCoreId(sourceCoreId);
|
||||
auto outputType = cast<ShapedType>(op.getOutput().getType());
|
||||
Value outputBuffer =
|
||||
tensor::EmptyOp::create(rewriter, op.getLoc(), outputType.getShape(), outputType.getElementType()).getResult();
|
||||
Value received =
|
||||
pim::PimReceiveTensorOp::create(
|
||||
rewriter, op.getLoc(), op.getOutput().getType(), outputBuffer, rewriter.getDenseI32ArrayAttr(sourceCoreIds))
|
||||
rewriter, op.getLoc(), op.getOutput().getType(), outputBuffer, rewriter.getDenseI32ArrayAttr(*sourceCoreIds))
|
||||
.getOutput();
|
||||
rewriter.replaceOp(op, received);
|
||||
return success();
|
||||
|
||||
@@ -1,42 +0,0 @@
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Cleanup.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
LogicalResult erasePendingOps(SmallVectorImpl<Operation*>& pendingOps, IRRewriter& rewriter) {
|
||||
while (!pendingOps.empty()) {
|
||||
bool erasedAnyOp = false;
|
||||
for (auto it = pendingOps.begin(); it != pendingOps.end();) {
|
||||
Operation* opToRemove = *it;
|
||||
if (!opToRemove->use_empty()) {
|
||||
++it;
|
||||
continue;
|
||||
}
|
||||
|
||||
rewriter.eraseOp(opToRemove);
|
||||
it = pendingOps.erase(it);
|
||||
erasedAnyOp = true;
|
||||
}
|
||||
|
||||
if (erasedAnyOp)
|
||||
continue;
|
||||
|
||||
for (Operation* opToRemove : pendingOps) {
|
||||
InFlightDiagnostic diag = opToRemove->emitError("pending Spatial-to-PIM cleanup could not erase operation");
|
||||
diag << "; op has " << llvm::range_size(opToRemove->getUsers()) << " remaining user(s)";
|
||||
for (Operation* user : opToRemove->getUsers()) {
|
||||
bool userPendingRemoval = llvm::is_contained(pendingOps, user);
|
||||
opToRemove->emitRemark() << "remaining user `" << user->getName() << "`"
|
||||
<< (userPendingRemoval ? " is also pending removal" : " is not pending removal");
|
||||
}
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -1,11 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
mlir::LogicalResult erasePendingOps(llvm::SmallVectorImpl<mlir::Operation*>& pendingOps, mlir::IRRewriter& rewriter);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -55,10 +55,6 @@ size_t getSliceActualOffset(tensor::ExtractSliceOp& sliceOp, ShapedType& inputSh
|
||||
return returnValue;
|
||||
}
|
||||
|
||||
size_t getShapedTypeSizeInBytes(ShapedType shapedType) {
|
||||
return shapedType.getNumElements() * shapedType.getElementTypeBitWidth() / 8;
|
||||
}
|
||||
|
||||
IntegerAttr getTensorSizeInBytesAttr(Builder& builder, mlir::Value value) {
|
||||
return builder.getI32IntegerAttr(static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(value.getType()))));
|
||||
}
|
||||
|
||||
@@ -20,8 +20,6 @@ namespace onnx_mlir {
|
||||
*/
|
||||
size_t getSliceActualOffset(mlir::tensor::ExtractSliceOp& sliceOp, mlir::ShapedType& inputShape);
|
||||
|
||||
size_t getShapedTypeSizeInBytes(mlir::ShapedType shapedType);
|
||||
|
||||
mlir::IntegerAttr getTensorSizeInBytesAttr(mlir::Builder& builder, mlir::Value value);
|
||||
|
||||
template <class T>
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
#include <cassert>
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ComputeLikeRegionUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
@@ -29,7 +31,18 @@ void replaceAndEraseDirectComputeLikeInput(PatternRewriter& rewriter,
|
||||
unsigned inputIndex,
|
||||
Value replacement) {
|
||||
Block& body = owner->getRegion(0).front();
|
||||
BlockArgument bodyArgument = body.getArgument(inputIndex);
|
||||
BlockArgument bodyArgument;
|
||||
if (auto compute = dyn_cast<spatial::SpatCompute>(owner)) {
|
||||
auto computeArg = compute.getInputArgument(inputIndex);
|
||||
assert(computeArg && "expected compute input block argument");
|
||||
bodyArgument = *computeArg;
|
||||
}
|
||||
else {
|
||||
auto batchArg = cast<spatial::SpatComputeBatch>(owner).getInputArgument(inputIndex);
|
||||
assert(batchArg && "expected compute_batch input block argument");
|
||||
bodyArgument = *batchArg;
|
||||
}
|
||||
unsigned bodyArgIndex = bodyArgument.getArgNumber();
|
||||
|
||||
rewriter.startOpModification(owner);
|
||||
bodyArgument.replaceAllUsesWith(replacement);
|
||||
@@ -37,7 +50,7 @@ void replaceAndEraseDirectComputeLikeInput(PatternRewriter& rewriter,
|
||||
compute.getInputsMutable().erase(inputIndex);
|
||||
else
|
||||
cast<spatial::SpatComputeBatch>(owner).getInputsMutable().erase(inputIndex);
|
||||
body.eraseArgument(inputIndex);
|
||||
body.eraseArgument(bodyArgIndex);
|
||||
rewriter.finalizeOpModification(owner);
|
||||
}
|
||||
|
||||
|
||||
@@ -3,11 +3,12 @@
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
|
||||
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "Conversion/SpatialToPim/SpatialToPimPass.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
@@ -27,7 +28,8 @@ static bool isChannelUseChainOp(Operation* op) {
|
||||
pim::PimTransposeOp>(op);
|
||||
}
|
||||
|
||||
static void cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewriter) {
|
||||
static void
|
||||
cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewriter, OperationFolder& constantFolder) {
|
||||
for (Value operand : op->getOperands()) {
|
||||
if (mapping.lookupOrNull(operand))
|
||||
continue;
|
||||
@@ -36,7 +38,12 @@ static void cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewri
|
||||
if (!definingOp)
|
||||
continue;
|
||||
|
||||
if (!isa<tensor::EmptyOp, arith::ConstantOp>(definingOp))
|
||||
if (auto constantOp = dyn_cast<arith::ConstantOp>(definingOp)) {
|
||||
mapping.map(operand, getOrCreateHostConstantLike(constantOp, constantFolder));
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!isa<tensor::EmptyOp>(definingOp))
|
||||
continue;
|
||||
|
||||
Operation* clonedOp = rewriter.clone(*definingOp, mapping);
|
||||
@@ -48,6 +55,18 @@ static void cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewri
|
||||
|
||||
static int32_t translateSpatialCoreIdToPimCoreId(size_t spatialCoreId) { return static_cast<int32_t>(spatialCoreId); }
|
||||
|
||||
static FailureOr<SmallVector<int32_t>> getConstantI32Values(ValueRange values) {
|
||||
SmallVector<int32_t> constants;
|
||||
constants.reserve(values.size());
|
||||
for (Value value : values) {
|
||||
APInt constantValue;
|
||||
if (!matchPattern(value, m_ConstantInt(&constantValue)))
|
||||
return failure();
|
||||
constants.push_back(static_cast<int32_t>(constantValue.getSExtValue()));
|
||||
}
|
||||
return constants;
|
||||
}
|
||||
|
||||
static int32_t getPimCoreIdForComputeOp(spatial::SpatCompute computeOp, size_t& fallbackCoreId) {
|
||||
if (auto spatialCoreIdAttr = computeOp->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
|
||||
return static_cast<int32_t>(spatialCoreIdAttr.getInt());
|
||||
@@ -92,7 +111,9 @@ static LogicalResult collectHelperComputeChain(spatial::SpatCompute computeOp,
|
||||
return success();
|
||||
}
|
||||
|
||||
static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatCompute computeOp, IRRewriter& rewriter) {
|
||||
static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatCompute computeOp,
|
||||
IRRewriter& rewriter,
|
||||
OperationFolder& constantFolder) {
|
||||
if (!computeOp.getInputs().empty() || computeOp.getNumResults() != 1)
|
||||
return false;
|
||||
if (!llvm::all_of(computeOp.getResult(0).getUsers(), [](Operation* user) {
|
||||
@@ -101,7 +122,7 @@ static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatCompute
|
||||
return false;
|
||||
|
||||
Block& block = computeOp.getBody().front();
|
||||
if (block.getNumArguments() != 0)
|
||||
if (block.getNumArguments() != computeOp.getWeights().size())
|
||||
return false;
|
||||
|
||||
auto yieldOp = dyn_cast<spatial::SpatYieldOp>(block.getTerminator());
|
||||
@@ -110,8 +131,14 @@ static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatCompute
|
||||
|
||||
rewriter.setInsertionPoint(computeOp);
|
||||
IRMapping mapping;
|
||||
for (auto [weightIndex, weight] : llvm::enumerate(computeOp.getWeights())) {
|
||||
auto weightArg = computeOp.getWeightArgument(weightIndex);
|
||||
if (!weightArg)
|
||||
return false;
|
||||
mapping.map(*weightArg, weight);
|
||||
}
|
||||
for (Operation& op : block.without_terminator()) {
|
||||
cloneMappedHelperOperands(&op, mapping, rewriter);
|
||||
cloneMappedHelperOperands(&op, mapping, rewriter, constantFolder);
|
||||
Operation* clonedOp = rewriter.clone(op, mapping);
|
||||
for (auto [originalResult, newResult] : llvm::zip(op.getResults(), clonedOp->getResults()))
|
||||
mapping.map(originalResult, newResult);
|
||||
@@ -125,15 +152,12 @@ static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatCompute
|
||||
|
||||
} // namespace
|
||||
|
||||
void markOpToRemove(CoreLoweringState& state, Operation* op) {
|
||||
if (!llvm::is_contained(state.operationsToRemove, op))
|
||||
state.operationsToRemove.push_back(op);
|
||||
}
|
||||
|
||||
LogicalResult lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState& state, IRRewriter& rewriter) {
|
||||
LogicalResult raptor::SpatialToPimPass::lowerComputeOp(spatial::SpatCompute computeOp,
|
||||
IRRewriter& rewriter,
|
||||
OperationFolder& constantFolder) {
|
||||
Location loc = computeOp->getLoc();
|
||||
|
||||
if (inlineInputlessHelperComputeForWeightLikeUsers(computeOp, rewriter))
|
||||
if (inlineInputlessHelperComputeForWeightLikeUsers(computeOp, rewriter, constantFolder))
|
||||
return success();
|
||||
|
||||
SmallVector<Operation*> helperChain;
|
||||
@@ -143,21 +167,44 @@ LogicalResult lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState&
|
||||
auto& block = computeOp.getRegion().front();
|
||||
auto yieldOp = cast<spatial::SpatYieldOp>(block.getTerminator());
|
||||
|
||||
for (auto [argIndex, blockArg] : llvm::enumerate(block.getArguments())) {
|
||||
auto receiveOp = dyn_cast_or_null<spatial::SpatChannelReceiveOp>(computeOp.getInputs()[argIndex].getDefiningOp());
|
||||
if (!receiveOp || blockArg.use_empty())
|
||||
for (auto [inputIndex, input] : llvm::enumerate(computeOp.getInputs())) {
|
||||
auto blockArg = computeOp.getInputArgument(inputIndex);
|
||||
if (!blockArg)
|
||||
return computeOp.emitOpError("expected compute input block arguments during lowering");
|
||||
auto receiveOp = dyn_cast_or_null<spatial::SpatChannelReceiveOp>(input.getDefiningOp());
|
||||
if (receiveOp && !blockArg->use_empty()) {
|
||||
rewriter.setInsertionPoint(getEarliestUserWithinBlock(*blockArg));
|
||||
auto outputType = cast<ShapedType>(blockArg->getType());
|
||||
auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveOp.getLoc(), outputType);
|
||||
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, *blockArg);
|
||||
Value received =
|
||||
PimReceiveOp::create(
|
||||
rewriter, receiveOp.getLoc(), outputBuffer.getType(), outputBuffer, sizeAttr, receiveOp.getSourceCoreId())
|
||||
.getOutput();
|
||||
blockArg->replaceAllUsesWith(received);
|
||||
markOpToRemove(receiveOp);
|
||||
continue;
|
||||
}
|
||||
|
||||
rewriter.setInsertionPoint(getEarliestUserWithinBlock(blockArg));
|
||||
auto outputType = cast<ShapedType>(blockArg.getType());
|
||||
auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveOp.getLoc(), outputType);
|
||||
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, blockArg);
|
||||
auto sourceCoreIdAttr = rewriter.getI32IntegerAttr(translateSpatialCoreIdToPimCoreId(receiveOp.getSourceCoreId()));
|
||||
Value received = PimReceiveOp::create(
|
||||
rewriter, receiveOp.getLoc(), outputBuffer.getType(), outputBuffer, sizeAttr, sourceCoreIdAttr)
|
||||
.getOutput();
|
||||
blockArg.replaceAllUsesWith(received);
|
||||
markOpToRemove(state, receiveOp);
|
||||
auto receiveTensorOp = dyn_cast_or_null<spatial::SpatChannelReceiveTensorOp>(input.getDefiningOp());
|
||||
if (receiveTensorOp && !blockArg->use_empty()) {
|
||||
FailureOr<SmallVector<int32_t>> sourceCoreIds = getConstantI32Values(receiveTensorOp.getSourceCoreIds());
|
||||
if (failed(sourceCoreIds))
|
||||
return receiveTensorOp.emitOpError("expected constant sourceCoreIds");
|
||||
for (int32_t& sourceCoreId : *sourceCoreIds)
|
||||
sourceCoreId = translateSpatialCoreIdToPimCoreId(sourceCoreId);
|
||||
rewriter.setInsertionPoint(getEarliestUserWithinBlock(*blockArg));
|
||||
auto outputType = cast<ShapedType>(blockArg->getType());
|
||||
auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveTensorOp.getLoc(), outputType);
|
||||
Value received = PimReceiveTensorOp::create(rewriter,
|
||||
receiveTensorOp.getLoc(),
|
||||
outputBuffer.getType(),
|
||||
outputBuffer,
|
||||
rewriter.getDenseI32ArrayAttr(*sourceCoreIds))
|
||||
.getOutput();
|
||||
blockArg->replaceAllUsesWith(received);
|
||||
markOpToRemove(receiveTensorOp);
|
||||
}
|
||||
}
|
||||
|
||||
if (computeOp.getNumResults() != yieldOp.getNumOperands())
|
||||
@@ -167,9 +214,8 @@ LogicalResult lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState&
|
||||
if (result.use_empty())
|
||||
continue;
|
||||
|
||||
ReturnPathState returnPathState {state.outputTensors, state.operationsToRemove};
|
||||
ReturnPathLoweringResult returnPathResult =
|
||||
lowerComputeResultReturnPath(computeOp, cast<OpResult>(result), yieldValue, returnPathState, rewriter);
|
||||
lowerComputeResultReturnPath(computeOp, cast<OpResult>(result), yieldValue, rewriter);
|
||||
if (returnPathResult == ReturnPathLoweringResult::Failure)
|
||||
return failure();
|
||||
if (returnPathResult == ReturnPathLoweringResult::Handled)
|
||||
@@ -193,15 +239,40 @@ LogicalResult lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState&
|
||||
if (!computeOp.getWeights().empty())
|
||||
computeWeights.append(computeOp.getWeights().begin(), computeOp.getWeights().end());
|
||||
rewriter.setInsertionPointAfter(computeOp);
|
||||
auto coreOp = PimCoreOp::create(rewriter,
|
||||
loc,
|
||||
ValueRange(computeWeights),
|
||||
rewriter.getI32IntegerAttr(getPimCoreIdForComputeOp(computeOp, state.nextCoreId)));
|
||||
auto coreOp = PimCoreOp::create(
|
||||
rewriter, loc, ValueRange(computeWeights), rewriter.getI32IntegerAttr(getPimCoreIdForComputeOp(computeOp, coreId)));
|
||||
rewriter.setInsertionPointToStart(&block);
|
||||
auto& coreOpBlocks = coreOp.getBody().getBlocks();
|
||||
for (auto [argIndex, blockArg] : llvm::enumerate(block.getArguments()))
|
||||
if (!blockArg.use_empty())
|
||||
blockArg.replaceAllUsesWith(computeOp.getInputs()[argIndex]);
|
||||
block.eraseArguments(0, block.getNumArguments());
|
||||
for (auto [inputIndex, input] : llvm::enumerate(computeOp.getInputs())) {
|
||||
auto blockArg = computeOp.getInputArgument(inputIndex);
|
||||
if (!blockArg)
|
||||
return computeOp.emitOpError("expected compute input block arguments during input materialization");
|
||||
if (blockArg->use_empty())
|
||||
continue;
|
||||
|
||||
if (auto constantOp = input.getDefiningOp<arith::ConstantOp>()) {
|
||||
blockArg->replaceAllUsesWith(getOrCreateHostConstantLike(constantOp, constantFolder));
|
||||
continue;
|
||||
}
|
||||
|
||||
auto inputType = dyn_cast<ShapedType>(input.getType());
|
||||
if (!inputType)
|
||||
return computeOp.emitOpError("expected shaped compute input during pim.core lowering");
|
||||
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, inputType);
|
||||
auto copied =
|
||||
PimMemCopyHostToDevOp::create(rewriter,
|
||||
loc,
|
||||
outputBuffer.getType(),
|
||||
getOrCreateHostIndexConstant(outputBuffer.getOperation(), 0, constantFolder),
|
||||
getOrCreateHostIndexConstant(outputBuffer.getOperation(), 0, constantFolder),
|
||||
outputBuffer,
|
||||
input,
|
||||
getTensorSizeInBytesAttr(rewriter, input))
|
||||
.getOutput();
|
||||
blockArg->replaceAllUsesWith(copied);
|
||||
}
|
||||
if (!computeOp.getInputs().empty())
|
||||
block.eraseArguments(computeOp.getWeights().size(), computeOp.getInputs().size());
|
||||
coreOpBlocks.splice(coreOpBlocks.begin(), computeOp.getBody().getBlocks());
|
||||
Block* tempComputeBlock = new Block();
|
||||
computeOp.getBody().push_back(tempComputeBlock);
|
||||
|
||||
@@ -1,21 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ReturnPathNormalization.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
struct CoreLoweringState {
|
||||
size_t& nextCoreId;
|
||||
llvm::SmallVectorImpl<OutputTensorFactory>& outputTensors;
|
||||
llvm::SmallVectorImpl<mlir::Operation*>& operationsToRemove;
|
||||
};
|
||||
|
||||
void markOpToRemove(CoreLoweringState& state, mlir::Operation* op);
|
||||
|
||||
mlir::LogicalResult
|
||||
lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState& state, mlir::IRRewriter& rewriter);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -76,10 +76,11 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
|
||||
auto inputIndex = getDirectComputeLikeInputIndex(spatCompute, uses.getOperandNumber());
|
||||
if (!inputIndex)
|
||||
return failure();
|
||||
auto BBArgIndex = *inputIndex;
|
||||
auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex);
|
||||
auto BBArgValue = spatCompute.getInputArgument(*inputIndex);
|
||||
if (!BBArgValue)
|
||||
return failure();
|
||||
|
||||
if (BBArgValue.use_empty())
|
||||
if (BBArgValue->use_empty())
|
||||
continue;
|
||||
|
||||
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
|
||||
@@ -89,16 +90,17 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
|
||||
}
|
||||
|
||||
replaceAndEraseDirectComputeLikeInput(
|
||||
rewriter, spatCompute.getOperation(), BBArgIndex, mapSpatToExtract[spatCompute.getOperation()]);
|
||||
rewriter, spatCompute.getOperation(), *inputIndex, mapSpatToExtract[spatCompute.getOperation()]);
|
||||
}
|
||||
else if (auto spatComputeBatch = dyn_cast<spatial::SpatComputeBatch>(uses.getOwner())) {
|
||||
auto inputIndex = getDirectComputeLikeInputIndex(spatComputeBatch, uses.getOperandNumber());
|
||||
if (!inputIndex)
|
||||
return failure();
|
||||
auto BBArgIndex = *inputIndex;
|
||||
auto BBArgValue = spatComputeBatch.getBody().front().getArgument(BBArgIndex);
|
||||
auto BBArgValue = spatComputeBatch.getInputArgument(*inputIndex);
|
||||
if (!BBArgValue)
|
||||
return failure();
|
||||
|
||||
if (BBArgValue.use_empty())
|
||||
if (BBArgValue->use_empty())
|
||||
continue;
|
||||
|
||||
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
|
||||
@@ -108,7 +110,7 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
|
||||
}
|
||||
|
||||
replaceAndEraseDirectComputeLikeInput(
|
||||
rewriter, spatComputeBatch.getOperation(), BBArgIndex, mapSpatToExtract[spatComputeBatch.getOperation()]);
|
||||
rewriter, spatComputeBatch.getOperation(), *inputIndex, mapSpatToExtract[spatComputeBatch.getOperation()]);
|
||||
}
|
||||
else {
|
||||
{
|
||||
@@ -143,170 +145,6 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
|
||||
}
|
||||
};
|
||||
|
||||
// Turns runtime constants consumed by compute regions into private globals and local loads.
|
||||
struct ArithConstToGlobalMemoryPattern final : OpRewritePattern<mlir::arith::ConstantOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(mlir::arith::ConstantOp constantOp, PatternRewriter& rewriter) const override {
|
||||
Location loc = constantOp.getLoc();
|
||||
|
||||
if (hasWeightAlways(constantOp))
|
||||
return failure();
|
||||
|
||||
if (!isa<func::FuncOp>(constantOp->getParentOp()))
|
||||
return failure();
|
||||
|
||||
if (llvm::all_of(constantOp->getUsers(), [](Operation* op) {
|
||||
if (isa<spatial::SpatCompute>(op))
|
||||
return false;
|
||||
if (isa<func::FuncOp>(op->getParentOp()))
|
||||
return true;
|
||||
return false;
|
||||
}))
|
||||
return failure();
|
||||
|
||||
rewriter.setInsertionPoint(constantOp->getParentOfType<func::FuncOp>());
|
||||
|
||||
auto constRankedTensorType = llvm::dyn_cast<mlir::RankedTensorType>(constantOp.getType());
|
||||
|
||||
if (constRankedTensorType) {
|
||||
mlir::MemRefType memRefType =
|
||||
mlir::MemRefType::get(constRankedTensorType.getShape(), constRankedTensorType.getElementType());
|
||||
auto globalOp = createPrivateMemrefGlobalWithUniqueName(rewriter,
|
||||
loc,
|
||||
constantOp->getParentOfType<ModuleOp>(),
|
||||
"const",
|
||||
memRefType,
|
||||
constantOp.getValueAttr(),
|
||||
rewriter.getUnitAttr());
|
||||
std::string argName = globalOp.getSymName().str();
|
||||
|
||||
llvm::DenseMap<Operation*, Value> mapSpatComputeToConst;
|
||||
|
||||
for (auto& constUses : llvm::make_early_inc_range(constantOp->getUses())) {
|
||||
auto constUsers = constUses.getOwner();
|
||||
|
||||
if (auto spatCompute = llvm::dyn_cast<spatial::SpatCompute>(constUsers)) {
|
||||
auto inputIndex = getDirectComputeLikeInputIndex(spatCompute, constUses.getOperandNumber());
|
||||
if (!inputIndex)
|
||||
return failure();
|
||||
auto BBArgIndex = *inputIndex;
|
||||
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
|
||||
if (!mapSpatComputeToConst.contains(spatCompute.getOperation())) {
|
||||
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
|
||||
auto toTensor = bufferization::ToTensorOp::create(
|
||||
rewriter, loc, constRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
|
||||
mapSpatComputeToConst.insert({spatCompute.getOperation(), toTensor.getResult()});
|
||||
}
|
||||
|
||||
replaceAndEraseDirectComputeLikeInput(
|
||||
rewriter, spatCompute.getOperation(), BBArgIndex, mapSpatComputeToConst[spatCompute.getOperation()]);
|
||||
}
|
||||
else if (auto spatComputeBatch = llvm::dyn_cast<spatial::SpatComputeBatch>(constUsers)) {
|
||||
auto inputIndex = getDirectComputeLikeInputIndex(spatComputeBatch, constUses.getOperandNumber());
|
||||
if (!inputIndex)
|
||||
return failure();
|
||||
auto BBArgIndex = *inputIndex;
|
||||
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
|
||||
if (!mapSpatComputeToConst.contains(spatComputeBatch.getOperation())) {
|
||||
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
|
||||
auto toTensor = bufferization::ToTensorOp::create(
|
||||
rewriter, loc, constRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
|
||||
mapSpatComputeToConst.insert({spatComputeBatch.getOperation(), toTensor.getResult()});
|
||||
}
|
||||
|
||||
replaceAndEraseDirectComputeLikeInput(rewriter,
|
||||
spatComputeBatch.getOperation(),
|
||||
BBArgIndex,
|
||||
mapSpatComputeToConst[spatComputeBatch.getOperation()]);
|
||||
}
|
||||
else {
|
||||
{
|
||||
|
||||
if (auto spatCompute = constUses.getOwner()->getParentOfType<spatial::SpatCompute>()) {
|
||||
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
|
||||
if (!mapSpatComputeToConst.contains(spatCompute.getOperation())) {
|
||||
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
|
||||
auto toTensor = bufferization::ToTensorOp::create(
|
||||
rewriter, loc, constRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
|
||||
mapSpatComputeToConst.insert({spatCompute.getOperation(), toTensor.getResult()});
|
||||
}
|
||||
|
||||
rewriter.startOpModification(spatCompute.getOperation());
|
||||
constUses.set(mapSpatComputeToConst[spatCompute.getOperation()]);
|
||||
rewriter.finalizeOpModification(spatCompute.getOperation());
|
||||
}
|
||||
else if (auto spatComputeBatch = constUses.getOwner()->getParentOfType<spatial::SpatComputeBatch>()) {
|
||||
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
|
||||
if (!mapSpatComputeToConst.contains(spatComputeBatch.getOperation())) {
|
||||
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
|
||||
auto toTensor = bufferization::ToTensorOp::create(
|
||||
rewriter, loc, constRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
|
||||
mapSpatComputeToConst.insert({spatComputeBatch.getOperation(), toTensor.getResult()});
|
||||
}
|
||||
|
||||
rewriter.startOpModification(spatComputeBatch.getOperation());
|
||||
constUses.set(mapSpatComputeToConst[spatComputeBatch.getOperation()]);
|
||||
rewriter.finalizeOpModification(spatComputeBatch.getOperation());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
else if (constantOp.getType().isIntOrIndexOrFloat()) {
|
||||
llvm::DenseMap<Operation*, Value> mapSpatComputeToConst;
|
||||
|
||||
for (auto& constUses : llvm::make_early_inc_range(constantOp->getUses())) {
|
||||
auto constUsers = constUses.getOwner();
|
||||
|
||||
if (auto spatCompute = llvm::dyn_cast<spatial::SpatCompute>(constUsers)) {
|
||||
auto inputIndex = getDirectComputeLikeInputIndex(spatCompute, constUses.getOperandNumber());
|
||||
if (!inputIndex)
|
||||
return failure();
|
||||
auto BBArgIndex = *inputIndex;
|
||||
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
|
||||
auto newConst = rewriter.clone(*constantOp);
|
||||
|
||||
replaceAndEraseDirectComputeLikeInput(
|
||||
rewriter, spatCompute.getOperation(), BBArgIndex, newConst->getResult(0));
|
||||
}
|
||||
else if (auto spatComputeBatch = llvm::dyn_cast<spatial::SpatComputeBatch>(constUsers)) {
|
||||
auto inputIndex = getDirectComputeLikeInputIndex(spatComputeBatch, constUses.getOperandNumber());
|
||||
if (!inputIndex)
|
||||
return failure();
|
||||
auto BBArgIndex = *inputIndex;
|
||||
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
|
||||
auto newConst = rewriter.clone(*constantOp);
|
||||
|
||||
replaceAndEraseDirectComputeLikeInput(
|
||||
rewriter, spatComputeBatch.getOperation(), BBArgIndex, newConst->getResult(0));
|
||||
}
|
||||
else if (auto parent = constUsers->getParentOfType<spatial::SpatCompute>()) {
|
||||
if (!mapSpatComputeToConst.contains(parent)) {
|
||||
rewriter.setInsertionPoint(&parent.getBody().front().front());
|
||||
auto newConst = rewriter.clone(*constantOp);
|
||||
mapSpatComputeToConst.insert({parent.getOperation(), newConst->getResult(0)});
|
||||
}
|
||||
constUses.set(mapSpatComputeToConst[parent.getOperation()]);
|
||||
}
|
||||
else {
|
||||
auto batchParent = constUsers->getParentOfType<spatial::SpatComputeBatch>();
|
||||
assert(batchParent && "Global Constant used direcly not within a compute");
|
||||
if (!mapSpatComputeToConst.contains(batchParent.getOperation())) {
|
||||
rewriter.setInsertionPoint(&batchParent.getBody().front().front());
|
||||
auto newConst = rewriter.clone(*constantOp);
|
||||
mapSpatComputeToConst.insert({batchParent.getOperation(), newConst->getResult(0)});
|
||||
}
|
||||
constUses.set(mapSpatComputeToConst[batchParent.getOperation()]);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (constantOp->use_empty())
|
||||
rewriter.eraseOp(constantOp);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// Materializes public function tensor inputs as globals so compute bodies can load them uniformly.
|
||||
struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern<mlir::func::FuncOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
@@ -383,8 +221,7 @@ struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern<mlir::func::FuncO
|
||||
|
||||
} // namespace
|
||||
void populateGlobalTensorMaterializationPatterns(RewritePatternSet& patterns) {
|
||||
patterns.add<MoveExtractSliceIntoCompute, FuncOpArgToGlobalMemoryPattern, ArithConstToGlobalMemoryPattern>(
|
||||
patterns.getContext());
|
||||
patterns.add<MoveExtractSliceIntoCompute, FuncOpArgToGlobalMemoryPattern>(patterns.getContext());
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -6,10 +6,12 @@
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "mlir/IR/SymbolTable.h"
|
||||
#include "mlir/Transforms/FoldUtils.h"
|
||||
|
||||
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "Conversion/SpatialToPim/SpatialToPimPass.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ReturnPathNormalization.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
@@ -42,11 +44,6 @@ static bool isReturnHelperChainOp(Operation* op) {
|
||||
pim::PimTransposeOp>(op);
|
||||
}
|
||||
|
||||
static void markOpToRemove(ReturnPathState& state, Operation* op) {
|
||||
if (!llvm::is_contained(state.operationsToRemove, op))
|
||||
state.operationsToRemove.push_back(op);
|
||||
}
|
||||
|
||||
static std::string makeUniqueSymbolName(Operation* symbolTableOp, StringRef baseName) {
|
||||
std::string name = baseName.str();
|
||||
unsigned suffix = 0;
|
||||
@@ -318,7 +315,8 @@ static LogicalResult mapIndicesThroughHelperChain(ArrayRef<int64_t> sourceIndice
|
||||
return success();
|
||||
}
|
||||
|
||||
static void cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewriter) {
|
||||
static void
|
||||
cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewriter, OperationFolder& constantFolder) {
|
||||
for (Value operand : op->getOperands()) {
|
||||
if (mapping.lookupOrNull(operand))
|
||||
continue;
|
||||
@@ -327,7 +325,12 @@ static void cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewri
|
||||
if (!definingOp)
|
||||
continue;
|
||||
|
||||
if (!isa<tensor::EmptyOp, arith::ConstantOp>(definingOp))
|
||||
if (auto constantOp = dyn_cast<arith::ConstantOp>(definingOp)) {
|
||||
mapping.map(operand, getOrCreateHostConstantLike(constantOp, constantFolder));
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!isa<tensor::EmptyOp>(definingOp))
|
||||
continue;
|
||||
|
||||
Operation* clonedOp = rewriter.clone(*definingOp, mapping);
|
||||
@@ -337,15 +340,18 @@ static void cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewri
|
||||
}
|
||||
}
|
||||
|
||||
static void
|
||||
cloneHelperChain(Value sourceValue, ArrayRef<Operation*> helperChain, IRRewriter& rewriter, Value& clonedValue) {
|
||||
static void cloneHelperChain(Value sourceValue,
|
||||
ArrayRef<Operation*> helperChain,
|
||||
IRRewriter& rewriter,
|
||||
OperationFolder& constantFolder,
|
||||
Value& clonedValue) {
|
||||
IRMapping mapping;
|
||||
mapping.map(sourceValue, sourceValue);
|
||||
clonedValue = sourceValue;
|
||||
|
||||
rewriter.setInsertionPointAfterValue(sourceValue);
|
||||
for (Operation* op : helperChain) {
|
||||
cloneMappedHelperOperands(op, mapping, rewriter);
|
||||
cloneMappedHelperOperands(op, mapping, rewriter, constantFolder);
|
||||
Operation* clonedOp = rewriter.clone(*op, mapping);
|
||||
for (auto [originalResult, newResult] : llvm::zip(op->getResults(), clonedOp->getResults()))
|
||||
mapping.map(originalResult, newResult);
|
||||
@@ -360,23 +366,26 @@ static Value emitHostCopy(IRRewriter& rewriter,
|
||||
Value sourceValue,
|
||||
int32_t hostTargetOffset,
|
||||
int32_t deviceSourceOffset,
|
||||
int32_t sizeInBytes) {
|
||||
int32_t sizeInBytes,
|
||||
OperationFolder& constantFolder) {
|
||||
Operation* anchorOp = sourceValue.getDefiningOp() ? sourceValue.getDefiningOp() : outputTensor.getDefiningOp();
|
||||
assert(anchorOp && "expected a concrete op anchor for return-path host copy constants");
|
||||
Value hostTargetOffsetValue = getOrCreateHostIndexConstant(anchorOp, hostTargetOffset, constantFolder);
|
||||
Value deviceSourceOffsetValue = getOrCreateHostIndexConstant(anchorOp, deviceSourceOffset, constantFolder);
|
||||
return PimMemCopyDevToHostOp::create(rewriter,
|
||||
loc,
|
||||
outputTensor.getType(),
|
||||
hostTargetOffsetValue,
|
||||
deviceSourceOffsetValue,
|
||||
outputTensor,
|
||||
sourceValue,
|
||||
rewriter.getI32IntegerAttr(hostTargetOffset),
|
||||
rewriter.getI32IntegerAttr(deviceSourceOffset),
|
||||
rewriter.getI32IntegerAttr(sizeInBytes))
|
||||
.getOutput();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void addReturnOutputBuffers(func::ReturnOp returnOp,
|
||||
IRRewriter& rewriter,
|
||||
SmallVectorImpl<OutputTensorFactory>& outputTensors) {
|
||||
void raptor::SpatialToPimPass::addReturnOutputBuffers(func::ReturnOp returnOp, IRRewriter& rewriter) {
|
||||
outputTensors.reserve(returnOp->getNumOperands());
|
||||
for (auto [index, returnValue] : llvm::enumerate(returnOp->getOperands())) {
|
||||
Value currentReturnValue = returnValue;
|
||||
@@ -411,70 +420,85 @@ void addReturnOutputBuffers(func::ReturnOp returnOp,
|
||||
}
|
||||
}
|
||||
|
||||
ReturnPathLoweringResult lowerComputeResultReturnPath(
|
||||
spatial::SpatCompute computeOp, OpResult result, Value yieldValue, ReturnPathState& state, IRRewriter& rewriter) {
|
||||
Location loc = computeOp->getLoc();
|
||||
auto yieldType = cast<TensorType>(yieldValue.getType());
|
||||
raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::lowerProducedValueReturnPath(
|
||||
Operation* producerOp, Value producedValue, Value storedValue, IRRewriter& rewriter) {
|
||||
Location loc = producerOp->getLoc();
|
||||
OperationFolder constantFolder(producerOp->getContext());
|
||||
auto storedTensorType = cast<TensorType>(storedValue.getType());
|
||||
|
||||
if (auto returnUse = analyzeReturnUse(result)) {
|
||||
Value storedValue = yieldValue;
|
||||
cloneHelperChain(yieldValue, returnUse->helperChain, rewriter, storedValue);
|
||||
if (auto returnUse = analyzeReturnUse(producedValue)) {
|
||||
Value currentStoredValue = storedValue;
|
||||
cloneHelperChain(storedValue, returnUse->helperChain, rewriter, constantFolder, currentStoredValue);
|
||||
for (Operation* op : returnUse->helperChain)
|
||||
markOpToRemove(state, op);
|
||||
markOpToRemove(op);
|
||||
|
||||
auto storedType = cast<ShapedType>(storedValue.getType());
|
||||
size_t elementSize = storedType.getElementTypeBitWidth() / 8;
|
||||
if (auto storedOp = storedValue.getDefiningOp())
|
||||
auto storedType = cast<ShapedType>(currentStoredValue.getType());
|
||||
size_t elementSize = getElementTypeSizeInBytes(storedType.getElementType());
|
||||
if (auto storedOp = currentStoredValue.getDefiningOp())
|
||||
rewriter.setInsertionPointAfter(storedOp);
|
||||
Value outputTensor = state.outputTensors[returnUse->returnIndex](rewriter, loc);
|
||||
emitHostCopy(
|
||||
rewriter, loc, outputTensor, storedValue, 0, 0, static_cast<int32_t>(storedType.getNumElements() * elementSize));
|
||||
Value outputTensor = outputTensors[returnUse->returnIndex](rewriter, loc);
|
||||
emitHostCopy(rewriter,
|
||||
loc,
|
||||
outputTensor,
|
||||
currentStoredValue,
|
||||
0,
|
||||
0,
|
||||
static_cast<int32_t>(storedType.getNumElements() * elementSize),
|
||||
constantFolder);
|
||||
return ReturnPathLoweringResult::Handled;
|
||||
}
|
||||
|
||||
auto resultUses = result.getUses();
|
||||
auto resultUses = producedValue.getUses();
|
||||
if (rangeLength(resultUses) == 1) {
|
||||
OpOperand& resultUse = *resultUses.begin();
|
||||
Operation* resultUser = resultUse.getOwner();
|
||||
|
||||
if (isa<func::ReturnOp>(resultUser)) {
|
||||
size_t resultIndexInReturn = resultUse.getOperandNumber();
|
||||
size_t elementSize = yieldType.getElementType().getIntOrFloatBitWidth() / 8;
|
||||
rewriter.setInsertionPointAfterValue(yieldValue);
|
||||
Value outputTensor = state.outputTensors[resultIndexInReturn](rewriter, loc);
|
||||
emitHostCopy(
|
||||
rewriter, loc, outputTensor, yieldValue, 0, 0, static_cast<int32_t>(yieldType.getNumElements() * elementSize));
|
||||
size_t elementSize = getElementTypeSizeInBytes(storedTensorType.getElementType());
|
||||
rewriter.setInsertionPointAfterValue(storedValue);
|
||||
Value outputTensor = outputTensors[resultIndexInReturn](rewriter, loc);
|
||||
emitHostCopy(rewriter,
|
||||
loc,
|
||||
outputTensor,
|
||||
storedValue,
|
||||
0,
|
||||
0,
|
||||
static_cast<int32_t>(storedTensorType.getNumElements() * elementSize),
|
||||
constantFolder);
|
||||
return ReturnPathLoweringResult::Handled;
|
||||
}
|
||||
}
|
||||
|
||||
if (auto concatReturnUse = analyzeConcatReturnUse(result)) {
|
||||
size_t elementSize = yieldType.getElementTypeBitWidth() / 8;
|
||||
if (auto concatReturnUse = analyzeConcatReturnUse(producedValue)) {
|
||||
size_t elementSize = getElementTypeSizeInBytes(storedTensorType.getElementType());
|
||||
for (Operation* concatOp : concatReturnUse->concatChain)
|
||||
markOpToRemove(state, concatOp);
|
||||
markOpToRemove(concatOp);
|
||||
|
||||
if (concatReturnUse->helperChain.empty()) {
|
||||
rewriter.setInsertionPointAfterValue(yieldValue);
|
||||
Value outputTensor = state.outputTensors[concatReturnUse->returnIndex](rewriter, loc);
|
||||
rewriter.setInsertionPointAfterValue(storedValue);
|
||||
Value outputTensor = outputTensors[concatReturnUse->returnIndex](rewriter, loc);
|
||||
auto outputType = cast<ShapedType>(outputTensor.getType());
|
||||
int64_t flatOffset = computeFlatElementIndex(concatReturnUse->sliceOffsets, outputType.getShape());
|
||||
emitHostCopy(rewriter,
|
||||
loc,
|
||||
outputTensor,
|
||||
yieldValue,
|
||||
storedValue,
|
||||
static_cast<int32_t>(flatOffset * elementSize),
|
||||
0,
|
||||
static_cast<int32_t>(yieldType.getNumElements() * elementSize));
|
||||
static_cast<int32_t>(storedTensorType.getNumElements() * elementSize),
|
||||
constantFolder);
|
||||
return ReturnPathLoweringResult::Handled;
|
||||
}
|
||||
|
||||
auto storedType = dyn_cast<RankedTensorType>(yieldValue.getType());
|
||||
auto storedType = dyn_cast<RankedTensorType>(storedValue.getType());
|
||||
if (!storedType) {
|
||||
computeOp.emitOpError("has an unsupported non-ranked concat-return helper yield during Spatial-to-PIM lowering");
|
||||
producerOp->emitOpError(
|
||||
"has an unsupported non-ranked concat-return helper yield during Spatial-to-PIM lowering");
|
||||
return ReturnPathLoweringResult::Failure;
|
||||
}
|
||||
rewriter.setInsertionPointAfterValue(yieldValue);
|
||||
Value outputTensor = state.outputTensors[concatReturnUse->returnIndex](rewriter, loc);
|
||||
rewriter.setInsertionPointAfterValue(storedValue);
|
||||
Value outputTensor = outputTensors[concatReturnUse->returnIndex](rewriter, loc);
|
||||
auto outputType = cast<ShapedType>(outputTensor.getType());
|
||||
for (int64_t linearIndex = 0; linearIndex < storedType.getNumElements(); ++linearIndex) {
|
||||
SmallVector<int64_t> sourceIndices = expandFlatElementIndex(linearIndex, storedType.getShape());
|
||||
@@ -484,7 +508,7 @@ ReturnPathLoweringResult lowerComputeResultReturnPath(
|
||||
SmallVector<int64_t> destinationIndices;
|
||||
if (failed(mapIndicesThroughHelperChain(
|
||||
sourceIndices, concatReturnUse->concatShape, concatReturnUse->helperChain, destinationIndices))) {
|
||||
computeOp.emitOpError("has an unsupported concat-return helper chain during Spatial-to-PIM lowering");
|
||||
producerOp->emitOpError("has an unsupported concat-return helper chain during Spatial-to-PIM lowering");
|
||||
return ReturnPathLoweringResult::Failure;
|
||||
}
|
||||
|
||||
@@ -503,7 +527,7 @@ ReturnPathLoweringResult lowerComputeResultReturnPath(
|
||||
auto scalarTensorType =
|
||||
RankedTensorType::get(SmallVector<int64_t>(storedType.getRank(), 1), storedType.getElementType());
|
||||
auto elementSlice = tensor::ExtractSliceOp::create(
|
||||
rewriter, loc, scalarTensorType, yieldValue, extractOffsets, extractSizes, extractStrides);
|
||||
rewriter, loc, scalarTensorType, storedValue, extractOffsets, extractSizes, extractStrides);
|
||||
rewriter.setInsertionPointAfter(elementSlice);
|
||||
|
||||
int64_t destinationFlatOffset = computeFlatElementIndex(destinationIndices, outputType.getShape());
|
||||
@@ -513,7 +537,8 @@ ReturnPathLoweringResult lowerComputeResultReturnPath(
|
||||
elementSlice.getResult(),
|
||||
static_cast<int32_t>(destinationFlatOffset * elementSize),
|
||||
0,
|
||||
static_cast<int32_t>(elementSize));
|
||||
static_cast<int32_t>(elementSize),
|
||||
constantFolder);
|
||||
}
|
||||
return ReturnPathLoweringResult::Handled;
|
||||
}
|
||||
@@ -521,7 +546,12 @@ ReturnPathLoweringResult lowerComputeResultReturnPath(
|
||||
return ReturnPathLoweringResult::NotReturnPath;
|
||||
}
|
||||
|
||||
void replaceReturnWithOutputBuffers(func::ReturnOp returnOp, IRRewriter& rewriter, ReturnPathState& state) {
|
||||
raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::lowerComputeResultReturnPath(
|
||||
spatial::SpatCompute computeOp, OpResult result, Value yieldValue, IRRewriter& rewriter) {
|
||||
return lowerProducedValueReturnPath(computeOp.getOperation(), result, yieldValue, rewriter);
|
||||
}
|
||||
|
||||
void raptor::SpatialToPimPass::replaceReturnWithOutputBuffers(func::ReturnOp returnOp, IRRewriter& rewriter) {
|
||||
auto markOwnedReturnChain = [&](Operation* op, auto&& markOwnedReturnChain) -> void {
|
||||
if (!op)
|
||||
return;
|
||||
@@ -538,13 +568,13 @@ void replaceReturnWithOutputBuffers(func::ReturnOp returnOp, IRRewriter& rewrite
|
||||
|
||||
if (isReturnHelperChainOp(op)) {
|
||||
Value source = op->getOperand(0);
|
||||
markOpToRemove(state, op);
|
||||
markOpToRemove(op);
|
||||
markOwnedReturnChain(source.getDefiningOp(), markOwnedReturnChain);
|
||||
return;
|
||||
}
|
||||
|
||||
if (auto computeOp = dyn_cast<spatial::SpatCompute>(op)) {
|
||||
markOpToRemove(state, computeOp);
|
||||
markOpToRemove(computeOp);
|
||||
if (!computeOp.getInputs().empty())
|
||||
for (Value input : computeOp.getInputs())
|
||||
markOwnedReturnChain(input.getDefiningOp(), markOwnedReturnChain);
|
||||
@@ -552,24 +582,33 @@ void replaceReturnWithOutputBuffers(func::ReturnOp returnOp, IRRewriter& rewrite
|
||||
}
|
||||
|
||||
if (auto concatOp = dyn_cast<tensor::ConcatOp>(op)) {
|
||||
markOpToRemove(state, concatOp);
|
||||
markOpToRemove(concatOp);
|
||||
for (Value operand : concatOp.getOperands())
|
||||
markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain);
|
||||
return;
|
||||
}
|
||||
|
||||
if (auto concatOp = dyn_cast<spatial::SpatConcatOp>(op)) {
|
||||
markOpToRemove(state, concatOp);
|
||||
markOpToRemove(concatOp);
|
||||
for (Value operand : concatOp.getInputs())
|
||||
markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain);
|
||||
return;
|
||||
}
|
||||
|
||||
if (auto concatOp = dyn_cast<pim::PimConcatOp>(op)) {
|
||||
markOpToRemove(state, concatOp);
|
||||
markOpToRemove(concatOp);
|
||||
for (Value operand : concatOp.getInputs())
|
||||
markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain);
|
||||
return;
|
||||
}
|
||||
|
||||
if (auto receiveOp = dyn_cast<spatial::SpatChannelReceiveOp>(op)) {
|
||||
markOpToRemove(receiveOp);
|
||||
return;
|
||||
}
|
||||
|
||||
if (auto receiveTensorOp = dyn_cast<spatial::SpatChannelReceiveTensorOp>(op))
|
||||
markOpToRemove(receiveTensorOp);
|
||||
};
|
||||
|
||||
SmallVector<Value> originalOperands(returnOp.getOperands().begin(), returnOp.getOperands().end());
|
||||
@@ -578,7 +617,7 @@ void replaceReturnWithOutputBuffers(func::ReturnOp returnOp, IRRewriter& rewrite
|
||||
size_t orderWithinReturn = it.index();
|
||||
Operation* returnOperand = it.value().getDefiningOp();
|
||||
rewriter.setInsertionPoint(returnOp);
|
||||
Value outputTensor = state.outputTensors[orderWithinReturn](rewriter, loc);
|
||||
Value outputTensor = outputTensors[orderWithinReturn](rewriter, loc);
|
||||
rewriter.modifyOpInPlace(returnOp, [&] { returnOp.setOperand(orderWithinReturn, outputTensor); });
|
||||
markOwnedReturnChain(returnOperand, markOwnedReturnChain);
|
||||
}
|
||||
|
||||
@@ -1,37 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
#include <functional>
|
||||
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
using OutputTensorFactory = std::function<mlir::Value(mlir::IRRewriter& rewriter, mlir::Location loc)>;
|
||||
|
||||
struct ReturnPathState {
|
||||
llvm::SmallVectorImpl<OutputTensorFactory>& outputTensors;
|
||||
llvm::SmallVectorImpl<mlir::Operation*>& operationsToRemove;
|
||||
};
|
||||
|
||||
enum class ReturnPathLoweringResult {
|
||||
Handled,
|
||||
NotReturnPath,
|
||||
Failure
|
||||
};
|
||||
|
||||
void addReturnOutputBuffers(mlir::func::ReturnOp returnOp,
|
||||
mlir::IRRewriter& rewriter,
|
||||
llvm::SmallVectorImpl<OutputTensorFactory>& outputTensors);
|
||||
|
||||
ReturnPathLoweringResult lowerComputeResultReturnPath(spatial::SpatCompute computeOp,
|
||||
mlir::OpResult result,
|
||||
mlir::Value yieldValue,
|
||||
ReturnPathState& state,
|
||||
mlir::IRRewriter& rewriter);
|
||||
|
||||
void replaceReturnWithOutputBuffers(mlir::func::ReturnOp returnOp, mlir::IRRewriter& rewriter, ReturnPathState& state);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -16,8 +16,8 @@ def onnxToPimTranspose : Pat<
|
||||
>;
|
||||
|
||||
def spatToPimVMM : Pat<
|
||||
(SpatVMMOp:$srcOpRes $weightIndex, $vector),
|
||||
(PimVMMOp $weightIndex, $vector,
|
||||
(SpatVMMOp:$srcOpRes $weight, $vector),
|
||||
(PimVMMOp $weight, $vector,
|
||||
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
||||
>;
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/Dialect/SCF/Utils/Utils.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/BuiltinDialect.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
@@ -12,6 +13,7 @@
|
||||
#include "mlir/IR/SymbolTable.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/FoldUtils.h"
|
||||
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
|
||||
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
@@ -21,54 +23,28 @@
|
||||
#include <cassert>
|
||||
#include <utility>
|
||||
|
||||
#include "Common/PimCommon.hpp"
|
||||
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ChannelLoweringPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Cleanup.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/GlobalTensorMaterialization.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/PhaseVerification.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ReturnPathNormalization.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/TensorPackingPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
|
||||
#include "Conversion/SpatialToPim/ChannelLoweringPatterns.hpp"
|
||||
#include "Conversion/SpatialToPim/Common.hpp"
|
||||
#include "Conversion/SpatialToPim/GlobalTensorMaterialization.hpp"
|
||||
#include "Conversion/SpatialToPim/PhaseVerification.hpp"
|
||||
#include "Conversion/SpatialToPim/TensorPackingPatterns.hpp"
|
||||
#include "Dialect/Pim/PimOps.hpp"
|
||||
#include "Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "Pass/PIMPasses.h"
|
||||
#include "SpatialToPimPass.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace onnx_mlir;
|
||||
using namespace pim;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
namespace {
|
||||
namespace raptor {
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/SpatialToPim.hpp.inc"
|
||||
|
||||
struct SpatialToPimPass : PassWrapper<SpatialToPimPass, OperationPass<ModuleOp>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SpatialToPimPass)
|
||||
StringRef getArgument() const override { return "convert-spatial-to-pim"; }
|
||||
StringRef getDescription() const override { return "Lower Spatial ops to PIM-ready format"; }
|
||||
|
||||
SpatialToPimPass() = default;
|
||||
SpatialToPimPass(const SpatialToPimPass& pass) {}
|
||||
|
||||
void runOnOperation() final;
|
||||
|
||||
private:
|
||||
SmallVector<OutputTensorFactory> outputTensors;
|
||||
size_t coreId = 0;
|
||||
SmallVector<Operation*> operationsToRemove;
|
||||
|
||||
LogicalResult allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter);
|
||||
|
||||
void markOpToRemove(Operation* op);
|
||||
|
||||
void enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter);
|
||||
};
|
||||
|
||||
} // namespace
|
||||
} // namespace raptor
|
||||
|
||||
static memref::GlobalOp getOrCreateZeroGlobal(IRRewriter& rewriter, Location loc, RankedTensorType tensorType) {
|
||||
auto moduleOp = rewriter.getBlock()->getParentOp()->getParentOfType<ModuleOp>();
|
||||
@@ -104,23 +80,34 @@ static memref::GlobalOp getOrCreateZeroGlobal(IRRewriter& rewriter, Location loc
|
||||
IntegerAttr {});
|
||||
}
|
||||
|
||||
static Value createZeroedDeviceHVector(IRRewriter& rewriter, Location loc, RankedTensorType tensorType) {
|
||||
static Value createZeroedDeviceHVector(IRRewriter& rewriter,
|
||||
Location loc,
|
||||
RankedTensorType tensorType,
|
||||
OperationFolder& constantFolder) {
|
||||
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, tensorType);
|
||||
auto zeroGlobal = getOrCreateZeroGlobal(rewriter, loc, tensorType);
|
||||
auto zeroValue = memref::GetGlobalOp::create(rewriter, loc, zeroGlobal.getType(), zeroGlobal.getName());
|
||||
auto zeroAttr = rewriter.getI32IntegerAttr(0);
|
||||
auto zeroIndex = getOrCreateHostIndexConstant(outputBuffer.getOperation(), 0, constantFolder);
|
||||
auto sizeAttr = rewriter.getI32IntegerAttr(static_cast<int32_t>(getShapedTypeSizeInBytes(tensorType)));
|
||||
|
||||
if (outputBuffer->getParentOfType<PimCoreBatchOp>())
|
||||
return PimMemCopyHostToDevBatchOp::create(
|
||||
rewriter, loc, tensorType, outputBuffer, zeroValue, zeroAttr, zeroAttr, sizeAttr)
|
||||
return PimMemCopyHostToDevBatchOp::create(rewriter,
|
||||
loc,
|
||||
tensorType,
|
||||
outputBuffer,
|
||||
zeroValue,
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
sizeAttr)
|
||||
.getOutput();
|
||||
|
||||
return PimMemCopyHostToDevOp::create(rewriter, loc, tensorType, outputBuffer, zeroValue, zeroAttr, zeroAttr, sizeAttr)
|
||||
return PimMemCopyHostToDevOp::create(
|
||||
rewriter, loc, tensorType, zeroIndex, zeroIndex, outputBuffer, zeroValue, sizeAttr)
|
||||
.getOutput();
|
||||
}
|
||||
|
||||
static Value padHVectorInputToCrossbarSize(IRRewriter& rewriter, Location loc, Value vector) {
|
||||
static Value
|
||||
padHVectorInputToCrossbarSize(IRRewriter& rewriter, Location loc, Value vector, OperationFolder& constantFolder) {
|
||||
auto vectorType = cast<RankedTensorType>(vector.getType());
|
||||
ArrayRef<int64_t> shape = vectorType.getShape();
|
||||
assert(isHVectorShape(shape) && "expected a horizontal vector");
|
||||
@@ -131,25 +118,29 @@ static Value padHVectorInputToCrossbarSize(IRRewriter& rewriter, Location loc, V
|
||||
|
||||
auto paddedType = RankedTensorType::get(
|
||||
{shape[0], static_cast<int64_t>(crossbarSize)}, vectorType.getElementType(), vectorType.getEncoding());
|
||||
Value zeroed = createZeroedDeviceHVector(rewriter, loc, paddedType);
|
||||
Value zeroed = createZeroedDeviceHVector(rewriter, loc, paddedType, constantFolder);
|
||||
auto zeroAttr = rewriter.getI32IntegerAttr(0);
|
||||
auto sizeAttr = rewriter.getI32IntegerAttr(static_cast<int32_t>(getShapedTypeSizeInBytes(vectorType)));
|
||||
return PimMemCopyOp::create(rewriter, loc, paddedType, zeroed, vector, zeroAttr, zeroAttr, sizeAttr).getOutput();
|
||||
}
|
||||
|
||||
void SpatialToPimPass::runOnOperation() {
|
||||
coreId = 1;
|
||||
void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() {
|
||||
coreId = 0;
|
||||
outputTensors.clear();
|
||||
operationsToRemove.clear();
|
||||
ModuleOp moduleOp = getOperation();
|
||||
MLIRContext* ctx = moduleOp.getContext();
|
||||
|
||||
auto entryFunc = getPimEntryFunc(moduleOp);
|
||||
if (failed(entryFunc)) {
|
||||
moduleOp.emitError("failed to locate the PIM entry function during Spatial-to-PIM lowering");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
func::FuncOp funcOp = *entryFunc;
|
||||
|
||||
IRRewriter rewriter(&getContext());
|
||||
OperationFolder constantFolder(&getContext());
|
||||
|
||||
ConversionTarget target(*ctx);
|
||||
target.addLegalDialect<PimDialect,
|
||||
@@ -169,34 +160,30 @@ void SpatialToPimPass::runOnOperation() {
|
||||
spatial::SpatChannelSendTensorBatchOp,
|
||||
spatial::SpatExtractRowsOp>();
|
||||
|
||||
{
|
||||
RewritePatternSet patterns(ctx);
|
||||
populateWithGenerated(patterns);
|
||||
|
||||
if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
RewritePatternSet patterns(ctx);
|
||||
populateGlobalTensorMaterializationPatterns(patterns);
|
||||
|
||||
walkAndApplyPatterns(moduleOp, std::move(patterns));
|
||||
}
|
||||
auto returnOp = cast<func::ReturnOp>(funcOp.front().getTerminator());
|
||||
|
||||
addReturnOutputBuffers(returnOp, rewriter, outputTensors);
|
||||
if (failed(allocateAndInitializeCoreLocalVariables(funcOp, rewriter))) {
|
||||
RewritePatternSet initialPatterns(ctx);
|
||||
populateWithGenerated(initialPatterns);
|
||||
if (failed(applyPartialConversion(moduleOp, target, std::move(initialPatterns)))) {
|
||||
moduleOp.emitError("failed to lower required Spatial ops to the initial PIM form");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
RewritePatternSet globalTensorPatterns(ctx);
|
||||
populateGlobalTensorMaterializationPatterns(globalTensorPatterns);
|
||||
walkAndApplyPatterns(moduleOp, std::move(globalTensorPatterns));
|
||||
|
||||
auto returnOp = cast<func::ReturnOp>(funcOp.front().getTerminator());
|
||||
addReturnOutputBuffers(returnOp, rewriter);
|
||||
if (failed(allocateAndInitializeCoreLocalVariables(funcOp, rewriter))) {
|
||||
funcOp.emitOpError("failed to allocate or initialize core-local tensors during Spatial-to-PIM lowering");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
CoreLoweringState coreLoweringState {coreId, outputTensors, operationsToRemove};
|
||||
for (auto computeOp : funcOp.getOps<spatial::SpatCompute>()) {
|
||||
markOpToRemove(computeOp);
|
||||
if (failed(lowerComputeOp(computeOp, coreLoweringState, rewriter))) {
|
||||
if (failed(lowerComputeOp(computeOp, rewriter, constantFolder))) {
|
||||
computeOp.emitOpError("failed to lower spat.compute to pim.core");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
@@ -204,18 +191,17 @@ void SpatialToPimPass::runOnOperation() {
|
||||
|
||||
for (auto computeBatchOp : funcOp.getOps<spatial::SpatComputeBatch>()) {
|
||||
markOpToRemove(computeBatchOp);
|
||||
if (failed(lowerComputeBatchOp(computeBatchOp, coreLoweringState, rewriter))) {
|
||||
if (failed(lowerComputeBatchOp(computeBatchOp, rewriter))) {
|
||||
computeBatchOp.emitOpError("failed to lower spat.compute_batch to pim.core_batch");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
RewritePatternSet patterns(ctx);
|
||||
populateTensorPackingPatterns(patterns);
|
||||
walkAndApplyPatterns(funcOp, std::move(patterns));
|
||||
eraseUnusedTensorPackingOps(funcOp, rewriter);
|
||||
}
|
||||
RewritePatternSet initialTensorPackingPatterns(ctx);
|
||||
populateTensorPackingPatterns(initialTensorPackingPatterns);
|
||||
walkAndApplyPatterns(funcOp, std::move(initialTensorPackingPatterns));
|
||||
eraseUnusedTensorPackingOps(funcOp, rewriter);
|
||||
|
||||
SmallVector<spatial::SpatChannelReceiveOp> receiveOps;
|
||||
for (auto op : funcOp.getOps<spatial::SpatChannelReceiveOp>())
|
||||
@@ -229,74 +215,66 @@ void SpatialToPimPass::runOnOperation() {
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
RewritePatternSet coreBodyPatterns(ctx);
|
||||
populateWithGenerated(coreBodyPatterns);
|
||||
FrozenRewritePatternSet frozenCoreBodyPatterns(std::move(coreBodyPatterns));
|
||||
RewritePatternSet coreBodyPatterns(ctx);
|
||||
populateWithGenerated(coreBodyPatterns);
|
||||
FrozenRewritePatternSet frozenCoreBodyPatterns(std::move(coreBodyPatterns));
|
||||
|
||||
SmallVector<pim::PimCoreOp> coreOps;
|
||||
funcOp.walk([&](pim::PimCoreOp coreOp) { coreOps.push_back(coreOp); });
|
||||
for (auto coreOp : coreOps) {
|
||||
if (failed(applyFullConversion(coreOp.getOperation(), target, frozenCoreBodyPatterns))) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
SmallVector<pim::PimCoreBatchOp> coreBatchOps;
|
||||
funcOp.walk([&](pim::PimCoreBatchOp coreBatchOp) { coreBatchOps.push_back(coreBatchOp); });
|
||||
for (auto coreBatchOp : coreBatchOps) {
|
||||
if (failed(applyFullConversion(coreBatchOp.getOperation(), target, frozenCoreBodyPatterns))) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
enlargeVMMOutTensorsToCrossbarSize(funcOp, rewriter);
|
||||
ReturnPathState returnPathState {outputTensors, operationsToRemove};
|
||||
replaceReturnWithOutputBuffers(returnOp, rewriter, returnPathState);
|
||||
|
||||
SmallVector<Operation*> pendingRemovals(operationsToRemove.begin(), operationsToRemove.end());
|
||||
if (failed(erasePendingOps(pendingRemovals, rewriter))) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
{
|
||||
RewritePatternSet patterns(ctx);
|
||||
populateTensorPackingPatterns(patterns);
|
||||
walkAndApplyPatterns(funcOp, std::move(patterns));
|
||||
eraseUnusedTensorPackingOps(funcOp, rewriter);
|
||||
}
|
||||
|
||||
{
|
||||
ConversionTarget communicationTarget(*ctx);
|
||||
communicationTarget.addLegalDialect<PimDialect,
|
||||
tensor::TensorDialect,
|
||||
arith::ArithDialect,
|
||||
bufferization::BufferizationDialect,
|
||||
func::FuncDialect,
|
||||
memref::MemRefDialect,
|
||||
scf::SCFDialect,
|
||||
BuiltinDialect>();
|
||||
communicationTarget.addLegalOp<ModuleOp>();
|
||||
communicationTarget.addIllegalOp<spatial::SpatConcatOp,
|
||||
spatial::SpatChannelReceiveOp,
|
||||
spatial::SpatChannelReceiveTensorOp,
|
||||
spatial::SpatChannelSendOp,
|
||||
spatial::SpatChannelSendTensorOp,
|
||||
spatial::SpatExtractRowsOp>();
|
||||
|
||||
RewritePatternSet communicationPatterns(ctx);
|
||||
populateChannelLoweringPatterns(communicationPatterns);
|
||||
if (failed(applyFullConversion(funcOp, communicationTarget, std::move(communicationPatterns)))) {
|
||||
SmallVector<pim::PimCoreOp> coreOps;
|
||||
funcOp.walk([&](pim::PimCoreOp coreOp) { coreOps.push_back(coreOp); });
|
||||
for (auto coreOp : coreOps) {
|
||||
if (failed(applyFullConversion(coreOp.getOperation(), target, frozenCoreBodyPatterns))) {
|
||||
coreOp.emitOpError("failed to convert nested Spatial ops inside pim.core");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
SmallVector<pim::PimCoreBatchOp> coreBatchOps;
|
||||
funcOp.walk([&](pim::PimCoreBatchOp coreBatchOp) { coreBatchOps.push_back(coreBatchOp); });
|
||||
for (auto coreBatchOp : coreBatchOps) {
|
||||
if (failed(applyFullConversion(coreBatchOp.getOperation(), target, frozenCoreBodyPatterns))) {
|
||||
coreBatchOp.emitOpError("failed to convert nested Spatial ops inside pim.core_batch");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
enlargeVMMOutTensorsToCrossbarSize(funcOp, rewriter);
|
||||
replaceReturnWithOutputBuffers(returnOp, rewriter);
|
||||
eraseOpsToRemove();
|
||||
|
||||
RewritePatternSet finalTensorPackingPatterns(ctx);
|
||||
populateTensorPackingPatterns(finalTensorPackingPatterns);
|
||||
walkAndApplyPatterns(funcOp, std::move(finalTensorPackingPatterns));
|
||||
eraseUnusedTensorPackingOps(funcOp, rewriter);
|
||||
|
||||
ConversionTarget communicationTarget(*ctx);
|
||||
communicationTarget.addLegalDialect<PimDialect,
|
||||
tensor::TensorDialect,
|
||||
arith::ArithDialect,
|
||||
bufferization::BufferizationDialect,
|
||||
func::FuncDialect,
|
||||
memref::MemRefDialect,
|
||||
scf::SCFDialect,
|
||||
BuiltinDialect>();
|
||||
communicationTarget.addLegalOp<ModuleOp>();
|
||||
communicationTarget.addIllegalOp<spatial::SpatConcatOp,
|
||||
spatial::SpatChannelReceiveOp,
|
||||
spatial::SpatChannelReceiveTensorOp,
|
||||
spatial::SpatChannelSendOp,
|
||||
spatial::SpatChannelSendTensorOp,
|
||||
spatial::SpatExtractRowsOp>();
|
||||
|
||||
RewritePatternSet communicationPatterns(ctx);
|
||||
populateChannelLoweringPatterns(communicationPatterns);
|
||||
if (failed(applyFullConversion(funcOp, communicationTarget, std::move(communicationPatterns)))) {
|
||||
funcOp.emitOpError("failed to lower Spatial communication ops to PIM communication ops");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
if (failed(verifySpatialToPimBoundary(moduleOp))) {
|
||||
moduleOp.emitError("Spatial-to-PIM boundary verification failed");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
@@ -305,7 +283,8 @@ void SpatialToPimPass::runOnOperation() {
|
||||
dumpModule(moduleOp, "pim0");
|
||||
}
|
||||
|
||||
void SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter) {
|
||||
void raptor::SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter) {
|
||||
OperationFolder constantFolder(funcOp.getContext());
|
||||
funcOp.walk([&](PimVMMOp vmmOp) {
|
||||
auto outputType = cast<RankedTensorType>(vmmOp.getOutput().getType());
|
||||
ArrayRef<int64_t> outputShape = outputType.getShape();
|
||||
@@ -313,7 +292,7 @@ void SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, I
|
||||
assert(outputShape[1] <= static_cast<int64_t>(crossbarSize) && "output width must fit in one crossbar");
|
||||
|
||||
rewriter.setInsertionPoint(vmmOp);
|
||||
Value paddedInput = padHVectorInputToCrossbarSize(rewriter, vmmOp.getLoc(), vmmOp.getInput());
|
||||
Value paddedInput = padHVectorInputToCrossbarSize(rewriter, vmmOp.getLoc(), vmmOp.getInput(), constantFolder);
|
||||
auto paddedOutputType = RankedTensorType::get(
|
||||
{outputShape[0], static_cast<int64_t>(crossbarSize)}, outputType.getElementType(), outputType.getEncoding());
|
||||
Value paddedOutputBuffer = outputShape[1] == static_cast<int64_t>(crossbarSize)
|
||||
@@ -338,13 +317,17 @@ void SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, I
|
||||
});
|
||||
}
|
||||
|
||||
LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter) {
|
||||
LogicalResult raptor::SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp,
|
||||
IRRewriter& rewriter) {
|
||||
Location loc = funcOp.getLoc();
|
||||
OperationFolder constantFolder(funcOp.getContext());
|
||||
|
||||
auto insertMemCopyHostToDev = [&](Value inputTensor, int64_t elementsOffset) {
|
||||
auto tensorType = cast<ShapedType>(inputTensor.getType());
|
||||
Type elementType = tensorType.getElementType();
|
||||
size_t elementByteSize = elementType.getIntOrFloatBitWidth() / 8;
|
||||
if (!hasByteSizedElementType(elementType))
|
||||
return;
|
||||
size_t elementByteSize = getElementTypeSizeInBytes(elementType);
|
||||
rewriter.setInsertionPointAfter(inputTensor.getDefiningOp());
|
||||
|
||||
auto deviceTensor = tensor::EmptyOp::create(rewriter, loc, tensorType.getShape(), elementType);
|
||||
@@ -353,10 +336,11 @@ LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::Fu
|
||||
rewriter,
|
||||
loc,
|
||||
tensorType,
|
||||
getOrCreateHostIndexConstant(deviceTensor.getOperation(), 0, constantFolder),
|
||||
getOrCreateHostIndexConstant(
|
||||
deviceTensor.getOperation(), static_cast<int64_t>(elementsOffset * elementByteSize), constantFolder),
|
||||
deviceTensor,
|
||||
inputTensor,
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(elementsOffset * elementByteSize)),
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(tensorType.getNumElements() * elementByteSize)));
|
||||
|
||||
rewriter.replaceAllUsesExcept(inputTensor, memCopyHostToDevOp.getResult(), {memCopyHostToDevOp});
|
||||
@@ -378,11 +362,18 @@ LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::Fu
|
||||
return success();
|
||||
}
|
||||
|
||||
void SpatialToPimPass::markOpToRemove(Operation* op) {
|
||||
void raptor::SpatialToPimPass::markOpToRemove(Operation* op) {
|
||||
if (!llvm::is_contained(operationsToRemove, op))
|
||||
operationsToRemove.push_back(op);
|
||||
}
|
||||
|
||||
std::unique_ptr<Pass> createSpatialToPimPass() { return std::make_unique<SpatialToPimPass>(); }
|
||||
void raptor::SpatialToPimPass::eraseOpsToRemove() {
|
||||
for (Operation* op : operationsToRemove) {
|
||||
op->dropAllUses();
|
||||
op->erase();
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<Pass> createSpatialToPimPass() { return std::make_unique<raptor::SpatialToPimPass>(); }
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -0,0 +1,72 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/SCF/Utils/Utils.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/FoldUtils.h"
|
||||
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
|
||||
#include <functional>
|
||||
|
||||
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "Conversion/SpatialToPim/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace raptor {
|
||||
|
||||
struct SpatialToPimPass : mlir::PassWrapper<SpatialToPimPass, mlir::OperationPass<mlir::ModuleOp>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SpatialToPimPass)
|
||||
llvm::StringRef getArgument() const override { return "convert-spatial-to-pim"; }
|
||||
llvm::StringRef getDescription() const override { return "Lower Spatial ops to PIM-ready format"; }
|
||||
|
||||
SpatialToPimPass() = default;
|
||||
SpatialToPimPass(const SpatialToPimPass& pass) {}
|
||||
|
||||
void runOnOperation() final;
|
||||
|
||||
private:
|
||||
using OutputTensorFactory = std::function<mlir::Value(mlir::IRRewriter& rewriter, mlir::Location loc)>;
|
||||
|
||||
llvm::SmallVector<OutputTensorFactory> outputTensors;
|
||||
size_t coreId = 0;
|
||||
llvm::SmallVector<mlir::Operation*> operationsToRemove;
|
||||
|
||||
mlir::LogicalResult allocateAndInitializeCoreLocalVariables(mlir::func::FuncOp funcOp, mlir::IRRewriter& rewriter);
|
||||
mlir::LogicalResult
|
||||
lowerComputeOp(spatial::SpatCompute computeOp, mlir::IRRewriter& rewriter, mlir::OperationFolder& constantFolder);
|
||||
mlir::LogicalResult lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, mlir::IRRewriter& rewriter);
|
||||
|
||||
enum class ReturnPathLoweringResult {
|
||||
Handled,
|
||||
NotReturnPath,
|
||||
Failure
|
||||
};
|
||||
|
||||
void addReturnOutputBuffers(mlir::func::ReturnOp returnOp, mlir::IRRewriter& rewriter);
|
||||
ReturnPathLoweringResult lowerComputeResultReturnPath(spatial::SpatCompute computeOp,
|
||||
mlir::OpResult result,
|
||||
mlir::Value yieldValue,
|
||||
mlir::IRRewriter& rewriter);
|
||||
ReturnPathLoweringResult lowerProducedValueReturnPath(mlir::Operation* producerOp,
|
||||
mlir::Value producedValue,
|
||||
mlir::Value storedValue,
|
||||
mlir::IRRewriter& rewriter);
|
||||
void replaceReturnWithOutputBuffers(mlir::func::ReturnOp returnOp, mlir::IRRewriter& rewriter);
|
||||
|
||||
void markOpToRemove(mlir::Operation* op);
|
||||
void eraseOpsToRemove();
|
||||
|
||||
void enlargeVMMOutTensorsToCrossbarSize(mlir::func::FuncOp funcOp, mlir::IRRewriter& rewriter);
|
||||
};
|
||||
|
||||
} // namespace raptor
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -1,5 +1,4 @@
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/TensorPackingPatterns.hpp"
|
||||
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
@@ -75,16 +74,14 @@ struct PackSpatialConcatInputsPattern final : OpRewritePattern<spatial::SpatConc
|
||||
return failure();
|
||||
|
||||
auto outputType = cast<ShapedType>(concatOp.getOutput().getType());
|
||||
auto newConcat = pim::PimConcatOp::create(rewriter,
|
||||
concatOp.getLoc(),
|
||||
concatOp.getOutput().getType(),
|
||||
concatOp.getAxisAttr(),
|
||||
ValueRange(packedInputs),
|
||||
tensor::EmptyOp::create(rewriter,
|
||||
concatOp.getLoc(),
|
||||
outputType.getShape(),
|
||||
outputType.getElementType())
|
||||
.getResult());
|
||||
auto newConcat = pim::PimConcatOp::create(
|
||||
rewriter,
|
||||
concatOp.getLoc(),
|
||||
concatOp.getOutput().getType(),
|
||||
concatOp.getAxisAttr(),
|
||||
ValueRange(packedInputs),
|
||||
tensor::EmptyOp::create(rewriter, concatOp.getLoc(), outputType.getShape(), outputType.getElementType())
|
||||
.getResult());
|
||||
rewriter.replaceOp(concatOp, newConcat.getOutput());
|
||||
return success();
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
+34
-16
@@ -2,6 +2,7 @@
|
||||
#define PIM_DIALECT_H
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir/IR/OpAsmInterface.td"
|
||||
include "mlir/IR/AttrTypeBase.td"
|
||||
include "mlir/Dialect/MemRef/IR/MemRefBase.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
@@ -24,7 +25,8 @@ def PimTensor :
|
||||
// Execution
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def PimCoreOp : PimOp<"core", [SingleBlock, IsolatedFromAbove]> {
|
||||
def PimCoreOp : PimOp<"core", [SingleBlock,
|
||||
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmBlockArgumentNames"]>]> {
|
||||
let summary = "Execute a block on a PIM core";
|
||||
|
||||
let regions = (region SizedRegion<1>:$body);
|
||||
@@ -34,12 +36,16 @@ def PimCoreOp : PimOp<"core", [SingleBlock, IsolatedFromAbove]> {
|
||||
I32Attr:$coreId
|
||||
);
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $weights `)` attr-dict regions `:` type($weights) `->` `(` `)`
|
||||
let extraClassDeclaration = [{
|
||||
::mlir::BlockArgument getWeightArgument(unsigned idx);
|
||||
}];
|
||||
|
||||
let hasVerifier = 1;
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
def PimCoreBatchOp : PimOp<"core_batch", [SingleBlock, IsolatedFromAbove, AttrSizedOperandSegments]> {
|
||||
def PimCoreBatchOp : PimOp<"core_batch", [SingleBlock, AttrSizedOperandSegments,
|
||||
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmBlockArgumentNames"]>]> {
|
||||
let summary = "Execute equivalent batched core bodies";
|
||||
|
||||
let regions = (region SizedRegion<1>:$body);
|
||||
@@ -50,6 +56,13 @@ def PimCoreBatchOp : PimOp<"core_batch", [SingleBlock, IsolatedFromAbove, AttrSi
|
||||
Variadic<PimTensor>:$inputs
|
||||
);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
::mlir::BlockArgument getLaneArgument();
|
||||
::mlir::BlockArgument getWeightArgument(unsigned idx);
|
||||
::mlir::BlockArgument getInputArgument(unsigned idx);
|
||||
}];
|
||||
|
||||
let hasVerifier = 1;
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
@@ -81,11 +94,11 @@ def PimSendOp : PimOp<"send", []> {
|
||||
let arguments = (ins
|
||||
PimTensor:$input,
|
||||
I32Attr:$size,
|
||||
I32Attr:$targetCoreId
|
||||
Index:$targetCoreId
|
||||
);
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $input `)` attr-dict `:` type($input) `->` `(` `)`
|
||||
`(` $input `,` $targetCoreId `)` attr-dict `:` type($input) `->` `(` `)`
|
||||
}];
|
||||
}
|
||||
|
||||
@@ -131,7 +144,7 @@ def PimReceiveOp : PimOp<"receive", [DestinationStyleOpInterface]> {
|
||||
let arguments = (ins
|
||||
PimTensor:$outputBuffer,
|
||||
I32Attr:$size,
|
||||
I32Attr:$sourceCoreId
|
||||
Index:$sourceCoreId
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
@@ -145,7 +158,7 @@ def PimReceiveOp : PimOp<"receive", [DestinationStyleOpInterface]> {
|
||||
}];
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $outputBuffer `)` attr-dict `:` type($outputBuffer) `->` type($output)
|
||||
`(` $outputBuffer `,` $sourceCoreId `)` attr-dict `:` type($outputBuffer) `->` type($output)
|
||||
}];
|
||||
}
|
||||
|
||||
@@ -219,10 +232,10 @@ def PimMemCopyHostToDevOp : PimOp<"memcp_hd", [DestinationStyleOpInterface]> {
|
||||
let summary = "Copy a memory region from host memory into device memory";
|
||||
|
||||
let arguments = (ins
|
||||
Index:$deviceTargetOffset,
|
||||
Index:$hostSourceOffset,
|
||||
PimTensor:$deviceTarget,
|
||||
PimTensor:$hostSource,
|
||||
I32Attr:$deviceTargetOffset,
|
||||
I32Attr:$hostSourceOffset,
|
||||
I32Attr:$size
|
||||
);
|
||||
|
||||
@@ -237,7 +250,9 @@ def PimMemCopyHostToDevOp : PimOp<"memcp_hd", [DestinationStyleOpInterface]> {
|
||||
}];
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $deviceTarget `,` $hostSource `)` attr-dict `:` `(` type($deviceTarget) `,` type($hostSource) `)` `->` type($output)
|
||||
`[` $deviceTargetOffset `,` $hostSourceOffset `]`
|
||||
`(` $deviceTarget `,` $hostSource `)` attr-dict
|
||||
`:` type($deviceTarget) `,` type($hostSource) `->` type($output)
|
||||
}];
|
||||
}
|
||||
|
||||
@@ -271,10 +286,10 @@ def PimMemCopyDevToHostOp : PimOp<"memcp_dh", [DestinationStyleOpInterface]> {
|
||||
let summary = "Copy a memory region from device memory into host memory";
|
||||
|
||||
let arguments = (ins
|
||||
Index:$hostTargetOffset,
|
||||
Index:$deviceSourceOffset,
|
||||
PimTensor:$hostTarget,
|
||||
PimTensor:$deviceSource,
|
||||
I32Attr:$hostTargetOffset,
|
||||
I32Attr:$deviceSourceOffset,
|
||||
I32Attr:$size
|
||||
);
|
||||
|
||||
@@ -289,7 +304,9 @@ def PimMemCopyDevToHostOp : PimOp<"memcp_dh", [DestinationStyleOpInterface]> {
|
||||
}];
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $hostTarget `,` $deviceSource `)` attr-dict `:` `(` type($hostTarget) `,` type($deviceSource) `)` `->` type($output)
|
||||
`[` $hostTargetOffset `,` $deviceSourceOffset `]`
|
||||
`(` $hostTarget `,` $deviceSource `)` attr-dict
|
||||
`:` type($hostTarget) `,` type($deviceSource) `->` type($output)
|
||||
}];
|
||||
}
|
||||
|
||||
@@ -374,7 +391,7 @@ def PimVMMOp : PimOp<"vmm", [DestinationStyleOpInterface]> {
|
||||
let summary = "Vector-matrix multiplication: c = a * b";
|
||||
|
||||
let arguments = (ins
|
||||
I32Attr:$weightIndex,
|
||||
PimTensor:$weight,
|
||||
PimTensor:$input,
|
||||
PimTensor:$outputBuffer
|
||||
);
|
||||
@@ -391,7 +408,8 @@ def PimVMMOp : PimOp<"vmm", [DestinationStyleOpInterface]> {
|
||||
|
||||
let hasVerifier = 1;
|
||||
let assemblyFormat = [{
|
||||
`(` $input `,` $outputBuffer `)` attr-dict `:` `(` type($input) `,` type($outputBuffer) `)` `->` type($output)
|
||||
`[` $weight `]` `(` $input `,` $outputBuffer `)` attr-dict `:` `(` type($weight) `,` type($input) `,`
|
||||
type($outputBuffer) `)` `->` type($output)
|
||||
}];
|
||||
}
|
||||
|
||||
|
||||
@@ -1,8 +1,41 @@
|
||||
#include <string>
|
||||
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace pim {
|
||||
|
||||
BlockArgument PimCoreOp::getWeightArgument(unsigned idx) { return getBody().front().getArgument(idx); }
|
||||
|
||||
void PimCoreOp::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) {
|
||||
if (region.empty())
|
||||
return;
|
||||
|
||||
for (unsigned index = 0; index < getWeights().size(); ++index)
|
||||
setNameFn(getWeightArgument(index), ("w" + std::to_string(index)).c_str());
|
||||
}
|
||||
|
||||
BlockArgument PimCoreBatchOp::getLaneArgument() { return getBody().front().getArgument(0); }
|
||||
|
||||
BlockArgument PimCoreBatchOp::getWeightArgument(unsigned idx) { return getBody().front().getArgument(1 + idx); }
|
||||
|
||||
BlockArgument PimCoreBatchOp::getInputArgument(unsigned idx) {
|
||||
return getBody().front().getArgument(1 + getWeights().size() + idx);
|
||||
}
|
||||
|
||||
void PimCoreBatchOp::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) {
|
||||
if (region.empty())
|
||||
return;
|
||||
|
||||
setNameFn(getLaneArgument(), "lane");
|
||||
for (unsigned index = 0; index < getWeights().size(); ++index)
|
||||
setNameFn(getWeightArgument(index), ("w" + std::to_string(index)).c_str());
|
||||
for (unsigned index = 0; index < getInputs().size(); ++index)
|
||||
setNameFn(getInputArgument(index), ("in" + std::to_string(index)).c_str());
|
||||
}
|
||||
|
||||
void PimDialect::initialize() {
|
||||
addOperations<
|
||||
#define GET_OP_LIST
|
||||
|
||||
@@ -20,6 +20,79 @@ static DenseI32ArrayAttr getDenseI32ArrayAttr(OpAsmParser& parser, ArrayRef<int3
|
||||
return parser.getBuilder().getDenseI32ArrayAttr(values);
|
||||
}
|
||||
|
||||
static IntegerAttr getI32Attr(OpAsmParser& parser, int32_t value) {
|
||||
return parser.getBuilder().getI32IntegerAttr(value);
|
||||
}
|
||||
|
||||
static bool parseOptionalKeywordAlias(OpAsmParser& parser, StringRef preferred, StringRef legacy) {
|
||||
return succeeded(parser.parseOptionalKeyword(preferred)) || succeeded(parser.parseOptionalKeyword(legacy));
|
||||
}
|
||||
|
||||
static void printBlockArgumentList(OpAsmPrinter& printer, ArrayRef<BlockArgument> arguments) {
|
||||
printer << "(";
|
||||
for (auto [index, argument] : llvm::enumerate(arguments)) {
|
||||
if (index != 0)
|
||||
printer << ", ";
|
||||
printer.printOperand(argument);
|
||||
}
|
||||
printer << ")";
|
||||
}
|
||||
|
||||
static ParseResult parseBlockArgumentList(OpAsmParser& parser, SmallVectorImpl<OpAsmParser::Argument>& arguments) {
|
||||
if (parser.parseLParen())
|
||||
return failure();
|
||||
if (succeeded(parser.parseOptionalRParen()))
|
||||
return success();
|
||||
|
||||
OpAsmParser::Argument argument;
|
||||
if (parser.parseArgument(argument))
|
||||
return failure();
|
||||
arguments.push_back(argument);
|
||||
while (succeeded(parser.parseOptionalComma())) {
|
||||
if (parser.parseArgument(argument))
|
||||
return failure();
|
||||
arguments.push_back(argument);
|
||||
}
|
||||
return parser.parseRParen();
|
||||
}
|
||||
|
||||
static void
|
||||
printBoundValueList(OpAsmPrinter& printer, ValueRange arguments, ValueRange operands, ListDelimiter delimiter) {
|
||||
printCompressedValueList(printer, arguments, delimiter);
|
||||
printer << " = ";
|
||||
printCompressedValueList(printer, operands, delimiter);
|
||||
}
|
||||
|
||||
static ParseResult parseBoundValueList(OpAsmParser& parser,
|
||||
ListDelimiter delimiter,
|
||||
SmallVectorImpl<OpAsmParser::Argument>& arguments,
|
||||
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
|
||||
if (parseOpenDelimiter(parser, delimiter))
|
||||
return failure();
|
||||
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter))) {
|
||||
if (parser.parseEqual() || parseCompressedOperandList(parser, delimiter, operands))
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
|
||||
if (parseOneCompressedArgumentEntry(parser, arguments))
|
||||
return failure();
|
||||
while (succeeded(parser.parseOptionalComma()))
|
||||
if (parseOneCompressedArgumentEntry(parser, arguments))
|
||||
return failure();
|
||||
|
||||
auto parseCloseDelimiter = [&](ListDelimiter currentDelimiter) -> ParseResult {
|
||||
switch (currentDelimiter) {
|
||||
case ListDelimiter::Paren: return parser.parseRParen();
|
||||
case ListDelimiter::Square: return parser.parseRSquare();
|
||||
}
|
||||
llvm_unreachable("unsupported delimiter");
|
||||
};
|
||||
if (parseCloseDelimiter(delimiter) || parser.parseEqual() || parseCompressedOperandList(parser, delimiter, operands))
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
|
||||
static void printCoreIdList(OpAsmPrinter& printer, StringRef keyword, ArrayRef<int32_t> coreIds) {
|
||||
printer << " " << keyword << " ";
|
||||
printCompressedIntegerList(printer, coreIds);
|
||||
@@ -33,15 +106,76 @@ static ParseResult parseOptionalCoreIdList(OpAsmParser& parser, StringRef keywor
|
||||
|
||||
} // namespace
|
||||
|
||||
void PimCoreBatchOp::print(OpAsmPrinter& printer) {
|
||||
printer << " lanes " << getLaneCount() << " ";
|
||||
size_t weightsPerLane = getLaneCount() > 0 ? getWeights().size() / static_cast<size_t>(getLaneCount()) : 0;
|
||||
if (getLaneCount() > 1 && hasRepeatedTuple(getWeights(), weightsPerLane))
|
||||
printValueTupleRun(printer, getWeights(), weightsPerLane, ListDelimiter::Paren);
|
||||
else
|
||||
printCompressedValueList(printer, getWeights(), ListDelimiter::Paren);
|
||||
void PimCoreOp::print(OpAsmPrinter& printer) {
|
||||
SmallVector<Value> weightArgs;
|
||||
weightArgs.reserve(getWeights().size());
|
||||
for (unsigned index = 0; index < getWeights().size(); ++index)
|
||||
weightArgs.push_back(getWeightArgument(index));
|
||||
|
||||
printer << " ";
|
||||
printCompressedValueList(printer, getInputs(), ListDelimiter::Square);
|
||||
printBoundValueList(printer, weightArgs, getWeights(), ListDelimiter::Square);
|
||||
printer << " coreId " << getCoreId();
|
||||
printer.printOptionalAttrDict((*this)->getAttrs(), {getCoreIdAttrName().getValue()});
|
||||
printer << " : ";
|
||||
printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Square);
|
||||
printer << " -> () ";
|
||||
printer.printRegion(getBody(), /*printEntryBlockArgs=*/false);
|
||||
}
|
||||
|
||||
ParseResult PimCoreOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
SmallVector<OpAsmParser::Argument> weightArgs;
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> weights;
|
||||
SmallVector<Type> weightTypes;
|
||||
int32_t coreId = 0;
|
||||
|
||||
if (parseBoundValueList(parser, ListDelimiter::Square, weightArgs, weights))
|
||||
return failure();
|
||||
|
||||
bool hasCoreId = parseOptionalKeywordAlias(parser, "coreId", "core_id");
|
||||
if (hasCoreId && parser.parseInteger(coreId))
|
||||
return failure();
|
||||
|
||||
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|
||||
|| parseCompressedRepeatedList(
|
||||
parser, ListDelimiter::Square, weightTypes, [&](Type& type) { return parser.parseType(type); })
|
||||
|| parser.parseArrow() || parser.parseLParen() || parser.parseRParen())
|
||||
return failure();
|
||||
|
||||
if (weights.size() != weightTypes.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match");
|
||||
if (weightArgs.size() != weights.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of weight bindings and weight operands must match");
|
||||
if (hasCoreId && result.attributes.get("coreId"))
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"coreId cannot be specified both positionally and in attr-dict");
|
||||
|
||||
if (parser.resolveOperands(weights, weightTypes, parser.getCurrentLocation(), result.operands))
|
||||
return failure();
|
||||
|
||||
if (hasCoreId)
|
||||
result.addAttribute("coreId", getI32Attr(parser, coreId));
|
||||
|
||||
Region* body = result.addRegion();
|
||||
applyArgumentTypes(weightTypes, weightArgs);
|
||||
return parser.parseRegion(*body, weightArgs);
|
||||
}
|
||||
|
||||
void PimCoreBatchOp::print(OpAsmPrinter& printer) {
|
||||
printer << " ";
|
||||
printer.printOperand(getLaneArgument());
|
||||
printer << " = 0 to " << getLaneCount() << " ";
|
||||
|
||||
SmallVector<Value> weightArgs;
|
||||
weightArgs.reserve(getWeights().size());
|
||||
for (unsigned index = 0; index < getWeights().size(); ++index)
|
||||
weightArgs.push_back(getWeightArgument(index));
|
||||
printBoundValueList(printer, weightArgs, getWeights(), ListDelimiter::Square);
|
||||
printer << " ";
|
||||
SmallVector<Value> inputArgs;
|
||||
inputArgs.reserve(getInputs().size());
|
||||
for (unsigned index = 0; index < getInputs().size(); ++index)
|
||||
inputArgs.push_back(getInputArgument(index));
|
||||
printBoundValueList(printer, inputArgs, getInputs(), ListDelimiter::Paren);
|
||||
|
||||
if (auto coreIdsAttr = (*this)->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName))
|
||||
printCoreIdList(printer, "coreIds", coreIdsAttr.asArrayRef());
|
||||
@@ -49,51 +183,57 @@ void PimCoreBatchOp::print(OpAsmPrinter& printer) {
|
||||
printer.printOptionalAttrDict(
|
||||
(*this)->getAttrs(),
|
||||
{getLaneCountAttrName().getValue(), getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdsAttrName});
|
||||
printer << " ";
|
||||
printer.printRegion(getBody(), /*printEntryBlockArgs=*/false);
|
||||
printer << " : ";
|
||||
if (getLaneCount() > 1 && hasRepeatedTuple(TypeRange(getWeights()), weightsPerLane))
|
||||
printTypeTupleRun(printer, TypeRange(getWeights()), weightsPerLane, ListDelimiter::Paren);
|
||||
else
|
||||
printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Paren);
|
||||
printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Square);
|
||||
printer << " ";
|
||||
printCompressedTypeList(printer, TypeRange(getInputs()), ListDelimiter::Square);
|
||||
printer << " -> ()";
|
||||
printCompressedTypeList(printer, TypeRange(getInputs()), ListDelimiter::Paren);
|
||||
printer << " -> () ";
|
||||
printer.printRegion(getBody(), /*printEntryBlockArgs=*/false);
|
||||
}
|
||||
|
||||
ParseResult PimCoreBatchOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
int64_t lowerBound = 0;
|
||||
int32_t laneCount = 0;
|
||||
OpAsmParser::Argument laneArg;
|
||||
SmallVector<OpAsmParser::Argument> weightArgs;
|
||||
SmallVector<OpAsmParser::Argument> inputArgs;
|
||||
SmallVector<OpAsmParser::Argument> regionArgs;
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> weights;
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
|
||||
SmallVector<Type> weightTypes;
|
||||
SmallVector<Type> inputTypes;
|
||||
SmallVector<int32_t> coreIds;
|
||||
|
||||
if (parser.parseKeyword("lanes") || parser.parseInteger(laneCount)
|
||||
|| parseCompressedOrTupleOperandList(parser, ListDelimiter::Paren, weights)
|
||||
|| parseCompressedOperandList(parser, ListDelimiter::Square, inputs))
|
||||
if (parser.parseArgument(laneArg) || parser.parseEqual() || parser.parseInteger(lowerBound)
|
||||
|| parser.parseKeyword("to") || parser.parseInteger(laneCount))
|
||||
return failure();
|
||||
if (lowerBound != 0)
|
||||
return parser.emitError(parser.getCurrentLocation(), "core_batch currently requires a zero lower bound");
|
||||
|
||||
if (parseBoundValueList(parser, ListDelimiter::Square, weightArgs, weights)
|
||||
|| parseBoundValueList(parser, ListDelimiter::Paren, inputArgs, inputs))
|
||||
return failure();
|
||||
|
||||
bool hasCoreIds = succeeded(parser.parseOptionalKeyword("coreIds"));
|
||||
bool hasCoreIds = parseOptionalKeywordAlias(parser, "coreIds", "core_ids");
|
||||
if (hasCoreIds && parseCompressedIntegerList(parser, coreIds))
|
||||
return failure();
|
||||
|
||||
if (parser.parseOptionalAttrDict(result.attributes))
|
||||
return failure();
|
||||
|
||||
Region* body = result.addRegion();
|
||||
if (parser.parseRegion(*body))
|
||||
return failure();
|
||||
|
||||
if (parser.parseColon() || parseCompressedOrTupleTypeList(parser, ListDelimiter::Paren, weightTypes)
|
||||
|| parseCompressedTypeList(parser, ListDelimiter::Square, inputTypes) || parser.parseArrow()
|
||||
|| parser.parseLParen() || parser.parseRParen())
|
||||
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon()
|
||||
|| parseCompressedRepeatedList(
|
||||
parser, ListDelimiter::Square, weightTypes, [&](Type& type) { return parser.parseType(type); })
|
||||
|| parseCompressedRepeatedList(
|
||||
parser, ListDelimiter::Paren, inputTypes, [&](Type& type) { return parser.parseType(type); })
|
||||
|| parser.parseArrow() || parser.parseLParen() || parser.parseRParen())
|
||||
return failure();
|
||||
|
||||
if (weights.size() != weightTypes.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match");
|
||||
if (weightArgs.size() != weights.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of weight bindings and weight operands must match");
|
||||
if (inputs.size() != inputTypes.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
|
||||
if (inputArgs.size() != inputs.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of input bindings and input operands must match");
|
||||
if (hasCoreIds && result.attributes.get(onnx_mlir::kCoreIdsAttrName))
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"coreIds cannot be specified both positionally and in attr-dict");
|
||||
@@ -110,7 +250,15 @@ ParseResult PimCoreBatchOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
|| parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands)) {
|
||||
return failure();
|
||||
}
|
||||
return success();
|
||||
|
||||
Region* body = result.addRegion();
|
||||
laneArg.type = builder.getIndexType();
|
||||
regionArgs.push_back(laneArg);
|
||||
applyArgumentTypes(weightTypes, weightArgs);
|
||||
llvm::append_range(regionArgs, weightArgs);
|
||||
applyArgumentTypes(inputTypes, inputArgs);
|
||||
llvm::append_range(regionArgs, inputArgs);
|
||||
return parser.parseRegion(*body, regionArgs);
|
||||
}
|
||||
|
||||
void PimYieldOp::print(OpAsmPrinter& printer) {
|
||||
|
||||
@@ -1,9 +1,14 @@
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/IR/Block.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
||||
#include "mlir/IR/Diagnostics.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
|
||||
#include "llvm/Support/LogicalResult.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
@@ -14,6 +19,63 @@ namespace pim {
|
||||
|
||||
namespace {
|
||||
|
||||
static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) {
|
||||
if (isa<PimMemCopyHostToDevOp>(op))
|
||||
return operandIndex == 3;
|
||||
if (isa<PimMemCopyHostToDevBatchOp>(op))
|
||||
return operandIndex == 1;
|
||||
if (isa<PimMemCopyDevToHostOp>(op))
|
||||
return operandIndex == 2;
|
||||
return false;
|
||||
}
|
||||
|
||||
static Region* getParentRegion(Value value) {
|
||||
if (auto blockArgument = dyn_cast<BlockArgument>(value))
|
||||
return blockArgument.getParentRegion();
|
||||
Operation* definingOp = value.getDefiningOp();
|
||||
return definingOp ? definingOp->getParentRegion() : nullptr;
|
||||
}
|
||||
|
||||
static bool isDefinedInsideRegion(Value value, Region& region) {
|
||||
Region* parentRegion = getParentRegion(value);
|
||||
return parentRegion && (®ion == parentRegion || region.isAncestor(parentRegion));
|
||||
}
|
||||
|
||||
static bool isConstantExternalValue(Value value) {
|
||||
Operation* definingOp = value.getDefiningOp();
|
||||
if (!definingOp)
|
||||
return false;
|
||||
if (definingOp->hasTrait<OpTrait::ConstantLike>())
|
||||
return true;
|
||||
|
||||
auto getGlobalOp = dyn_cast<memref::GetGlobalOp>(definingOp);
|
||||
if (!getGlobalOp)
|
||||
return false;
|
||||
|
||||
auto moduleOp = definingOp->getParentOfType<ModuleOp>();
|
||||
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
||||
return globalOp && globalOp.getConstant();
|
||||
}
|
||||
|
||||
static LogicalResult verifyOnlyConstantExternalValues(Operation* ownerOp, Region& region, StringRef kind) {
|
||||
bool hasFailure = false;
|
||||
region.walk([&](Operation* op) {
|
||||
for (OpOperand& operand : op->getOpOperands()) {
|
||||
Value value = operand.get();
|
||||
if (isDefinedInsideRegion(value, region) || isConstantExternalValue(value)
|
||||
|| isExplicitHostOperand(op, operand.getOperandNumber()))
|
||||
continue;
|
||||
|
||||
InFlightDiagnostic diagnostic = ownerOp->emitOpError()
|
||||
<< kind << " body may only directly reference external constants";
|
||||
diagnostic.attachNote(op->getLoc())
|
||||
<< "non-constant external operand #" << operand.getOperandNumber() << " is used by " << op->getName();
|
||||
hasFailure = true;
|
||||
}
|
||||
});
|
||||
return success(!hasFailure);
|
||||
}
|
||||
|
||||
static bool haveSameShapedContainerKind(Type lhs, Type rhs) {
|
||||
return (isa<RankedTensorType>(lhs) && isa<RankedTensorType>(rhs)) || (isa<MemRefType>(lhs) && isa<MemRefType>(rhs));
|
||||
}
|
||||
@@ -78,24 +140,43 @@ verifyTensorBatchCommunication(Operation* op, Type type, ArrayRef<int32_t> coreI
|
||||
return success();
|
||||
}
|
||||
|
||||
static FailureOr<ArrayRef<int64_t>> getWeightShapeForVMM(Operation* op, size_t weightIndex) {
|
||||
if (auto coreOp = op->getParentOfType<PimCoreOp>()) {
|
||||
if (weightIndex >= coreOp.getWeights().size())
|
||||
return failure();
|
||||
return cast<ShapedType>(coreOp.getWeights()[weightIndex].getType()).getShape();
|
||||
}
|
||||
|
||||
if (auto coreBatchOp = op->getParentOfType<PimCoreBatchOp>()) {
|
||||
if (weightIndex >= coreBatchOp.getWeights().size())
|
||||
return failure();
|
||||
return cast<ShapedType>(coreBatchOp.getWeights()[weightIndex].getType()).getShape();
|
||||
}
|
||||
|
||||
return failure();
|
||||
static FailureOr<ArrayRef<int64_t>> getWeightShapeForVMM(Value weight) {
|
||||
auto shapedType = dyn_cast<ShapedType>(weight.getType());
|
||||
if (!shapedType)
|
||||
return failure();
|
||||
return shapedType.getShape();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
LogicalResult PimCoreOp::verify() {
|
||||
Block& block = getBody().front();
|
||||
if (block.getNumArguments() != getWeights().size())
|
||||
return emitError("core body must have one block argument per weight");
|
||||
for (auto [weightIndex, weight] : llvm::enumerate(getWeights()))
|
||||
if (getWeightArgument(weightIndex).getType() != weight.getType())
|
||||
return emitError("core weight block argument types must match weight operand types exactly");
|
||||
return verifyOnlyConstantExternalValues(getOperation(), getBody(), "pim.core");
|
||||
}
|
||||
|
||||
LogicalResult PimCoreBatchOp::verify() {
|
||||
if (getLaneCount() <= 0)
|
||||
return emitError("laneCount must be positive");
|
||||
Block& block = getBody().front();
|
||||
unsigned expectedArgCount = 1 + getWeights().size() + getInputs().size();
|
||||
if (block.getNumArguments() != expectedArgCount)
|
||||
return emitError("core_batch body must have lane, weight, and input block arguments");
|
||||
if (!getLaneArgument().getType().isIndex())
|
||||
return emitError("core_batch first block argument must have index type");
|
||||
for (auto [weightIndex, weight] : llvm::enumerate(getWeights()))
|
||||
if (getWeightArgument(weightIndex).getType() != weight.getType())
|
||||
return emitError("core_batch weight block argument types must match weight operand types exactly");
|
||||
for (auto [inputIndex, input] : llvm::enumerate(getInputs()))
|
||||
if (getInputArgument(inputIndex).getType() != input.getType())
|
||||
return emitError("core_batch input block argument types must match input operand types exactly");
|
||||
return verifyOnlyConstantExternalValues(getOperation(), getBody(), "pim.core_batch");
|
||||
}
|
||||
|
||||
LogicalResult PimSendTensorOp::verify() {
|
||||
return verifyTensorCommunication(getOperation(), getInput().getType(), getTargetCoreIds(), "send_tensor");
|
||||
}
|
||||
@@ -126,9 +207,9 @@ LogicalResult PimVMMOp::verify() {
|
||||
getOperation(), getOutputBuffer().getType(), getOutput().getType(), "output buffer and output must match")))
|
||||
return failure();
|
||||
|
||||
auto matrixShapeOpt = getWeightShapeForVMM(getOperation(), getWeightIndex());
|
||||
auto matrixShapeOpt = getWeightShapeForVMM(getWeight());
|
||||
if (failed(matrixShapeOpt))
|
||||
return emitError("must be nested inside pim.core or pim.core_batch with a valid weightIndex");
|
||||
return emitError("weight must be a shaped value");
|
||||
ArrayRef<int64_t> matrixShape = *matrixShapeOpt;
|
||||
|
||||
auto vectorType = dyn_cast<ShapedType>(getInput().getType());
|
||||
|
||||
@@ -17,7 +17,7 @@ Value materializeContiguousMemRef(Value memrefValue, Location loc, RewriterBase&
|
||||
auto shapedType = cast<ShapedType>(memrefValue.getType());
|
||||
auto contiguousType = MemRefType::get(shapedType.getShape(), shapedType.getElementType());
|
||||
Value contiguousBuffer = memref::AllocOp::create(rewriter, loc, contiguousType);
|
||||
auto sizeInBytes = shapedType.getNumElements() * shapedType.getElementTypeBitWidth() / 8;
|
||||
auto sizeInBytes = getShapedTypeSizeInBytes(shapedType);
|
||||
|
||||
return PimMemCopyOp::create(rewriter,
|
||||
loc,
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
#include "Dialect/Pim/Transforms/Bufferization/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
IntegerAttr onnx_mlir::pim::getMemRefSizeInBytesAttr(OpBuilder& builder, Value memref) {
|
||||
auto type = mlir::cast<MemRefType>(memref.getType());
|
||||
int32_t sizeInBytes = static_cast<int32_t>(type.getNumElements() * type.getElementTypeBitWidth() / 8);
|
||||
int32_t sizeInBytes = static_cast<int32_t>(getShapedTypeSizeInBytes(type));
|
||||
return builder.getI32IntegerAttr(sizeInBytes);
|
||||
}
|
||||
|
||||
@@ -38,10 +38,10 @@ struct MemCopyHostToDevOpInterface
|
||||
replaceOpWithNewBufferizedOp<PimMemCopyHostToDevOp>(rewriter,
|
||||
memCopyHostToDevOp,
|
||||
deviceTargetMemRef.getType(),
|
||||
memCopyHostToDevOp.getDeviceTargetOffset(),
|
||||
memCopyHostToDevOp.getHostSourceOffset(),
|
||||
deviceTargetMemRef,
|
||||
hostSourceMemRef,
|
||||
memCopyHostToDevOp.getDeviceTargetOffsetAttr(),
|
||||
memCopyHostToDevOp.getHostSourceOffsetAttr(),
|
||||
memCopyHostToDevOp.getSizeAttr());
|
||||
return success();
|
||||
}
|
||||
@@ -96,10 +96,10 @@ struct MemCopyDevToHostOpInterface
|
||||
replaceOpWithNewBufferizedOp<PimMemCopyDevToHostOp>(rewriter,
|
||||
memCopyDevToHostOp,
|
||||
hostTargetMemRef.getType(),
|
||||
memCopyDevToHostOp.getHostTargetOffset(),
|
||||
memCopyDevToHostOp.getDeviceSourceOffset(),
|
||||
hostTargetMemRef,
|
||||
deviceSourceMemRef,
|
||||
memCopyDevToHostOp.getHostTargetOffsetAttr(),
|
||||
memCopyDevToHostOp.getDeviceSourceOffsetAttr(),
|
||||
memCopyDevToHostOp.getSizeAttr());
|
||||
return success();
|
||||
}
|
||||
@@ -151,12 +151,8 @@ struct ReceiveOpInterface : DstBufferizableOpInterfaceExternalModel<ReceiveOpInt
|
||||
if (failed(outputBufferOpt))
|
||||
return failure();
|
||||
|
||||
replaceOpWithNewBufferizedOp<PimReceiveOp>(rewriter,
|
||||
op,
|
||||
outputBufferOpt->getType(),
|
||||
*outputBufferOpt,
|
||||
receiveOp.getSizeAttr(),
|
||||
receiveOp.getSourceCoreIdAttr());
|
||||
replaceOpWithNewBufferizedOp<PimReceiveOp>(
|
||||
rewriter, op, outputBufferOpt->getType(), *outputBufferOpt, receiveOp.getSizeAttr(), receiveOp.getSourceCoreId());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@@ -302,7 +298,7 @@ struct SendOpInterface : BufferizableOpInterface::ExternalModel<SendOpInterface,
|
||||
op,
|
||||
materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter),
|
||||
sendOp.getSizeAttr(),
|
||||
sendOp.getTargetCoreIdAttr());
|
||||
sendOp.getTargetCoreId());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@@ -368,6 +364,37 @@ struct CoreOpInterface : BufferizableOpInterface::ExternalModel<CoreOpInterface,
|
||||
return {};
|
||||
}
|
||||
|
||||
AliasingOpOperandList getAliasingOpOperands(Operation* op, Value value, const AnalysisState& state) const {
|
||||
auto coreOp = cast<PimCoreOp>(op);
|
||||
auto bbArg = dyn_cast<BlockArgument>(value);
|
||||
if (!bbArg || bbArg.getOwner() != &coreOp.getBody().front())
|
||||
return {};
|
||||
|
||||
unsigned weightIndex = bbArg.getArgNumber();
|
||||
return {
|
||||
{&coreOp->getOpOperand(weightIndex), BufferRelation::Equivalent}
|
||||
};
|
||||
}
|
||||
|
||||
bool isWritable(Operation* op, Value value, const AnalysisState& state) const { return false; }
|
||||
|
||||
FailureOr<BufferLikeType> getBufferType(Operation* op,
|
||||
Value value,
|
||||
const BufferizationOptions& options,
|
||||
const BufferizationState& state,
|
||||
SmallVector<Value>& invocationStack) const {
|
||||
auto coreOp = cast<PimCoreOp>(op);
|
||||
auto bbArg = dyn_cast<BlockArgument>(value);
|
||||
if (!bbArg || bbArg.getOwner() != &coreOp.getBody().front())
|
||||
return failure();
|
||||
|
||||
Value tiedWeight = coreOp.getWeights()[bbArg.getArgNumber()];
|
||||
if (auto memRefType = dyn_cast<BufferLikeType>(tiedWeight.getType()))
|
||||
return memRefType;
|
||||
|
||||
return bufferization::getBufferType(tiedWeight, options, state, invocationStack);
|
||||
}
|
||||
|
||||
LogicalResult bufferize(Operation* op,
|
||||
RewriterBase& rewriter,
|
||||
const BufferizationOptions& options,
|
||||
@@ -375,7 +402,10 @@ struct CoreOpInterface : BufferizableOpInterface::ExternalModel<CoreOpInterface,
|
||||
auto coreOp = cast<PimCoreOp>(op);
|
||||
|
||||
bool alreadyBufferized =
|
||||
llvm::all_of(coreOp.getWeights(), [](Value weight) { return isa<BufferLikeType>(weight.getType()); });
|
||||
llvm::all_of(coreOp.getWeights(), [](Value weight) { return isa<BufferLikeType>(weight.getType()); })
|
||||
&& llvm::all_of(coreOp.getBody().front().getArguments(), [](BlockArgument arg) {
|
||||
return !isa<TensorType>(arg.getType()) || isa<BufferLikeType>(arg.getType());
|
||||
});
|
||||
if (alreadyBufferized)
|
||||
return success();
|
||||
|
||||
@@ -420,9 +450,17 @@ struct CoreBatchOpInterface : BufferizableOpInterface::ExternalModel<CoreBatchOp
|
||||
if (!bbArg || bbArg.getOwner() != &coreBatchOp.getBody().front())
|
||||
return {};
|
||||
|
||||
unsigned inputOperandIndex = coreBatchOp.getWeights().size() + bbArg.getArgNumber();
|
||||
unsigned argNumber = bbArg.getArgNumber();
|
||||
if (argNumber == 0)
|
||||
return {};
|
||||
|
||||
unsigned weightCount = coreBatchOp.getWeights().size();
|
||||
unsigned operandIndex = argNumber - 1;
|
||||
if (argNumber > weightCount + 1)
|
||||
operandIndex = weightCount + (argNumber - 1 - weightCount);
|
||||
|
||||
return {
|
||||
{&coreBatchOp->getOpOperand(inputOperandIndex), BufferRelation::Equivalent}
|
||||
{&coreBatchOp->getOpOperand(operandIndex), BufferRelation::Equivalent}
|
||||
};
|
||||
}
|
||||
|
||||
@@ -438,11 +476,21 @@ struct CoreBatchOpInterface : BufferizableOpInterface::ExternalModel<CoreBatchOp
|
||||
if (!bbArg || bbArg.getOwner() != &coreBatchOp.getBody().front())
|
||||
return failure();
|
||||
|
||||
Value tiedInput = coreBatchOp.getInputs()[bbArg.getArgNumber()];
|
||||
if (auto memRefType = dyn_cast<BufferLikeType>(tiedInput.getType()))
|
||||
unsigned argNumber = bbArg.getArgNumber();
|
||||
if (argNumber == 0)
|
||||
return failure();
|
||||
|
||||
Value tiedOperand;
|
||||
unsigned weightCount = coreBatchOp.getWeights().size();
|
||||
if (argNumber <= weightCount)
|
||||
tiedOperand = coreBatchOp.getWeights()[argNumber - 1];
|
||||
else
|
||||
tiedOperand = coreBatchOp.getInputs()[argNumber - 1 - weightCount];
|
||||
|
||||
if (auto memRefType = dyn_cast<BufferLikeType>(tiedOperand.getType()))
|
||||
return memRefType;
|
||||
|
||||
return bufferization::getBufferType(tiedInput, options, state, invocationStack);
|
||||
return bufferization::getBufferType(tiedOperand, options, state, invocationStack);
|
||||
}
|
||||
|
||||
LogicalResult bufferize(Operation* op,
|
||||
@@ -454,8 +502,9 @@ struct CoreBatchOpInterface : BufferizableOpInterface::ExternalModel<CoreBatchOp
|
||||
bool alreadyBufferized =
|
||||
llvm::all_of(coreBatchOp.getWeights(), [](Value weight) { return isa<BufferLikeType>(weight.getType()); })
|
||||
&& llvm::all_of(coreBatchOp.getInputs(), [](Value input) { return isa<BufferLikeType>(input.getType()); })
|
||||
&& llvm::all_of(coreBatchOp.getBody().front().getArguments(),
|
||||
[](BlockArgument arg) { return isa<BufferLikeType>(arg.getType()); });
|
||||
&& llvm::all_of(coreBatchOp.getBody().front().getArguments(), [](BlockArgument arg) {
|
||||
return !isa<TensorType>(arg.getType()) || isa<BufferLikeType>(arg.getType());
|
||||
});
|
||||
if (alreadyBufferized)
|
||||
return success();
|
||||
|
||||
@@ -553,6 +602,10 @@ struct VMMOpInterface : DstBufferizableOpInterfaceExternalModel<VMMOpInterface,
|
||||
BufferizationState& state) const {
|
||||
auto vmmOp = cast<PimVMMOp>(op);
|
||||
|
||||
auto weightOpt = getBufferOrValue(rewriter, vmmOp.getWeight(), options, state);
|
||||
if (failed(weightOpt))
|
||||
return failure();
|
||||
|
||||
auto inputOpt = getBufferOrValue(rewriter, vmmOp.getInput(), options, state);
|
||||
if (failed(inputOpt))
|
||||
return failure();
|
||||
@@ -564,7 +617,7 @@ struct VMMOpInterface : DstBufferizableOpInterfaceExternalModel<VMMOpInterface,
|
||||
Value contiguousInput = materializeContiguousMemRef(*inputOpt, op->getLoc(), rewriter);
|
||||
|
||||
replaceOpWithNewBufferizedOp<PimVMMOp>(
|
||||
rewriter, op, outputBufferOpt->getType(), vmmOp.getWeightIndexAttr(), contiguousInput, *outputBufferOpt);
|
||||
rewriter, op, outputBufferOpt->getType(), *weightOpt, contiguousInput, *outputBufferOpt);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
@@ -79,6 +79,7 @@ void PimBufferizationPass::runOnOperation() {
|
||||
return WalkResult::skip();
|
||||
});
|
||||
if (hasFailed) {
|
||||
moduleOp.emitError("failed to lower memref.copy-like ops inside PIM core bodies during bufferization");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -1,15 +1,17 @@
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/Transforms/StaticMemoryCoalescing/StaticMemoryCoalescing.hpp"
|
||||
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallPtrSet.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
|
||||
#include <limits>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/Transforms/StaticMemoryCoalescing/StaticMemoryCoalescing.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
@@ -22,16 +24,16 @@ static bool isSupportedAliasOp(Operation* op) {
|
||||
}
|
||||
|
||||
static bool isCandidateAllocType(MemRefType type) {
|
||||
return type && type.hasStaticShape() && type.getLayout().isIdentity() && type.getElementTypeBitWidth() > 0;
|
||||
return type && type.hasStaticShape() && type.getLayout().isIdentity()
|
||||
&& hasByteSizedElementType(type.getElementType());
|
||||
}
|
||||
|
||||
static uint64_t getTypeSizeBytes(MemRefType type) {
|
||||
return static_cast<uint64_t>(type.getNumElements() * type.getElementTypeBitWidth() / 8);
|
||||
return static_cast<uint64_t>(type.getNumElements() * getElementTypeSizeInBytes(type.getElementType()));
|
||||
}
|
||||
|
||||
static FailureOr<uint64_t> getLastUseInstruction(memref::AllocOp allocOp,
|
||||
Block& body,
|
||||
const DenseMap<Operation*, uint64_t>& opOrder) {
|
||||
static FailureOr<uint64_t>
|
||||
getLastUseInstruction(memref::AllocOp allocOp, Block& body, const DenseMap<Operation*, uint64_t>& opOrder) {
|
||||
uint64_t endInstruction = opOrder.lookup(allocOp);
|
||||
SmallPtrSet<Operation*, 16> visited;
|
||||
SmallVector<Value> pendingValues;
|
||||
@@ -45,9 +47,14 @@ static FailureOr<uint64_t> getLastUseInstruction(memref::AllocOp allocOp,
|
||||
if (!visited.insert(user).second)
|
||||
continue;
|
||||
|
||||
if (isSupportedAliasOp(user)) {
|
||||
if (isSupportedAliasOp(user))
|
||||
for (Value result : user->getResults())
|
||||
pendingValues.push_back(result);
|
||||
|
||||
if (auto forOp = dyn_cast<scf::ForOp>(user)) {
|
||||
for (auto [index, initArg] : llvm::enumerate(forOp.getInitArgs()))
|
||||
if (initArg == value)
|
||||
pendingValues.push_back(forOp.getResult(index));
|
||||
}
|
||||
|
||||
if (auto dpsOp = dyn_cast<DestinationStyleOpInterface>(user)) {
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
|
||||
+20
-19
@@ -45,9 +45,7 @@ struct CoalescingReportEntry {
|
||||
CoalescingReportRow row;
|
||||
};
|
||||
|
||||
static std::string formatMemory(uint64_t bytes) {
|
||||
return formatReportMemory(bytes);
|
||||
}
|
||||
static std::string formatMemory(uint64_t bytes) { return formatReportMemory(bytes); }
|
||||
|
||||
static SmallVector<int32_t> getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp) {
|
||||
auto coreIdsAttr = coreBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName);
|
||||
@@ -58,9 +56,10 @@ static SmallVector<int32_t> getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp) {
|
||||
static void printReportRow(raw_ostream& os, const CoalescingReportRow& row) {
|
||||
llvm::SmallVector<ReportField, 4> fields = {
|
||||
{"Number of candidates", std::to_string(row.numCandidates)},
|
||||
{"Skipped allocations", std::to_string(row.numSkipped)},
|
||||
{"Removed allocations", std::to_string(row.numRemoved)},
|
||||
{"Saved memory", formatMemory(row.savedBytes)}};
|
||||
{"Skipped allocations", std::to_string(row.numSkipped) },
|
||||
{"Removed allocations", std::to_string(row.numRemoved) },
|
||||
{"Saved memory", formatMemory(row.savedBytes) }
|
||||
};
|
||||
printReportFlatFields(os, fields);
|
||||
}
|
||||
|
||||
@@ -87,10 +86,12 @@ static void emitReport(ArrayRef<CoalescingReportEntry> entries) {
|
||||
totalRow.savedBytes += entryTotal.savedBytes;
|
||||
}
|
||||
|
||||
llvm::SmallVector<ReportField, 4> totalFields = {{"Number of candidates", std::to_string(totalRow.numCandidates)},
|
||||
{"Skipped allocations", std::to_string(totalRow.numSkipped)},
|
||||
{"Removed allocations", std::to_string(totalRow.numRemoved)},
|
||||
{"Saved memory", formatMemory(totalRow.savedBytes)}};
|
||||
llvm::SmallVector<ReportField, 4> totalFields = {
|
||||
{"Number of candidates", std::to_string(totalRow.numCandidates)},
|
||||
{"Skipped allocations", std::to_string(totalRow.numSkipped) },
|
||||
{"Removed allocations", std::to_string(totalRow.numRemoved) },
|
||||
{"Saved memory", formatMemory(totalRow.savedBytes) }
|
||||
};
|
||||
printReportTotalsBlock(os, totalFields);
|
||||
if (!entries.empty())
|
||||
os << "\n";
|
||||
@@ -127,15 +128,17 @@ static void emitReport(ArrayRef<CoalescingReportEntry> entries) {
|
||||
if (sortedEntries[index].kind == CoalescingReportEntry::Kind::Batch) {
|
||||
llvm::SmallVector<ReportField, 4> perCoreFields = {
|
||||
{"Number of candidates", std::to_string(sortedEntries[index].row.numCandidates)},
|
||||
{"Skipped allocations", std::to_string(sortedEntries[index].row.numSkipped)},
|
||||
{"Removed allocations", std::to_string(sortedEntries[index].row.numRemoved)},
|
||||
{"Saved memory", formatMemory(sortedEntries[index].row.savedBytes)}};
|
||||
{"Skipped allocations", std::to_string(sortedEntries[index].row.numSkipped) },
|
||||
{"Removed allocations", std::to_string(sortedEntries[index].row.numRemoved) },
|
||||
{"Saved memory", formatMemory(sortedEntries[index].row.savedBytes) }
|
||||
};
|
||||
CoalescingReportRow totalRow = getTotalRow(sortedEntries[index]);
|
||||
llvm::SmallVector<ReportField, 4> totalFields = {
|
||||
{"Number of candidates", std::to_string(totalRow.numCandidates)},
|
||||
{"Skipped allocations", std::to_string(totalRow.numSkipped)},
|
||||
{"Removed allocations", std::to_string(totalRow.numRemoved)},
|
||||
{"Saved memory", formatMemory(totalRow.savedBytes)}};
|
||||
{"Skipped allocations", std::to_string(totalRow.numSkipped) },
|
||||
{"Removed allocations", std::to_string(totalRow.numRemoved) },
|
||||
{"Saved memory", formatMemory(totalRow.savedBytes) }
|
||||
};
|
||||
printReportPerCoreAndTotalFields(os, perCoreFields, totalFields);
|
||||
}
|
||||
else {
|
||||
@@ -196,8 +199,6 @@ struct StaticMemoryCoalescingPass : PassWrapper<StaticMemoryCoalescingPass, Oper
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<Pass> createPimStaticMemoryCoalescingPass() {
|
||||
return std::make_unique<StaticMemoryCoalescingPass>();
|
||||
}
|
||||
std::unique_ptr<Pass> createPimStaticMemoryCoalescingPass() { return std::make_unique<StaticMemoryCoalescingPass>(); }
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -8,7 +8,14 @@ add_pim_library(SpatialOps
|
||||
SpatialOpsVerify.cpp
|
||||
SpatialOpsCanonicalization.cpp
|
||||
Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp
|
||||
Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp
|
||||
Transforms/MergeComputeNodes/PostMergeCompaction.cpp
|
||||
Transforms/MergeComputeNodes/RegularOpCompaction.cpp
|
||||
Transforms/MergeComputeNodes/Scheduling/ComputeGraph.cpp
|
||||
Transforms/MergeComputeNodes/Scheduling/ComputeInstanceUtils.cpp
|
||||
Transforms/MergeComputeNodes/Scheduling/DcpScheduler.cpp
|
||||
Transforms/MergeComputeNodes/Scheduling/MergeSchedulingAnalysis.cpp
|
||||
Transforms/MergeComputeNodes/Scheduling/PeftScheduler.cpp
|
||||
Transforms/MergeComputeNodes/DCPGraph/Graph.cpp
|
||||
Transforms/MergeComputeNodes/DCPGraph/GraphDebug.cpp
|
||||
Transforms/MergeComputeNodes/DCPGraph/GraphSupport.cpp
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/Diagnostics.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/Channels.hpp"
|
||||
|
||||
@@ -9,19 +10,62 @@ namespace onnx_mlir::spatial {
|
||||
|
||||
namespace {
|
||||
|
||||
static Channels::ChannelId getChannelId(SpatChannelSendOp sendOp) { return sendOp.getChannelId(); }
|
||||
static FailureOr<int64_t> getConstantI64(Value value) {
|
||||
APInt constantValue;
|
||||
if (!matchPattern(value, m_ConstantInt(&constantValue)))
|
||||
return failure();
|
||||
return constantValue.getSExtValue();
|
||||
}
|
||||
|
||||
static Channels::ChannelId getChannelId(SpatChannelReceiveOp receiveOp) { return receiveOp.getChannelId(); }
|
||||
static FailureOr<int32_t> getConstantI32(Value value) {
|
||||
APInt constantValue;
|
||||
if (!matchPattern(value, m_ConstantInt(&constantValue)))
|
||||
return failure();
|
||||
return static_cast<int32_t>(constantValue.getSExtValue());
|
||||
}
|
||||
|
||||
static FailureOr<Channels::ChannelId> getChannelId(SpatChannelSendOp sendOp) {
|
||||
return getConstantI64(sendOp.getChannelId());
|
||||
}
|
||||
|
||||
static FailureOr<Channels::ChannelId> getChannelId(SpatChannelReceiveOp receiveOp) {
|
||||
return getConstantI64(receiveOp.getChannelId());
|
||||
}
|
||||
|
||||
static FailureOr<int32_t> getSourceCoreId(SpatChannelSendOp sendOp) { return getConstantI32(sendOp.getSourceCoreId()); }
|
||||
|
||||
static FailureOr<int32_t> getSourceCoreId(SpatChannelReceiveOp receiveOp) {
|
||||
return getConstantI32(receiveOp.getSourceCoreId());
|
||||
}
|
||||
|
||||
static FailureOr<int32_t> getTargetCoreId(SpatChannelSendOp sendOp) { return getConstantI32(sendOp.getTargetCoreId()); }
|
||||
|
||||
static FailureOr<int32_t> getTargetCoreId(SpatChannelReceiveOp receiveOp) {
|
||||
return getConstantI32(receiveOp.getTargetCoreId());
|
||||
}
|
||||
|
||||
static LogicalResult verifyEndpointPair(ChannelEndpoints endpoints) {
|
||||
if (!endpoints.send || !endpoints.receive)
|
||||
return failure();
|
||||
|
||||
if (endpoints.send.getSourceCoreId() != endpoints.receive.getSourceCoreId()) {
|
||||
FailureOr<int32_t> sendSourceCoreId = getSourceCoreId(endpoints.send);
|
||||
FailureOr<int32_t> receiveSourceCoreId = getSourceCoreId(endpoints.receive);
|
||||
if (failed(sendSourceCoreId) || failed(receiveSourceCoreId)) {
|
||||
endpoints.send.emitOpError("channel endpoints must use constant sourceCoreId operands");
|
||||
return failure();
|
||||
}
|
||||
if (*sendSourceCoreId != *receiveSourceCoreId) {
|
||||
endpoints.send.emitOpError("sourceCoreId does not match paired spat.channel_receive");
|
||||
return failure();
|
||||
}
|
||||
if (endpoints.send.getTargetCoreId() != endpoints.receive.getTargetCoreId()) {
|
||||
|
||||
FailureOr<int32_t> sendTargetCoreId = getTargetCoreId(endpoints.send);
|
||||
FailureOr<int32_t> receiveTargetCoreId = getTargetCoreId(endpoints.receive);
|
||||
if (failed(sendTargetCoreId) || failed(receiveTargetCoreId)) {
|
||||
endpoints.send.emitOpError("channel endpoints must use constant targetCoreId operands");
|
||||
return failure();
|
||||
}
|
||||
if (*sendTargetCoreId != *receiveTargetCoreId) {
|
||||
endpoints.send.emitOpError("targetCoreId does not match paired spat.channel_receive");
|
||||
return failure();
|
||||
}
|
||||
@@ -46,20 +90,26 @@ Channels::Channels(func::FuncOp funcOp) {
|
||||
Channels::ChannelId Channels::allocate() { return nextChannelId++; }
|
||||
|
||||
void Channels::insertSend(SpatChannelSendOp sendOp) {
|
||||
ChannelId channelId = getChannelId(sendOp);
|
||||
nextChannelId = std::max(nextChannelId, channelId + 1);
|
||||
endpoints[channelId].send = sendOp;
|
||||
FailureOr<ChannelId> channelId = getChannelId(sendOp);
|
||||
if (failed(channelId))
|
||||
return;
|
||||
nextChannelId = std::max(nextChannelId, *channelId + 1);
|
||||
endpoints[*channelId].send = sendOp;
|
||||
}
|
||||
|
||||
void Channels::insertReceive(SpatChannelReceiveOp receiveOp) {
|
||||
ChannelId channelId = getChannelId(receiveOp);
|
||||
nextChannelId = std::max(nextChannelId, channelId + 1);
|
||||
endpoints[channelId].receive = receiveOp;
|
||||
FailureOr<ChannelId> channelId = getChannelId(receiveOp);
|
||||
if (failed(channelId))
|
||||
return;
|
||||
nextChannelId = std::max(nextChannelId, *channelId + 1);
|
||||
endpoints[*channelId].receive = receiveOp;
|
||||
}
|
||||
|
||||
void Channels::eraseSend(SpatChannelSendOp sendOp) {
|
||||
ChannelId channelId = getChannelId(sendOp);
|
||||
auto it = endpoints.find(channelId);
|
||||
FailureOr<ChannelId> channelId = getChannelId(sendOp);
|
||||
if (failed(channelId))
|
||||
return;
|
||||
auto it = endpoints.find(*channelId);
|
||||
if (it == endpoints.end())
|
||||
return;
|
||||
it->second.send = {};
|
||||
@@ -68,8 +118,10 @@ void Channels::eraseSend(SpatChannelSendOp sendOp) {
|
||||
}
|
||||
|
||||
void Channels::eraseReceive(SpatChannelReceiveOp receiveOp) {
|
||||
ChannelId channelId = getChannelId(receiveOp);
|
||||
auto it = endpoints.find(channelId);
|
||||
FailureOr<ChannelId> channelId = getChannelId(receiveOp);
|
||||
if (failed(channelId))
|
||||
return;
|
||||
auto it = endpoints.find(*channelId);
|
||||
if (it == endpoints.end())
|
||||
return;
|
||||
it->second.receive = {};
|
||||
@@ -85,14 +137,20 @@ FailureOr<ChannelEndpoints> Channels::lookup(ChannelId id) const {
|
||||
}
|
||||
|
||||
FailureOr<SpatChannelReceiveOp> Channels::getReceiveFor(SpatChannelSendOp sendOp) const {
|
||||
auto endpointsOr = lookup(getChannelId(sendOp));
|
||||
FailureOr<ChannelId> channelId = getChannelId(sendOp);
|
||||
if (failed(channelId))
|
||||
return failure();
|
||||
auto endpointsOr = lookup(*channelId);
|
||||
if (failed(endpointsOr) || !endpointsOr->receive)
|
||||
return failure();
|
||||
return endpointsOr->receive;
|
||||
}
|
||||
|
||||
FailureOr<SpatChannelSendOp> Channels::getSendFor(SpatChannelReceiveOp receiveOp) const {
|
||||
auto endpointsOr = lookup(getChannelId(receiveOp));
|
||||
FailureOr<ChannelId> channelId = getChannelId(receiveOp);
|
||||
if (failed(channelId))
|
||||
return failure();
|
||||
auto endpointsOr = lookup(*channelId);
|
||||
if (failed(endpointsOr) || !endpointsOr->send)
|
||||
return failure();
|
||||
return endpointsOr->send;
|
||||
|
||||
@@ -2,8 +2,12 @@
|
||||
#define SPATIAL_DIALECT_H
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir/IR/OpAsmInterface.td"
|
||||
include "mlir/IR/BuiltinTypes.td"
|
||||
include "mlir/IR/AttrTypeBase.td"
|
||||
include "mlir/IR/RegionKindInterface.td"
|
||||
include "mlir/Interfaces/ParallelCombiningOpInterface.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
|
||||
def SpatialDialect : Dialect {
|
||||
let name = "spat";
|
||||
@@ -22,7 +26,9 @@ def SpatTensor :
|
||||
// Execution
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def SpatCompute : SpatOp<"compute", [SingleBlock, AttrSizedOperandSegments]> {
|
||||
def SpatCompute : SpatOp<"compute",
|
||||
[SingleBlock, AttrSizedOperandSegments,
|
||||
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmBlockArgumentNames"]>]> {
|
||||
let summary = "Compute region with attached constant weights";
|
||||
|
||||
let arguments = (ins
|
||||
@@ -36,14 +42,26 @@ def SpatCompute : SpatOp<"compute", [SingleBlock, AttrSizedOperandSegments]> {
|
||||
|
||||
let regions = (region SizedRegion<1>:$body);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
std::optional<::mlir::BlockArgument> getWeightArgument(unsigned idx);
|
||||
std::optional<::mlir::BlockArgument> getInputArgument(unsigned idx);
|
||||
std::optional<std::tuple<::mlir::Value, ::mlir::BlockArgument>>
|
||||
insertWeight(unsigned idx, ::mlir::Value weight, ::mlir::Location loc);
|
||||
std::optional<std::tuple<::mlir::Value, ::mlir::BlockArgument>>
|
||||
insertInput(unsigned idx, ::mlir::Value input, ::mlir::Location loc);
|
||||
::mlir::FailureOr<std::tuple<::mlir::OpResult, SpatCompute>>
|
||||
insertOutput(::mlir::RewriterBase &rewriter, unsigned idx, ::mlir::Type type, ::mlir::Location loc);
|
||||
}];
|
||||
|
||||
let hasVerifier = 1;
|
||||
let hasFolder = 1;
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
def SpatComputeBatch : SpatOp<"compute_batch",
|
||||
[SingleBlock, AttrSizedOperandSegments]> {
|
||||
let summary = "Compressed batch of independent equivalent compute lanes";
|
||||
[SingleBlock, AttrSizedOperandSegments,
|
||||
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmBlockArgumentNames"]>]> {
|
||||
let summary = "Tensor-native batch of equivalent compute lanes with shared weights and packed inputs";
|
||||
|
||||
let arguments = (ins
|
||||
I32Attr:$laneCount,
|
||||
@@ -57,10 +75,47 @@ def SpatComputeBatch : SpatOp<"compute_batch",
|
||||
|
||||
let regions = (region SizedRegion<1>:$body);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
std::optional<::mlir::BlockArgument> getLaneArgument();
|
||||
std::optional<::mlir::BlockArgument> getWeightArgument(unsigned idx);
|
||||
std::optional<::mlir::BlockArgument> getInputArgument(unsigned idx);
|
||||
std::optional<::mlir::BlockArgument> getOutputArgument(unsigned idx);
|
||||
std::optional<std::tuple<::mlir::Value, ::mlir::BlockArgument>>
|
||||
insertWeight(unsigned idx, ::mlir::Value weight, ::mlir::Location loc);
|
||||
std::optional<std::tuple<::mlir::Value, ::mlir::BlockArgument>>
|
||||
insertInput(unsigned idx, ::mlir::Value input, ::mlir::Location loc);
|
||||
::mlir::FailureOr<std::tuple<::mlir::OpResult, ::mlir::BlockArgument, SpatComputeBatch>>
|
||||
insertOutput(::mlir::RewriterBase &rewriter, unsigned idx, ::mlir::Type type, ::mlir::Location loc);
|
||||
}];
|
||||
|
||||
let hasVerifier = 1;
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
def SpatInParallelOp : SpatOp<"in_parallel", [
|
||||
Pure,
|
||||
Terminator,
|
||||
DeclareOpInterfaceMethods<InParallelOpInterface>,
|
||||
HasParent<"SpatComputeBatch">,
|
||||
] # GraphRegionNoTerminator.traits> {
|
||||
let summary = "Parallel combining terminator for resultful spat.compute_batch";
|
||||
|
||||
let regions = (region SizedRegion<1>:$region);
|
||||
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let hasVerifier = 1;
|
||||
|
||||
let skipDefaultBuilders = 1;
|
||||
let builders = [
|
||||
OpBuilder<(ins)>,
|
||||
];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
::llvm::iterator_range<::mlir::Block::iterator> getYieldingOps();
|
||||
::mlir::OpResult getParentResult(int64_t idx);
|
||||
}];
|
||||
}
|
||||
|
||||
def SpatYieldOp : SpatOp<"yield", [Terminator]> {
|
||||
let summary = "Yield results from a compute region";
|
||||
|
||||
@@ -110,14 +165,14 @@ def SpatChannelSendOp : SpatOp<"channel_send", []> {
|
||||
let summary = "Send a tensor through a logical channel";
|
||||
|
||||
let arguments = (ins
|
||||
I64Attr:$channelId,
|
||||
I32Attr:$sourceCoreId,
|
||||
I32Attr:$targetCoreId,
|
||||
Index:$channelId,
|
||||
Index:$sourceCoreId,
|
||||
Index:$targetCoreId,
|
||||
SpatTensor:$input
|
||||
);
|
||||
|
||||
let assemblyFormat = [{
|
||||
$input attr-dict `:` type($input)
|
||||
$input `channel` $channelId `from` $sourceCoreId `to` $targetCoreId attr-dict `:` type($input)
|
||||
}];
|
||||
}
|
||||
|
||||
@@ -125,9 +180,9 @@ def SpatChannelReceiveOp : SpatOp<"channel_receive", []> {
|
||||
let summary = "Receive a tensor from a logical channel";
|
||||
|
||||
let arguments = (ins
|
||||
I64Attr:$channelId,
|
||||
I32Attr:$sourceCoreId,
|
||||
I32Attr:$targetCoreId
|
||||
Index:$channelId,
|
||||
Index:$sourceCoreId,
|
||||
Index:$targetCoreId
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
@@ -135,31 +190,33 @@ def SpatChannelReceiveOp : SpatOp<"channel_receive", []> {
|
||||
);
|
||||
|
||||
let assemblyFormat = [{
|
||||
attr-dict `:` type($output)
|
||||
`channel` $channelId `from` $sourceCoreId `to` $targetCoreId attr-dict `:` type($output)
|
||||
}];
|
||||
}
|
||||
|
||||
def SpatChannelSendTensorOp : SpatOp<"channel_send_tensor", []> {
|
||||
def SpatChannelSendTensorOp : SpatOp<"channel_send_tensor", [AttrSizedOperandSegments]> {
|
||||
let summary = "Send equal contiguous chunks of one tensor through logical channels";
|
||||
|
||||
let arguments = (ins
|
||||
DenseI64ArrayAttr:$channelIds,
|
||||
DenseI32ArrayAttr:$sourceCoreIds,
|
||||
DenseI32ArrayAttr:$targetCoreIds,
|
||||
Variadic<Index>:$channelIds,
|
||||
Variadic<Index>:$sourceCoreIds,
|
||||
Variadic<Index>:$targetCoreIds,
|
||||
SpatTensor:$input
|
||||
);
|
||||
|
||||
let hasVerifier = 1;
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let assemblyFormat = [{
|
||||
$input `channels` `(` $channelIds `)` `from` `(` $sourceCoreIds `)` `to` `(` $targetCoreIds `)` attr-dict `:` type($input)
|
||||
}];
|
||||
}
|
||||
|
||||
def SpatChannelReceiveTensorOp : SpatOp<"channel_receive_tensor", []> {
|
||||
def SpatChannelReceiveTensorOp : SpatOp<"channel_receive_tensor", [AttrSizedOperandSegments]> {
|
||||
let summary = "Receive equal contiguous chunks of one tensor from logical channels";
|
||||
|
||||
let arguments = (ins
|
||||
DenseI64ArrayAttr:$channelIds,
|
||||
DenseI32ArrayAttr:$sourceCoreIds,
|
||||
DenseI32ArrayAttr:$targetCoreIds
|
||||
Variadic<Index>:$channelIds,
|
||||
Variadic<Index>:$sourceCoreIds,
|
||||
Variadic<Index>:$targetCoreIds
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
@@ -167,44 +224,50 @@ def SpatChannelReceiveTensorOp : SpatOp<"channel_receive_tensor", []> {
|
||||
);
|
||||
|
||||
let hasVerifier = 1;
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let assemblyFormat = [{
|
||||
`channels` `(` $channelIds `)` `from` `(` $sourceCoreIds `)` `to` `(` $targetCoreIds `)` attr-dict `:` type($output)
|
||||
}];
|
||||
}
|
||||
|
||||
def SpatChannelSendBatchOp : SpatOp<"channel_send_batch", []> {
|
||||
def SpatChannelSendBatchOp : SpatOp<"channel_send_batch", [AttrSizedOperandSegments]> {
|
||||
let summary = "Send per-lane tensors through logical channels in a batch body";
|
||||
|
||||
let arguments = (ins
|
||||
DenseI64ArrayAttr:$channelIds,
|
||||
DenseI32ArrayAttr:$sourceCoreIds,
|
||||
DenseI32ArrayAttr:$targetCoreIds,
|
||||
Variadic<Index>:$channelIds,
|
||||
Variadic<Index>:$sourceCoreIds,
|
||||
Variadic<Index>:$targetCoreIds,
|
||||
SpatTensor:$input
|
||||
);
|
||||
|
||||
let hasVerifier = 1;
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let assemblyFormat = [{
|
||||
$input `channels` `(` $channelIds `)` `from` `(` $sourceCoreIds `)` `to` `(` $targetCoreIds `)` attr-dict `:` type($input)
|
||||
}];
|
||||
}
|
||||
|
||||
def SpatChannelSendTensorBatchOp : SpatOp<"channel_send_tensor_batch", []> {
|
||||
def SpatChannelSendTensorBatchOp : SpatOp<"channel_send_tensor_batch", [AttrSizedOperandSegments]> {
|
||||
let summary = "Send equal contiguous chunks of one per-lane tensor through logical channels in a batch body";
|
||||
|
||||
let arguments = (ins
|
||||
DenseI64ArrayAttr:$channelIds,
|
||||
DenseI32ArrayAttr:$sourceCoreIds,
|
||||
DenseI32ArrayAttr:$targetCoreIds,
|
||||
Variadic<Index>:$channelIds,
|
||||
Variadic<Index>:$sourceCoreIds,
|
||||
Variadic<Index>:$targetCoreIds,
|
||||
SpatTensor:$input
|
||||
);
|
||||
|
||||
let hasVerifier = 1;
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let assemblyFormat = [{
|
||||
$input `channels` `(` $channelIds `)` `from` `(` $sourceCoreIds `)` `to` `(` $targetCoreIds `)` attr-dict `:` type($input)
|
||||
}];
|
||||
}
|
||||
|
||||
def SpatChannelReceiveBatchOp : SpatOp<"channel_receive_batch", []> {
|
||||
def SpatChannelReceiveBatchOp : SpatOp<"channel_receive_batch", [AttrSizedOperandSegments]> {
|
||||
let summary = "Receive a per-lane tensor through logical channels in a batch body";
|
||||
|
||||
let arguments = (ins
|
||||
DenseI64ArrayAttr:$channelIds,
|
||||
DenseI32ArrayAttr:$sourceCoreIds,
|
||||
DenseI32ArrayAttr:$targetCoreIds
|
||||
Variadic<Index>:$channelIds,
|
||||
Variadic<Index>:$sourceCoreIds,
|
||||
Variadic<Index>:$targetCoreIds
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
@@ -212,16 +275,18 @@ def SpatChannelReceiveBatchOp : SpatOp<"channel_receive_batch", []> {
|
||||
);
|
||||
|
||||
let hasVerifier = 1;
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let assemblyFormat = [{
|
||||
`channels` `(` $channelIds `)` `from` `(` $sourceCoreIds `)` `to` `(` $targetCoreIds `)` attr-dict `:` type($output)
|
||||
}];
|
||||
}
|
||||
|
||||
def SpatChannelReceiveTensorBatchOp : SpatOp<"channel_receive_tensor_batch", []> {
|
||||
def SpatChannelReceiveTensorBatchOp : SpatOp<"channel_receive_tensor_batch", [AttrSizedOperandSegments]> {
|
||||
let summary = "Receive equal contiguous chunks of one per-lane tensor through logical channels in a batch body";
|
||||
|
||||
let arguments = (ins
|
||||
DenseI64ArrayAttr:$channelIds,
|
||||
DenseI32ArrayAttr:$sourceCoreIds,
|
||||
DenseI32ArrayAttr:$targetCoreIds
|
||||
Variadic<Index>:$channelIds,
|
||||
Variadic<Index>:$sourceCoreIds,
|
||||
Variadic<Index>:$targetCoreIds
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
@@ -229,7 +294,9 @@ def SpatChannelReceiveTensorBatchOp : SpatOp<"channel_receive_tensor_batch", []>
|
||||
);
|
||||
|
||||
let hasVerifier = 1;
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let assemblyFormat = [{
|
||||
`channels` `(` $channelIds `)` `from` `(` $sourceCoreIds `)` `to` `(` $targetCoreIds `)` attr-dict `:` type($output)
|
||||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -240,7 +307,7 @@ def SpatVMMOp : SpatOp<"wvmm", []> {
|
||||
let summary = "Vector-matrix multiplication within a weighted compute operation";
|
||||
|
||||
let arguments = (ins
|
||||
I32Attr:$weightIndex,
|
||||
SpatTensor:$weight,
|
||||
SpatTensor:$input
|
||||
);
|
||||
|
||||
@@ -251,7 +318,7 @@ def SpatVMMOp : SpatOp<"wvmm", []> {
|
||||
let hasVerifier = 1;
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $input `)` attr-dict `:` type($input) `->` type($output)
|
||||
`[` $weight `]` `(` $input `)` attr-dict `:` `(` type($weight) `,` type($input) `)` `->` type($output)
|
||||
}];
|
||||
}
|
||||
|
||||
@@ -259,7 +326,7 @@ def SpatMVMOp : SpatOp<"Wmvm", []> {
|
||||
let summary = "Matrix-vector multiplication within a weighted compute operation";
|
||||
|
||||
let arguments = (ins
|
||||
I32Attr:$weightIndex,
|
||||
SpatTensor:$weight,
|
||||
SpatTensor:$input
|
||||
);
|
||||
|
||||
@@ -270,7 +337,7 @@ def SpatMVMOp : SpatOp<"Wmvm", []> {
|
||||
let hasVerifier = 1;
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $input `)` attr-dict `:` type($input) `->` type($output)
|
||||
`[` $weight `]` `(` $input `)` attr-dict `:` `(` type($weight) `,` type($input) `)` `->` type($output)
|
||||
}];
|
||||
}
|
||||
|
||||
|
||||
@@ -1,9 +1,208 @@
|
||||
#include <string>
|
||||
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace spatial {
|
||||
namespace {
|
||||
|
||||
std::optional<BlockArgument> getBatchBodyArgument(Region& body, unsigned argIdx) {
|
||||
if (body.empty())
|
||||
return std::nullopt;
|
||||
|
||||
Block& block = body.front();
|
||||
if (argIdx >= block.getNumArguments())
|
||||
return std::nullopt;
|
||||
return block.getArgument(argIdx);
|
||||
}
|
||||
|
||||
std::optional<BlockArgument> insertBatchBodyArgument(Region& body, unsigned argIdx, Type type, Location loc) {
|
||||
if (body.empty())
|
||||
return std::nullopt;
|
||||
return body.insertArgument(argIdx, type, loc);
|
||||
}
|
||||
|
||||
void setComputeOperandSegmentSizes(Operation* op, int32_t weightCount, int32_t inputCount) {
|
||||
if (auto compute = dyn_cast<SpatCompute>(op)) {
|
||||
compute.getProperties().setOperandSegmentSizes({weightCount, inputCount});
|
||||
return;
|
||||
}
|
||||
cast<SpatComputeBatch>(op).getProperties().setOperandSegmentSizes({weightCount, inputCount});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
std::optional<BlockArgument> SpatCompute::getWeightArgument(unsigned idx) {
|
||||
return getBatchBodyArgument(getBody(), idx);
|
||||
}
|
||||
|
||||
std::optional<BlockArgument> SpatCompute::getInputArgument(unsigned idx) {
|
||||
return getBatchBodyArgument(getBody(), getWeights().size() + idx);
|
||||
}
|
||||
|
||||
std::optional<std::tuple<Value, BlockArgument>> SpatCompute::insertWeight(unsigned idx, Value weight, Location loc) {
|
||||
unsigned weightCount = getWeights().size();
|
||||
unsigned inputCount = getInputs().size();
|
||||
getOperation()->insertOperands(idx, ValueRange {weight});
|
||||
setComputeOperandSegmentSizes(
|
||||
getOperation(), static_cast<int32_t>(weightCount + 1), static_cast<int32_t>(inputCount));
|
||||
auto blockArg = insertBatchBodyArgument(getBody(), idx, weight.getType(), loc);
|
||||
if (!blockArg)
|
||||
return std::nullopt;
|
||||
return std::make_tuple(getOperation()->getOperand(idx), *blockArg);
|
||||
}
|
||||
|
||||
std::optional<std::tuple<Value, BlockArgument>> SpatCompute::insertInput(unsigned idx, Value input, Location loc) {
|
||||
unsigned weightCount = getWeights().size();
|
||||
unsigned inputCount = getInputs().size();
|
||||
getOperation()->insertOperands(weightCount + idx, ValueRange {input});
|
||||
setComputeOperandSegmentSizes(
|
||||
getOperation(), static_cast<int32_t>(weightCount), static_cast<int32_t>(inputCount + 1));
|
||||
auto blockArg = insertBatchBodyArgument(getBody(), weightCount + idx, input.getType(), loc);
|
||||
if (!blockArg)
|
||||
return std::nullopt;
|
||||
return std::make_tuple(getOperation()->getOperand(weightCount + idx), *blockArg);
|
||||
}
|
||||
|
||||
FailureOr<std::tuple<OpResult, SpatCompute>>
|
||||
SpatCompute::insertOutput(RewriterBase& rewriter, unsigned idx, Type type, Location loc) {
|
||||
if (idx > getNumResults())
|
||||
return failure();
|
||||
|
||||
rewriter.setInsertionPoint(getOperation());
|
||||
SmallVector<Type> resultTypes(getResultTypes().begin(), getResultTypes().end());
|
||||
resultTypes.insert(resultTypes.begin() + idx, type);
|
||||
auto newCompute = SpatCompute::create(rewriter, getLoc(), TypeRange(resultTypes), getWeights(), getInputs());
|
||||
newCompute->setAttrs((*this)->getAttrs());
|
||||
setComputeOperandSegmentSizes(newCompute.getOperation(),
|
||||
static_cast<int32_t>(newCompute.getWeights().size()),
|
||||
static_cast<int32_t>(newCompute.getInputs().size()));
|
||||
rewriter.inlineRegionBefore(getBody(), newCompute.getBody(), newCompute.getBody().end());
|
||||
for (unsigned oldResultIdx = 0; oldResultIdx < getNumResults(); ++oldResultIdx)
|
||||
getResult(oldResultIdx)
|
||||
.replaceAllUsesWith(newCompute.getResult(oldResultIdx < idx ? oldResultIdx : oldResultIdx + 1));
|
||||
rewriter.eraseOp(getOperation());
|
||||
return std::make_tuple(cast<OpResult>(newCompute.getResult(idx)), newCompute);
|
||||
}
|
||||
|
||||
void SpatCompute::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) {
|
||||
if (region.empty())
|
||||
return;
|
||||
|
||||
for (unsigned index = 0; index < getWeights().size(); ++index)
|
||||
if (auto weightArg = getWeightArgument(index))
|
||||
setNameFn(*weightArg, ("w" + std::to_string(index)).c_str());
|
||||
|
||||
for (unsigned index = 0; index < getInputs().size(); ++index)
|
||||
if (auto inputArg = getInputArgument(index))
|
||||
setNameFn(*inputArg, ("in" + std::to_string(index)).c_str());
|
||||
}
|
||||
|
||||
std::optional<BlockArgument> SpatComputeBatch::getLaneArgument() { return getBatchBodyArgument(getBody(), 0); }
|
||||
|
||||
std::optional<BlockArgument> SpatComputeBatch::getWeightArgument(unsigned idx) {
|
||||
return getBatchBodyArgument(getBody(), 1 + idx);
|
||||
}
|
||||
|
||||
std::optional<BlockArgument> SpatComputeBatch::getInputArgument(unsigned idx) {
|
||||
return getBatchBodyArgument(getBody(), 1 + getWeights().size() + idx);
|
||||
}
|
||||
|
||||
std::optional<BlockArgument> SpatComputeBatch::getOutputArgument(unsigned idx) {
|
||||
return getBatchBodyArgument(getBody(), 1 + getWeights().size() + getInputs().size() + idx);
|
||||
}
|
||||
|
||||
std::optional<std::tuple<Value, BlockArgument>>
|
||||
SpatComputeBatch::insertWeight(unsigned idx, Value weight, Location loc) {
|
||||
unsigned weightCount = getWeights().size();
|
||||
unsigned inputCount = getInputs().size();
|
||||
getOperation()->insertOperands(idx, ValueRange {weight});
|
||||
setComputeOperandSegmentSizes(
|
||||
getOperation(), static_cast<int32_t>(weightCount + 1), static_cast<int32_t>(inputCount));
|
||||
auto blockArg = insertBatchBodyArgument(getBody(), 1 + idx, weight.getType(), loc);
|
||||
if (!blockArg)
|
||||
return std::nullopt;
|
||||
return std::make_tuple(getOperation()->getOperand(idx), *blockArg);
|
||||
}
|
||||
|
||||
std::optional<std::tuple<Value, BlockArgument>> SpatComputeBatch::insertInput(unsigned idx, Value input, Location loc) {
|
||||
unsigned weightCount = getWeights().size();
|
||||
unsigned inputCount = getInputs().size();
|
||||
getOperation()->insertOperands(weightCount + idx, ValueRange {input});
|
||||
setComputeOperandSegmentSizes(
|
||||
getOperation(), static_cast<int32_t>(weightCount), static_cast<int32_t>(inputCount + 1));
|
||||
auto blockArg = insertBatchBodyArgument(getBody(), 1 + weightCount + idx, input.getType(), loc);
|
||||
if (!blockArg)
|
||||
return std::nullopt;
|
||||
return std::make_tuple(getOperation()->getOperand(weightCount + idx), *blockArg);
|
||||
}
|
||||
|
||||
FailureOr<std::tuple<OpResult, BlockArgument, SpatComputeBatch>>
|
||||
SpatComputeBatch::insertOutput(RewriterBase& rewriter, unsigned idx, Type type, Location loc) {
|
||||
if (idx > getNumResults())
|
||||
return failure();
|
||||
|
||||
rewriter.setInsertionPoint(getOperation());
|
||||
SmallVector<Type> resultTypes(getResultTypes().begin(), getResultTypes().end());
|
||||
resultTypes.insert(resultTypes.begin() + idx, type);
|
||||
auto newBatch =
|
||||
SpatComputeBatch::create(rewriter, getLoc(), TypeRange(resultTypes), getLaneCountAttr(), getWeights(), getInputs());
|
||||
newBatch->setAttrs((*this)->getAttrs());
|
||||
setComputeOperandSegmentSizes(newBatch.getOperation(),
|
||||
static_cast<int32_t>(newBatch.getWeights().size()),
|
||||
static_cast<int32_t>(newBatch.getInputs().size()));
|
||||
rewriter.inlineRegionBefore(getBody(), newBatch.getBody(), newBatch.getBody().end());
|
||||
if (newBatch.getBody().empty()) {
|
||||
rewriter.eraseOp(newBatch);
|
||||
return failure();
|
||||
}
|
||||
auto blockArg = newBatch.getBody().front().insertArgument(
|
||||
1 + newBatch.getWeights().size() + newBatch.getInputs().size() + idx, type, loc);
|
||||
for (unsigned oldResultIdx = 0; oldResultIdx < getNumResults(); ++oldResultIdx)
|
||||
getResult(oldResultIdx)
|
||||
.replaceAllUsesWith(newBatch.getResult(oldResultIdx < idx ? oldResultIdx : oldResultIdx + 1));
|
||||
rewriter.eraseOp(getOperation());
|
||||
return std::make_tuple(cast<OpResult>(newBatch.getResult(idx)), blockArg, newBatch);
|
||||
}
|
||||
|
||||
void SpatComputeBatch::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) {
|
||||
if (region.empty())
|
||||
return;
|
||||
|
||||
if (auto laneArg = getLaneArgument())
|
||||
setNameFn(*laneArg, "lane");
|
||||
|
||||
for (unsigned index = 0; index < getWeights().size(); ++index)
|
||||
if (auto weightArg = getWeightArgument(index))
|
||||
setNameFn(*weightArg, ("w" + std::to_string(index)).c_str());
|
||||
|
||||
for (unsigned index = 0; index < getInputs().size(); ++index)
|
||||
if (auto inputArg = getInputArgument(index))
|
||||
setNameFn(*inputArg, ("in" + std::to_string(index)).c_str());
|
||||
|
||||
for (unsigned index = 0; index < getNumResults(); ++index) {
|
||||
auto outputArg = getOutputArgument(index);
|
||||
if (!outputArg)
|
||||
continue;
|
||||
if (index == 0) {
|
||||
setNameFn(*outputArg, "out");
|
||||
continue;
|
||||
}
|
||||
setNameFn(*outputArg, ("out" + std::to_string(index)).c_str());
|
||||
}
|
||||
}
|
||||
|
||||
void SpatInParallelOp::build(OpBuilder& builder, OperationState& result) {
|
||||
OpBuilder::InsertionGuard guard(builder);
|
||||
Region* bodyRegion = result.addRegion();
|
||||
builder.createBlock(bodyRegion);
|
||||
}
|
||||
|
||||
OpResult SpatInParallelOp::getParentResult(int64_t idx) { return getOperation()->getParentOp()->getResult(idx); }
|
||||
|
||||
llvm::iterator_range<Block::iterator> SpatInParallelOp::getYieldingOps() { return getRegion().front().getOperations(); }
|
||||
|
||||
void SpatialDialect::initialize() {
|
||||
addTypes<
|
||||
|
||||
@@ -5,10 +5,15 @@
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/RegionKindInterface.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
#include "mlir/Interfaces/ParallelCombiningOpInterface.h"
|
||||
|
||||
#include <map>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
|
||||
/// Include the auto-generated header files containing the declarations
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialDialect.hpp.inc"
|
||||
|
||||
@@ -23,22 +23,6 @@ static bool parseOptionalKeywordAlias(OpAsmParser& parser, StringRef preferred,
|
||||
return succeeded(parser.parseOptionalKeyword(preferred)) || succeeded(parser.parseOptionalKeyword(legacy));
|
||||
}
|
||||
|
||||
static void printChannelMetadata(OpAsmPrinter& printer,
|
||||
ArrayRef<int64_t> channelIds,
|
||||
ArrayRef<int32_t> sourceCoreIds,
|
||||
ArrayRef<int32_t> targetCoreIds) {
|
||||
printer << " channels ";
|
||||
printCompressedIntegerList(printer, channelIds);
|
||||
printer << " from ";
|
||||
printCompressedIntegerList(printer, sourceCoreIds);
|
||||
printer << " to ";
|
||||
printCompressedIntegerList(printer, targetCoreIds);
|
||||
}
|
||||
|
||||
static DenseI64ArrayAttr getDenseI64ArrayAttr(OpAsmParser& parser, ArrayRef<int64_t> values) {
|
||||
return parser.getBuilder().getDenseI64ArrayAttr(values);
|
||||
}
|
||||
|
||||
static DenseI32ArrayAttr getDenseI32ArrayAttr(OpAsmParser& parser, ArrayRef<int32_t> values) {
|
||||
return parser.getBuilder().getDenseI32ArrayAttr(values);
|
||||
}
|
||||
@@ -47,94 +31,86 @@ static IntegerAttr getI32Attr(OpAsmParser& parser, int32_t value) {
|
||||
return parser.getBuilder().getI32IntegerAttr(value);
|
||||
}
|
||||
|
||||
template <typename TensorSendOpTy>
|
||||
static void printTensorSendOp(OpAsmPrinter& printer, TensorSendOpTy op) {
|
||||
printer << " ";
|
||||
printer.printOperand(op.getInput());
|
||||
printChannelMetadata(printer, op.getChannelIds(), op.getSourceCoreIds(), op.getTargetCoreIds());
|
||||
printer.printOptionalAttrDict(op->getAttrs(),
|
||||
{op.getChannelIdsAttrName().getValue(),
|
||||
op.getSourceCoreIdsAttrName().getValue(),
|
||||
op.getTargetCoreIdsAttrName().getValue()});
|
||||
printer << " : ";
|
||||
printer.printType(op.getInput().getType());
|
||||
static void printBlockArgumentList(OpAsmPrinter& printer, ArrayRef<BlockArgument> arguments) {
|
||||
printer << "(";
|
||||
for (auto [index, argument] : llvm::enumerate(arguments)) {
|
||||
if (index != 0)
|
||||
printer << ", ";
|
||||
printer.printOperand(argument);
|
||||
}
|
||||
printer << ")";
|
||||
}
|
||||
|
||||
template <typename TensorReceiveOpTy>
|
||||
static void printTensorReceiveOp(OpAsmPrinter& printer, TensorReceiveOpTy op) {
|
||||
printChannelMetadata(printer, op.getChannelIds(), op.getSourceCoreIds(), op.getTargetCoreIds());
|
||||
printer.printOptionalAttrDict(op->getAttrs(),
|
||||
{op.getChannelIdsAttrName().getValue(),
|
||||
op.getSourceCoreIdsAttrName().getValue(),
|
||||
op.getTargetCoreIdsAttrName().getValue()});
|
||||
printer << " : ";
|
||||
printer.printType(op.getOutput().getType());
|
||||
}
|
||||
|
||||
static ParseResult parseTensorSendOp(OpAsmParser& parser, OperationState& result) {
|
||||
OpAsmParser::UnresolvedOperand input;
|
||||
Type inputType;
|
||||
SmallVector<int64_t> channelIds;
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
SmallVector<int32_t> targetCoreIds;
|
||||
|
||||
if (parser.parseOperand(input))
|
||||
static ParseResult parseBlockArgumentList(OpAsmParser& parser, SmallVectorImpl<OpAsmParser::Argument>& arguments) {
|
||||
if (parser.parseLParen())
|
||||
return failure();
|
||||
if (succeeded(parser.parseOptionalRParen()))
|
||||
return success();
|
||||
|
||||
bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels"));
|
||||
if (hasMetadata) {
|
||||
if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from")
|
||||
|| parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to")
|
||||
|| parseCompressedIntegerList(parser, targetCoreIds))
|
||||
OpAsmParser::Argument argument;
|
||||
if (parser.parseArgument(argument))
|
||||
return failure();
|
||||
arguments.push_back(argument);
|
||||
while (succeeded(parser.parseOptionalComma())) {
|
||||
if (parser.parseArgument(argument))
|
||||
return failure();
|
||||
arguments.push_back(argument);
|
||||
}
|
||||
|
||||
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(inputType))
|
||||
return failure();
|
||||
|
||||
if (hasMetadata
|
||||
&& (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds")
|
||||
|| result.attributes.get("targetCoreIds")))
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"channel metadata cannot be specified both positionally and in attr-dict");
|
||||
if (hasMetadata) {
|
||||
result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds));
|
||||
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
|
||||
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
|
||||
}
|
||||
|
||||
return parser.resolveOperand(input, inputType, result.operands);
|
||||
return parser.parseRParen();
|
||||
}
|
||||
|
||||
static ParseResult parseTensorReceiveOp(OpAsmParser& parser, OperationState& result) {
|
||||
Type outputType;
|
||||
SmallVector<int64_t> channelIds;
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
SmallVector<int32_t> targetCoreIds;
|
||||
static void applyBatchRegionArgumentTypes(ArrayRef<Type> inputTypes,
|
||||
ArrayRef<Type> weightTypes,
|
||||
ArrayRef<Type> outputTypes,
|
||||
OpAsmParser::Argument& laneArg,
|
||||
SmallVectorImpl<OpAsmParser::Argument>& weightArgs,
|
||||
SmallVectorImpl<OpAsmParser::Argument>& inputArgs,
|
||||
SmallVectorImpl<OpAsmParser::Argument>& outputArgs,
|
||||
SmallVectorImpl<OpAsmParser::Argument>& regionArgs,
|
||||
Builder& builder) {
|
||||
laneArg.type = builder.getIndexType();
|
||||
regionArgs.push_back(laneArg);
|
||||
applyArgumentTypes(weightTypes, weightArgs);
|
||||
llvm::append_range(regionArgs, weightArgs);
|
||||
applyArgumentTypes(inputTypes, inputArgs);
|
||||
applyArgumentTypes(outputTypes, outputArgs);
|
||||
llvm::append_range(regionArgs, inputArgs);
|
||||
llvm::append_range(regionArgs, outputArgs);
|
||||
}
|
||||
|
||||
bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels"));
|
||||
if (hasMetadata) {
|
||||
if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from")
|
||||
|| parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to")
|
||||
|| parseCompressedIntegerList(parser, targetCoreIds))
|
||||
return failure();
|
||||
}
|
||||
static void
|
||||
printBoundValueList(OpAsmPrinter& printer, ValueRange arguments, ValueRange operands, ListDelimiter delimiter) {
|
||||
printCompressedValueList(printer, arguments, delimiter);
|
||||
printer << " = ";
|
||||
printCompressedValueList(printer, operands, delimiter);
|
||||
}
|
||||
|
||||
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(outputType))
|
||||
static ParseResult parseBoundValueList(OpAsmParser& parser,
|
||||
ListDelimiter delimiter,
|
||||
SmallVectorImpl<OpAsmParser::Argument>& arguments,
|
||||
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
|
||||
if (parseOpenDelimiter(parser, delimiter))
|
||||
return failure();
|
||||
|
||||
if (hasMetadata
|
||||
&& (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds")
|
||||
|| result.attributes.get("targetCoreIds")))
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"channel metadata cannot be specified both positionally and in attr-dict");
|
||||
if (hasMetadata) {
|
||||
result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds));
|
||||
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
|
||||
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
|
||||
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter))) {
|
||||
if (parser.parseEqual() || parseCompressedOperandList(parser, delimiter, operands))
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
|
||||
result.addTypes(outputType);
|
||||
if (parseOneCompressedArgumentEntry(parser, arguments))
|
||||
return failure();
|
||||
while (succeeded(parser.parseOptionalComma()))
|
||||
if (parseOneCompressedArgumentEntry(parser, arguments))
|
||||
return failure();
|
||||
auto parseCloseDelimiter = [&](ListDelimiter currentDelimiter) -> ParseResult {
|
||||
switch (currentDelimiter) {
|
||||
case ListDelimiter::Paren: return parser.parseRParen();
|
||||
case ListDelimiter::Square: return parser.parseRSquare();
|
||||
}
|
||||
llvm_unreachable("unsupported delimiter");
|
||||
};
|
||||
if (parseCloseDelimiter(delimiter) || parser.parseEqual() || parseCompressedOperandList(parser, delimiter, operands))
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
|
||||
@@ -242,10 +218,27 @@ ParseResult SpatConcatOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
}
|
||||
|
||||
void SpatCompute::print(OpAsmPrinter& printer) {
|
||||
SmallVector<Value> weightArgs;
|
||||
weightArgs.reserve(getWeights().size());
|
||||
for (unsigned index = 0; index < getWeights().size(); ++index) {
|
||||
auto weightArg = getWeightArgument(index);
|
||||
if (!weightArg)
|
||||
return printer.printGenericOp(getOperation(), /*printOpName=*/false);
|
||||
weightArgs.push_back(*weightArg);
|
||||
}
|
||||
SmallVector<Value> inputArgs;
|
||||
inputArgs.reserve(getInputs().size());
|
||||
for (unsigned index = 0; index < getInputs().size(); ++index) {
|
||||
auto inputArg = getInputArgument(index);
|
||||
if (!inputArg)
|
||||
return printer.printGenericOp(getOperation(), /*printOpName=*/false);
|
||||
inputArgs.push_back(*inputArg);
|
||||
}
|
||||
|
||||
printer << " ";
|
||||
printCompressedValueList(printer, getWeights(), ListDelimiter::Square);
|
||||
printBoundValueList(printer, weightArgs, getWeights(), ListDelimiter::Square);
|
||||
printer << " ";
|
||||
printArgumentBindings(printer, getBody().front(), getInputs());
|
||||
printBoundValueList(printer, inputArgs, getInputs(), ListDelimiter::Paren);
|
||||
|
||||
if (auto coreIdAttr = (*this)->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
|
||||
printer << " coreId " << coreIdAttr.getInt();
|
||||
@@ -264,6 +257,7 @@ void SpatCompute::print(OpAsmPrinter& printer) {
|
||||
}
|
||||
|
||||
ParseResult SpatCompute::parse(OpAsmParser& parser, OperationState& result) {
|
||||
SmallVector<OpAsmParser::Argument> weightArgs;
|
||||
SmallVector<OpAsmParser::Argument> regionArgs;
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> weights;
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
|
||||
@@ -272,10 +266,11 @@ ParseResult SpatCompute::parse(OpAsmParser& parser, OperationState& result) {
|
||||
SmallVector<Type> outputTypes;
|
||||
int32_t coreId = 0;
|
||||
|
||||
if (parseCompressedOperandList(parser, ListDelimiter::Square, weights))
|
||||
if (parseBoundValueList(parser, ListDelimiter::Square, weightArgs, weights))
|
||||
return failure();
|
||||
|
||||
if (parseArgumentBindings(parser, regionArgs, inputs))
|
||||
SmallVector<OpAsmParser::Argument> inputArgs;
|
||||
if (parseBoundValueList(parser, ListDelimiter::Paren, inputArgs, inputs))
|
||||
return failure();
|
||||
|
||||
bool hasCoreId = parseOptionalKeywordAlias(parser, "coreId", "core_id");
|
||||
@@ -292,9 +287,11 @@ ParseResult SpatCompute::parse(OpAsmParser& parser, OperationState& result) {
|
||||
|
||||
if (weights.size() != weightTypes.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match");
|
||||
if (weightArgs.size() != weights.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of weight bindings and weight operands must match");
|
||||
if (inputs.size() != inputTypes.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
|
||||
if (regionArgs.size() != inputs.size())
|
||||
if (inputArgs.size() != inputs.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of argument bindings and input operands must match");
|
||||
if (hasCoreId && result.attributes.get(onnx_mlir::kCoreIdAttrName))
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
@@ -313,19 +310,58 @@ ParseResult SpatCompute::parse(OpAsmParser& parser, OperationState& result) {
|
||||
result.addTypes(outputTypes);
|
||||
|
||||
Region* body = result.addRegion();
|
||||
applyArgumentTypes(inputTypes, regionArgs);
|
||||
applyArgumentTypes(weightTypes, weightArgs);
|
||||
applyArgumentTypes(inputTypes, inputArgs);
|
||||
llvm::append_range(regionArgs, weightArgs);
|
||||
llvm::append_range(regionArgs, inputArgs);
|
||||
return parser.parseRegion(*body, regionArgs);
|
||||
}
|
||||
|
||||
void SpatComputeBatch::print(OpAsmPrinter& printer) {
|
||||
printer << " lanes " << getLaneCount() << " ";
|
||||
size_t weightsPerLane = getLaneCount() > 0 ? getWeights().size() / static_cast<size_t>(getLaneCount()) : 0;
|
||||
if (getLaneCount() > 1 && hasRepeatedTuple(getWeights(), weightsPerLane))
|
||||
printValueTupleRun(printer, getWeights(), weightsPerLane, ListDelimiter::Square);
|
||||
else
|
||||
printCompressedValueList(printer, getWeights(), ListDelimiter::Square);
|
||||
auto laneArg = getLaneArgument();
|
||||
SmallVector<Value> weightArgs;
|
||||
weightArgs.reserve(getWeights().size());
|
||||
for (unsigned index = 0; index < getWeights().size(); ++index) {
|
||||
auto weightArg = getWeightArgument(index);
|
||||
if (!weightArg)
|
||||
return printer.printGenericOp(getOperation(), /*printOpName=*/false);
|
||||
weightArgs.push_back(*weightArg);
|
||||
}
|
||||
SmallVector<Value> inputArgs;
|
||||
inputArgs.reserve(getInputs().size());
|
||||
for (unsigned index = 0; index < getInputs().size(); ++index) {
|
||||
auto inputArg = getInputArgument(index);
|
||||
if (!inputArg)
|
||||
return printer.printGenericOp(getOperation(), /*printOpName=*/false);
|
||||
inputArgs.push_back(*inputArg);
|
||||
}
|
||||
|
||||
SmallVector<BlockArgument> outputArgs;
|
||||
if (!laneArg)
|
||||
return printer.printGenericOp(getOperation(), /*printOpName=*/false);
|
||||
if (getNumResults() != 0) {
|
||||
outputArgs.reserve(getNumResults());
|
||||
for (unsigned index = 0; index < getNumResults(); ++index) {
|
||||
auto outputArg = getOutputArgument(index);
|
||||
if (!outputArg)
|
||||
return printer.printGenericOp(getOperation(), /*printOpName=*/false);
|
||||
outputArgs.push_back(*outputArg);
|
||||
}
|
||||
}
|
||||
|
||||
printer << " ";
|
||||
printArgumentBindings(printer, getBody().front(), getInputs());
|
||||
printer.printOperand(*laneArg);
|
||||
printer << " = 0 to " << getLaneCount();
|
||||
|
||||
printer << " ";
|
||||
printBoundValueList(printer, weightArgs, getWeights(), ListDelimiter::Square);
|
||||
printer << " ";
|
||||
printBoundValueList(printer, inputArgs, getInputs(), ListDelimiter::Paren);
|
||||
|
||||
if (getNumResults() != 0) {
|
||||
printer << " shared_outs";
|
||||
printBlockArgumentList(printer, outputArgs);
|
||||
}
|
||||
|
||||
if (auto coreIdsAttr = (*this)->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName)) {
|
||||
printer << " coreIds ";
|
||||
@@ -337,10 +373,7 @@ void SpatComputeBatch::print(OpAsmPrinter& printer) {
|
||||
{getLaneCountAttrName().getValue(), getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdsAttrName});
|
||||
|
||||
printer << " : ";
|
||||
if (getLaneCount() > 1 && hasRepeatedTuple(TypeRange(getWeights()), weightsPerLane))
|
||||
printTypeTupleRun(printer, TypeRange(getWeights()), weightsPerLane, ListDelimiter::Square);
|
||||
else
|
||||
printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Square);
|
||||
printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Square);
|
||||
printer << " ";
|
||||
printCompressedTypeList(printer, TypeRange(getInputs()), ListDelimiter::Paren);
|
||||
printer << " -> ";
|
||||
@@ -350,7 +383,12 @@ void SpatComputeBatch::print(OpAsmPrinter& printer) {
|
||||
}
|
||||
|
||||
ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result) {
|
||||
int64_t lowerBound = 0;
|
||||
int32_t laneCount = 0;
|
||||
OpAsmParser::Argument laneArg;
|
||||
SmallVector<OpAsmParser::Argument> weightArgs;
|
||||
SmallVector<OpAsmParser::Argument> inputArgs;
|
||||
SmallVector<OpAsmParser::Argument> outputArgs;
|
||||
SmallVector<OpAsmParser::Argument> regionArgs;
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> weights;
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> inputs;
|
||||
@@ -359,14 +397,21 @@ ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result)
|
||||
SmallVector<Type> outputTypes;
|
||||
SmallVector<int32_t> coreIds;
|
||||
|
||||
if (parser.parseKeyword("lanes") || parser.parseInteger(laneCount))
|
||||
if (parser.parseArgument(laneArg) || parser.parseEqual() || parser.parseInteger(lowerBound)
|
||||
|| parser.parseKeyword("to") || parser.parseInteger(laneCount))
|
||||
return failure();
|
||||
if (lowerBound != 0)
|
||||
return parser.emitError(parser.getCurrentLocation(), "compute_batch currently requires a zero lower bound");
|
||||
|
||||
if (parseBoundValueList(parser, ListDelimiter::Square, weightArgs, weights))
|
||||
return failure();
|
||||
|
||||
if (parseCompressedOrTupleOperandList(parser, ListDelimiter::Square, weights))
|
||||
if (parseBoundValueList(parser, ListDelimiter::Paren, inputArgs, inputs))
|
||||
return failure();
|
||||
|
||||
if (parseArgumentBindings(parser, regionArgs, inputs))
|
||||
return failure();
|
||||
if (succeeded(parser.parseOptionalKeyword("shared_outs")))
|
||||
if (parseBlockArgumentList(parser, outputArgs))
|
||||
return failure();
|
||||
|
||||
bool hasCoreIds = parseOptionalKeywordAlias(parser, "coreIds", "core_ids");
|
||||
if (hasCoreIds && parseCompressedIntegerList(parser, coreIds))
|
||||
@@ -381,10 +426,15 @@ ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result)
|
||||
|
||||
if (weights.size() != weightTypes.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match");
|
||||
if (weightArgs.size() != weights.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of weight bindings and weight operands must match");
|
||||
if (inputs.size() != inputTypes.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match");
|
||||
if (regionArgs.size() != inputs.size())
|
||||
if (inputArgs.size() != inputs.size())
|
||||
return parser.emitError(parser.getCurrentLocation(), "number of argument bindings and input operands must match");
|
||||
if (outputArgs.size() != outputTypes.size())
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"number of shared output bindings and result types must match");
|
||||
if (hasCoreIds && result.attributes.get(onnx_mlir::kCoreIdsAttrName))
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"coreIds cannot be specified both positionally and in attr-dict");
|
||||
@@ -403,119 +453,28 @@ ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result)
|
||||
result.addTypes(outputTypes);
|
||||
|
||||
Region* body = result.addRegion();
|
||||
applyArgumentTypes(inputTypes, regionArgs);
|
||||
applyBatchRegionArgumentTypes(
|
||||
inputTypes, weightTypes, outputTypes, laneArg, weightArgs, inputArgs, outputArgs, regionArgs, parser.getBuilder());
|
||||
return parser.parseRegion(*body, regionArgs);
|
||||
}
|
||||
|
||||
void SpatChannelSendTensorOp::print(OpAsmPrinter& printer) { printTensorSendOp(printer, *this); }
|
||||
|
||||
ParseResult SpatChannelSendTensorOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
return parseTensorSendOp(parser, result);
|
||||
}
|
||||
|
||||
void SpatChannelSendBatchOp::print(OpAsmPrinter& printer) {
|
||||
void SpatInParallelOp::print(OpAsmPrinter& printer) {
|
||||
printer << " ";
|
||||
printer.printOperand(getInput());
|
||||
printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
|
||||
printer.printOptionalAttrDict(
|
||||
(*this)->getAttrs(),
|
||||
{getChannelIdsAttrName().getValue(), getSourceCoreIdsAttrName().getValue(), getTargetCoreIdsAttrName().getValue()});
|
||||
printer << " : ";
|
||||
printer.printType(getInput().getType());
|
||||
printer.printRegion(getRegion(), /*printEntryBlockArgs=*/false, /*printBlockTerminators=*/false);
|
||||
printer.printOptionalAttrDict((*this)->getAttrs());
|
||||
}
|
||||
|
||||
ParseResult SpatChannelSendBatchOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
OpAsmParser::UnresolvedOperand input;
|
||||
Type inputType;
|
||||
SmallVector<int64_t> channelIds;
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
SmallVector<int32_t> targetCoreIds;
|
||||
|
||||
if (parser.parseOperand(input))
|
||||
ParseResult SpatInParallelOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
auto& builder = parser.getBuilder();
|
||||
std::unique_ptr<Region> region = std::make_unique<Region>();
|
||||
SmallVector<OpAsmParser::Argument, 4> regionArgs;
|
||||
if (parser.parseRegion(*region, regionArgs))
|
||||
return failure();
|
||||
|
||||
bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels"));
|
||||
if (hasMetadata) {
|
||||
if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from")
|
||||
|| parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to")
|
||||
|| parseCompressedIntegerList(parser, targetCoreIds))
|
||||
return failure();
|
||||
}
|
||||
|
||||
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(inputType))
|
||||
return failure();
|
||||
|
||||
if (hasMetadata
|
||||
&& (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds")
|
||||
|| result.attributes.get("targetCoreIds")))
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"channel metadata cannot be specified both positionally and in attr-dict");
|
||||
if (hasMetadata) {
|
||||
result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds));
|
||||
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
|
||||
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
|
||||
}
|
||||
|
||||
return parser.resolveOperand(input, inputType, result.operands);
|
||||
}
|
||||
|
||||
void SpatChannelSendTensorBatchOp::print(OpAsmPrinter& printer) { printTensorSendOp(printer, *this); }
|
||||
|
||||
ParseResult SpatChannelSendTensorBatchOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
return parseTensorSendOp(parser, result);
|
||||
}
|
||||
|
||||
void SpatChannelReceiveTensorOp::print(OpAsmPrinter& printer) { printTensorReceiveOp(printer, *this); }
|
||||
|
||||
ParseResult SpatChannelReceiveTensorOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
return parseTensorReceiveOp(parser, result);
|
||||
}
|
||||
|
||||
void SpatChannelReceiveBatchOp::print(OpAsmPrinter& printer) {
|
||||
printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
|
||||
printer.printOptionalAttrDict(
|
||||
(*this)->getAttrs(),
|
||||
{getChannelIdsAttrName().getValue(), getSourceCoreIdsAttrName().getValue(), getTargetCoreIdsAttrName().getValue()});
|
||||
printer << " : ";
|
||||
printer.printType(getOutput().getType());
|
||||
}
|
||||
|
||||
ParseResult SpatChannelReceiveBatchOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
Type outputType;
|
||||
SmallVector<int64_t> channelIds;
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
SmallVector<int32_t> targetCoreIds;
|
||||
|
||||
bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels"));
|
||||
if (hasMetadata) {
|
||||
if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from")
|
||||
|| parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to")
|
||||
|| parseCompressedIntegerList(parser, targetCoreIds))
|
||||
return failure();
|
||||
}
|
||||
|
||||
if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(outputType))
|
||||
return failure();
|
||||
|
||||
if (hasMetadata
|
||||
&& (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds")
|
||||
|| result.attributes.get("targetCoreIds")))
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"channel metadata cannot be specified both positionally and in attr-dict");
|
||||
if (hasMetadata) {
|
||||
result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds));
|
||||
result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds));
|
||||
result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds));
|
||||
}
|
||||
|
||||
result.addTypes(outputType);
|
||||
return success();
|
||||
}
|
||||
|
||||
void SpatChannelReceiveTensorBatchOp::print(OpAsmPrinter& printer) { printTensorReceiveOp(printer, *this); }
|
||||
|
||||
ParseResult SpatChannelReceiveTensorBatchOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
return parseTensorReceiveOp(parser, result);
|
||||
if (region->empty())
|
||||
OpBuilder(builder.getContext()).createBlock(region.get());
|
||||
result.addRegion(std::move(region));
|
||||
return parser.parseOptionalAttrDict(result.attributes);
|
||||
}
|
||||
|
||||
} // namespace spatial
|
||||
|
||||
@@ -1,8 +1,12 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/Block.h"
|
||||
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
||||
#include "mlir/IR/Diagnostics.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
|
||||
#include "llvm/ADT/DenseSet.h"
|
||||
#include "llvm/Support/LogicalResult.h"
|
||||
@@ -81,20 +85,11 @@ inline LogicalResult mvmOpVerifySize4(SpatMVMOp* emitter,
|
||||
return success();
|
||||
}
|
||||
|
||||
static FailureOr<ArrayRef<int64_t>> getWeightShapeForWeightedOp(Operation* weightedOp, size_t weightIndex) {
|
||||
if (auto computeOp = weightedOp->getParentOfType<SpatCompute>())
|
||||
return cast<ShapedType>(computeOp.getWeights()[weightIndex].getType()).getShape();
|
||||
|
||||
if (auto coreOp = weightedOp->getParentOfType<pim::PimCoreOp>())
|
||||
return cast<ShapedType>(coreOp.getWeights()[weightIndex].getType()).getShape();
|
||||
|
||||
if (auto batchOp = weightedOp->getParentOfType<SpatComputeBatch>()) {
|
||||
if (batchOp.getWeights().empty() || weightIndex >= batchOp.getWeights().size())
|
||||
return failure();
|
||||
return cast<ShapedType>(batchOp.getWeights()[weightIndex].getType()).getShape();
|
||||
}
|
||||
|
||||
return failure();
|
||||
static FailureOr<ArrayRef<int64_t>> getWeightShapeForWeightedOp(Value weight) {
|
||||
auto shapedType = dyn_cast<ShapedType>(weight.getType());
|
||||
if (!shapedType)
|
||||
return failure();
|
||||
return shapedType.getShape();
|
||||
}
|
||||
|
||||
static FailureOr<int32_t> getParentBatchLaneCount(Operation* op) {
|
||||
@@ -104,15 +99,98 @@ static FailureOr<int32_t> getParentBatchLaneCount(Operation* op) {
|
||||
return batchOp.getLaneCount();
|
||||
}
|
||||
|
||||
static LogicalResult verifyTensorChannelSizes(Operation* op,
|
||||
Type type,
|
||||
ArrayRef<int64_t> channelIds,
|
||||
ArrayRef<int32_t> sourceCoreIds,
|
||||
ArrayRef<int32_t> targetCoreIds,
|
||||
StringRef kind) {
|
||||
if (channelIds.size() != sourceCoreIds.size() || channelIds.size() != targetCoreIds.size())
|
||||
static bool isBatchOutputArgument(SpatComputeBatch batchOp, Value value) {
|
||||
if (batchOp.getNumResults() == 0)
|
||||
return false;
|
||||
auto blockArg = dyn_cast<BlockArgument>(value);
|
||||
if (!blockArg || blockArg.getOwner() != &batchOp.getBody().front())
|
||||
return false;
|
||||
|
||||
unsigned argNumber = blockArg.getArgNumber();
|
||||
auto firstOutputArg = batchOp.getOutputArgument(0);
|
||||
if (!firstOutputArg)
|
||||
return false;
|
||||
unsigned firstOutputArgNumber = firstOutputArg->getArgNumber();
|
||||
return argNumber >= firstOutputArgNumber && argNumber < firstOutputArgNumber + batchOp.getNumResults();
|
||||
}
|
||||
|
||||
static bool isConstantIndexLike(Value value) {
|
||||
APInt constantValue;
|
||||
return matchPattern(value, m_ConstantInt(&constantValue));
|
||||
}
|
||||
|
||||
static bool isSupportedLaneOffsetExpr(Value value, BlockArgument laneArg) {
|
||||
if (value == laneArg || isConstantIndexLike(value))
|
||||
return true;
|
||||
|
||||
auto extractOp = value.getDefiningOp<tensor::ExtractOp>();
|
||||
if (extractOp) {
|
||||
auto constantTensor = extractOp.getTensor().getDefiningOp<arith::ConstantOp>();
|
||||
auto denseAttr = constantTensor ? dyn_cast<DenseIntElementsAttr>(constantTensor.getValue()) : nullptr;
|
||||
if (!denseAttr || denseAttr.getType().getRank() != 1 || extractOp.getIndices().size() != 1)
|
||||
return false;
|
||||
return isSupportedLaneOffsetExpr(extractOp.getIndices().front(), laneArg);
|
||||
}
|
||||
|
||||
auto addOp = value.getDefiningOp<arith::AddIOp>();
|
||||
if (!addOp)
|
||||
return false;
|
||||
return (addOp.getLhs() == laneArg && isConstantIndexLike(addOp.getRhs()))
|
||||
|| (addOp.getRhs() == laneArg && isConstantIndexLike(addOp.getLhs()));
|
||||
}
|
||||
|
||||
static LogicalResult
|
||||
verifyStaticUnitStrideExtractSliceOp(tensor::ExtractSliceOp sliceOp, BlockArgument laneArg, StringRef kind) {
|
||||
auto sourceType = dyn_cast<RankedTensorType>(sliceOp.getSource().getType());
|
||||
auto resultType = dyn_cast<RankedTensorType>(sliceOp.getResult().getType());
|
||||
if (!sourceType || !resultType || !sourceType.hasStaticShape() || !resultType.hasStaticShape())
|
||||
return sliceOp.emitOpError() << kind << " requires static ranked tensor types";
|
||||
if (!sliceOp.hasUnitStride())
|
||||
return sliceOp.emitOpError() << kind << " requires unit strides";
|
||||
|
||||
for (int64_t size : sliceOp.getStaticSizes())
|
||||
if (ShapedType::isDynamic(size))
|
||||
return sliceOp.emitOpError() << kind << " requires static slice sizes";
|
||||
|
||||
auto offsets = sliceOp.getOffsets();
|
||||
for (auto [offsetIndex, offset] : llvm::enumerate(offsets)) {
|
||||
bool supported = offsetIndex == 0 ? isSupportedLaneOffsetExpr(offset, laneArg) : isConstantIndexLike(offset);
|
||||
if (!supported)
|
||||
return sliceOp.emitOpError() << kind << " requires simple lane-dependent offsets";
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
static LogicalResult verifyStaticUnitStrideParallelInsertSliceOp(tensor::ParallelInsertSliceOp sliceOp,
|
||||
BlockArgument laneArg,
|
||||
StringRef kind) {
|
||||
RankedTensorType sourceType = sliceOp.getSourceType();
|
||||
RankedTensorType destType = sliceOp.getDestType();
|
||||
if (!sourceType.hasStaticShape() || !destType.hasStaticShape())
|
||||
return sliceOp.emitOpError() << kind << " requires static ranked tensor types";
|
||||
if (!sliceOp.hasUnitStride())
|
||||
return sliceOp.emitOpError() << kind << " requires unit strides";
|
||||
|
||||
for (int64_t size : sliceOp.getStaticSizes())
|
||||
if (ShapedType::isDynamic(size))
|
||||
return sliceOp.emitOpError() << kind << " requires static slice sizes";
|
||||
|
||||
auto offsets = sliceOp.getOffsets();
|
||||
for (auto [offsetIndex, offset] : llvm::enumerate(offsets)) {
|
||||
bool supported = offsetIndex == 0 ? isSupportedLaneOffsetExpr(offset, laneArg) : isConstantIndexLike(offset);
|
||||
if (!supported)
|
||||
return sliceOp.emitOpError() << kind << " requires simple lane-dependent offsets";
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
static LogicalResult verifyTensorChannelSizes(
|
||||
Operation* op, Type type, size_t channelCount, size_t sourceCoreCount, size_t targetCoreCount, StringRef kind) {
|
||||
if (channelCount != sourceCoreCount || channelCount != targetCoreCount)
|
||||
return op->emitError("channelIds, sourceCoreIds, and targetCoreIds must have the same length");
|
||||
if (channelIds.empty())
|
||||
if (channelCount == 0)
|
||||
return op->emitError() << kind << " must carry at least one chunk";
|
||||
|
||||
auto shapedType = dyn_cast<ShapedType>(type);
|
||||
@@ -124,40 +202,34 @@ static LogicalResult verifyTensorChannelSizes(Operation* op,
|
||||
return op->emitError() << kind << " requires byte-sized elements";
|
||||
|
||||
int64_t totalBytes = shapedType.getNumElements() * elementBits / 8;
|
||||
if (totalBytes % static_cast<int64_t>(channelIds.size()) != 0)
|
||||
if (totalBytes % static_cast<int64_t>(channelCount) != 0)
|
||||
return op->emitError() << kind << " tensor byte size must be divisible by the number of channel ids";
|
||||
return success();
|
||||
}
|
||||
|
||||
static LogicalResult verifyBatchChannelSizes(Operation* op,
|
||||
ArrayRef<int64_t> channelIds,
|
||||
ArrayRef<int32_t> sourceCoreIds,
|
||||
ArrayRef<int32_t> targetCoreIds) {
|
||||
if (channelIds.size() != sourceCoreIds.size() || channelIds.size() != targetCoreIds.size())
|
||||
static LogicalResult
|
||||
verifyBatchChannelSizes(Operation* op, size_t channelCount, size_t sourceCoreCount, size_t targetCoreCount) {
|
||||
if (channelCount != sourceCoreCount || channelCount != targetCoreCount)
|
||||
return op->emitError("channelIds, sourceCoreIds, and targetCoreIds must have the same length");
|
||||
|
||||
auto laneCount = getParentBatchLaneCount(op);
|
||||
if (failed(laneCount))
|
||||
return op->emitError("must be nested inside spat.compute_batch");
|
||||
if (channelIds.size() != static_cast<size_t>(*laneCount))
|
||||
if (channelCount != static_cast<size_t>(*laneCount))
|
||||
return op->emitError("channel metadata length must match parent laneCount");
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
static LogicalResult verifyTensorBatchChannelSizes(Operation* op,
|
||||
Type type,
|
||||
ArrayRef<int64_t> channelIds,
|
||||
ArrayRef<int32_t> sourceCoreIds,
|
||||
ArrayRef<int32_t> targetCoreIds,
|
||||
StringRef kind) {
|
||||
if (channelIds.size() != sourceCoreIds.size() || channelIds.size() != targetCoreIds.size())
|
||||
static LogicalResult verifyTensorBatchChannelSizes(
|
||||
Operation* op, Type type, size_t channelCount, size_t sourceCoreCount, size_t targetCoreCount, StringRef kind) {
|
||||
if (channelCount != sourceCoreCount || channelCount != targetCoreCount)
|
||||
return op->emitError("channelIds, sourceCoreIds, and targetCoreIds must have the same length");
|
||||
|
||||
auto laneCount = getParentBatchLaneCount(op);
|
||||
if (failed(laneCount))
|
||||
return op->emitError("must be nested inside spat.compute_batch");
|
||||
if (channelIds.empty() || channelIds.size() % static_cast<size_t>(*laneCount) != 0)
|
||||
if (channelCount == 0 || channelCount % static_cast<size_t>(*laneCount) != 0)
|
||||
return op->emitError() << kind << " channel metadata length must be a positive multiple of parent laneCount";
|
||||
|
||||
auto shapedType = dyn_cast<ShapedType>(type);
|
||||
@@ -168,7 +240,7 @@ static LogicalResult verifyTensorBatchChannelSizes(Operation* op,
|
||||
if (elementBits <= 0 || elementBits % 8 != 0)
|
||||
return op->emitError() << kind << " requires byte-sized elements";
|
||||
|
||||
int64_t chunkCount = static_cast<int64_t>(channelIds.size()) / *laneCount;
|
||||
int64_t chunkCount = static_cast<int64_t>(channelCount) / *laneCount;
|
||||
int64_t totalBytes = shapedType.getNumElements() * elementBits / 8;
|
||||
if (totalBytes % chunkCount != 0)
|
||||
return op->emitError() << kind << " tensor byte size must be divisible by the chunk count per lane";
|
||||
@@ -176,28 +248,61 @@ static LogicalResult verifyTensorBatchChannelSizes(Operation* op,
|
||||
return success();
|
||||
}
|
||||
|
||||
static LogicalResult verifyBatchBody(Operation* op, Block& block, TypeRange outputTypes, size_t weightsPerLane) {
|
||||
auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator());
|
||||
if (!yieldOp)
|
||||
return op->emitError("body must terminate with spat.yield");
|
||||
if (outputTypes.empty()) {
|
||||
static Region* getParentRegion(Value value) {
|
||||
if (auto blockArg = dyn_cast<BlockArgument>(value))
|
||||
return blockArg.getOwner()->getParent();
|
||||
if (Operation* definingOp = value.getDefiningOp())
|
||||
return definingOp->getParentRegion();
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
static bool isDefinedInsideRegion(Value value, Region& region) {
|
||||
Region* parentRegion = getParentRegion(value);
|
||||
return parentRegion && (®ion == parentRegion || region.isAncestor(parentRegion));
|
||||
}
|
||||
|
||||
static bool isConstantExternalValue(Value value) {
|
||||
Operation* definingOp = value.getDefiningOp();
|
||||
return definingOp && definingOp->hasTrait<OpTrait::ConstantLike>();
|
||||
}
|
||||
|
||||
static LogicalResult verifyOnlyConstantExternalValues(Operation* ownerOp, Region& region, StringRef kind) {
|
||||
bool hasFailure = false;
|
||||
region.walk([&](Operation* op) {
|
||||
for (OpOperand& operand : op->getOpOperands()) {
|
||||
Value value = operand.get();
|
||||
if (isDefinedInsideRegion(value, region) || isConstantExternalValue(value))
|
||||
continue;
|
||||
|
||||
InFlightDiagnostic diagnostic = ownerOp->emitOpError()
|
||||
<< kind << " body may only directly reference external constants";
|
||||
diagnostic.attachNote(op->getLoc())
|
||||
<< "non-constant external operand #" << operand.getOperandNumber() << " is used by " << op->getName();
|
||||
hasFailure = true;
|
||||
}
|
||||
});
|
||||
return success(!hasFailure);
|
||||
}
|
||||
|
||||
static LogicalResult verifyBatchBody(SpatComputeBatch batchOp, Block& block) {
|
||||
if (batchOp.getNumResults() == 0) {
|
||||
auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator());
|
||||
if (!yieldOp)
|
||||
return batchOp.emitError("resultless compute_batch body must terminate with spat.yield");
|
||||
if (yieldOp.getNumOperands() != 0)
|
||||
return op->emitError("body yield must be empty when compute_batch has no results");
|
||||
return batchOp.emitError("resultless compute_batch body yield must be empty");
|
||||
}
|
||||
else {
|
||||
if (yieldOp.getNumOperands() != 1)
|
||||
return op->emitError("body yield must produce exactly one value");
|
||||
if (yieldOp.getOperand(0).getType() != outputTypes[0])
|
||||
return op->emitError("body yield type must match output type");
|
||||
else if (!isa_and_nonnull<SpatInParallelOp>(block.getTerminator())) {
|
||||
return batchOp.emitError("resultful compute_batch body must terminate with spat.in_parallel");
|
||||
}
|
||||
|
||||
auto laneArg = batchOp.getLaneArgument();
|
||||
if (!laneArg)
|
||||
return batchOp.emitError("compute_batch body must have a lane block argument");
|
||||
for (auto& bodyOp : block) {
|
||||
if (auto wvmm = dyn_cast<SpatVMMOp>(&bodyOp))
|
||||
if (wvmm.getWeightIndex() < 0 || static_cast<size_t>(wvmm.getWeightIndex()) >= weightsPerLane)
|
||||
return op->emitError("compute_batch body Wvmm weightIndex is out of range for one lane");
|
||||
if (auto wmvm = dyn_cast<SpatMVMOp>(&bodyOp))
|
||||
if (wmvm.getWeightIndex() < 0 || static_cast<size_t>(wmvm.getWeightIndex()) >= weightsPerLane)
|
||||
return op->emitError("compute_batch body Wmvm weightIndex is out of range for one lane");
|
||||
if (auto extractSlice = dyn_cast<tensor::ExtractSliceOp>(&bodyOp))
|
||||
if (failed(verifyStaticUnitStrideExtractSliceOp(extractSlice, *laneArg, "tensor.extract_slice")))
|
||||
return failure();
|
||||
}
|
||||
return success();
|
||||
}
|
||||
@@ -205,9 +310,9 @@ static LogicalResult verifyBatchBody(Operation* op, Block& block, TypeRange outp
|
||||
} // namespace
|
||||
|
||||
LogicalResult SpatMVMOp::verify() {
|
||||
auto matrixShapeOpt = getWeightShapeForWeightedOp(this->getOperation(), this->getWeightIndex());
|
||||
auto matrixShapeOpt = getWeightShapeForWeightedOp(getWeight());
|
||||
if (failed(matrixShapeOpt))
|
||||
return emitError("SpatMVMOp was not within a SpatCompute or Core op");
|
||||
return emitError("weight must be a shaped value");
|
||||
auto matrixShape = *matrixShapeOpt;
|
||||
auto vectorShape = getInput().getType().getShape();
|
||||
auto outputShape = getOutput().getType().getShape();
|
||||
@@ -220,9 +325,9 @@ LogicalResult SpatMVMOp::verify() {
|
||||
}
|
||||
|
||||
LogicalResult SpatVMMOp::verify() {
|
||||
auto matrixShapeOpt = getWeightShapeForWeightedOp(this->getOperation(), this->getWeightIndex());
|
||||
auto matrixShapeOpt = getWeightShapeForWeightedOp(getWeight());
|
||||
if (failed(matrixShapeOpt))
|
||||
return emitError("SpatVMMOp was not within a SpatCompute or Core op");
|
||||
return emitError("weight must be a shaped value");
|
||||
auto matrixShape = *matrixShapeOpt;
|
||||
auto vectorShape = getInput().getType().getShape();
|
||||
auto outputShape = getOutput().getType().getShape();
|
||||
@@ -338,8 +443,36 @@ LogicalResult SpatConcatOp::verify() {
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult verifyComputeResultsUses(Operation* op) {
|
||||
if (!isa<SpatCompute, SpatComputeBatch>(op))
|
||||
return op->emitError("verifyComputeResultUses: Op is not a SpatCompute/SpatComputeBatch operation");
|
||||
if (!llvm::all_of(op->getResults(), [](Value result) {
|
||||
return llvm::all_of(result.getUsers(), [](Operation* op) {
|
||||
return !(op->getParentOfType<SpatCompute>() || op->getParentOfType<SpatComputeBatch>());
|
||||
});
|
||||
})) {
|
||||
return op->emitError("ComputeResult used directly inside another Compute");
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult SpatCompute::verify() {
|
||||
auto& block = getBody().front();
|
||||
unsigned expectedArgCount = getWeights().size() + getInputs().size();
|
||||
if (block.getNumArguments() != expectedArgCount)
|
||||
return emitError("compute body must have weight and input block arguments");
|
||||
|
||||
for (auto [weightIndex, weight] : llvm::enumerate(getWeights())) {
|
||||
auto blockArg = getWeightArgument(weightIndex);
|
||||
if (!blockArg || blockArg->getType() != weight.getType())
|
||||
return emitError("compute weight block argument types must match weight operand types exactly");
|
||||
}
|
||||
for (auto [inputIndex, input] : llvm::enumerate(getInputs())) {
|
||||
auto blockArg = getInputArgument(inputIndex);
|
||||
if (!blockArg || blockArg->getType() != input.getType())
|
||||
return emitError("compute input block argument types must match input operand types exactly");
|
||||
}
|
||||
|
||||
if (block.mightHaveTerminator()) {
|
||||
auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator());
|
||||
if (!yieldOp)
|
||||
@@ -372,54 +505,59 @@ LogicalResult SpatCompute::verify() {
|
||||
}
|
||||
}
|
||||
|
||||
for (auto arg : block.getArguments())
|
||||
if (arg.use_empty())
|
||||
for (unsigned inputIndex = 0; inputIndex < getInputs().size(); ++inputIndex)
|
||||
if (auto inputArg = getInputArgument(inputIndex); !inputArg || inputArg->use_empty())
|
||||
return emitError("ComputeOp block argument is not used");
|
||||
|
||||
if (failed(verifyOnlyConstantExternalValues(this->getOperation(), getBody(), "spat.compute")))
|
||||
return failure();
|
||||
if (failed(verifyComputeResultsUses(this->getOperation())))
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult SpatChannelSendTensorOp::verify() {
|
||||
return verifyTensorChannelSizes(getOperation(),
|
||||
getInput().getType(),
|
||||
getChannelIds(),
|
||||
getSourceCoreIds(),
|
||||
getTargetCoreIds(),
|
||||
getChannelIds().size(),
|
||||
getSourceCoreIds().size(),
|
||||
getTargetCoreIds().size(),
|
||||
"channel_send_tensor");
|
||||
}
|
||||
|
||||
LogicalResult SpatChannelReceiveTensorOp::verify() {
|
||||
return verifyTensorChannelSizes(getOperation(),
|
||||
getOutput().getType(),
|
||||
getChannelIds(),
|
||||
getSourceCoreIds(),
|
||||
getTargetCoreIds(),
|
||||
getChannelIds().size(),
|
||||
getSourceCoreIds().size(),
|
||||
getTargetCoreIds().size(),
|
||||
"channel_receive_tensor");
|
||||
}
|
||||
|
||||
LogicalResult SpatChannelSendBatchOp::verify() {
|
||||
return verifyBatchChannelSizes(getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
|
||||
return verifyBatchChannelSizes(
|
||||
getOperation(), getChannelIds().size(), getSourceCoreIds().size(), getTargetCoreIds().size());
|
||||
}
|
||||
|
||||
LogicalResult SpatChannelSendTensorBatchOp::verify() {
|
||||
return verifyTensorBatchChannelSizes(getOperation(),
|
||||
getInput().getType(),
|
||||
getChannelIds(),
|
||||
getSourceCoreIds(),
|
||||
getTargetCoreIds(),
|
||||
getChannelIds().size(),
|
||||
getSourceCoreIds().size(),
|
||||
getTargetCoreIds().size(),
|
||||
"channel_send_tensor_batch");
|
||||
}
|
||||
|
||||
LogicalResult SpatChannelReceiveBatchOp::verify() {
|
||||
return verifyBatchChannelSizes(getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds());
|
||||
return verifyBatchChannelSizes(
|
||||
getOperation(), getChannelIds().size(), getSourceCoreIds().size(), getTargetCoreIds().size());
|
||||
}
|
||||
|
||||
LogicalResult SpatChannelReceiveTensorBatchOp::verify() {
|
||||
return verifyTensorBatchChannelSizes(getOperation(),
|
||||
getOutput().getType(),
|
||||
getChannelIds(),
|
||||
getSourceCoreIds(),
|
||||
getTargetCoreIds(),
|
||||
getChannelIds().size(),
|
||||
getSourceCoreIds().size(),
|
||||
getTargetCoreIds().size(),
|
||||
"channel_receive_tensor_batch");
|
||||
}
|
||||
|
||||
@@ -429,35 +567,6 @@ LogicalResult SpatComputeBatch::verify() {
|
||||
return emitError("laneCount must be positive");
|
||||
|
||||
auto laneCountSz = static_cast<size_t>(count);
|
||||
if (getWeights().size() % laneCountSz != 0)
|
||||
return emitError("number of weights must be a multiple of laneCount");
|
||||
|
||||
if (!getInputs().empty() && getInputs().size() != laneCountSz)
|
||||
return emitError("number of inputs must be either 0 or laneCount");
|
||||
if (!getOutputs().empty() && getOutputs().size() != laneCountSz)
|
||||
return emitError("number of outputs must be either 0 or laneCount");
|
||||
|
||||
size_t weightsPerLane = getWeights().size() / laneCountSz;
|
||||
for (size_t weightIndex = 0; weightIndex < weightsPerLane; ++weightIndex) {
|
||||
Type weightType = getWeights()[weightIndex].getType();
|
||||
for (size_t lane = 1; lane < laneCountSz; ++lane)
|
||||
if (getWeights()[lane * weightsPerLane + weightIndex].getType() != weightType)
|
||||
return emitError("corresponding weights across lanes must have the same type");
|
||||
}
|
||||
|
||||
if (!getInputs().empty()) {
|
||||
Type inputType = getInputs()[0].getType();
|
||||
for (Value in : getInputs().drop_front())
|
||||
if (in.getType() != inputType)
|
||||
return emitError("all inputs must have the same type");
|
||||
}
|
||||
|
||||
if (!getOutputs().empty()) {
|
||||
Type outputType = getOutputs()[0].getType();
|
||||
for (Value out : getOutputs().drop_front())
|
||||
if (out.getType() != outputType)
|
||||
return emitError("all outputs must have the same type");
|
||||
}
|
||||
|
||||
if (auto coreIdAttr = (*this)->getAttr(kCoreIdsAttrName)) {
|
||||
auto coreIdsAttr = dyn_cast<DenseI32ArrayAttr>(coreIdAttr);
|
||||
@@ -465,27 +574,72 @@ LogicalResult SpatComputeBatch::verify() {
|
||||
return emitError("compute_batch coreIds attribute must be a dense i32 array");
|
||||
if (coreIdsAttr.size() != static_cast<int64_t>(laneCountSz))
|
||||
return emitError("compute_batch coreIds array length must match laneCount");
|
||||
if (llvm::any_of(coreIdsAttr.asArrayRef(), [](int32_t coreId) { return coreId <= 0; }))
|
||||
return emitError("compute_batch coreIds values must be positive");
|
||||
llvm::SmallDenseSet<int32_t, 8> seenCoreIds;
|
||||
if (llvm::any_of(coreIdsAttr.asArrayRef(), [](int32_t coreId) { return coreId < 0; }))
|
||||
return emitError("compute_batch coreIds values must be non-negative");
|
||||
DenseSet<int32_t> seenCoreIds;
|
||||
for (int32_t coreId : coreIdsAttr.asArrayRef())
|
||||
if (!seenCoreIds.insert(coreId).second)
|
||||
return emitError("compute_batch coreIds values must be distinct");
|
||||
return emitError("compute_batch coreIds values must be unique");
|
||||
}
|
||||
|
||||
Block& block = getBody().front();
|
||||
if (getInputs().empty()) {
|
||||
if (block.getNumArguments() != 0)
|
||||
return emitError("compute_batch body must have no block arguments when there are no inputs");
|
||||
if (block.getNumArguments() == 0)
|
||||
return emitError("compute_batch body must have exactly one lane block argument");
|
||||
unsigned expectedArgCount = 1 + getWeights().size() + getInputs().size() + getNumResults();
|
||||
if (block.getNumArguments() != expectedArgCount)
|
||||
return emitError("compute_batch body block arguments must match lane, weight, input, and output operands/results");
|
||||
auto laneArg = getLaneArgument();
|
||||
if (!laneArg || !laneArg->getType().isIndex())
|
||||
return emitError("compute_batch first block argument must have index type");
|
||||
|
||||
for (auto [weightIndex, weight] : llvm::enumerate(getWeights())) {
|
||||
auto blockArg = getWeightArgument(weightIndex);
|
||||
if (!blockArg || blockArg->getType() != weight.getType())
|
||||
return emitError("compute_batch weight block argument types must match weight operand types exactly");
|
||||
}
|
||||
else {
|
||||
if (block.getNumArguments() != 1)
|
||||
return emitError("compute_batch body must have exactly one block argument");
|
||||
if (block.getArgument(0).getType() != getInputs()[0].getType())
|
||||
return emitError("body block argument type must match input type");
|
||||
for (auto [inputIndex, input] : llvm::enumerate(getInputs())) {
|
||||
auto blockArg = getInputArgument(inputIndex);
|
||||
if (!blockArg || blockArg->getType() != input.getType())
|
||||
return emitError("compute_batch input block argument types must match input operand types exactly");
|
||||
}
|
||||
for (auto [resultIndex, resultType] : llvm::enumerate(getResultTypes())) {
|
||||
auto blockArg = getOutputArgument(resultIndex);
|
||||
if (!blockArg || blockArg->getType() != resultType)
|
||||
return emitError("compute_batch output block argument types must match result types exactly");
|
||||
}
|
||||
|
||||
return verifyBatchBody(getOperation(), block, getResultTypes(), weightsPerLane);
|
||||
if (failed(verifyComputeResultsUses(this->getOperation())))
|
||||
return failure();
|
||||
if (failed(verifyOnlyConstantExternalValues(this->getOperation(), getBody(), "spat.compute_batch")))
|
||||
return failure();
|
||||
return verifyBatchBody(*this, block);
|
||||
}
|
||||
|
||||
LogicalResult SpatInParallelOp::verify() {
|
||||
auto batchOp = getOperation()->getParentOfType<SpatComputeBatch>();
|
||||
if (!batchOp)
|
||||
return emitOpError("expected spat.compute_batch parent");
|
||||
if (batchOp.getNumResults() == 0)
|
||||
return emitOpError("requires a resultful spat.compute_batch parent");
|
||||
|
||||
auto laneArg = batchOp.getLaneArgument();
|
||||
if (!laneArg)
|
||||
return emitOpError("expected compute_batch lane block argument");
|
||||
for (Operation& op : getRegion().front().getOperations()) {
|
||||
auto insertSliceOp = dyn_cast<tensor::ParallelInsertSliceOp>(&op);
|
||||
if (!insertSliceOp)
|
||||
return emitOpError("expected only tensor.parallel_insert_slice ops");
|
||||
|
||||
if (failed(verifyStaticUnitStrideParallelInsertSliceOp(insertSliceOp, *laneArg, "tensor.parallel_insert_slice")))
|
||||
return failure();
|
||||
|
||||
MutableOperandRange destinations = insertSliceOp.getUpdatedDestinations();
|
||||
for (OpOperand& destination : destinations)
|
||||
if (!isBatchOutputArgument(batchOp, destination.get()))
|
||||
return op.emitOpError("may only insert into a compute_batch output block argument");
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
} // namespace spatial
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user