Compare commits

73 Commits

Author SHA1 Message Date
NiccoloN de0a2f4561 remove useless guard in gemm lowering
Validate Operations / validate-operations (push) Has been cancelled
2026-05-15 18:22:13 +02:00
NiccoloN 1c4a5bde76 compact softmax op lowering
Validate Operations / validate-operations (push) Has been cancelled
2026-05-15 18:14:59 +02:00
NiccoloN 78242e2887 compact resize op lowering
Validate Operations / validate-operations (push) Has been cancelled
2026-05-15 17:36:12 +02:00
NiccoloN fe244d5aa1 new ops tests for matmul, grouped conv, concat and reshape
Validate Operations / validate-operations (push) Has been cancelled
related fixes
2026-05-14 15:54:06 +02:00
NiccoloN d09e76c8f9 fix matmul rewriting/lowering
Validate Operations / validate-operations (push) Has been cancelled
fix reshape lowering
add support for grouped-convolution lowering
quieter verifier with capped error messages
2026-05-14 14:09:30 +02:00
NiccoloN c5e608fa5b replace greedy pattern rewrites with partial conversions
Validate Operations / validate-operations (push) Has been cancelled
better failure messages
2026-05-14 11:48:16 +02:00
ilgeco 43f3ccdd21 new yolo nodes with 100% more statics
Validate Operations / validate-operations (push) Has been cancelled
2026-05-14 10:47:31 +02:00
NiccoloN 8d95c604a6 automatic code formatting
Validate Operations / validate-operations (push) Has been cancelled
2026-05-13 21:51:19 +02:00
NiccoloN 55eda487dc use seed in validate.py for deterministic tests 2026-05-13 21:49:36 +02:00
NiccoloN 061139aefb fix wrong send/receive reordering in post dcp merge instructions compaction 2026-05-13 21:48:49 +02:00
NiccoloN ea61540e08 fix failing validations after last commit
Validate Operations / validate-operations (push) Has been cancelled
2026-05-13 17:46:19 +02:00
NiccoloN 324178cba8 fix instructions explosion in pim host constant folding pass
Validate Operations / validate-operations (push) Has been cancelled
2026-05-13 17:31:05 +02:00
NiccoloN e71ba07cd5 fix pim-simulator stale tests
Validate Operations / validate-operations (push) Has been cancelled
2026-05-13 16:59:53 +02:00
NiccoloN 64a3805619 fix pim-simulator stale tests
Validate Operations / validate-operations (push) Has been cancelled
2026-05-13 16:59:43 +02:00
NiccoloN 9f9e7c0892 Merge remote-tracking branch 'origin/main'
Validate Operations / validate-operations (push) Has been cancelled
2026-05-13 16:38:33 +02:00
NiccoloN 03eab42971 remove host core generation
strip config.json emitted by raptor
add actual pimsim-nn configs in validation pimsim-configs
2026-05-13 16:31:01 +02:00
ilgeco c15aba5d96 pim-simulator removed useless comment
Validate Operations / validate-operations (push) Has been cancelled
2026-05-13 15:05:17 +02:00
ilgeco 4821e8a55e Merge branch 'main' of chef.heaplab.deib.polimi.it:nnicolosi/Raptor 2026-05-13 15:03:20 +02:00
ilgeco 88bb223bb1 Fix multiple address bug 2026-05-13 14:15:41 +02:00
NiccoloN 623ee62a04 point pimsim-nn submodule to the HEAPLab fork
Validate Operations / validate-operations (push) Has been cancelled
2026-05-13 14:10:27 +02:00
ilgeco ad56888b0b Broken pim-sim commit 2026-05-13 13:32:09 +02:00
NiccoloN f993840641 update pimsim-nn submodule
Validate Operations / validate-operations (push) Has been cancelled
2026-05-13 11:39:47 +02:00
NiccoloN 0c7db55a24 binary pim code for reduced memory usage
Validate Operations / validate-operations (push) Has been cancelled
fast pim code emission
2026-05-13 11:15:54 +02:00
NiccoloN 41de3cb150 add memory coalescing pass
Validate Operations / validate-operations (push) Has been cancelled
better reports
refactor for more code-reuse and patter usage
fixes
2026-05-12 18:17:00 +02:00
NiccoloN 4f3570520c add pim.vmm verifier and fix vmm lowering
reuse code for subviews
2026-05-12 15:13:50 +02:00
NiccoloN 628dc630a4 compact syntax for spatial tensor ops
Validate Operations / validate-operations (push) Has been cancelled
better IR compaction after dcp merge
remove pim.mvm op
better memory report
2026-05-12 13:35:25 +02:00
NiccoloN 80a7298552 fix pool lowering
Validate Operations / validate-operations (push) Has been cancelled
better reports (dcp merge and memory)
2026-05-12 12:32:23 +02:00
ilgeco 8ad504fcdf Yolo splitted at conv boundary
Validate Operations / validate-operations (push) Has been cancelled
2026-05-12 11:33:15 +02:00
ilgeco e6f442c5d2 Merge branch 'main' of chef.heaplab.deib.polimi.it:nnicolosi/Raptor
Validate Operations / validate-operations (push) Has been cancelled
2026-05-12 10:43:01 +02:00
ilgeco f6b97b3813 Fix report memory 2026-05-12 10:42:38 +02:00
ilgeco 26317ea7d0 Shorter Memory Reporty 2026-05-12 10:38:35 +02:00
NiccoloN 909c4acfdd huge refactor for high RewritePatterns usage and less ad-hoc cpp code
Validate Operations / validate-operations (push) Has been cancelled
remove Spatial many ops in favor of tensor ops like in pim
2026-05-12 10:35:44 +02:00
ilgeco feaff820e1 pim-sim TraceTime + faer
Validate Operations / validate-operations (push) Has been cancelled
2026-05-11 18:19:30 +02:00
NiccoloN 1e279ae9bb minor fix
Validate Operations / validate-operations (push) Has been cancelled
2026-05-11 16:01:42 +02:00
NiccoloN 57f0cca8c0 remove duplicated code
Validate Operations / validate-operations (push) Has been cancelled
quieter validation scripts (with optional verbose flag)
2026-05-11 15:52:26 +02:00
NiccoloN 5ff364027b big cleanup: remove remaining pim many operations, simplify bufferization logic
Validate Operations / validate-operations (push) Has been cancelled
2026-05-11 14:38:13 +02:00
NiccoloN b1272d2283 fast pim bufferization using tensors
Validate Operations / validate-operations (push) Successful in 24m29s
2026-05-08 14:21:45 +02:00
NiccoloN 58e6587697 Merge remote-tracking branch 'origin/main' 2026-05-08 13:12:47 +02:00
NiccoloN f6c8cc4aa5 sightly better bufferization
minor fixes
2026-05-07 17:53:47 +02:00
ilgeco 566630b99a Removed SpatNopPattern
Validate Operations / validate-operations (push) Successful in 22m36s
2026-05-07 17:03:35 +02:00
ilgeco 74931ad75b Single Concat Fix 2026-05-07 16:47:01 +02:00
NiccoloN f2fe147961 compact pim IR
Validate Operations / validate-operations (push) Successful in 22m15s
2026-05-06 17:16:51 +02:00
ilgeco 7bb58e80de Merge branch 'main' of chef.heaplab.deib.polimi.it:nnicolosi/Raptor into main
Validate Operations / validate-operations (push) Successful in 24m25s
2026-05-06 12:25:29 +02:00
NiccoloN b2dc9c38b6 better spatial IR compaction with better custom syntax, scf.for and
Validate Operations / validate-operations (push) Has been cancelled
spat.map
2026-05-06 12:21:58 +02:00
ilgeco 3cb6a1abc5 Memory report 2026-05-06 10:47:04 +02:00
NiccoloN 285773fa55 rework actually broken dcp merge + compute re-batching (still to refine) 2026-05-04 19:30:40 +02:00
NiccoloN bdacb9871d fix dcp merge bug
Validate Operations / validate-operations (push) Failing after 15m54s
2026-05-04 15:58:14 +02:00
NiccoloN 5b9bb0c191 refactor spatial ops
Validate Operations / validate-operations (push) Successful in 24m55s
2026-05-04 14:19:30 +02:00
NiccoloN f789954ad7 Refactor ONNXToSpatial Common and diagnostics 2026-05-04 13:42:43 +02:00
ilgeco b6ba1e4fea Fix DCPTest using old constructor
Validate Operations / validate-operations (push) Successful in 24m15s
2026-05-04 10:58:51 +02:00
NiccoloN 717ad160cd Refactor PIM/Common (splitting in files, adding helpers, adding brief
Validate Operations / validate-operations (push) Failing after 18m36s
docs)
2026-05-04 09:20:43 +02:00
NiccoloN 905fa9f9a7 Merge remote changes
Validate Operations / validate-operations (push) Failing after 18m42s
2026-05-03 23:09:32 +02:00
NiccoloN 62b0a6e19d merge remote changes 2026-05-03 22:30:46 +02:00
NiccoloN b605585b1f compact spatial IR through different new operations and dedicated syntax
fast spatial node merging with batch operations
2026-05-03 14:14:14 +02:00
ilgeco 08b0fcd850 Parallel bufferization
Validate Operations / validate-operations (push) Successful in 21m49s
2026-04-30 11:48:17 +02:00
ilgeco 9dccc2c701 Translate global constant to symble 2026-04-28 12:42:01 +02:00
ilgeco 5c839e62c1 Func Input converted to symbol 2026-04-27 13:48:03 +02:00
NiccoloN 15e8edb9c4 better spat computes merging
Validate Operations / validate-operations (push) Successful in 21m14s
2026-04-25 19:24:09 +02:00
ilgeco 951baca106 Merge Node update fix comparison bug
Validate Operations / validate-operations (push) Successful in 20m21s
2026-04-23 19:52:16 +02:00
ilgeco fc5bccb487 Merge Node update status file
Validate Operations / validate-operations (push) Has started running
2026-04-23 19:42:56 +02:00
ilgeco 49dea15b95 DCP Merge status
Validate Operations / validate-operations (push) Successful in 22m29s
2026-04-23 18:40:33 +02:00
NiccoloN 5545b0f672 fix MatMul pattern non-contiguous extract_slices
Validate Operations / validate-operations (push) Successful in 22m31s
2026-04-23 14:44:30 +02:00
NiccoloN cff929a083 fix sigmoid implementation stability in pim-simulator
Validate Operations / validate-operations (push) Successful in 23m4s
2026-04-23 10:34:29 +02:00
NiccoloN 89b3501aa8 fix weightAlways attribute in spatial 2026-04-23 10:04:47 +02:00
NiccoloN 412ca957f6 multiple-output spat computes
Validate Operations / validate-operations (push) Successful in 22m38s
2026-04-23 09:28:57 +02:00
NiccoloN 0f13269040 faster DCPAnalysis on partial graph
Validate Operations / validate-operations (push) Successful in 27m37s
2026-04-21 18:36:16 +02:00
NiccoloN dafc1d15b7 faster pim-simulator 2026-04-21 18:35:51 +02:00
NiccoloN 3fa140be25 Merge branch 'main' of chef.heaplab.deib.polimi.it:nnicolosi/Raptor 2026-04-21 16:23:16 +02:00
ilgeco df703f0be9 pim-simulator add progress report
Validate Operations / validate-operations (push) Successful in 21m24s
2026-04-21 16:23:03 +02:00
NiccoloN 9fa850c140 Merge branch 'main' of chef.heaplab.deib.polimi.it:nnicolosi/Raptor 2026-04-21 15:59:08 +02:00
ilgeco 186c88d860 Merge branch 'main' of chef.heaplab.deib.polimi.it:nnicolosi/Raptor into main
Validate Operations / validate-operations (push) Has been cancelled
2026-04-21 15:44:40 +02:00
ilgeco 0368f96593 pims-simulator symlink memory opt 2026-04-21 15:43:10 +02:00
NiccoloN 25ade1bd63 fix memory allocation in pim codegen
fix crossbar allocation to only consider weights from vmm and mvm
2026-04-21 13:31:10 +02:00
281 changed files with 21180 additions and 4982 deletions
+12
@@ -1,5 +1,17 @@
.zed
.idea
**/.vscode
.claude
.codex
AGENTS.md
CMakeUserPresets.json
build
build_release
cmake-build-debug
cmake-build-release
compile.sh
**/__*
+1 -1
@@ -3,4 +3,4 @@
url = https://github.com/onnx/onnx-mlir.git
[submodule "backend-simulators/pim/pimsim-nn"]
path = backend-simulators/pim/pimsim-nn
url = https://github.com/wangxy-2000/pimsim-nn.git
url = https://github.com/HEAPLab/pimsim-nn.git
+158
@@ -1,5 +1,163 @@
# Raptor
Raptor is a domain-specific MLIR compiler for neural networks (ONNX format)
targeting in-memory computing / processing-in-memory (PIM) architectures.
It progressively lowers the ONNX-MLIR representation of the network through a set of MLIR dialects down to
target-specific artifacts (currently JSON code for the `pimsim-nn` simulator).
## Overview
PIM architectures perform most of the computation directly in memory.
Raptor's first supported target is `pimsim-nn`, which simulates a chip with:
- a shared host memory,
- a number of cores that do most of the computation directly in their memory
(vector ops, vmm/mvm on ReRAM crossbars),
- no branching instructions (branchless architecture) and no hardware loop
support — any repeated work (e.g. convolutions) must be unrolled into
explicit per-iteration instructions.
Because of this, the number of emitted instructions grows very quickly (each
iteration of, for example, a convolution becomes its own explicit instruction
sequence), and the compiler must optimize aggressively at every stage to keep
compilation tractable.
A second target, `PulPim`, is planned for an accelerator with RISC-V cores
each carrying its own in-memory computing unit and crossbars. It will live in
a dedicated dialect (future work).
### Targets and simulators
`pimsim-nn` (under `backend-simulators/pim/pimsim-nn`) is used for
**performance** estimates (latency, energy), but does not functionally execute
the JSON code it consumes. To validate the numerical correctness of the JSON
code produced by Raptor (or, for comparison, by the `pimcomp` compiler), we use
a Rust simulator we maintain in-tree at
`backend-simulators/pim/pim-simulator`.
## Compilation pipeline
The PIM-related sources live under `src/PIM` and the tests under `test/PIM`.
When working on this codebase, most changes should stay confined to those
trees (you only need to look outside, e.g. at `onnx-mlir` or `llvm`, for
framework-level details).
High-level lowering flow:
```
ONNX-MLIR ──► Spatial ──► Pim (tensor) ──► Pim (bufferized) ──► PIM JSON
```
1. **ONNX → Spatial** (`src/PIM/Conversion/ONNXToSpatial`).
Lowers ONNX ops into the `spat` dialect (`src/PIM/Dialect/Spatial`).
Spatial models a high-level spatial in-memory accelerator: vmm/mvm
operations are accelerated by storing a constant RHS matrix into a
crossbar. Crossbars cannot be re-programmed during execution, have a
limited fixed size, and there is a limited number of them per core.
Conversion patterns are split by op family under
`Conversion/ONNXToSpatial/Patterns/{Math,NN,Tensor}` (Conv, Gemm, MatMul,
Elementwise, ReduceMean, Pool, Relu, Sigmoid, Softmax, Concat, Gather,
Reshape, Resize, Split).
2. **Spatial → Pim** (`src/PIM/Conversion/SpatialToPim`).
Lowers Spatial to the `pim` dialect (`src/PIM/Dialect/Pim`), which
materializes PIM cores (`pim.core`), inter-core communication
(`pim.send` / `pim.receive`), halts, and crossbar-level operations.
3. **Merge compute nodes** (`src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes`).
A DCP-inspired heuristic (Dynamic Critical Path — see the original
scheduling paper by Kwok & Ahmad,
[DCP-eScience2007](https://clouds.cis.unimelb.edu.au/papers/DCP-eScience2007.pdf))
that coarsens the virtual node graph and decides how to group compute
nodes onto cores. Our implementation is only DCP-*inspired*: it is a
heuristic with different assumptions from the paper (different cost
model, constraints from crossbar capacity / core resources, and a
windowed coarsening loop instead of full-graph reprioritization). The
`dcp-critical-window-size` option controls how many lowest-slack virtual
nodes each coarsening iteration considers (0 = legacy full-graph
analysis). Related sources: `DCPGraph/DCPAnalysis.cpp`, `Graph.cpp/.hpp`,
`MergeComputeNodesPass.cpp`.
4. **Bufferization** (`src/PIM/Dialect/Pim/Transforms/Bufferization`).
Converts tensor-semantics PIM IR into memref-semantics PIM IR using the
standard MLIR `BufferizableOpInterface` machinery
(`OpBufferizationInterfaces.*`, `PimBufferization.td`).
5. **Static memory coalescing** (`src/PIM/Dialect/Pim/Transforms/StaticMemoryCoalescing`).
Conservatively reuses same-typed local memref allocations inside PIM cores
after bufferization and before code generation.
6. **PIM code generation** (`src/PIM/Pass/PimCodegen`):
- `HostConstantFolding` — folds host-side constants.
- `MaterializeHostConstantsPass` — materializes the remaining host
constants for emission.
- `VerificationPass` — checks invariants before emission.
- `EmitPimJsonPass` — emits the final PIM JSON consumed by `pimsim-nn`
and `pim-simulator`.
Supporting pieces:
- `src/PIM/Compiler` — PIM-specific compiler options (crossbar size/count,
core count, DCP window, experimental conv impl, concat error handling, …)
and `PimCodeGen` entry points.
- `src/PIM/Common` — shared utilities (`PimCommon`, `LabeledList`).
- `src/PIM/Pass` — auxiliary passes (`MessagePass`, `CountInstructionPass`)
and the `PIMPasses.h` registry used by `PimAccelerator`.
- `src/PIM/PimAccelerator.{cpp,hpp}` — accelerator entry point: registers
dialects, passes, and plugs Raptor into the ONNX-MLIR driver.
## Key compiler options
Pass these on the `onnx-mlir` command line when compiling for PIM:
- `--maccel=PIM` — select the PIM accelerator.
- `--EmitSpatial` / `--EmitPim` / `--EmitPimBufferized` / `--EmitPimCodegen`
— stop the pipeline at the requested stage (default: `EmitPimCodegen`).
- `--pim-only-codegen` — assume the input is already bufferized PIM IR and
run only the codegen tail.
- `--crossbar-size=<N>` / `--crossbar-count=<N>` — crossbar dimensions and
per-core count.
- `--core-count=<N>` — number of cores (`-1` picks the minimum).
- `--dcp-critical-window-size=<N>` — DCP coarsening window (0 = legacy).
- `--use-experimental-conv-impl` — alternative convolution lowering.
- `--ignore-concat-error` — soft-fail on a known corner case in `ConcatOp`.
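As a rough sketch (the model path is a placeholder and the sizes are only the
values used in the validation examples below; the flags are the ones listed
above), a full PIM compilation down to JSON might be invoked as:
```
onnx-mlir --maccel=PIM --EmitPimCodegen \
  --crossbar-size=2048 --crossbar-count=256 --core-count=1000 \
  model.onnx
```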
## Validation
Functional validation lives in `validation/` and drives the Rust
`pim-simulator` to compare Raptor's output against a reference.
Per-operation validation (from `validation/`):
```
validate.py \
--raptor-path ../cmake-build-release/Release/bin/onnx-mlir \
--onnx-include-dir ../onnx-mlir/include
```
End-to-end network validation (example: first 4 layers of YOLOv11n):
```
validate.py \
--raptor-path ../cmake-build-release/Release/bin/onnx-mlir \
--onnx-include-dir ../onnx-mlir/include \
--operations-dir ./networks/yolo11n/depth_04 \
--crossbar-size 2048 --crossbar-count 256 --core-count 1000
```
Available networks under `validation/networks/`: `vgg16`, `yolo11n`.
Available operations under `validation/operations/`: `add`, `conv`, `div`,
`gather`, `gemm`, `gemv`, `mul`, `pool`, `reduce_mean`, `relu`, `resize`,
`sigmoid`, `softmax`, `split`.
## Rebuilding
Release build (fast):
```
cmake --build /home/nico/raptor/raptor/cmake-build-release --target onnx-mlir -j 30
```
A slower debug build is also available — configure it the same way but with
`-DCMAKE_BUILD_TYPE=Debug` (see installation instructions below).
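A minimal sketch of the debug configure/build, assuming the `cmake-build-debug`
directory name from `.gitignore` (the actual configure line also needs the
options from the Build section below):
```
cmake -S . -B cmake-build-debug -DCMAKE_BUILD_TYPE=Debug
cmake --build cmake-build-debug --target onnx-mlir -j 30
```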
## Build
### Protobuf
File diff suppressed because it is too large
@@ -1,4 +1,3 @@
[package]
name = "pim-simulator"
version = "0.1.0"
@@ -13,8 +12,9 @@ name = "pimcore"
path = "src/lib/pimcore.rs"
[features]
default = ["tracing"]
default = []
tracing = []
profile_time = ["dep:plotly", "dep:comfy-table", "dep:statrs"]
@@ -27,3 +27,10 @@ hex = "0"
paste = "1"
serde = { version = "1", features = ["derive"] }
serde_json = "1"
statrs = {version="0.16", optional=true}
comfy-table = {version="7.1", optional=true}
plotly = {version="0.8", optional=true}
rayon = "1.12.0"
faer = "0.24.0"
faer-traits = "0.24.0"
mimalloc = "0.1.50"
@@ -1,11 +1,20 @@
use anyhow::{bail, Context, Result};
use mimalloc::MiMalloc;
#[global_allocator]
static GLOBAL: MiMalloc = MiMalloc;
use anyhow::{Context, Result, bail};
use clap::Parser;
use glob::glob;
use pimcore::binary_to_instruction::binary_to_executor;
use pimcore::cpu::crossbar::Crossbar;
use pimcore::json_to_instruction::json_to_executor;
use pimcore::memory_manager::CoreMemory;
use pimcore::tracing::TRACER;
use serde_json::Value;
use std::fs;
use std::io::Write;
use std::collections::HashMap;
use std::fs::{self, File, read_link};
use std::io::{BufReader, Write};
use std::path::PathBuf;
/// Program to simulate core execution configuration
@@ -41,10 +50,18 @@ fn main() -> Result<()> {
let args = Args::parse();
let config_json = retrive_config(&args)?;
let core_jsons = retrive_cores(&args)?;
let mut core_inputs = retrive_cores(&args)?;
let memory = retrive_memory(&args)?;
let mut executor = json_to_executor::json_to_executor(config_json, core_jsons.iter());
populate_crossbar(&args, &mut executor);
let global_crossbars = get_crossbars(&config_json, &args).unwrap();
let crossbars = map_crossbars_to_cores(&config_json, &args, &global_crossbars);
let mut executor = match &mut core_inputs {
CoreInputs::Json(core_jsons) => {
json_to_executor::json_to_executor(config_json, core_jsons, crossbars)
}
CoreInputs::Binary(core_bins) => {
binary_to_executor(config_json, core_bins.iter(), crossbars)?
}
};
set_memory(&mut executor, memory);
TRACER
.lock()
@@ -55,21 +72,30 @@ fn main() -> Result<()> {
Ok(())
}
fn populate_crossbar(args: &Args, executor: &mut pimcore::Executable) {
let num_cores = executor.cpu_mut().num_core();
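// Resolve each core's crossbar_<n>.bin symlinks against the crossbars preloaded
// from the weights folder (keyed by the symlink target path), so each weight file
// is stored once and cores only hold references to it.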
fn map_crossbars_to_cores<'c>(
config: &Value,
args: &Args,
global_crossbars: &'c HashMap<String, Crossbar>,
) -> Vec<Vec<&'c Crossbar>> {
let mut res = vec![Vec::new()];
let num_cores = config.get("core_cnt").unwrap().as_i64().unwrap() as i32;
if let Some(folder) = args.folder.as_ref() {
for core_idx in 0..num_cores {
let core_folder = folder.join(format!("core_{}", core_idx));
res.push(Vec::new());
if !core_folder.is_dir() {
continue;
}
let mut bin_files: Vec<(u32, std::path::PathBuf)> = std::fs::read_dir(&core_folder)
let mut sym_link_files: Vec<(u32, std::path::PathBuf)> =
std::fs::read_dir(&core_folder)
.expect("Failed to read core directory")
.filter_map(|entry| {
let path = entry.ok()?.path();
let entry = entry.ok()?;
assert!(entry.metadata().unwrap().is_symlink());
let path = entry.path();
let file_name = path.file_name()?.to_str()?;
if file_name.starts_with("crossbar_") && file_name.ends_with(".bin") {
@@ -81,20 +107,64 @@ fn populate_crossbar(args: &Args, executor: &mut pimcore::Executable) {
}
})
.collect();
bin_files.sort_by_key(|&(num, _)| num);
let core = executor.cpu_mut().core(core_idx);
let (_memory, crossbars) = core.get_memory_crossbar();
sym_link_files.sort_by_key(|&(num, _)| num);
for (i, path) in bin_files {
let bytes = std::fs::read(path).expect("Failed to read binary file");
crossbars
.get_mut(i as usize)
for (_, symlink) in sym_link_files {
let real_path = read_link(symlink).unwrap();
let path_as_str = real_path.to_str().unwrap();
assert!(
global_crossbars.contains_key(path_as_str),
"symlink point to {:?}\n a not stored crossbar",
real_path
);
res.iter_mut()
.next_back()
.unwrap()
.execute_store(&bytes)
.unwrap();
.push(global_crossbars.get(path_as_str).unwrap());
}
}
}
res
}
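// Load every file under <folder>/weights into its own Crossbar, keyed by the
// file's path string, so the per-core symlinks above can be resolved against it.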
fn get_crossbars(config: &Value, args: &Args) -> anyhow::Result<HashMap<String, Crossbar>> {
let xbar_size = config.get("xbar_size").unwrap().as_array().unwrap();
let rows_crossbar = xbar_size[0].as_i64().unwrap() as usize;
let column_corssbar = xbar_size[1].as_i64().unwrap() as usize;
let mut res = HashMap::new();
if let Some(folder) = args.folder.as_ref() {
let weight_folder = folder.join("weights");
if !weight_folder.is_dir() {
bail!("Not a directory");
}
for weight_file in
std::fs::read_dir(&weight_folder).context("Weight folder not iterable")?
{
let weight_file = weight_file.context("File not iterable")?;
if weight_file
.metadata()
.context("Doesn't contain metadata")?
.is_dir()
{
continue;
}
let bytes = std::fs::read(weight_file.path()).expect("Failed to read binary file");
let mut crossbar = Crossbar::new(column_corssbar * 4, rows_crossbar, CoreMemory::new());
crossbar.execute_store(&bytes).unwrap();
res.insert(
weight_file
.path()
.to_str()
.context("file name not utf-8")?
.to_string(),
crossbar,
);
}
}
Ok(res)
}
fn dump_memory(mut executor: pimcore::Executable, args: &Args) -> Result<()> {
@@ -155,41 +225,82 @@ fn retrive_memory(args: &Args) -> Result<Vec<u8>> {
Ok(memory_vector)
}
fn retrive_cores(args: &Args) -> Result<Vec<Value>, anyhow::Error> {
let mut core_jsons: Vec<Value> = Vec::new();
if let Some(cores_override) = &args.cores {
for core in cores_override {
let content = fs::read_to_string(core)
.with_context(|| format!("Failed to read core file: {:?}", cores_override))?;
let json: Value =
serde_json::from_str(&content).context("Failed to parse core json override")?;
core_jsons.push(json);
}
} else if let Some(folder) = args.folder.as_ref() {
let pattern = folder.join("core*.json");
let pattern_str = pattern.to_str().context("Invalid path encoding")?;
let mut paths: Vec<_> = glob(pattern_str)?.map(|x| x.unwrap()).collect();
paths.sort_by_cached_key(|x| {
let mut x = x.file_stem().expect("Extracting the stem").to_str().expect("File not utf-8");
x = &x[5..];
x.parse::<i32>().unwrap()
});
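/// Core programs as read from disk: either JSON (`core*.json`) or the compact
/// binary encoding (`core*.pim`); when a folder contains both, the binary files
/// are preferred.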
enum CoreInputs {
Json(Vec<BufReader<File>>),
Binary(Vec<Vec<u8>>),
}
if paths.is_empty() {
bail!("No core*.json files found in {:?}", folder);
fn retrive_cores(args: &Args) -> Result<CoreInputs, anyhow::Error> {
if let Some(cores_override) = &args.cores {
let first_extension = cores_override
.first()
.and_then(|path| path.extension())
.and_then(|ext| ext.to_str())
.unwrap_or_default();
if first_extension == "pim" {
let mut core_bins = Vec::with_capacity(cores_override.len());
for core in cores_override {
core_bins.push(
fs::read(core)
.with_context(|| format!("Failed to read binary core file: {:?}", core))?,
);
}
for entry in paths {
let path = entry;
let content = fs::read_to_string(&path)
.with_context(|| format!("Failed to read core file: {:?}", path))?;
let json: Value = serde_json::from_str(&content)
.with_context(|| format!("Failed to parse JSON in {:?}", path))?;
core_jsons.push(json);
return Ok(CoreInputs::Binary(core_bins));
}
} else {
let mut core_jsons_reader: Vec<BufReader<File>> = Vec::with_capacity(cores_override.len());
for core in cores_override {
let file = File::open(core)?;
let reader = BufReader::new(file);
core_jsons_reader.push(reader);
}
return Ok(CoreInputs::Json(core_jsons_reader));
}
if let Some(folder) = args.folder.as_ref() {
let binary_pattern = folder.join("core*.pim");
let binary_pattern_str = binary_pattern.to_str().context("Invalid path encoding")?;
let mut binary_paths: Vec<_> = glob(binary_pattern_str)?.map(|x| x.unwrap()).collect();
binary_paths.sort_by_cached_key(core_sort_key);
if !binary_paths.is_empty() {
let mut core_bins = Vec::with_capacity(binary_paths.len());
for path in binary_paths {
core_bins.push(
fs::read(&path)
.with_context(|| format!("Failed to read core file: {:?}", path))?,
);
}
return Ok(CoreInputs::Binary(core_bins));
}
let json_pattern = folder.join("core*.json");
let json_pattern_str = json_pattern.to_str().context("Invalid path encoding")?;
let mut json_paths: Vec<_> = glob(json_pattern_str)?.map(|x| x.unwrap()).collect();
json_paths.sort_by_cached_key(core_sort_key);
if json_paths.is_empty() {
bail!("No core*.pim or core*.json files found in {:?}", folder);
}
let mut core_json_reader: Vec<BufReader<File>> = Vec::with_capacity(json_paths.len());
for path in json_paths {
let file = File::open(path)?;
let reader = BufReader::new(file);
core_json_reader.push(reader);
}
return Ok(CoreInputs::Json(core_json_reader));
}
bail!("Either --core or --folder must be provided to find core definitions.");
}
Ok(core_jsons)
}
fn core_sort_key(path: &PathBuf) -> i32 {
let mut stem = path
.file_stem()
.expect("Extracting the stem")
.to_str()
.expect("File not utf-8");
stem = &stem[5..];
stem.parse::<i32>().unwrap()
}
fn retrive_config(args: &Args) -> Result<Value, anyhow::Error> {
@@ -0,0 +1,497 @@
use crate::{
CoreInstructionsBuilder, Executable,
cpu::{CPU, crossbar::Crossbar},
instruction_set::{InstructionsBuilder, instruction_data::InstructionDataBuilder, isa::*},
};
use anyhow::{Context, Result, bail, ensure};
use serde_json::Value;
use std::collections::HashMap;
use std::convert::TryFrom;
use std::sync::LazyLock;
const MAGIC: &[u8; 4] = b"PIMB";
const VERSION: u32 = 1;
const HEADER_SIZE: usize = 12;
const RECORD_SIZE: usize = 20;
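// On-disk layout (all fields little-endian): a 12-byte header made of the 4-byte
// magic "PIMB", a u32 format version and a u32 instruction count, followed by one
// 20-byte record per instruction: opcode, rd, r1, flags (one byte each) and then
// r2_or_imm, generic1, generic2, generic3 as four i32 values.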
macro_rules! add_name {
($storage:ident, $opcode:literal, $name:literal) => {
$storage.insert($opcode, $name);
};
}
static INSTRUCTIONS: LazyLock<HashMap<usize, &'static str>> = LazyLock::new(|| {
let mut hash = HashMap::new();
add_name!(hash, 0, "nop");
add_name!(hash, 1, "sldi");
add_name!(hash, 2, "sld");
add_name!(hash, 3, "sadd");
add_name!(hash, 4, "ssub");
add_name!(hash, 5, "smul");
add_name!(hash, 6, "saddi");
add_name!(hash, 7, "smuli");
add_name!(hash, 8, "setbw");
add_name!(hash, 9, "mvmul");
add_name!(hash, 10, "vvadd");
add_name!(hash, 11, "vvsub");
add_name!(hash, 12, "vvmul");
add_name!(hash, 13, "vvdmul");
add_name!(hash, 14, "vvmax");
add_name!(hash, 15, "vvsll");
add_name!(hash, 16, "vvsra");
add_name!(hash, 17, "vavg");
add_name!(hash, 18, "vrelu");
add_name!(hash, 19, "vtanh");
add_name!(hash, 20, "vsigm");
add_name!(hash, 21, "vsoftmax");
add_name!(hash, 22, "vmv");
add_name!(hash, 23, "vrsu");
add_name!(hash, 24, "vrsl");
add_name!(hash, 25, "ld");
add_name!(hash, 26, "st");
add_name!(hash, 27, "lldi");
add_name!(hash, 28, "lmv");
add_name!(hash, 29, "send");
add_name!(hash, 30, "recv");
add_name!(hash, 31, "wait");
add_name!(hash, 32, "sync");
hash
});
#[derive(Clone, Copy, Debug, Default)]
struct InstructionRecord {
opcode: u8,
rd: u8,
r1: u8,
r2_or_imm: i32,
generic1: i32,
generic2: i32,
generic3: i32,
flags: u8,
}
fn read_u32_le(bytes: &[u8], offset: usize) -> u32 {
u32::from_le_bytes(bytes[offset..offset + 4].try_into().unwrap())
}
fn read_i32_le(bytes: &[u8], offset: usize) -> i32 {
i32::from_le_bytes(bytes[offset..offset + 4].try_into().unwrap())
}
fn parse_binary_records(bytes: &[u8]) -> Result<Vec<InstructionRecord>> {
ensure!(bytes.len() >= HEADER_SIZE, "binary core file too small");
ensure!(&bytes[0..4] == MAGIC, "invalid PIM binary magic");
let version = read_u32_le(bytes, 4);
ensure!(
version == VERSION,
"unsupported PIM binary version {version}"
);
let instruction_count = read_u32_le(bytes, 8) as usize;
let expected_len = HEADER_SIZE + instruction_count * RECORD_SIZE;
ensure!(
bytes.len() == expected_len,
"PIM binary size mismatch: expected {expected_len} bytes, got {}",
bytes.len()
);
let mut records = Vec::with_capacity(instruction_count);
for index in 0..instruction_count {
let base = HEADER_SIZE + index * RECORD_SIZE;
records.push(InstructionRecord {
opcode: bytes[base],
rd: bytes[base + 1],
r1: bytes[base + 2],
flags: bytes[base + 3],
r2_or_imm: read_i32_le(bytes, base + 4),
generic1: read_i32_le(bytes, base + 8),
generic2: read_i32_le(bytes, base + 12),
generic3: read_i32_le(bytes, base + 16),
});
}
Ok(records)
}
fn append_record(
inst_builder: &mut InstructionsBuilder,
inst_data_builder: &mut InstructionDataBuilder,
record: InstructionRecord,
) -> Result<()> {
let InstructionRecord {
opcode,
rd,
r1,
r2_or_imm,
generic1,
generic2,
generic3,
flags: _,
} = record;
match opcode {
0 => {}
1 => {
inst_data_builder.set_rd_u8(rd).set_imm(r2_or_imm);
inst_builder.make_inst(sldi, inst_data_builder.build());
}
2 => {
inst_data_builder
.set_rd_u8(rd)
.set_r1_u8(r1)
.set_offset_select(generic1)
.set_offset_value(generic2);
inst_builder.make_inst(sld, inst_data_builder.build());
}
3 => {
inst_data_builder.set_rdr1r2_u8(rd, r1, r2_or_imm);
inst_builder.make_inst(sadd, inst_data_builder.build());
}
4 => {
inst_data_builder.set_rdr1r2_u8(rd, r1, r2_or_imm);
inst_builder.make_inst(ssub, inst_data_builder.build());
}
5 => {
inst_data_builder.set_rdr1r2_u8(rd, r1, r2_or_imm);
inst_builder.make_inst(smul, inst_data_builder.build());
}
6 => {
inst_data_builder.set_rdr1imm_u8(rd, r1, r2_or_imm);
inst_builder.make_inst(saddi, inst_data_builder.build());
}
7 => {
inst_data_builder.set_rdr1imm_u8(rd, r1, r2_or_imm);
inst_builder.make_inst(smuli, inst_data_builder.build());
}
8 => {
inst_data_builder.set_ibiw_obiw(generic1, generic2);
inst_builder.make_inst(setbw, inst_data_builder.build());
}
9 => {
inst_data_builder
.set_rd_u8(rd)
.set_r1_u8(r1)
.set_mbiw_immrelu_immgroup(r2_or_imm, generic1, generic2);
inst_builder.make_inst(mvmul, inst_data_builder.build());
}
10 => {
inst_data_builder
.set_rdr1r2_u8(rd, r1, r2_or_imm)
.set_imm_len(generic3)
.set_offset_select_value(generic1, generic2);
inst_builder.make_inst(vvadd, inst_data_builder.build());
}
11 => {
inst_data_builder
.set_rdr1r2_u8(rd, r1, r2_or_imm)
.set_imm_len(generic3)
.set_offset_select_value(generic1, generic2);
inst_builder.make_inst(vvsub, inst_data_builder.build());
}
12 => {
inst_data_builder
.set_rdr1r2_u8(rd, r1, r2_or_imm)
.set_imm_len(generic3)
.set_offset_select_value(generic1, generic2);
inst_builder.make_inst(vvmul, inst_data_builder.build());
}
13 => {
inst_data_builder
.set_rdr1r2_u8(rd, r1, r2_or_imm)
.set_imm_len(generic3)
.set_offset_select_value(generic1, generic2);
inst_builder.make_inst(vvdmul, inst_data_builder.build());
}
14 => {
inst_data_builder
.set_rdr1r2_u8(rd, r1, r2_or_imm)
.set_imm_len(generic3)
.set_offset_select_value(generic1, generic2);
inst_builder.make_inst(vvmax, inst_data_builder.build());
}
15 => {
inst_data_builder
.set_rdr1r2_u8(rd, r1, r2_or_imm)
.set_imm_len(generic3)
.set_offset_select_value(generic1, generic2);
inst_builder.make_inst(vvsll, inst_data_builder.build());
}
16 => {
inst_data_builder
.set_rdr1r2_u8(rd, r1, r2_or_imm)
.set_imm_len(generic3)
.set_offset_select_value(generic1, generic2);
inst_builder.make_inst(vvsra, inst_data_builder.build());
}
17 => {
inst_data_builder
.set_rdr1r2_u8(rd, r1, r2_or_imm)
.set_imm_len(generic3)
.set_offset_select_value(generic1, generic2);
inst_builder.make_inst(vavg, inst_data_builder.build());
}
18 => {
inst_data_builder
.set_rdr1_u8(rd, r1)
.set_imm_len(generic3)
.set_offset_select_value(generic1, generic2);
inst_builder.make_inst(vrelu, inst_data_builder.build());
}
19 => {
inst_data_builder
.set_rdr1_u8(rd, r1)
.set_imm_len(generic3)
.set_offset_select_value(generic1, generic2);
inst_builder.make_inst(vtanh, inst_data_builder.build());
}
20 => {
inst_data_builder
.set_rdr1_u8(rd, r1)
.set_imm_len(generic3)
.set_offset_select_value(generic1, generic2);
inst_builder.make_inst(vsigm, inst_data_builder.build());
}
21 => {
inst_data_builder
.set_rdr1_u8(rd, r1)
.set_imm_len(generic3)
.set_offset_select_value(generic1, generic2);
inst_builder.make_inst(vsoftmax, inst_data_builder.build());
}
22 => {
inst_data_builder
.set_rdr1r2_u8(rd, r1, r2_or_imm)
.set_imm_len(generic3)
.set_offset_select_value(generic1, generic2);
inst_builder.make_inst(vmv, inst_data_builder.build());
}
23 => {
inst_data_builder
.set_rdr1r2_u8(rd, r1, r2_or_imm)
.set_imm_len(generic3)
.set_offset_select_value(generic1, generic2);
inst_builder.make_inst(vrsu, inst_data_builder.build());
}
24 => {
inst_data_builder
.set_rdr1r2_u8(rd, r1, r2_or_imm)
.set_imm_len(generic3)
.set_offset_select_value(generic1, generic2);
inst_builder.make_inst(vrsl, inst_data_builder.build());
}
25 => {
inst_data_builder
.set_rdr1_u8(rd, r1)
.set_imm_len(generic3)
.set_offset_select_value(generic1, generic2);
inst_builder.make_inst(ld, inst_data_builder.build());
}
26 => {
inst_data_builder
.set_rdr1_u8(rd, r1)
.set_imm_len(generic3)
.set_offset_select_value(generic1, generic2);
inst_builder.make_inst(st, inst_data_builder.build());
}
27 => {
inst_data_builder
.set_rd_u8(rd)
.set_imm(r2_or_imm)
.set_imm_len(generic3)
.set_offset_select_value(generic1, generic2);
inst_builder.make_inst(lldi, inst_data_builder.build());
}
28 => {
inst_data_builder
.set_rdr1_u8(rd, r1)
.set_imm_len(generic3)
.set_offset_select_value(generic1, generic2);
inst_builder.make_inst(lmv, inst_data_builder.build());
}
29 => {
inst_data_builder
.set_rd_u8(rd)
.set_imm_core(r2_or_imm + 1)
.set_imm_len(generic3)
.set_offset_select_value(generic1, generic2);
inst_builder.make_inst(send, inst_data_builder.build());
}
30 => {
inst_data_builder
.set_rd_u8(rd)
.set_imm_core(r2_or_imm + 1)
.set_imm_len(generic3)
.set_offset_select_value(generic1, generic2);
inst_builder.make_inst(recv, inst_data_builder.build());
}
31 => {
inst_builder.make_inst(wait, inst_data_builder.build());
}
32 => {
inst_builder.make_inst(sync, inst_data_builder.build());
}
_ => bail!("unsupported PIM binary opcode {opcode}"),
}
Ok(())
}
fn binary_to_instructions(
core_bytes: &[u8],
core_index: i32,
) -> Result<Vec<crate::instruction_set::Instruction>> {
let records = parse_binary_records(core_bytes)?;
let mut insts_builder = InstructionsBuilder::new();
let mut inst_data_builder = InstructionDataBuilder::new();
inst_data_builder
.set_core_indx_u16(u16::try_from(core_index).expect("core index does not fit in u16"))
.fix_core_indx();
for record in records {
let opcode = record.opcode;
let name = INSTRUCTIONS
.get(&(opcode as usize))
.copied()
.unwrap_or("<unknown>");
append_record(&mut insts_builder, &mut inst_data_builder, record).with_context(|| {
format!(
"while decoding binary instruction for core {core_index}: opcode {opcode} ({name})"
)
})?;
}
Ok(insts_builder.build())
}
pub fn binary_to_executor<'a, 'b>(
config: Value,
cores: impl Iterator<Item = &'b Vec<u8>>,
crossbars: Vec<Vec<&'a Crossbar>>,
) -> Result<Executable<'a>> {
let core_cnt = config
.get("core_cnt")
.context("missing core_cnt in config")?
.as_i64()
.context("core_cnt is not an integer")? as i32;
let cpu = CPU::new(core_cnt, crossbars);
let mut core_insts_builder = CoreInstructionsBuilder::new(core_cnt as usize);
for (external_core_indx, core_bytes) in cores.enumerate() {
let core_indx = external_core_indx as i32 + 1;
let instructions = binary_to_instructions(core_bytes, core_indx)?;
core_insts_builder.set_core(core_indx, instructions);
}
Ok(Executable::new(cpu, core_insts_builder.build()))
}
#[cfg(test)]
mod tests {
use super::{
HEADER_SIZE, InstructionRecord, MAGIC, RECORD_SIZE, VERSION, binary_to_instructions,
};
use crate::{
functor_to_name,
instruction_set::{InstructionsBuilder, instruction_data::InstructionDataBuilder},
json_to_instruction::json_isa::json_to_instruction,
};
fn encode_record(record: InstructionRecord, dst: &mut Vec<u8>) {
dst.push(record.opcode);
dst.push(record.rd);
dst.push(record.r1);
dst.push(record.flags);
dst.extend_from_slice(&record.r2_or_imm.to_le_bytes());
dst.extend_from_slice(&record.generic1.to_le_bytes());
dst.extend_from_slice(&record.generic2.to_le_bytes());
dst.extend_from_slice(&record.generic3.to_le_bytes());
}
fn binary_blob(records: &[InstructionRecord]) -> Vec<u8> {
let mut blob = Vec::with_capacity(HEADER_SIZE + records.len() * RECORD_SIZE);
blob.extend_from_slice(MAGIC);
blob.extend_from_slice(&VERSION.to_le_bytes());
blob.extend_from_slice(&(records.len() as u32).to_le_bytes());
for &record in records {
encode_record(record, &mut blob);
}
blob
}
#[test]
fn json_and_binary_decoders_match_for_representative_ops() {
let json_program = [
r#"{"imm":64,"op":"sldi","rd":0}"#,
r#"{"imm":128,"op":"sldi","rd":1}"#,
r#"{"len":16,"offset":{"offset_select":0,"offset_value":0},"op":"lmv","rd":0,"rs1":1}"#,
r#"{"group":3,"mbiw":8,"op":"mvmul","rd":0,"relu":0,"rs1":1}"#,
r#"{"len":16,"offset":{"offset_select":0,"offset_value":0},"op":"vvadd","rd":0,"rs1":1,"rs2":2}"#,
r#"{"core":2,"offset":{"offset_select":0,"offset_value":0},"op":"send","rd":0,"size":16}"#,
];
let binary_program = binary_blob(&[
InstructionRecord {
opcode: 1,
rd: 0,
r2_or_imm: 64,
..Default::default()
},
InstructionRecord {
opcode: 1,
rd: 1,
r2_or_imm: 128,
..Default::default()
},
InstructionRecord {
opcode: 28,
rd: 0,
r1: 1,
generic3: 16,
..Default::default()
},
InstructionRecord {
opcode: 9,
rd: 0,
r1: 1,
r2_or_imm: 8,
generic2: 3,
..Default::default()
},
InstructionRecord {
opcode: 10,
rd: 0,
r1: 1,
r2_or_imm: 2,
generic3: 16,
..Default::default()
},
InstructionRecord {
opcode: 29,
rd: 0,
r2_or_imm: 2,
generic3: 16,
..Default::default()
},
]);
let mut json_builder = InstructionsBuilder::new();
let mut json_data_builder = InstructionDataBuilder::new();
json_data_builder.set_core_indx(1).fix_core_indx();
for inst in json_program {
let value = serde_json::from_str(inst).unwrap();
json_to_instruction(&mut json_builder, &mut json_data_builder, &value);
}
let json_instructions = json_builder.build();
let binary_instructions = binary_to_instructions(&binary_program, 1).unwrap();
assert_eq!(json_instructions.len(), binary_instructions.len());
for (json_inst, binary_inst) in json_instructions.iter().zip(binary_instructions.iter()) {
assert_eq!(
functor_to_name(json_inst.functor as usize),
functor_to_name(binary_inst.functor as usize)
);
assert_eq!(json_inst.data, binary_inst.data);
}
}
}
@@ -38,14 +38,14 @@ impl Crossbar {
self.memory.execute_store(0, element)
}
pub fn load<T>(&mut self, size: usize) -> Result<Vec<&[T]>> where
pub fn load<T>(&self, size: usize) -> Result<Vec<&[T]>> where
T: MemoryStorable, {
if self.memory.get_len() < size
//|| self.stored_bytes < size
{
bail!("Loading outside crossbar boundary [{} {}] < {}", self.stored_bytes, self.memory.get_len() , size);
}
self.memory.load(0, size)
self.memory.load_const(0, size)
}
@@ -1,5 +1,5 @@
use std::{collections::HashMap, fmt::Debug};
use anyhow::{Context, Result};
use anyhow::{Context, Result, ensure};
use crate::{
cpu::crossbar::Crossbar,
@@ -10,53 +10,44 @@ use crate::{
pub mod crossbar;
#[derive(Debug, Clone)]
pub struct CPU {
cores: Box<[Core]>,
pub struct CPU<'a> {
cores: Box<[Core<'a>]>,
}
impl CPU {
pub fn new(num_cores: impl TryToUsize) -> Self {
impl<'a> CPU<'a> {
pub fn new(num_cores: impl TryToUsize, crossbars: Vec<Vec<&'a Crossbar>> ) -> Self {
let num_cores = num_cores.try_into().expect("num_cores can not be negative");
let mut cores: Vec<Core> = std::iter::repeat_with(Core::new)
.take(num_cores + 1)
.collect();
assert!(crossbars.len() == num_cores + 1);
let mut cores = Vec::new();
for crossbar in crossbars {
cores.push(Core::new(crossbar));
}
Self {
cores: cores.into(),
}
}
pub fn reserve_crossbar(
&mut self,
num_crossbar: impl TryToUsize,
byte_width: impl TryToUsize,
height: impl TryToUsize,
) {
let num_crossbar = num_crossbar
.try_into()
.expect("num_crossbar can not be negative");
let byte_width = byte_width
.try_into()
.expect("byte_width can not be negative");
let height = height.try_into().expect("height can not be negative");
for core in &mut self.cores {
core.reserve_crossbar(num_crossbar, byte_width, height);
}
pub fn host<'b>(&'b mut self) -> &'b mut Core<'a>
where 'a : 'b
{
& mut self.cores[0]
}
pub fn host(&mut self) -> &mut Core {
&mut self.cores[0]
}
pub fn core(&mut self, index: impl TryToUsize) -> &mut Core {
pub fn core<'b >(&'b mut self, index: impl TryToUsize) -> &'b mut Core<'a>
where 'a : 'b
{
let index = index.try_into().expect("can not be negative");
&mut self.cores[index]
& mut self.cores[index]
}
pub fn num_core(&self) -> usize {
self.cores.len()
}
pub(crate) fn host_and_cores(&mut self, core: impl TryToUsize) -> (&mut Core, &mut Core) {
pub(crate) fn host_and_cores<'b, 'c >(&'b mut self, core: impl TryToUsize) -> (&'c mut Core<'a>, &'c mut Core<'a>)
where 'a: 'b,
'b: 'c
{
let core = core.try_into().expect("core can not be negative");
assert_ne!(
core, 0,
@@ -70,45 +61,29 @@ impl CPU {
(host, core)
}
pub fn get_multiple_cores<const N: usize>(&mut self, indices: [usize; N]) -> [&mut Core; N] {
pub fn get_multiple_cores<'b, const N: usize>(&'b mut self, indices: [usize; N]) -> [&'b mut Core<'a>; N]
where 'a : 'b
{
self.cores.get_disjoint_mut(indices).unwrap()
}
}
#[derive(Debug, Clone)]
pub struct Core {
crossbars: Vec<Crossbar>,
pub struct Core<'a> {
crossbars: Vec<&'a Crossbar>,
memory: CoreMemory,
registers: [i32; 32],
}
impl Core {
fn new() -> Self {
impl<'a> Core<'a> {
fn new(crossbars : Vec<&'a Crossbar>) -> Self {
Self {
crossbars: Vec::new(),
crossbars,
memory: CoreMemory::new(),
registers: [0; 32],
}
}
pub fn reserve_crossbar(
&mut self,
num_crossbar: impl TryToUsize,
width: impl TryToUsize,
height: impl TryToUsize,
) {
let num_crossbar = num_crossbar
.try_into()
.expect("num_crossbar can not be negative");
let width = width.try_into().expect("width can not be negative");
let height = height.try_into().expect("height can not be negative");
for _ in 0..num_crossbar {
let mut crossbar = CoreMemory::new();
crossbar.set_capacity(width * height);
self.crossbars.push(Crossbar::new(width, height, crossbar));
}
}
pub fn execute_load<T>(&mut self) -> Result<Vec<&[T]>>
where
T: MemoryStorable,
@@ -157,7 +132,7 @@ impl Core {
self.memory.load(address, size)
}
pub fn get_memory_crossbar(&mut self) -> (&mut CoreMemory, &mut Vec<Crossbar>) {
pub fn get_memory_crossbar(&mut self) -> (&mut CoreMemory, &mut Vec<&'a Crossbar>) {
let Self {
crossbars,
memory,
@@ -1,10 +1,11 @@
use paste::paste;
use std::convert::TryFrom;
#[derive(Clone, Copy, Debug, Default)]
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct InstructionData {
core_indx: i32,
rd: i32,
r1: i32,
core_indx: u16,
rd: u8,
r1: u8,
//r2 imm mbiw imm_core
r2_or_imm: i32,
//offset_select imm_relu ibiw
@@ -16,18 +17,30 @@ pub struct InstructionData {
}
impl InstructionData {
pub fn core_indx(&self) -> i32 {
pub fn core_indx_u16(&self) -> u16 {
self.core_indx
}
pub fn rd(&self) -> i32 {
pub fn core_indx(&self) -> i32 {
i32::from(self.core_indx)
}
pub fn rd_u8(&self) -> u8 {
self.rd
}
pub fn r1(&self) -> i32 {
pub fn rd(&self) -> i32 {
i32::from(self.rd)
}
pub fn r1_u8(&self) -> u8 {
self.r1
}
pub fn r1(&self) -> i32 {
i32::from(self.r1)
}
pub fn r2(&self) -> i32 {
self.r2_or_imm
}
@@ -49,26 +62,26 @@ impl InstructionData {
}
pub fn get_core_rd_r1(&self) -> (i32, i32, i32) {
(self.core_indx, self.rd, self.r1)
(self.core_indx(), self.rd(), self.r1())
}
pub fn get_core_rd_r1_r2(&self) -> (i32, i32, i32, i32) {
(self.core_indx, self.rd, self.r1, self.r2_or_imm)
(self.core_indx(), self.rd(), self.r1(), self.r2_or_imm)
}
pub fn get_core_rd_imm(&self) -> (i32, i32, i32) {
(self.core_indx, self.rd, self.r2_or_imm)
(self.core_indx(), self.rd(), self.r2_or_imm)
}
pub fn get_core_rd_r1_imm(&self) -> (i32, i32, i32, i32) {
(self.core_indx, self.rd, self.r1, self.r2_or_imm)
(self.core_indx(), self.rd(), self.r1(), self.r2_or_imm)
}
pub fn get_core_rd_r1_r2_immlen_offset(&self) -> (i32, i32, i32, i32, i32, i32, i32) {
(
self.core_indx,
self.rd,
self.r1,
self.core_indx(),
self.rd(),
self.r1(),
self.r2_or_imm,
self.generic3,
self.generic1,
@@ -78,9 +91,9 @@ impl InstructionData {
pub fn get_core_rd_r1_mbiw_immrelu_immgroup(&self) -> (i32, i32, i32, i32, i32, i32) {
(
self.core_indx,
self.rd,
self.r1,
self.core_indx(),
self.rd(),
self.r1(),
self.r2_or_imm,
self.generic1,
self.generic2,
@@ -100,7 +113,7 @@ impl InstructionData {
}
pub(crate) fn get_core_immcore(&self) -> (i32, i32) {
(self.core_indx, self.r2_or_imm)
(self.core_indx(), self.r2_or_imm)
}
}
@@ -216,6 +229,18 @@ impl InstructionDataBuilder {
common_getter_setter![imm_group];
common_getter_setter![imm_core];
pub fn set_core_indx_u16(&mut self, val: u16) -> &mut Self {
self.set_core_indx(i32::from(val))
}
pub fn set_rd_u8(&mut self, val: u8) -> &mut Self {
self.set_rd(i32::from(val))
}
pub fn set_r1_u8(&mut self, val: u8) -> &mut Self {
self.set_r1(i32::from(val))
}
pub fn new() -> Self {
Self {
core_indx: Fixer::Edit(0),
@@ -254,20 +279,16 @@ impl InstructionDataBuilder {
fn check_sanity(&self) {
assert!(!(self.get_r2() != 0 && self.get_imm() != 0 && self.get_mbiw() != 0 && self.get_imm_core() != 0));
assert!(
!(self.get_ibiw() != 0 && self.get_offset_select() != 0 && self.get_imm_relu() != 0)
);
assert!(
!(self.get_obiw() != 0 && self.get_offset_value() != 0 && self.get_imm_group() != 0)
);
assert!(!(self.get_ibiw() != 0 && self.get_offset_select() != 0 && self.get_imm_relu() != 0));
assert!(!(self.get_obiw() != 0 && self.get_offset_value() != 0 && self.get_imm_group() != 0));
}
pub fn build(&mut self) -> InstructionData {
self.check_sanity();
let inst_data = InstructionData {
core_indx: self.get_core_indx(),
rd: self.get_rd(),
r1: self.get_r1(),
core_indx: u16::try_from(self.get_core_indx()).expect("core index does not fit in u16"),
rd: u8::try_from(self.get_rd()).expect("rd does not fit in u8"),
r1: u8::try_from(self.get_r1()).expect("r1 does not fit in u8"),
r2_or_imm: self.get_r2() + self.get_imm() + self.get_mbiw() + self.get_imm_core(),
generic1: self.get_offset_select() + self.get_ibiw() + self.get_imm_relu(),
generic2: self.get_offset_value() + self.get_obiw() + self.get_imm_group(),
@@ -281,6 +302,10 @@ impl InstructionDataBuilder {
self.set_rd(rd).set_r1(r1).set_r2(r2)
}
pub fn set_rdr1r2_u8(&mut self, rd: u8, r1: u8, r2: i32) -> &mut Self {
self.set_rd_u8(rd).set_r1_u8(r1).set_r2(r2)
}
pub fn set_offset_select_value(&mut self, offset_select: i32, offset_value: i32) -> &mut Self {
self.set_offset_select(offset_select)
.set_offset_value(offset_value)
@@ -290,14 +315,26 @@ impl InstructionDataBuilder {
self.set_rd(rd).set_r1(r1).set_imm(imm)
}
pub fn set_rdr1imm_u8(&mut self, rd: u8, r1: u8, imm: i32) -> &mut Self {
self.set_rd_u8(rd).set_r1_u8(r1).set_imm(imm)
}
pub fn set_rdr1(&mut self, rd: i32, r1: i32) -> &mut Self {
self.set_rd(rd).set_r1(r1)
}
pub fn set_rdr1_u8(&mut self, rd: u8, r1: u8) -> &mut Self {
self.set_rd_u8(rd).set_r1_u8(r1)
}
pub fn set_rdimm(&mut self, rd: i32, imm: i32) -> &mut Self {
self.set_rd(rd).set_imm(imm)
}
pub fn set_rdimm_u8(&mut self, rd: u8, imm: i32) -> &mut Self {
self.set_rd_u8(rd).set_imm(imm)
}
pub fn set_ibiw_obiw(&mut self, ibiw: i32, obiw: i32) -> &mut Self {
self.set_ibiw(ibiw).set_obiw(obiw)
}
@@ -1,17 +1,22 @@
use crate::{
cpu::{CPU, crossbar}, instruction_set::{
cpu::{CPU, crossbar},
instruction_set::{
Instruction, InstructionData, InstructionStatus, InstructionType, VectorBitWith,
helper::add_all,
}, memory_manager::{
},
memory_manager::{
MemoryStorable,
type_traits::{FromFloat, UpcastDestTraits, UpcastSlice},
}, tracing::TRACER, utility::{add_offset_r1, add_offset_r2, add_offset_rd}
},
tracing::TRACER,
utility::{add_offset_r1, add_offset_r2, add_offset_rd},
};
use aligned_vec::{AVec, ConstAlign};
use anyhow::{Context, Result, ensure};
use rayon::prelude::*;
use paste::paste;
use std::{borrow::Cow, cell::OnceCell, collections::HashMap};
use std::{borrow::Cow, cell::OnceCell, collections::HashMap };
use std::{collections::HashSet, sync::LazyLock};
macro_rules! add_name {
@@ -30,7 +35,7 @@ macro_rules! add_name_simd {
};
}
static NAMES: LazyLock<HashMap<usize, &'static str>> = LazyLock::new(|| {
pub static NAMES: LazyLock<HashMap<usize, &'static str>> = LazyLock::new(|| {
let mut hash = HashMap::new();
add_name!(hash, sldi);
add_name!(hash, sld);
@@ -76,6 +81,7 @@ pub fn functor_to_name(functor: usize) -> &'static str {
///////////////////////////////////////////////////////////////
/////////////////Scalar/register Instructions//////////////////
///////////////////////////////////////////////////////////////
#[inline(never)]
pub fn sldi(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
TRACER.lock().unwrap().pre_sldi(cores, data);
let (core_indx, rd, imm) = data.get_core_rd_imm();
@@ -85,6 +91,7 @@ pub fn sldi(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
Ok(InstructionStatus::Completed)
}
#[inline(never)]
pub fn sld(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
TRACER.lock().unwrap().pre_sld(cores, data);
let (core_indx, rd, r1) = data.get_core_rd_r1();
@@ -99,6 +106,7 @@ pub fn sld(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
Ok(InstructionStatus::Completed)
}
#[inline(never)]
pub fn sadd(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
TRACER.lock().unwrap().pre_sadd(cores, data);
let (core_indx, rd, r1, r2) = data.get_core_rd_r1_r2();
@@ -109,6 +117,7 @@ pub fn sadd(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
Ok(InstructionStatus::Completed)
}
#[inline(never)]
pub fn ssub(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
TRACER.lock().unwrap().pre_ssub(cores, data);
let (core_indx, rd, r1, r2) = data.get_core_rd_r1_r2();
@@ -119,6 +128,7 @@ pub fn ssub(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
Ok(InstructionStatus::Completed)
}
#[inline(never)]
pub fn smul(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
TRACER.lock().unwrap().pre_smul(cores, data);
let (core_indx, rd, r1, r2) = data.get_core_rd_r1_r2();
@@ -129,6 +139,7 @@ pub fn smul(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
Ok(InstructionStatus::Completed)
}
#[inline(never)]
pub fn saddi(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
TRACER.lock().unwrap().pre_saddi(cores, data);
let (core_indx, rd, r1, imm) = data.get_core_rd_r1_imm();
@@ -138,6 +149,7 @@ pub fn saddi(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus
Ok(InstructionStatus::Completed)
}
#[inline(never)]
pub fn smuli(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
TRACER.lock().unwrap().pre_smuli(cores, data);
let (core_indx, rd, r1, imm) = data.get_core_rd_r1_imm();
@@ -212,14 +224,17 @@ pub fn is_setbw(functor: InstructionType) -> bool {
functor as usize == setbw as *const () as usize
}
#[inline(never)]
pub fn setbw(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
panic!("You are calling a placeholder, this instruction is resolved in the construction phase");
}
#[inline(never)]
pub fn mvmul(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
panic!("You are calling a placeholder, the real call is the generic version");
}
#[inline(never)]
pub(super) fn mvm_impl_internal<F, M, T>(
cores: &mut CPU,
data: InstructionData,
@@ -228,25 +243,30 @@ where
[F]: UpcastSlice<T> + UpcastSlice<M>,
[M]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
M: UpcastDestTraits<M> + MemoryStorable + FromFloat,
// Add faer::ComplexField HERE, directly bounding M for this function only
M: UpcastDestTraits<M> + MemoryStorable + FromFloat + faer_traits::ComplexField,
F: UpcastDestTraits<F> + MemoryStorable,
{
TRACER.lock().unwrap().pre_mvm::<F,M,T>(cores, data);
TRACER.lock().unwrap().pre_mvm::<F, M, T>(cores, data);
let (core_indx, rd, r1, mbiw, relu, group) = data.get_core_rd_r1_mbiw_immrelu_immgroup();
let group: usize = group.try_into().context("group can not be negative")?;
let core = cores.core(core_indx);
let r1_val = core.register(r1);
let rd_val = core.register(rd);
let (memory, crossbars) = core.get_memory_crossbar();
let crossbar = crossbars.get_mut(group).unwrap();
let crossbar_stored_bytes = crossbar.stored_bytes();
let crossbar_byte_width = crossbar.width();
//Fix this
let crossbar_elem_width = crossbar_byte_width / size_of::<M>();
ensure!(
crossbar_byte_width & size_of::<M>() == 0,
crossbar_byte_width % size_of::<M>() == 0,
"M not divisor of the crosbbar size"
);
let crossbar_height = crossbar.height();
let crossbar_byte_size = crossbar_byte_width * crossbar_height;
@@ -256,19 +276,29 @@ where
let load = loads[0];
let vec: Cow<[M]> = load.up();
let matrix = crossbar.load::<M>(crossbar_byte_size)?[0];
let mut res = Vec::with_capacity(crossbar_elem_width);
let mut partial :AVec<M, _> = AVec::<M, ConstAlign<64>>::with_capacity(64, vec.len());
partial.resize(vec.len(), M::from_f32(0.0));
for x in 0..crossbar_elem_width {
partial[0] = vec[0] * matrix[x];
for y in 1..crossbar_height {
partial[y] = vec[y] * matrix[y * crossbar_elem_width + x];
}
// --- FAER IMPLEMENTATION ---
// 1. Explicitly create a Matrix Reference (MatRef)
let matrix_view = faer::mat::MatRef::from_row_major_slice(
matrix.as_ref(),
crossbar_height,
crossbar_elem_width,
);
// 2. Explicitly create a Column Vector Reference (ColRef)
// Using `ColRef` here guarantees we don't accidentally get a RowRef (Fixes E0277)
let vec_view = faer::col::ColRef::from_slice(vec.as_ref());
let res_col: faer::col::Col<M> = matrix_view.transpose() * vec_view;
// 4. Convert back to standard Rust Vec
// try_as_slice() returns an Option<&[M]>.
// We can safely unwrap() because a freshly allocated, owned Col is ALWAYS contiguous!
let mut res: Vec<M> = (0..crossbar_elem_width).map(|i| res_col[i]).collect();
// --- END FAER ---
let mut acc = add_all(partial.as_slice());
res.push(acc);
}
if relu != 0 {
res.iter_mut().for_each(|x| {
if *x < M::from_f32(0.0) {
@@ -276,16 +306,20 @@ where
}
});
}
ensure!(
res.len() == crossbar_elem_width,
"mvm generate a vector bigger thant it's requested elements"
);
let res_up: Cow<[T]> = res.as_slice().up();
core.execute_store(rd_val, res_up.as_ref());
TRACER.lock().unwrap().post_mvm::<F,M,T>(cores, data);
TRACER.lock().unwrap().post_mvm::<F, M, T>(cores, data);
Ok(InstructionStatus::Completed)
}
#[inline(never)]
pub(super) fn mvmul_impl<F, T>(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
where
[F]: UpcastSlice<T> + UpcastSlice<f32> + UpcastSlice<f64>,
@@ -306,17 +340,19 @@ where
}
}
#[inline(never)]
pub fn vvadd(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
panic!("You are calling a placeholder, the real call is the generic version");
}
#[inline(never)]
pub(super) fn vvadd_impl<F, T>(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable,
{
TRACER.lock().unwrap().pre_vvadd::<F,T>(cores, data);
TRACER.lock().unwrap().pre_vvadd::<F, T>(cores, data);
let (core_indx, rd, r1, r2, imm_len, offset_select, offset_value) =
data.get_core_rd_r1_r2_immlen_offset();
let core = cores.core(core_indx);
@@ -344,21 +380,23 @@ where
);
let res_up: Cow<[T]> = res.as_slice().up();
core.execute_store(rd_val, res_up.as_ref());
TRACER.lock().unwrap().post_vvadd::<F,T>(cores, data);
TRACER.lock().unwrap().post_vvadd::<F, T>(cores, data);
Ok(InstructionStatus::Completed)
}
#[inline(never)]
pub fn vvsub(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
panic!("You are calling a placeholder, the real call is the generic version");
}
#[inline(never)]
pub(super) fn vvsub_impl<F, T>(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable,
{
TRACER.lock().unwrap().pre_vvsub::<F, T>(cores, data);
let (core_indx, rd, r1, r2, imm_len, offset_select, offset_value) =
data.get_core_rd_r1_r2_immlen_offset();
let core = cores.core(core_indx);
@@ -393,13 +431,14 @@ pub fn vvmul(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus
panic!("You are calling a placeholder, the real call is the generic version");
}
#[inline(never)]
pub(super) fn vvmul_impl<F, T>(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable,
{
TRACER.lock().unwrap().pre_vvmul::<F, T>(cores, data);
let (core_indx, rd, r1, r2, imm_len, offset_select, offset_value) =
data.get_core_rd_r1_r2_immlen_offset();
let core = cores.core(core_indx);
@@ -429,17 +468,19 @@ where
Ok(InstructionStatus::Completed)
}
#[inline(never)]
pub fn vvdmul(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
panic!("You are calling a placeholder, the real call is the generic version");
}
#[inline(never)]
pub(super) fn vvdmul_impl<F, T>(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable,
{
TRACER.lock().unwrap().pre_vvdmul::<F, T>(cores, data);
let (core_indx, rd, r1, r2, imm_len, offset_select, offset_value) =
data.get_core_rd_r1_r2_immlen_offset();
let core = cores.core(core_indx);
@@ -465,17 +506,19 @@ where
Ok(InstructionStatus::Completed)
}
#[inline(never)]
pub fn vvmax(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
panic!("You are calling a placeholder, the real call is the generic version");
}
#[inline(never)]
pub(super) fn vvmax_impl<F, T>(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable,
{
TRACER.lock().unwrap().pre_vvmax::<F, T>(cores, data);
let (core_indx, rd, r1, r2, imm_len, offset_select, offset_value) =
data.get_core_rd_r1_r2_immlen_offset();
let core = cores.core(core_indx);
@@ -502,29 +545,33 @@ where
Ok(InstructionStatus::Completed)
}
#[inline(never)]
pub fn vvsll(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
panic!(
"Shift left on floating point what does it means? who has generated this instruction???"
);
}
#[inline(never)]
pub fn vvsra(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
panic!(
"Shift right on floating point what does it means? who has generated this instruction???"
);
}
#[inline(never)]
pub fn vavg(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
panic!("You are calling a placeholder, the real call is the generic version");
}
#[inline(never)]
pub(super) fn vavg_impl<F, T>(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable,
{
TRACER.lock().unwrap().pre_vavg::<F, T>(cores, data);
let (core_indx, rd, r1, r2, imm_len, offset_select, offset_value) =
data.get_core_rd_r1_r2_immlen_offset();
let core = cores.core(core_indx);
@@ -532,7 +579,10 @@ where
let r2_val = r2;
ensure!(r2_val == 1, "Stride different than 1 not supported");
let rd_val = core.register(rd);
ensure!(offset_select == 1, "Offset select cannot be different from 1");
ensure!(
offset_select == 1,
"Offset select cannot be different from 1"
);
let r1_val = add_offset_r1(r1_val, offset_select, offset_value);
let loads = core.reserve_load(r1_val, imm_len)?.execute_load::<F>()?;
let load1 = loads[0];
@@ -544,17 +594,19 @@ where
Ok(InstructionStatus::Completed)
}
#[inline(never)]
pub fn vrelu(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
panic!("You are calling a placeholder, the real call is the generic version");
}
#[inline(never)]
pub(super) fn vrelu_impl<F, T>(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable + From<f32>,
{
TRACER.lock().unwrap().pre_vrelu::<F, T>(cores, data);
let (core_indx, rd, r1, r2, imm_len, offset_select, offset_value) =
data.get_core_rd_r1_r2_immlen_offset();
let core = cores.core(core_indx);
@@ -574,17 +626,19 @@ where
Ok(InstructionStatus::Completed)
}
#[inline(never)]
pub fn vtanh(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
panic!("You are calling a placeholder, the real call is the generic version");
}
#[inline(never)]
pub(super) fn vtanh_impl<F, T>(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable + From<f32>,
{
TRACER.lock().unwrap().pre_vtanh::<F, T>(cores, data);
let (core_indx, rd, r1, r2, imm_len, offset_select, offset_value) =
data.get_core_rd_r1_r2_immlen_offset();
let core = cores.core(core_indx);
@@ -602,17 +656,19 @@ where
Ok(InstructionStatus::Completed)
}
#[inline(never)]
pub fn vsigm(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
panic!("You are calling a placeholder, the real call is the generic version");
}
#[inline(never)]
pub(super) fn vsigm_impl<F, T>(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable + From<f32>,
{
TRACER.lock().unwrap().pre_vsigm::<F, T>(cores, data);
let (core_indx, rd, r1, r2, imm_len, offset_select, offset_value) =
data.get_core_rd_r1_r2_immlen_offset();
let core = cores.core(core_indx);
@@ -628,17 +684,22 @@ where
Ok(InstructionStatus::Completed)
}
#[inline(never)]
pub fn vsoftmax(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
panic!("You are calling a placeholder, the real call is the generic version");
}
#[inline(never)]
pub(super) fn vsoftmax_impl<F, T>(
cores: &mut CPU,
data: InstructionData,
) -> Result<InstructionStatus>
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable + From<f32>,
{
TRACER.lock().unwrap().pre_vsoftmax::<F, T>(cores, data);
let (core_indx, rd, r1, r2, imm_len, offset_select, offset_value) =
data.get_core_rd_r1_r2_immlen_offset();
let core = cores.core(core_indx);
@@ -655,27 +716,29 @@ where
.reduce(|a, b| if a > b { a } else { b })
.unwrap();
let exp_values: Vec<F> = load1.iter().map(|&a| (a - max_val).exp()).collect();
let sum = exp_values.iter().copied().reduce(|a, b| a + b).unwrap();
ensure!(
sum > 0.0.into(),
"vsoftmax normalization sum must be positive"
);
let res: Vec<F> = exp_values.iter().map(|&a| a / sum).collect();
let res_up: Cow<[T]> = res.as_slice().up();
core.execute_store(rd_val, res_up.as_ref());
TRACER.lock().unwrap().post_vsoftmax::<F, T>(cores, data);
Ok(InstructionStatus::Completed)
}
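// Editor's note (hedged sketch, not part of the diff): the max-subtraction performed
// above is the standard trick for a numerically stable softmax; mathematically
// softmax(v) == softmax(v - max(v)), but the shifted form keeps exp() from overflowing
// for large positive inputs. Minimal standalone illustration of the same idea:
#[cfg(test)]
fn stable_softmax(values: &[f32]) -> Vec<f32> {
    // Shift by the maximum so the largest exponent is exp(0) = 1.
    let max = values.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = values.iter().map(|&x| (x - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}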
#[inline(never)]
pub fn vmv(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
todo!()
}
#[inline(never)]
pub fn vrsu(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
todo!()
}
#[inline(never)]
pub fn vrsl(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
todo!()
}
@@ -683,6 +746,7 @@ pub fn vrsl(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
///////////////////////////////////////////////////////////////
///Communication/synchronization Instructions/////////////////
///////////////////////////////////////////////////////////////
#[inline(never)]
pub fn ld(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
TRACER.lock().unwrap().pre_ld(cores, data);
let (core, rd, r1, _, imm_len, offset_select, offset_value) =
@@ -699,6 +763,7 @@ pub fn ld(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
Ok(InstructionStatus::Completed)
}
#[inline(never)]
pub fn st(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
TRACER.lock().unwrap().pre_st(cores, data);
let (core, rd, r1, _, imm_len, offset_select, offset_value) =
@@ -715,6 +780,7 @@ pub fn st(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
Ok(InstructionStatus::Completed)
}
#[inline(never)]
pub fn lldi(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
TRACER.lock().unwrap().pre_lldi(cores, data);
let (core, rd, imm) = data.get_core_rd_imm();
@@ -731,6 +797,7 @@ pub fn lldi(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
Ok(InstructionStatus::Completed)
}
#[inline(never)]
pub fn lmv(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
TRACER.lock().unwrap().pre_lmv(cores, data);
let (core, rd, r1, _, imm_len, offset_select, offset_value) =
@@ -747,20 +814,32 @@ pub fn lmv(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
Ok(InstructionStatus::Completed)
}
#[inline(never)]
pub fn isa_send(functor: usize) -> bool {
(send as *const () as usize) == functor
}
#[inline(never)]
pub fn send(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
TRACER.lock().unwrap().pre_send(cores, data);
Ok(InstructionStatus::Sending(data))
}
#[inline(never)]
pub fn isa_recv(functor: usize) -> bool {
(recv as *const () as usize) == functor
}
#[inline(never)]
pub fn recv(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
TRACER.lock().unwrap().pre_recv(cores, data);
Ok(InstructionStatus::Reciving(data))
}
#[inline(never)]
pub fn wait(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
Ok(InstructionStatus::Waiting(data))
}
#[inline(never)]
pub fn sync(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
Ok(InstructionStatus::Sync(data))
}
@@ -14,7 +14,7 @@ pub mod helper;
#[derive(Clone, Copy, Debug)]
pub struct Instruction {
pub data: InstructionData,
pub functor: InstructionType,
}
#[derive(Debug, Clone, Copy, Default)]
@@ -40,7 +40,9 @@ impl Instruction {
Self { data, functor }
}
pub fn execute<'a, 'b>(&'b self, cpu: &mut CPU<'a>) -> InstructionStatus
where
'a: 'b,
{
(self.functor)(cpu, self.data)
.with_context(|| format!("Instruction: {}", functor_to_name(self.functor as usize)))
.with_context(|| format!("Error in core: {}", self.data.core_indx() - 1))
@@ -567,7 +567,7 @@ fn json_to_send(
let (offset_select, offset_value) = json_to_offset(json.get("offset").unwrap());
inst_data_builder
.set_rd(rd)
.set_imm_core(core + 1)
.set_imm_len(size)
.set_offset_select(offset_select)
.set_offset_value(offset_value);
@@ -588,7 +588,7 @@ fn json_to_recv(
let (offset_select, offset_value) = json_to_offset(json.get("offset").unwrap());
inst_data_builder
.set_rd(rd)
.set_imm_core(core + 1)
.set_imm_len(size)
.set_offset_select(offset_select)
.set_offset_value(offset_value);
@@ -1,42 +1,32 @@
use core::panic;
use serde_json::Value;
use std::{fs::File, io::BufReader};
use crate::{
CoreInstructionsBuilder, Executable,
cpu::{CPU, crossbar::Crossbar},
instruction_set::{InstructionsBuilder, instruction_data::InstructionDataBuilder},
json_to_instruction::json_isa,
};
pub fn json_to_executor<'a, 'b>(
config: Value,
cores: &'b mut Vec<BufReader<File>>,
crossbars: Vec<Vec<&'a Crossbar>>,
) -> Executable<'a> {
let core_cnt = config.get("core_cnt").unwrap().as_i64().unwrap() as i32;
let cpu = CPU::new(core_cnt, crossbars);
let mut core_insts_builder = CoreInstructionsBuilder::new(core_cnt as usize);
for (external_core_indx, json_core_reader) in cores.iter_mut().enumerate() {
let core_indx = external_core_indx as i32 + 1;
let mut insts_builder = InstructionsBuilder::new();
let mut inst_data_builder = InstructionDataBuilder::new();
inst_data_builder.set_core_indx(core_indx).fix_core_indx();
let json_core: Value = serde_json::from_reader(json_core_reader)
.unwrap_or_else(|err| panic!("failed to parse core{}: {}", external_core_indx, err));
let json_core_insts = json_core
.as_array()
.unwrap_or_else(|| panic!("core{} does not contain a list of instructions", external_core_indx));
for json_inst in json_core_insts {
json_isa::json_to_instruction(&mut insts_builder, &mut inst_data_builder, json_inst);
}
@@ -1,2 +1,2 @@
pub(crate) mod json_isa;
pub mod json_to_executor;
@@ -111,6 +111,24 @@ where {
Ok(res)
}
pub fn load_const<T>(&self, address: impl TryToUsize, size: impl TryToUsize) -> Result<Vec<&[T]>>
where
T: MemoryStorable,
{
let address = address.try_into().expect("address can not be negative");
let size = size.try_into().expect("size can not be negative");
let Self {
memory,
load_requests,
} = self;
let mut res = Vec::new();
let memory_slice = &memory[address..address + size];
let memory_slice = unsafe { slice_from_u8(memory_slice) }
.with_context(|| format!("Accessing from {} to {}", address, address + size))?;
res.push(memory_slice);
Ok(res)
}
pub fn load<T>(&mut self, address: impl TryToUsize, size: impl TryToUsize) -> Result<Vec<&[T]>>
where
T: MemoryStorable,
@@ -55,16 +55,24 @@ pub trait HasSigm {
impl HasSigm for f32 {
fn sigm(self) -> Self {
if self >= 0.0 {
1.0 / (1.0 + (-self).exp())
} else {
let ex = self.exp();
ex / (1.0 + ex)
}
}
}
impl HasSigm for f64 {
fn sigm(self) -> Self {
if self >= 0.0 {
1.0 / (1.0 + (-self).exp())
} else {
let ex = self.exp();
ex / (1.0 + ex)
}
}
}
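// Editor's note (hedged aside, not part of the diff): the branch above is the usual
// numerically stable sigmoid. For very negative x, 1.0 / (1.0 + (-x).exp()) evaluates
// exp() of a large positive number and overflows, while exp(x) / (1.0 + exp(x)) stays
// finite; both branches agree wherever both are well-behaved.
#[test]
fn sigm_branches_agree_in_safe_range() {
    let x = -20.0_f64;
    let stable = x.exp() / (1.0 + x.exp());
    let naive = 1.0 / (1.0 + (-x).exp());
    assert!((stable - naive).abs() < 1e-12);
}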
pub trait HasExp {
@@ -1,50 +1,61 @@
#![allow(unused)]
use std::{
collections::{HashMap, HashSet},
time::{Duration, SystemTime},
};
use crate::{
cpu::CPU,
instruction_set::{
Instruction, InstructionStatus, Instructions,
isa::{NAMES, functor_to_name, isa_recv, isa_send},
},
memory_manager::type_traits::TryToUsize,
send_recv::{SendRecv, handle_send_recv},
tracing::TRACER,
};
pub mod binary_to_instruction;
pub mod cpu;
pub mod instruction_set;
pub mod json_to_instruction;
pub mod memory_manager;
pub mod send_recv;
pub mod tracing;
pub mod utility;
#[derive(Debug, Clone)]
pub struct CoreInstructionsBuilder {
core_instructions: Vec<CoreInstructions>,
}
impl CoreInstructionsBuilder {
pub fn new(size: usize) -> Self {
let mut core_instructions = Vec::with_capacity(size);
for _ in 0..=size {
core_instructions.push(CoreInstructions::empty());
}
Self { core_instructions }
}
pub fn build(self) -> Vec<CoreInstructions> {
self.core_instructions
}
pub fn set_core(&mut self, core: impl TryToUsize, core_instruction: Instructions) -> &mut Self {
self.core_instructions[core.try_into().expect("Set core with not valid size")] =
core_instruction.into();
self
}
}
#[derive(Debug, Clone)]
pub struct CoreInstructions {
instructions: Instructions,
program_counter: usize,
}
impl CoreInstructions {
fn new(instructions: Instructions, program_counter: usize) -> Self {
Self {
instructions,
@@ -53,13 +64,16 @@ impl CoreInstruction {
}
fn empty() -> Self {
Self {
instructions: Vec::new(),
program_counter: 0,
}
}
}
impl From<Instructions> for CoreInstructions {
fn from(value: Instructions) -> Self {
CoreInstructions {
instructions: value,
program_counter: 0,
}
@@ -67,39 +81,64 @@ impl From<Instructions> for CoreInstruction {
}
#[derive(Debug, Clone)]
pub struct Executable<'a> {
cpu: CPU<'a>,
core_instructions: Vec<CoreInstructions>,
send_recv: SendRecv,
}
fn print_status(core_instructions: &[CoreInstructions]) {
let mut tot_instructions = 0;
let mut progress = 0;
for core_instruction in core_instructions.iter() {
tot_instructions += core_instruction.instructions.len();
progress += core_instruction.program_counter;
}
println!(
"Progress: {}% ({}/{}) ",
progress as f32 / tot_instructions as f32 * 100.0,
progress,
tot_instructions
);
}
impl<'a> Executable<'a> {
pub fn new(cpu: CPU<'a>, core_instructions: Vec<CoreInstructions>) -> Executable<'a> {
let num_core = cpu.num_core();
let send_recv = SendRecv::new(num_core);
assert_eq!(
num_core,
core_instructions.len(),
"Some core doesn't have its list of instructions (required even if empty)"
);
Self {
cpu,
core_instructions,
send_recv,
}
}
pub fn execute<'b>(&'b mut self)
where
'a: 'b,
{
let Self {
cpu,
core_instructions: cores_instructions,
send_recv,
} = self;
let mut cpu_progressed = 0;
let max_core = cpu.num_core();
let mut cpu_index = 0;
let mut now = SystemTime::now();
while cpu_progressed > -2 {
let mut core_result = InstructionStatus::Completed;
while core_result.is_completed()
&& let Some(core_instruction) = cores_instructions.get_mut(cpu_index)
{
core_result = InstructionStatus::NotExecuted;
let CoreInstructions {
instructions,
program_counter,
} = core_instruction;
@@ -112,23 +151,42 @@ impl Executable {
cpu_progressed = 0;
*program_counter += 1;
}
if now.elapsed().unwrap() > Duration::from_secs(5) {
print_status(cores_instructions);
check_cycle(cpu, cores_instructions, send_recv);
now = SystemTime::now();
}
handle_wait_sync(cpu, cores_instructions, core_result);
match handle_send_recv(cpu, cores_instructions, send_recv, core_result) {
(true, other_cpu_index) => {
cpu_progressed = 0;
cpu_index = other_cpu_index;
}
(false, 0) => {
cpu_index = if cpu_index + 1 >= cores_instructions.len() {
cpu_progressed -= 1;
0
} else {
cpu_index + 1
};
}
(false, other_cpu_index) => {
cpu_index = other_cpu_index;
}
}
}
print_status(cores_instructions);
#[cfg(feature = "profile_time")]
TRACER.lock().unwrap().report();
}
pub fn cpu(&self) -> &CPU<'a> {
&self.cpu
}
pub fn cpu_mut(&mut self) -> &mut CPU<'a> {
&mut self.cpu
}
@@ -143,64 +201,107 @@ impl Executable {
}
}
fn check_cycle(cpu: &mut CPU, cores_instructions: &[CoreInstructions], send_recv: &mut SendRecv) {
#[derive(Debug, PartialEq, Eq)]
enum CoreState {
SendingTo(i32),
ReceivingFrom(i32),
Working,
Halted,
}
let mut states = HashMap::new();
for core_inst in cores_instructions.iter() {
if core_inst.program_counter >= core_inst.instructions.len() {
continue;
}
let Instruction { data, functor } = core_inst.instructions[core_inst.program_counter];
let functor_address = functor as usize;
let (this_core, target_core) = data.get_core_immcore();
if isa_recv(functor_address) {
states.insert(this_core, CoreState::ReceivingFrom(target_core));
} else if isa_send(functor_address) {
states.insert(this_core, CoreState::SendingTo(target_core));
} else {
states.insert(this_core, CoreState::Working);
}
}
let mut wait_for = HashMap::new();
for (&core_id, state) in states.iter() {
match state {
CoreState::SendingTo(target_core) => {
let target_state = states.get(target_core).unwrap_or(&CoreState::Halted);
if target_state != &CoreState::ReceivingFrom(core_id) {
wait_for.insert(core_id, *target_core);
}
}
CoreState::ReceivingFrom(target_core) => {
let target_state = states.get(target_core).unwrap_or(&CoreState::Halted);
if target_state != &CoreState::SendingTo(core_id) {
wait_for.insert(core_id, *target_core);
}
}
CoreState::Working | CoreState::Halted => {
}
}
}
let mut visited = HashSet::new();
for &start_core in wait_for.keys() {
if visited.contains(&start_core) {
continue;
}
let mut path = Vec::new();
let mut current_core = start_core;
let mut in_path = HashSet::new();
while let Some(&waiting_for) = wait_for.get(&current_core) {
path.push(current_core);
in_path.insert(current_core);
visited.insert(current_core);
// Found a closed loop!
if in_path.contains(&waiting_for) {
let cycle_start = path.iter().position(|&c| c == waiting_for).unwrap();
let cycle = &path[cycle_start..];
let cycle_str = cycle
.iter()
.map(|c| c.to_string())
.collect::<Vec<_>>()
.join(" -> ");
let cycle_msg = format!("{} -> {}", cycle_str, waiting_for);
println!("Fatal: Deadlock cycle detected: {}", cycle_msg);
// bail!("Deadlock detected: {}", cycle_msg);
break; // Stop tracing
}
// Hit a known branch that didn't result in a cycle
if visited.contains(&waiting_for) {
break;
}
current_core = waiting_for;
}
}
}
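// Editor's note (hedged sketch, not part of the diff): check_cycle above builds a
// "wait-for" map (core -> core it is blocked on) and walks it to report closed loops.
// The helper below shows the same walk on a plain map, returning one detected cycle.
#[cfg(test)]
fn find_cycle(wait_for: &std::collections::HashMap<i32, i32>) -> Option<Vec<i32>> {
    use std::collections::HashSet;
    let mut visited = HashSet::new();
    for &start in wait_for.keys() {
        if visited.contains(&start) {
            continue;
        }
        let mut path = Vec::new();
        let mut in_path = HashSet::new();
        let mut current = start;
        while let Some(&next) = wait_for.get(&current) {
            path.push(current);
            in_path.insert(current);
            visited.insert(current);
            // A core already on the current path means we closed a loop.
            if in_path.contains(&next) {
                let pos = path.iter().position(|&c| c == next).unwrap();
                return Some(path[pos..].to_vec());
            }
            // A core seen on a previous walk that did not cycle: stop tracing.
            if visited.contains(&next) {
                break;
            }
            current = next;
        }
    }
    None
}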
#[cfg(test)]
mod test {
use super::*;
use crate::instruction_set::instruction_data::InstructionDataBuilder;
use crate::instruction_set::{InstructionsBuilder, isa::*};
#[test]
fn test_only_host() {
let mut cpu = CPU::new(0);
cpu.host()
.execute_store(0, &[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]);
let mut inst_builder = InstructionsBuilder::new();
let mut idata_build = InstructionDataBuilder::new();
idata_build.set_core_indx(0).fix_core_indx();
inst_builder.make_inst(sldi, idata_build.set_rdimm(1, 0).build());
inst_builder.make_inst(sld, idata_build.set_rdr1(1, 1).build());
inst_builder.make_inst(sldi, idata_build.set_rdimm(2, 8).build());
inst_builder.make_inst(sld, idata_build.set_rdr1(2, 2).build());
inst_builder.make_inst(sadd, idata_build.set_rdr1r2(2, 1, 2).build());
let mut core_instruction = vec![inst_builder.build().into()];
let mut executable = Executable::new(cpu, core_instruction);
executable.execute();
assert_eq!(executable.cpu_mut().host().register(2), 4, "Not sum to 4");
}
#[test]
fn test_10_core_same_code() {
let setup_core = |index: usize, cpu: &mut CPU| -> Instructions {
cpu.core(index)
.execute_store(0, &[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]);
let mut inst_builder = InstructionsBuilder::new();
let mut idata_build = InstructionDataBuilder::new();
idata_build.set_core_indx(index as i32).fix_core_indx();
inst_builder.make_inst(sldi, idata_build.set_rdimm(1, 0).build());
inst_builder.make_inst(sld, idata_build.set_rdr1(1, 1).build());
inst_builder.make_inst(sldi, idata_build.set_rdimm(2, 8).build());
inst_builder.make_inst(sld, idata_build.set_rdr1(2, 2).build());
inst_builder.make_inst(sadd, idata_build.set_rdr1r2(2, 1, 2).build());
inst_builder.build()
};
let mut cpu = CPU::new(10);
let mut core_instruction = Vec::new();
for i in 0..cpu.num_core() {
core_instruction.push(setup_core(i, &mut cpu).into())
}
let mut executable = Executable::new(cpu, core_instruction);
executable.execute();
for i in 0.. executable.cpu.num_core() {
assert_eq!(executable.cpu_mut().core(i).register(2), 4, "Core {} not sum to 4", i);
}
}
fn handle_wait_sync<'a, 'b, 'c>(
cpu: &'b mut CPU<'a>,
core_instructions: &'c mut [CoreInstructions],
core_result: InstructionStatus,
) where
'a: 'b,
'a: 'c,
{
}
@@ -1,7 +1,7 @@
use anyhow::Context;
use crate::{
CoreInstructions, cpu::CPU, instruction_set::InstructionStatus, tracing::TRACER,
utility::add_offset_rd,
};
@@ -41,14 +41,16 @@ impl SendRecv {
}
}
pub fn handle_send_recv<'a, 'b>(
cpu: &'b mut CPU<'a>,
core_instructions: &mut [CoreInstructions],
send_recv: &mut SendRecv,
core_result: InstructionStatus,
) -> (bool, usize)
where
'a: 'b,
{
let transfer_memory = |cpu: &'b mut CPU<'a>,
core_instructions: &mut [CoreInstructions],
sender: Option<SendRecvInfo>,
receiver: Option<SendRecvInfo>| {
if let Some(sender) = sender
@@ -56,6 +58,20 @@ pub fn handle_send_recv(
&& sender.internal_core == receiver.external_core
&& receiver.internal_core == sender.external_core
{
{
let sender = &mut core_instructions[sender.internal_core];
let pc = sender.program_counter;
let inst = sender.instructions.get(pc).unwrap();
let data = inst.data;
TRACER.lock().unwrap().pre_send(cpu, data);
}
{
let recv = &mut core_instructions[receiver.internal_core];
let pc = recv.program_counter;
let inst = recv.instructions.get(pc).unwrap();
let data = inst.data;
TRACER.lock().unwrap().pre_recv(cpu, data);
}
let [sender_core, reciver_core] =
cpu.get_multiple_cores([sender.internal_core, receiver.internal_core]);
let memory = sender_core
@@ -117,7 +133,7 @@ pub fn handle_send_recv(
send_recv.sending[sender] = None;
send_recv.receiving[receiver] = None;
}
(transfered, receiver)
}
InstructionStatus::Reciving(instruction_data) => {
let (core_idx, imm_core) = instruction_data.get_core_immcore();
@@ -146,8 +162,8 @@ pub fn handle_send_recv(
send_recv.sending[sender] = None;
send_recv.receiving[receiver] = None;
}
(transfered, sender)
}
_ => (false, 0),
}
}
@@ -13,7 +13,7 @@ use crate::{
};
use std::io::Write;
#[cfg(not(feature = "tracing"))]
#[cfg(not(any(feature = "tracing", feature = "profile_time")))]
impl Trace {
///////////////////////////////////////////////////////////////
/////////////////Scalar/register Instructions//////////////////
@@ -1,53 +1,32 @@
mod disable;
#[cfg(feature = "profile_time")]
mod profile;
#[cfg(feature = "profile_time")]
use profile::Trace;
#[cfg(feature = "tracing")]
mod trace;
#[cfg(feature = "tracing")]
use trace::Trace;
use crate::Executable;
#[cfg(not(any(feature = "tracing", feature = "profile_time")))]
use std::path::PathBuf;
use std::sync::{LazyLock, Mutex};
#[cfg(not(any(feature = "tracing", feature = "profile_time")))]
pub struct Trace {}
#[cfg(not(any(feature = "tracing", feature = "profile_time")))]
impl Trace {
fn new() -> Self {
Self {}
}
pub fn init(&mut self, num_core: usize, path: PathBuf) {}
}
pub static TRACER: LazyLock<Mutex<Trace>> = LazyLock::new(|| Trace::new().into());
@@ -0,0 +1,73 @@
use std::{collections::HashMap, path::PathBuf, time::Instant};
use crate::tracing::profile::profile_analysis::{
analyze_timings, generate_interactive_report, print_textual_report,
};
pub mod profile_analysis;
pub mod profile_isa;
pub struct Trace {
instruction_times: HashMap<String, Vec<(u128,u128)>>,
core_start_time: HashMap<usize, Option<Instant>>,
start_time: Instant,
}
impl Trace {
pub fn new() -> Self {
let mut instruction_times = HashMap::new();
instruction_times.insert("sldi".to_string(), Vec::with_capacity(20000));
instruction_times.insert("sld".to_string(), Vec::with_capacity(20000));
instruction_times.insert("sadd".to_string(), Vec::with_capacity(20000));
instruction_times.insert("ssub".to_string(), Vec::with_capacity(20000));
instruction_times.insert("smul".to_string(), Vec::with_capacity(20000));
instruction_times.insert("saddi".to_string(), Vec::with_capacity(20000));
instruction_times.insert("smuli".to_string(), Vec::with_capacity(20000));
instruction_times.insert("setbw".to_string(), Vec::with_capacity(20000));
instruction_times.insert("mvmul".to_string(), Vec::with_capacity(20000));
instruction_times.insert("vvadd".to_string(), Vec::with_capacity(20000));
instruction_times.insert("vvsub".to_string(), Vec::with_capacity(20000));
instruction_times.insert("vvmul".to_string(), Vec::with_capacity(20000));
instruction_times.insert("vvdmul".to_string(), Vec::with_capacity(20000));
instruction_times.insert("vvmax".to_string(), Vec::with_capacity(20000));
instruction_times.insert("vvsll".to_string(), Vec::with_capacity(20000));
instruction_times.insert("vvsra".to_string(), Vec::with_capacity(20000));
instruction_times.insert("vavg".to_string(), Vec::with_capacity(20000));
instruction_times.insert("vrelu".to_string(), Vec::with_capacity(20000));
instruction_times.insert("vtanh".to_string(), Vec::with_capacity(20000));
instruction_times.insert("vsigm".to_string(), Vec::with_capacity(20000));
instruction_times.insert("vsoftmax".to_string(), Vec::with_capacity(20000));
instruction_times.insert("vmv".to_string(), Vec::with_capacity(20000));
instruction_times.insert("vrsu".to_string(), Vec::with_capacity(20000));
instruction_times.insert("vrsl".to_string(), Vec::with_capacity(20000));
instruction_times.insert("ld".to_string(), Vec::with_capacity(20000));
instruction_times.insert("st".to_string(), Vec::with_capacity(20000));
instruction_times.insert("lldi".to_string(), Vec::with_capacity(20000));
instruction_times.insert("lmv".to_string(), Vec::with_capacity(20000));
instruction_times.insert("send".to_string(), Vec::with_capacity(20000));
instruction_times.insert("recv".to_string(), Vec::with_capacity(20000));
instruction_times.insert("wait".to_string(), Vec::with_capacity(20000));
instruction_times.insert("sync".to_string(), Vec::with_capacity(20000));
Self {
instruction_times,
core_start_time: HashMap::new(),
start_time: Instant::now()
}
}
pub fn init(&mut self, num_core: usize, path: PathBuf) {
for i in 0..num_core {
self.core_start_time.insert(i, None);
}
}
pub fn report(&self) {
let res = analyze_timings(&self.instruction_times);
print_textual_report(&res);
generate_interactive_report(
&self.instruction_times,
&["mvmul", "recv"],
"/tmp/report.html",
);
}
}
@@ -0,0 +1,192 @@
use comfy_table::{Cell, Table, modifiers::UTF8_ROUND_CORNERS, presets::UTF8_FULL};
use statrs::statistics::{Data, Distribution, Max, Min, OrderStatistics};
use std::collections::HashMap;
#[derive(Debug)]
pub struct InstructionStats {
pub name: String,
pub count: usize,
pub total_time: u128,
pub min: f64,
pub max: f64,
pub mean: f64,
pub median: f64,
pub std_dev: f64,
pub cv: f64,
pub p95: f64,
pub p99: f64,
pub skewness: f64,
pub kurtosis: f64,
}
fn format_time(ns: f64) -> String {
if ns.is_nan() {
return "NaN".to_string();
}
if ns >= 1_000_000_000.0 {
format!("{:.2} s", ns / 1_000_000_000.0)
} else if ns >= 1_000_000.0 {
format!("{:.2} ms", ns / 1_000_000.0)
} else if ns >= 1_000.0 {
format!("{:.2} µs", ns / 1_000.0)
} else {
format!("{:.2} ns", ns)
}
}
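// Editor's note (hedged usage example, not part of the diff): format_time picks the
// largest unit whose threshold the nanosecond value reaches, rendering with two decimals.
#[test]
fn format_time_picks_the_largest_fitting_unit() {
    assert_eq!(format_time(250.0), "250.00 ns");
    assert_eq!(format_time(1_500.0), "1.50 µs");
    assert_eq!(format_time(2_000_000.0), "2.00 ms");
    assert_eq!(format_time(3_000_000_000.0), "3.00 s");
}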
fn calculate_skewness_kurtosis(times: &[f64], mean: f64, std_dev: f64) -> (f64, f64) {
let n = times.len() as f64;
if n < 4.0 || std_dev == 0.0 {
return (f64::NAN, f64::NAN);
}
let mut sum_m3 = 0.0;
let mut sum_m4 = 0.0;
for &x in times {
let deviation = x - mean;
sum_m3 += deviation.powi(3);
sum_m4 += deviation.powi(4);
}
let m3 = sum_m3 / n;
let m4 = sum_m4 / n;
let skewness = m3 / std_dev.powi(3);
let kurtosis = (m4 / std_dev.powi(4)) - 3.0;
(skewness, kurtosis)
}
pub fn analyze_timings(timings: &HashMap<String, Vec<(u128, u128)>>) -> Vec<InstructionStats> {
let mut results = Vec::new();
for (instruction, times) in timings {
let count = times.len();
if count == 0 {
continue;
}
// Extract ONLY the duration (the second element of the tuple) for stats
let durations: Vec<u128> = times.iter().map(|&(_, duration)| duration).collect();
let total_time: u128 = durations.iter().sum();
let f64_times: Vec<f64> = durations.iter().map(|&t| t as f64).collect();
let mut data = Data::new(f64_times.clone());
let mean = data.mean().unwrap_or(0.0);
let std_dev = data.std_dev().unwrap_or(0.0);
let cv = if mean > 0.0 { std_dev / mean } else { 0.0 };
let (skewness, kurtosis) = calculate_skewness_kurtosis(&f64_times, mean, std_dev);
results.push(InstructionStats {
name: instruction.clone(),
count,
total_time,
min: data.min(),
max: data.max(),
mean,
median: data.median(),
std_dev,
cv,
p95: data.percentile(95),
p99: data.percentile(99),
skewness,
kurtosis,
});
}
results.sort_by(|a, b| b.mean.partial_cmp(&a.mean).unwrap());
results
}
pub fn print_textual_report(stats: &[InstructionStats]) {
let mut table = Table::new();
table
.load_preset(UTF8_FULL)
.apply_modifier(UTF8_ROUND_CORNERS)
.set_header(vec![
"Instruction",
"Count",
"Total Time",
"Mean",
"Median",
"Min",
"Max",
"P95",
"P99",
"StdDev",
"CV",
"Skewness",
"Kurtosis",
]);
for stat in stats {
table.add_row(vec![
Cell::new(&stat.name),
Cell::new(stat.count.to_string()),
Cell::new(format_time(stat.total_time as f64)), // Cast u128 to f64 for formatting
Cell::new(format_time(stat.mean)),
Cell::new(format_time(stat.median)),
Cell::new(format_time(stat.min)),
Cell::new(format_time(stat.max)),
Cell::new(format_time(stat.p95)),
Cell::new(format_time(stat.p99)),
Cell::new(format_time(stat.std_dev)),
Cell::new(format!("{:.3}", stat.cv)),
Cell::new(format!("{:.2}", stat.skewness)),
Cell::new(format!("{:.2}", stat.kurtosis)),
]);
}
println!("{table}");
}
pub fn generate_interactive_report(
timings: &HashMap<String, Vec<(u128, u128)>>,
instructions_to_plot: &[&str], // <-- NEW: Only plot these
file_path: &str,
) {
use plotly::common::{Mode, Marker, Line};
use plotly::layout::{Axis, Layout};
use plotly::{Plot, Scatter};
use std::collections::HashMap;
let mut plot = Plot::new();
for &instruction_name in instructions_to_plot {
// Only proceed if the instruction exists in our timings map
if let Some(times) = timings.get(instruction_name) {
let x_axis: Vec<f64> = times.iter().map(|&(ts, _)| ts as f64).collect();
let y_axis: Vec<f64> = times.iter().map(|&(_, dur)| dur as f64).collect();
let text_array: Vec<String> = times.iter()
.map(|&(_, dur)| format_time(dur as f64))
.collect();
let trace = Scatter::new(x_axis, y_axis)
.name(instruction_name)
.mode(Mode::LinesMarkers)
.marker(Marker::new().size(4).opacity(0.6))
.line(Line::new().width(1.0))
.text_array(text_array)
.hover_info(plotly::common::HoverInfo::All);
plot.add_trace(trace);
}
}
let layout = Layout::new()
.title(plotly::common::Title::new("Simulator Timeline: Top Offenders"))
.x_axis(Axis::new().title(plotly::common::Title::new("Absolute Time (ns)")))
.y_axis(Axis::new().title(plotly::common::Title::new("Execution Duration")));
plot.set_layout(layout);
plot.write_html(file_path);
println!("🌐 Interactive timeline saved to {}", file_path);
}
@@ -0,0 +1,364 @@
use crate::{
cpu::CPU,
instruction_set::instruction_data::InstructionData,
memory_manager::{
MemoryStorable,
type_traits::{FromFloat, UpcastDestTraits, UpcastSlice},
},
tracing::Trace,
utility::{add_offset_r1, add_offset_rd},
};
use std::io::Write;
use std::time::Instant;
#[cfg(feature = "profile_time")]
impl Trace {
///////////////////////////////////////////////////////////////
/////////////////Scalar/register Instructions//////////////////
///////////////////////////////////////////////////////////////
fn pre_impl(&mut self, cores: &mut CPU, data: InstructionData) {
let (core_indx, rd, imm) = data.get_core_rd_imm();
let core_indx = core_indx as usize;
if self.core_start_time.get(&core_indx).unwrap().is_none() {
self.core_start_time.insert(core_indx, Some(Instant::now()));
}
}
fn post_impl(&mut self, cores: &mut CPU, data: InstructionData, name: &'static str) {
let (core_indx, rd, imm) = data.get_core_rd_imm();
let core_indx = core_indx as usize;
let Self {
instruction_times,
core_start_time,
start_time,
} = self;
let now = Instant::now();
instruction_times
.get_mut(name)
.unwrap()
.push((now.duration_since(*start_time).as_nanos(), now.duration_since(core_start_time[&core_indx].unwrap()).as_nanos()));
self.core_start_time.insert(core_indx, None);
}
pub fn pre_sldi(&mut self, cores: &mut CPU, data: InstructionData) {
self.pre_impl(cores, data);
}
pub fn post_sldi(&mut self, cores: &mut CPU, data: InstructionData) {
self.post_impl(cores, data, "sldi");
}
pub fn pre_sld(&mut self, cores: &mut CPU, data: InstructionData) {
self.pre_impl(cores, data);
}
pub fn post_sld(&mut self, cores: &mut CPU, data: InstructionData) {
self.post_impl(cores, data, "sld");
}
pub fn pre_sadd(&mut self, cores: &mut CPU, data: InstructionData) {
self.pre_impl(cores, data);
}
pub fn post_sadd(&mut self, cores: &mut CPU, data: InstructionData) {
self.post_impl(cores, data, "sadd");
}
pub fn pre_ssub(&mut self, cores: &mut CPU, data: InstructionData) {
self.pre_impl(cores, data);
}
pub fn post_ssub(&mut self, cores: &mut CPU, data: InstructionData) {
self.post_impl(cores, data, "ssub");
}
pub fn pre_smul(&mut self, cores: &mut CPU, data: InstructionData) {
self.pre_impl(cores, data);
}
pub fn post_smul(&mut self, cores: &mut CPU, data: InstructionData) {
self.post_impl(cores, data, "smul");
}
pub fn pre_saddi(&mut self, cores: &mut CPU, data: InstructionData) {
self.pre_impl(cores, data);
}
pub fn post_saddi(&mut self, cores: &mut CPU, data: InstructionData) {
self.post_impl(cores, data, "saddi");
}
pub fn pre_smuli(&mut self, cores: &mut CPU, data: InstructionData) {
self.pre_impl(cores, data);
}
pub fn post_smuli(&mut self, cores: &mut CPU, data: InstructionData) {
self.post_impl(cores, data, "smuli");
}
/////////////////////////////////////////////////////////////////
///////////////////Matrix/vector Instructions////////////////////
/////////////////////////////////////////////////////////////////
pub fn pre_setbw(&mut self, cores: &mut CPU, data: InstructionData) {
self.pre_impl(cores, data);
}
pub fn post_setbw(&mut self, cores: &mut CPU, data: InstructionData) {
self.post_impl(cores, data, "setbw");
}
pub fn pre_mvm<F, M, T>(&mut self, cores: &mut CPU, data: InstructionData)
where
[F]: UpcastSlice<T> + UpcastSlice<M>,
[M]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
M: UpcastDestTraits<M> + MemoryStorable + FromFloat,
F: UpcastDestTraits<F> + MemoryStorable,
{
self.pre_impl(cores, data);
}
pub fn post_mvm<F, M, T>(&mut self, cores: &mut CPU, data: InstructionData)
where
[F]: UpcastSlice<T> + UpcastSlice<M>,
[M]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
M: UpcastDestTraits<M> + MemoryStorable + FromFloat,
F: UpcastDestTraits<F> + MemoryStorable,
{
self.post_impl(cores, data, "mvmul");
}
pub fn pre_vvadd<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable,
{
self.pre_impl(cores, data);
}
pub fn post_vvadd<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable,
{
self.post_impl(cores, data, "vvadd");
}
pub fn pre_vvsub<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable,
{
self.pre_impl(cores, data);
}
pub fn post_vvsub<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable,
{
self.post_impl(cores, data, "vvsub");
}
pub fn pre_vvmul<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable,
{
self.pre_impl(cores, data);
}
pub fn post_vvmul<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable,
{
self.post_impl(cores, data, "vvmul");
}
pub fn pre_vvdmul<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable,
{
self.pre_impl(cores, data);
}
pub fn post_vvdmul<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable,
{
self.post_impl(cores, data, "vvdmul");
}
pub fn pre_vvmax<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable,
{
self.pre_impl(cores, data);
}
pub fn post_vvmax<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable,
{
self.post_impl(cores, data, "vvmax");
}
pub fn pre_vavg<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable,
{
self.pre_impl(cores, data);
}
pub fn post_vavg<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable,
{
self.post_impl(cores, data, "vavg");
}
pub fn pre_vrelu<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable + From<f32>,
{
self.pre_impl(cores, data);
}
pub fn post_vrelu<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable + From<f32>,
{
self.post_impl(cores, data, "vrelu");
}
pub fn pre_vtanh<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable + From<f32>,
{
self.pre_impl(cores, data);
}
pub fn post_vtanh<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable + From<f32>,
{
self.post_impl(cores, data, "vtanh");
}
pub fn pre_vsigm<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable + From<f32>,
{
self.pre_impl(cores, data);
}
pub fn post_vsigm<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable + From<f32>,
{
self.post_impl(cores, data, "vsigm");
}
pub fn pre_vsoftmax<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable + From<f32>,
{
self.pre_impl(cores, data);
}
pub fn post_vsoftmax<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable + From<f32>,
{
self.post_impl(cores, data, "vsoftmax");
}
/////////////////////////////////////////////////////////////////
/////Communication/synchronization Instructions/////////////////
/////////////////////////////////////////////////////////////////
pub fn pre_ld(&mut self, cores: &mut CPU, data: InstructionData) {
self.pre_impl(cores, data);
}
pub fn post_ld(&mut self, cores: &mut CPU, data: InstructionData) {
self.post_impl(cores, data, "ld");
}
pub fn pre_st(&mut self, cores: &mut CPU, data: InstructionData) {
self.pre_impl(cores, data);
}
pub fn post_st(&mut self, cores: &mut CPU, data: InstructionData) {
self.post_impl(cores, data, "st");
}
pub fn pre_lldi(&mut self, cores: &mut CPU, data: InstructionData) {
self.pre_impl(cores, data);
}
pub fn post_lldi(&mut self, cores: &mut CPU, data: InstructionData) {
self.post_impl(cores, data, "lldi");
}
pub fn pre_lmv(&mut self, cores: &mut CPU, data: InstructionData) {
self.pre_impl(cores, data);
}
pub fn post_lmv(&mut self, cores: &mut CPU, data: InstructionData) {
self.post_impl(cores, data, "lmv");
}
pub fn pre_send(&mut self, cores: &mut CPU, data: InstructionData) {
self.pre_impl(cores, data);
}
pub fn post_send(&mut self, cores: &mut CPU, data: InstructionData) {
self.post_impl(cores, data, "send");
}
pub fn pre_recv(&mut self, cores: &mut CPU, data: InstructionData) {
self.pre_impl(cores, data);
}
pub fn post_recv(&mut self, cores: &mut CPU, data: InstructionData) {
self.post_impl(cores, data, "recv");
}
}
@@ -0,0 +1,28 @@
use std::{fs::File, path::PathBuf};
pub mod pretty_print;
pub mod tracing_isa;
pub struct Trace {
out_files: Vec<File>,
}
impl Trace {
pub fn new() -> Self {
Self {
out_files: Vec::new(),
}
}
pub fn init(&mut self, num_core: usize, mut path: PathBuf) {
path.pop();
for i in 0..num_core {
path.push(format!("TraceCore{}", i));
let file = File::create(&path).expect("Can not create file");
self.out_files.push(file);
path.pop();
}
}
}
@@ -1,4 +1,4 @@
use crate::{tracing::trace::pretty_print, utility::add_offset_r2};
use std::fs::File;
use crate::{
@@ -13,7 +13,6 @@ use crate::{
};
use std::io::Write;
#[cfg(feature = "tracing")]
impl Trace {
///////////////////////////////////////////////////////////////
/////////////////Scalar/register Instructions//////////////////
@@ -284,7 +283,6 @@ impl Trace {
M: UpcastDestTraits<M> + MemoryStorable + FromFloat,
F: UpcastDestTraits<F> + MemoryStorable,
{
let (core_indx, rd, r1, mbiw, relu, group) = data.get_core_rd_r1_mbiw_immrelu_immgroup();
let file: &mut File = self
@@ -358,8 +356,6 @@ impl Trace {
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable,
{
let (core_indx, rd, r1, r2, imm_len, offset_select, offset_value) =
data.get_core_rd_r1_r2_immlen_offset();
let file: &mut File = self
@@ -990,8 +986,6 @@ impl Trace {
/////////////////////////////////////////////////////////////////
pub fn ld_impl(&mut self, cores: &mut CPU, data: InstructionData, prefix: &'static str) {
let (core, rd, r1, _, imm_len, offset_select, offset_value) =
data.get_core_rd_r1_r2_immlen_offset();
let file: &mut File = self
@@ -1044,8 +1038,6 @@ impl Trace {
}
pub fn st_impl(&mut self, cores: &mut CPU, data: InstructionData, prefix: &'static str) {
let (core, rd, r1, _, imm_len, offset_select, offset_value) =
data.get_core_rd_r1_r2_immlen_offset();
let file: &mut File = self
@@ -1138,7 +1130,6 @@ impl Trace {
}
fn lmv_impl (&mut self, cores: &mut CPU, data: InstructionData, prefix: &'static str) {
let (core, rd, r1, _, imm_len, offset_select, offset_value) =
data.get_core_rd_r1_r2_immlen_offset();
@@ -1,6 +1,11 @@
use std::path::Path;
use pimcore::{
Executable,
cpu::crossbar::Crossbar,
instruction_set::{InstructionsBuilder, instruction_data::InstructionDataBuilder, isa::*},
memory_manager::CoreMemory,
};
fn simple_read(path: &Path) -> Vec<f32> {
if !path.exists() {
@@ -17,14 +22,12 @@ fn simple_read(path: &Path) -> Vec<f32> {
fn mvmul_f32(err: &str)
where
{
let matrix = simple_read(Path::new("tests/B.txt"));
let mut crossbar = Crossbar::new(1024 * size_of::<f32>(), 1024, CoreMemory::new());
crossbar.execute_store(&matrix).unwrap();
let mut cpu = pimcore::cpu::CPU::new(0, vec![vec![&crossbar]]);
let (memory, _) = cpu.host().get_memory_crossbar();
let vector = simple_read(Path::new("tests/A.txt"));
memory.execute_store(0, &vector).unwrap();
let mut inst_builder = InstructionsBuilder::new();
@@ -57,7 +60,7 @@ where
.cpu_mut()
.host()
.load::<f32>(1024 * size_of::<f32>(), 1024*size_of::<f32>()).unwrap()[0].iter().zip(
simple_read(Path::new("X.txt")) ).all(|(&a,b) : (&f32, f32)| {a-b < 0.001}),
simple_read(Path::new("tests/X.txt")) ).all(|(&a,b) : (&f32, f32)| {a-b < 0.001}),
"Wrong result for {}",
err
);
@@ -69,5 +72,3 @@ fn mvmul_big_test() {
}
@@ -0,0 +1,5 @@
use pimcore::cpu::CPU;
pub fn empty_cpu(num_cores: usize) -> CPU<'static> {
CPU::new(num_cores, vec![Vec::new(); num_cores + 1])
}
@@ -1,51 +1,103 @@
use std::{fs, io::BufReader, path::Path};
use std::{
fs::{self, File},
io::BufReader,
path::{Path, PathBuf},
};
use anyhow::{Context, Result};
use pimcore::json_to_instruction::json_to_executor;
use pimcore::{
cpu::crossbar::Crossbar,
json_to_instruction::json_to_executor,
memory_manager::CoreMemory,
};
use serde_json::Value;
fn collect_examples<P: AsRef<Path>>(root: P) -> Result<Vec<PathBuf>> {
let mut result = Vec::new();
for entry in fs::read_dir(root)? {
let entry = entry.context("Root not found")?;
let path = entry.path();
if path.is_dir() {
result.push(path);
}
}
Ok(result)
}
fn core_sort_key(path: &Path) -> i32 {
let stem = path.file_stem().unwrap().to_str().unwrap();
stem[5..].parse::<i32>().unwrap()
}
fn crossbar_sort_key(path: &Path) -> i32 {
let stem = path.file_stem().unwrap().to_str().unwrap();
stem[9..].parse::<i32>().unwrap()
}
fn load_crossbars(folder: &Path, config: &Value) -> Result<Vec<Vec<Crossbar>>> {
let xbar_size = config.get("xbar_size").unwrap().as_array().unwrap();
let rows = xbar_size[0].as_i64().unwrap() as usize;
let cols = xbar_size[1].as_i64().unwrap() as usize;
let core_cnt = config.get("core_cnt").unwrap().as_i64().unwrap() as usize;
let mut owned_crossbars = Vec::with_capacity(core_cnt + 1);
owned_crossbars.push(Vec::new());
for core_idx in 0..core_cnt {
let core_folder = folder.join(format!("core_{core_idx}"));
let mut core_crossbars = Vec::new();
if core_folder.is_dir() {
let mut paths: Vec<_> = fs::read_dir(&core_folder)?
.map(|entry| entry.map(|entry| entry.path()))
.collect::<std::io::Result<Vec<_>>>()?;
paths.sort_by_cached_key(|path| crossbar_sort_key(path));
for path in paths {
if path.extension().and_then(|ext| ext.to_str()) != Some("bin") {
continue;
}
let bytes = fs::read(&path)
.with_context(|| format!("failed to read crossbar {}", path.display()))?;
let mut crossbar = Crossbar::new(cols * 4, rows, CoreMemory::new());
crossbar.execute_store(&bytes)?;
core_crossbars.push(crossbar);
}
}
owned_crossbars.push(core_crossbars);
}
Ok(owned_crossbars)
}
#[test]
fn json_folder_tester() {
let examples = collect_json_from_subfolders("data").unwrap();
for example in examples {
let (config, cores) = example;
json_to_executor::json_to_executor(config, cores.iter()).execute();
let examples = collect_examples("tests/data").unwrap();
for folder in examples {
let config_path = folder.join("config.json");
let config_file = File::open(&config_path).unwrap();
let config: Value = serde_json::from_reader(BufReader::new(config_file)).unwrap();
let core_cnt = config.get("core_cnt").unwrap().as_i64().unwrap() as usize;
let mut core_paths: Vec<_> = fs::read_dir(&folder)
.unwrap()
.map(|entry| entry.unwrap().path())
.filter(|path| path.extension().and_then(|ext| ext.to_str()) == Some("json"))
.filter(|path| path.file_name().unwrap() != "config.json")
.collect();
core_paths.sort_by_cached_key(|path| core_sort_key(path));
assert_eq!(core_paths.len(), core_cnt);
let mut core_readers: Vec<_> = core_paths
.into_iter()
.map(|path| BufReader::new(File::open(path).unwrap()))
.collect();
let owned_crossbars = load_crossbars(&folder, &config).unwrap();
let crossbars = owned_crossbars
.iter()
.map(|core_crossbars| core_crossbars.iter().collect())
.collect();
let mut executable = json_to_executor::json_to_executor(config, &mut core_readers, crossbars);
let memory = fs::read(folder.join("memory.bin")).unwrap();
executable.cpu_mut().host().execute_store(0, &memory).unwrap();
executable.execute();
}
}
@@ -1,11 +1,17 @@
mod common;
use pimcore::{
Executable,
instruction_set::{
InstructionType, InstructionsBuilder, instruction_data::InstructionDataBuilder, isa::*,
},
};
#[test]
#[should_panic(expected = "Function not found for the requested size")]
fn wrong_size_place_holder() {
let cpu = common::empty_cpu(0);
let mut inst_builder = InstructionsBuilder::new();
let mut idata_build = InstructionDataBuilder::new();
idata_build.set_core_indx(0).fix_core_indx();
@@ -30,7 +36,7 @@ fn wrong_size_place_holder() {
fn place_holder(inst : InstructionType) {
let mut cpu = common::empty_cpu(0);
let mut idata_build = InstructionDataBuilder::new();
idata_build.set_core_indx(0).fix_core_indx();
inst(&mut cpu, idata_build.build()).unwrap();
@@ -1,8 +1,10 @@
mod common;
use pimcore::{
Executable,
cpu::crossbar::Crossbar,
instruction_set::{InstructionsBuilder, instruction_data::InstructionDataBuilder, isa::*},
memory_manager::{CoreMemory, MemoryStorable, type_traits::UpcastDestTraits},
};
/// VVADD Test
@@ -11,7 +13,7 @@ where
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable,
{
let mut cpu = CPU::new(0);
let mut cpu = common::empty_cpu(0);
let buff: [F; _] = [
1.0.into(),
2.0.into(),
@@ -115,7 +117,7 @@ where
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable,
{
let mut cpu = CPU::new(0);
let mut cpu = common::empty_cpu(0);
let buff: [F; _] = [
1.0.into(),
2.0.into(),
@@ -219,7 +221,7 @@ where
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable,
{
let mut cpu = CPU::new(0);
let mut cpu = common::empty_cpu(0);
let buff: [F; _] = [
1.0.into(),
2.0.into(),
@@ -323,7 +325,7 @@ where
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable,
{
let mut cpu = CPU::new(0);
let mut cpu = common::empty_cpu(0);
let buff: [F; _] = [
1.0.into(),
2.0.into(),
@@ -420,7 +422,7 @@ where
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable,
{
let mut cpu = CPU::new(0);
let mut cpu = common::empty_cpu(0);
let buff: [F; _] = [
9.0.into(),
2.0.into(),
@@ -524,7 +526,7 @@ where
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable,
{
let mut cpu = CPU::new(0);
let mut cpu = common::empty_cpu(0);
let buff: [F; _] = [
9.0.into(),
2.0.into(),
@@ -562,6 +564,7 @@ where
vavg,
idata_build
.set_rdr1r2(3, 1, 1)
.set_offset_select(1)
.set_imm_len(8 * size_of::<F>() as i32)
.build(),
);
@@ -617,7 +620,7 @@ where
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable,
{
let mut cpu = CPU::new(0);
let mut cpu = common::empty_cpu(0);
let buff: [F; _] = [
(-9.0).into(),
2.0.into(),
@@ -717,7 +720,7 @@ where
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable + UpcastDestTraits<T>,
{
let mut cpu = CPU::new(0);
let mut cpu = common::empty_cpu(0);
let buff: [F; _] = [
0.1.into(),
0.2.into(),
@@ -819,7 +822,7 @@ where
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable + UpcastDestTraits<T>,
{
let mut cpu = CPU::new(0);
let mut cpu = common::empty_cpu(0);
let buff: [F; _] = [
0.1.into(),
0.2.into(),
@@ -923,9 +926,6 @@ where
M: From<f32> + std::fmt::Debug + PartialEq<M> + MemoryStorable,
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable + UpcastDestTraits<T>,
{
let mut cpu = CPU::new(0);
cpu.reserve_crossbar(1, 4 * size_of::<M>(), 4);
let (memory, crossbars) = cpu.host().get_memory_crossbar();
let matrix: [M; _] = [
1.0.into(),
2.0.into(),
@@ -944,7 +944,10 @@ where
15.0.into(),
16.0.into(),
];
crossbars.get_mut(0).unwrap().execute_store( &matrix).unwrap();
let mut crossbar = Crossbar::new(4 * size_of::<M>(), 4, CoreMemory::new());
crossbar.execute_store(&matrix).unwrap();
let mut cpu = pimcore::cpu::CPU::new(0, vec![vec![&crossbar]]);
let (memory, _) = cpu.host().get_memory_crossbar();
let vector: [F; _] = [
1.0.into(),
2.0.into(),
@@ -1054,5 +1057,3 @@ fn mvmul_test() {
mvmul_test_generic::<f64,f64,f64>("mvmul<f64,f64,f64>",1);
}
@@ -1,12 +1,13 @@
mod common;
use pimcore::{
Executable, CoreInstructionsBuilder,
cpu::CPU,
instruction_set::{InstructionsBuilder, instruction_data::InstructionDataBuilder, isa::*},
};
#[test]
fn ld_test() {
let mut cpu = CPU::new(1);
let mut cpu = common::empty_cpu(1);
let mut core_instruction_builder = CoreInstructionsBuilder::new(1);
let buff: [f32; _] = [
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
@@ -41,7 +42,7 @@ fn ld_test() {
#[test]
fn st_test() {
let mut cpu = CPU::new(1);
let mut cpu = common::empty_cpu(1);
let mut core_instruction_builder = CoreInstructionsBuilder::new(1);
let buff: [f32; _] = [
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
@@ -76,7 +77,7 @@ fn st_test() {
#[test]
fn lldi_test() {
let cpu = CPU::new(1);
let cpu = common::empty_cpu(1);
let mut core_instruction_builder = CoreInstructionsBuilder::new(1);
let mut inst_builder = InstructionsBuilder::new();
let mut idata_build = InstructionDataBuilder::new();
@@ -106,7 +107,7 @@ fn lldi_test() {
#[test]
fn lmv_test() {
let mut cpu = CPU::new(1);
let mut cpu = common::empty_cpu(1);
let mut core_instruction_builder = CoreInstructionsBuilder::new(1);
let buff: [f32; _] = [
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
@@ -148,7 +149,7 @@ fn lmv_test() {
#[test]
fn simple_send_recv_test() {
let mut cpu = CPU::new(2);
let mut cpu = common::empty_cpu(2);
let mut core_instruction_builder = CoreInstructionsBuilder::new(2);
let buff: [f32; _] = [
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
@@ -207,7 +208,7 @@ fn simple_send_recv_test() {
#[test]
fn multiple_send_recv_test() {
let mut cpu = CPU::new(4);
let mut cpu = common::empty_cpu(4);
let mut core_instruction_builder = CoreInstructionsBuilder::new(4);
let buff: [f32; _] = [
1.0, 1.0, 1.0, 1.0, 1.0
@@ -226,7 +227,7 @@ fn multiple_send_recv_test() {
];
cpu.core(4).execute_store(0, &buff).unwrap();
let send_inst = |cpu :&mut CPU, core_instruction_builder: &mut CoreInstructionsBuilder, inst_builder: &mut InstructionsBuilder, from : i32, to : i32| {
let send_inst = |inst_builder: &mut InstructionsBuilder, from: i32, to: i32| {
let mut idata_build = InstructionDataBuilder::new();
idata_build.set_core_indx(from).fix_core_indx();
inst_builder.make_inst(sldi, idata_build.set_rdimm(1, from*size_of::<f32>() as i32).build());
@@ -240,7 +241,7 @@ fn multiple_send_recv_test() {
);
};
let recv_inst = |cpu :&mut CPU, core_instruction_builder: &mut CoreInstructionsBuilder, mut inst_builder: &mut InstructionsBuilder, to : i32, from : i32| {
let recv_inst = |inst_builder: &mut InstructionsBuilder, to: i32, from: i32| {
let mut idata_build = InstructionDataBuilder::new();
idata_build.set_core_indx(to).fix_core_indx();
inst_builder.make_inst(sldi, idata_build.set_rdimm(1, from*size_of::<f32>() as i32).build());
@@ -257,26 +258,26 @@ fn multiple_send_recv_test() {
// 1 -> 3
send_inst(&mut cpu,&mut core_instruction_builder,&mut inst_builder,1, 3);
send_inst(&mut inst_builder, 1, 3);
core_instruction_builder.set_core(1, inst_builder.build());
// 2 -> 3
// 2 <- 4
send_inst(&mut cpu,&mut core_instruction_builder,&mut inst_builder,2, 3);
recv_inst(&mut cpu,&mut core_instruction_builder,&mut inst_builder,2, 4);
send_inst(&mut inst_builder, 2, 3);
recv_inst(&mut inst_builder, 2, 4);
core_instruction_builder.set_core(2, inst_builder.build());
// 3 <- 2
// 3 <- 4
// 3 <- 1
recv_inst(&mut cpu,&mut core_instruction_builder,&mut inst_builder,3, 2);
recv_inst(&mut cpu,&mut core_instruction_builder,&mut inst_builder,3, 4);
recv_inst(&mut cpu,&mut core_instruction_builder,&mut inst_builder,3, 1);
recv_inst(&mut inst_builder, 3, 2);
recv_inst(&mut inst_builder, 3, 4);
recv_inst(&mut inst_builder, 3, 1);
core_instruction_builder.set_core(3, inst_builder.build());
// 4 -> 2
// 4 -> 3
send_inst(&mut cpu,&mut core_instruction_builder,&mut inst_builder,4, 2);
send_inst(&mut cpu,&mut core_instruction_builder,&mut inst_builder,4, 3);
send_inst(&mut inst_builder, 4, 2);
send_inst(&mut inst_builder, 4, 3);
core_instruction_builder.set_core(4, inst_builder.build());
let mut executable = Executable::new(cpu, core_instruction_builder.build());
+1
@@ -68,5 +68,6 @@ add_pim_library(OMPIMAccel
OMSpatialToPim
OMPimCommon
OMPimBufferization
OMPimStaticMemoryCoalescing
MLIRTensorInferTypeOpInterfaceImpl
)
+10 -1
@@ -1,5 +1,14 @@
add_pim_library(OMPimCommon
PimCommon.cpp
IR/AddressAnalysis.cpp
IR/CoreBlockUtils.cpp
IR/EntryPointUtils.cpp
IR/ShapeUtils.cpp
IR/SubviewUtils.cpp
IR/WeightUtils.cpp
Support/DebugDump.cpp
Support/Diagnostics.cpp
Support/FileSystemUtils.cpp
Support/ReportUtils.cpp
EXCLUDE_FROM_OM_LIBS
+267
@@ -0,0 +1,267 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
namespace onnx_mlir {
mlir::memref::GlobalOp lookupGlobalForGetGlobal(mlir::ModuleOp moduleOp, mlir::memref::GetGlobalOp getGlobalOp) {
if (!moduleOp || !getGlobalOp)
return {};
return moduleOp.lookupSymbol<mlir::memref::GlobalOp>(getGlobalOp.getName());
}
namespace {
mlir::Value resolveAlias(mlir::Value value, const StaticValueKnowledge* knowledge) {
if (!knowledge)
return value;
auto iter = knowledge->aliases.find(value);
while (iter != knowledge->aliases.end()) {
value = iter->second;
iter = knowledge->aliases.find(value);
}
return value;
}
mlir::Value resolveLoopCarriedAliasImpl(mlir::Value value, const StaticValueKnowledge* knowledge) {
value = resolveAlias(value, knowledge);
if (mlir::isa<mlir::BlockArgument>(value))
return value;
mlir::Operation* definingOp = value.getDefiningOp();
if (!definingOp)
return value;
if (auto dpsDefiningOp = mlir::dyn_cast<mlir::DestinationStyleOpInterface>(definingOp)) {
if (auto result = mlir::dyn_cast<mlir::OpResult>(value))
if (mlir::OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(result))
return resolveLoopCarriedAliasImpl(tiedOperand->get(), knowledge);
}
if (auto castOp = mlir::dyn_cast<mlir::memref::CastOp>(definingOp))
return resolveLoopCarriedAliasImpl(castOp.getSource(), knowledge);
if (auto collapseOp = mlir::dyn_cast<mlir::memref::CollapseShapeOp>(definingOp))
return resolveLoopCarriedAliasImpl(collapseOp.getSrc(), knowledge);
if (auto expandOp = mlir::dyn_cast<mlir::memref::ExpandShapeOp>(definingOp))
return resolveLoopCarriedAliasImpl(expandOp.getSrc(), knowledge);
return value;
}
llvm::FailureOr<int64_t> resolveOpFoldResult(mlir::OpFoldResult ofr, const StaticValueKnowledge* knowledge);
llvm::FailureOr<int64_t> resolveIndexValueImpl(mlir::Value value, const StaticValueKnowledge* knowledge) {
value = resolveAlias(value, knowledge);
if (knowledge) {
auto iter = knowledge->indexValues.find(value);
if (iter != knowledge->indexValues.end())
return iter->second;
}
auto constantOp = value.getDefiningOp<mlir::arith::ConstantOp>();
if (constantOp) {
if (auto integerAttr = mlir::dyn_cast<mlir::IntegerAttr>(constantOp.getValue()))
return integerAttr.getInt();
}
mlir::Operation* definingOp = value.getDefiningOp();
if (!definingOp)
return mlir::failure();
if (auto indexCastOp = mlir::dyn_cast<mlir::arith::IndexCastOp>(definingOp))
return resolveIndexValueImpl(indexCastOp.getIn(), knowledge);
if (auto addOp = mlir::dyn_cast<mlir::arith::AddIOp>(definingOp)) {
auto lhs = resolveIndexValueImpl(addOp.getLhs(), knowledge);
auto rhs = resolveIndexValueImpl(addOp.getRhs(), knowledge);
if (failed(lhs) || failed(rhs))
return mlir::failure();
return *lhs + *rhs;
}
if (auto subOp = mlir::dyn_cast<mlir::arith::SubIOp>(definingOp)) {
auto lhs = resolveIndexValueImpl(subOp.getLhs(), knowledge);
auto rhs = resolveIndexValueImpl(subOp.getRhs(), knowledge);
if (failed(lhs) || failed(rhs))
return mlir::failure();
return *lhs - *rhs;
}
if (auto mulOp = mlir::dyn_cast<mlir::arith::MulIOp>(definingOp)) {
auto lhs = resolveIndexValueImpl(mulOp.getLhs(), knowledge);
auto rhs = resolveIndexValueImpl(mulOp.getRhs(), knowledge);
if (failed(lhs) || failed(rhs))
return mlir::failure();
return *lhs * *rhs;
}
if (auto divOp = mlir::dyn_cast<mlir::arith::DivUIOp>(definingOp)) {
auto lhs = resolveIndexValueImpl(divOp.getLhs(), knowledge);
auto rhs = resolveIndexValueImpl(divOp.getRhs(), knowledge);
if (failed(lhs) || failed(rhs) || *rhs == 0)
return mlir::failure();
return static_cast<int64_t>(static_cast<uint64_t>(*lhs) / static_cast<uint64_t>(*rhs));
}
if (auto minOp = mlir::dyn_cast<mlir::arith::MinUIOp>(definingOp)) {
auto lhs = resolveIndexValueImpl(minOp.getLhs(), knowledge);
auto rhs = resolveIndexValueImpl(minOp.getRhs(), knowledge);
if (failed(lhs) || failed(rhs))
return mlir::failure();
return static_cast<int64_t>(std::min(static_cast<uint64_t>(*lhs), static_cast<uint64_t>(*rhs)));
}
if (auto remOp = mlir::dyn_cast<mlir::arith::RemUIOp>(definingOp)) {
auto lhs = resolveIndexValueImpl(remOp.getLhs(), knowledge);
auto rhs = resolveIndexValueImpl(remOp.getRhs(), knowledge);
if (failed(lhs) || failed(rhs) || *rhs == 0)
return mlir::failure();
return static_cast<int64_t>(static_cast<uint64_t>(*lhs) % static_cast<uint64_t>(*rhs));
}
return mlir::failure();
}
llvm::FailureOr<int64_t> resolveOpFoldResult(mlir::OpFoldResult ofr, const StaticValueKnowledge* knowledge) {
if (auto attr = mlir::dyn_cast<mlir::Attribute>(ofr)) {
auto integerAttr = mlir::dyn_cast<mlir::IntegerAttr>(attr);
if (!integerAttr)
return mlir::failure();
return integerAttr.getInt();
}
return resolveIndexValueImpl(mlir::cast<mlir::Value>(ofr), knowledge);
}
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddressImpl(mlir::Value value,
const StaticValueKnowledge* knowledge) {
int64_t byteOffset = 0;
value = resolveAlias(value, knowledge);
while (true) {
if (mlir::isa<mlir::BlockArgument>(value))
return ResolvedContiguousAddress {value, byteOffset};
mlir::Operation* definingOp = value.getDefiningOp();
if (!definingOp)
return mlir::failure();
if (auto dpsDefiningOp = mlir::dyn_cast<mlir::DestinationStyleOpInterface>(definingOp)) {
mlir::OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(mlir::dyn_cast<mlir::OpResult>(value));
if (!tiedOperand)
return mlir::failure();
value = resolveAlias(tiedOperand->get(), knowledge);
continue;
}
if (auto forOp = mlir::dyn_cast<mlir::scf::ForOp>(definingOp)) {
auto result = mlir::dyn_cast<mlir::OpResult>(value);
if (!result)
return mlir::failure();
auto yieldOp = mlir::cast<mlir::scf::YieldOp>(forOp.getBody()->getTerminator());
mlir::Value yieldedValue = resolveLoopCarriedAliasImpl(yieldOp.getOperand(result.getResultNumber()), knowledge);
if (auto blockArgument = mlir::dyn_cast<mlir::BlockArgument>(yieldedValue)) {
if (blockArgument.getOwner() == forOp.getBody() && blockArgument.getArgNumber() > 0
&& static_cast<unsigned>(blockArgument.getArgNumber() - 1) < forOp.getInitArgs().size()) {
value = resolveAlias(forOp.getInitArgs()[blockArgument.getArgNumber() - 1], knowledge);
continue;
}
}
value = yieldedValue;
continue;
}
if (auto subviewOp = mlir::dyn_cast<mlir::memref::SubViewOp>(definingOp)) {
auto sourceType = mlir::dyn_cast<mlir::MemRefType>(subviewOp.getSource().getType());
auto subviewType = mlir::dyn_cast<mlir::MemRefType>(subviewOp.getType());
if (!sourceType || !subviewType || !sourceType.hasStaticShape() || !subviewType.hasStaticShape())
return mlir::failure();
llvm::SmallVector<int64_t> offsets;
llvm::SmallVector<int64_t> sizes;
llvm::SmallVector<int64_t> strides;
offsets.reserve(subviewOp.getMixedOffsets().size());
sizes.reserve(subviewOp.getMixedSizes().size());
strides.reserve(subviewOp.getMixedStrides().size());
for (mlir::OpFoldResult offset : subviewOp.getMixedOffsets()) {
auto resolvedOffset = resolveOpFoldResult(offset, knowledge);
if (failed(resolvedOffset))
return mlir::failure();
offsets.push_back(*resolvedOffset);
}
for (mlir::OpFoldResult size : subviewOp.getMixedSizes()) {
auto resolvedSize = resolveOpFoldResult(size, knowledge);
if (failed(resolvedSize))
return mlir::failure();
sizes.push_back(*resolvedSize);
}
for (mlir::OpFoldResult stride : subviewOp.getMixedStrides()) {
auto resolvedStride = resolveOpFoldResult(stride, knowledge);
if (failed(resolvedStride))
return mlir::failure();
strides.push_back(*resolvedStride);
}
if (!isMemoryContiguous(sourceType.getShape(), offsets, sizes, strides))
return mlir::failure();
auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
byteOffset += linearizeIndex(offsets, sourceStrides) * subviewType.getElementTypeBitWidth() / 8;
value = resolveAlias(subviewOp.getSource(), knowledge);
continue;
}
if (auto castOp = mlir::dyn_cast<mlir::memref::CastOp>(definingOp)) {
value = resolveAlias(castOp.getSource(), knowledge);
continue;
}
if (auto collapseOp = mlir::dyn_cast<mlir::memref::CollapseShapeOp>(definingOp)) {
value = resolveAlias(collapseOp.getSrc(), knowledge);
continue;
}
if (auto expandOp = mlir::dyn_cast<mlir::memref::ExpandShapeOp>(definingOp)) {
value = resolveAlias(expandOp.getSrc(), knowledge);
continue;
}
if (mlir::isa<mlir::memref::AllocOp, mlir::memref::GetGlobalOp>(definingOp))
return ResolvedContiguousAddress {value, byteOffset};
return mlir::failure();
}
}
} // namespace
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value) { return resolveIndexValueImpl(value, nullptr); }
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value, const StaticValueKnowledge& knowledge) {
return resolveIndexValueImpl(value, &knowledge);
}
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value) {
return resolveContiguousAddressImpl(value, nullptr);
}
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value,
const StaticValueKnowledge& knowledge) {
return resolveContiguousAddressImpl(value, &knowledge);
}
mlir::Value resolveLoopCarriedAlias(mlir::Value value, const StaticValueKnowledge& knowledge) {
return resolveLoopCarriedAliasImpl(value, &knowledge);
}
} // namespace onnx_mlir
+43
@@ -0,0 +1,43 @@
#pragma once
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/DenseMap.h"
namespace onnx_mlir {
/// Describes a value as a base addressable object plus a statically known
/// byte offset after peeling aliases, casts, and contiguous subviews.
struct ResolvedContiguousAddress {
mlir::Value base;
int64_t byteOffset = 0;
};
/// Records compile-time facts used when interpreting address arithmetic and
/// loop-carried aliases inside PIM regions.
struct StaticValueKnowledge {
llvm::DenseMap<mlir::Value, int64_t> indexValues;
llvm::DenseMap<mlir::Value, mlir::Value> aliases;
StaticValueKnowledge() {}
};
mlir::memref::GlobalOp lookupGlobalForGetGlobal(mlir::ModuleOp moduleOp, mlir::memref::GetGlobalOp getGlobalOp);
/// Resolves a value to contiguous backing storage when that storage can be
/// proven statically from aliases, DPS ties, casts, and subviews.
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value);
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value,
const StaticValueKnowledge& knowledge);
/// Statically evaluates index-like SSA values, including simple integer
/// arithmetic and loop facts recorded in `knowledge`.
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value);
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value, const StaticValueKnowledge& knowledge);
/// Follows alias, view, and DPS chains to recover the backing value of a
/// loop-carried memref/result.
mlir::Value resolveLoopCarriedAlias(mlir::Value value, const StaticValueKnowledge& knowledge);
} // namespace onnx_mlir
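A minimal usage sketch (not part of the diff) of the API declared above: seed StaticValueKnowledge with one known induction-variable value, then ask resolveIndexValue to fold an address expression. The helper name and its arguments are invented for illustration.
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
namespace onnx_mlir {
// Hypothetical helper: pin the scf.for induction variable `iv` to a concrete
// trip value and try to fold `indexExpr` (e.g. iv * stride + base) to a constant.
static llvm::FailureOr<int64_t> foldIndexAtTrip(
    mlir::Value indexExpr, mlir::Value iv, int64_t tripValue) {
  StaticValueKnowledge knowledge;
  knowledge.indexValues[iv] = tripValue; // record the compile-time loop fact
  return resolveIndexValue(indexExpr, knowledge);
}
} // namespace onnx_mlir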
+745
@@ -0,0 +1,745 @@
#ifndef ONNX_MLIR_PIM_COMPACT_ASM_UTILS_HPP
#define ONNX_MLIR_PIM_COMPACT_ASM_UTILS_HPP
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/LogicalResult.h"
namespace onnx_mlir {
namespace compact_asm {
using namespace mlir;
enum class ListDelimiter {
Square,
Paren
};
inline ParseResult parseOpenDelimiter(OpAsmParser& parser, ListDelimiter delimiter) {
if (delimiter == ListDelimiter::Square)
return parser.parseLSquare();
return parser.parseLParen();
}
inline ParseResult parseOptionalCloseDelimiter(OpAsmParser& parser, ListDelimiter delimiter) {
if (delimiter == ListDelimiter::Square)
return parser.parseOptionalRSquare();
return parser.parseOptionalRParen();
}
template <typename StreamT>
inline void printOpenDelimiter(StreamT& stream, ListDelimiter delimiter) {
stream << (delimiter == ListDelimiter::Square ? "[" : "(");
}
template <typename StreamT>
inline void printCloseDelimiter(StreamT& stream, ListDelimiter delimiter) {
stream << (delimiter == ListDelimiter::Square ? "]" : ")");
}
template <typename EntryT, typename ParseEntryFn>
inline ParseResult parseCompressedRepeatedList(OpAsmParser& parser,
ListDelimiter delimiter,
SmallVectorImpl<EntryT>& entries,
ParseEntryFn parseEntry) {
if (parseOpenDelimiter(parser, delimiter))
return failure();
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
return success();
while (true) {
EntryT entry;
if (parseEntry(entry))
return failure();
int64_t repeatCount = 1;
if (succeeded(parser.parseOptionalKeyword("x"))) {
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
}
for (int64_t index = 0; index < repeatCount; ++index)
entries.push_back(entry);
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
return success();
if (parser.parseComma())
return failure();
}
}
template <typename IntT>
inline ParseResult
parseCompressedIntegerEntries(OpAsmParser& parser, ListDelimiter delimiter, SmallVectorImpl<IntT>& values) {
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
return success();
while (true) {
if (succeeded(parser.parseOptionalLParen())) {
SmallVector<IntT> subgroup;
if (parseCompressedIntegerEntries(parser, ListDelimiter::Paren, subgroup))
return failure();
int64_t repeatCount = 1;
if (succeeded(parser.parseOptionalKeyword("x"))) {
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
}
for (int64_t repeat = 0; repeat < repeatCount; ++repeat)
llvm::append_range(values, subgroup);
}
else {
int64_t first = 0;
if (parser.parseInteger(first))
return failure();
if (succeeded(parser.parseOptionalKeyword("to"))) {
int64_t last = 0;
if (parser.parseInteger(last) || last < first)
return parser.emitError(parser.getCurrentLocation(), "invalid ascending range");
int64_t step = 1;
if (succeeded(parser.parseOptionalKeyword("by"))) {
if (parser.parseInteger(step) || step <= 0)
return parser.emitError(parser.getCurrentLocation(), "step after 'by' must be positive");
}
int64_t repeatCount = 1;
if (succeeded(parser.parseOptionalKeyword("x"))) {
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
}
if ((last - first) % step != 0) {
return parser.emitError(parser.getCurrentLocation(),
"range end must be reachable from start using the given step");
}
for (int64_t value = first; value <= last; value += step)
for (int64_t index = 0; index < repeatCount; ++index)
values.push_back(static_cast<IntT>(value));
}
else {
int64_t repeatCount = 1;
if (succeeded(parser.parseOptionalKeyword("x"))) {
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
}
for (int64_t index = 0; index < repeatCount; ++index)
values.push_back(static_cast<IntT>(first));
}
}
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
return success();
if (parser.parseComma())
return failure();
}
}
template <typename IntT>
inline ParseResult
parseCompressedIntegerSequence(OpAsmParser& parser, ListDelimiter delimiter, SmallVectorImpl<IntT>& values) {
if (parseOpenDelimiter(parser, delimiter))
return failure();
return parseCompressedIntegerEntries(parser, delimiter, values);
}
template <typename RangeT, typename PrintEntryFn>
inline void printCompressedEqualRuns(OpAsmPrinter& printer, RangeT entries, PrintEntryFn printEntry) {
for (size_t index = 0; index < entries.size();) {
size_t runEnd = index + 1;
while (runEnd < entries.size() && entries[runEnd] == entries[index])
++runEnd;
if (index != 0)
printer << ", ";
printEntry(entries[index]);
size_t runLength = runEnd - index;
if (runLength > 1)
printer << " x" << runLength;
index = runEnd;
}
}
template <typename StreamT, typename IntT>
inline void printCompressedIntegerEntries(StreamT& stream, ArrayRef<IntT> values) {
struct FlatCompression {
enum class Kind {
Single,
EqualRun,
Progression
};
Kind kind = Kind::Single;
size_t covered = 1;
size_t repeatCount = 1;
size_t progressionValueCount = 1;
int64_t step = 1;
IntT firstValue {};
IntT lastValue {};
};
auto computeFlatCompression = [&](size_t start) {
FlatCompression compression;
compression.firstValue = values[start];
compression.lastValue = values[start];
auto findEqualRunEnd = [&](size_t runStart) {
size_t runEnd = runStart + 1;
while (runEnd < values.size() && values[runEnd] == values[runStart])
++runEnd;
return runEnd;
};
size_t firstRunEnd = findEqualRunEnd(start);
compression.repeatCount = firstRunEnd - start;
size_t progressionEnd = firstRunEnd;
int64_t step = 0;
IntT lastValue = values[start];
if (firstRunEnd < values.size()) {
size_t secondRunEnd = findEqualRunEnd(firstRunEnd);
step = static_cast<int64_t>(values[firstRunEnd]) - static_cast<int64_t>(values[start]);
if (step > 0 && secondRunEnd - firstRunEnd == compression.repeatCount) {
progressionEnd = secondRunEnd;
lastValue = values[firstRunEnd];
size_t currentRunStart = secondRunEnd;
while (currentRunStart < values.size()) {
size_t currentRunEnd = findEqualRunEnd(currentRunStart);
if (currentRunEnd - currentRunStart != compression.repeatCount)
break;
if (static_cast<int64_t>(values[currentRunStart]) != static_cast<int64_t>(lastValue) + step)
break;
lastValue = values[currentRunStart];
progressionEnd = currentRunEnd;
currentRunStart = currentRunEnd;
}
}
else {
step = 0;
}
}
compression.covered = 1;
if (progressionEnd > firstRunEnd) {
size_t progressionValueCount = (progressionEnd - start) / compression.repeatCount;
if (progressionValueCount >= 3) {
compression.kind = FlatCompression::Kind::Progression;
compression.covered = progressionEnd - start;
compression.progressionValueCount = progressionValueCount;
compression.step = step;
compression.lastValue = lastValue;
return compression;
}
}
if (compression.repeatCount > 1) {
compression.kind = FlatCompression::Kind::EqualRun;
compression.covered = compression.repeatCount;
return compression;
}
return compression;
};
auto findRepeatedSublist = [&](size_t start) {
size_t bestLength = 0;
size_t bestRepeatCount = 1;
size_t remaining = values.size() - start;
for (size_t length = 2; length * 2 <= remaining; ++length) {
size_t repeatCount = 1;
ArrayRef<IntT> candidate = values.slice(start, length);
while (start + (repeatCount + 1) * length <= values.size()
&& llvm::equal(candidate, values.slice(start + repeatCount * length, length))) {
++repeatCount;
}
if (repeatCount <= 1)
continue;
size_t covered = length * repeatCount;
size_t bestCovered = bestLength * bestRepeatCount;
if (covered > bestCovered || (covered == bestCovered && length < bestLength)) {
bestLength = length;
bestRepeatCount = repeatCount;
}
}
return std::pair(bestLength, bestRepeatCount);
};
for (size_t index = 0; index < values.size();) {
if (index != 0)
stream << ", ";
FlatCompression flat = computeFlatCompression(index);
auto [sublistLength, sublistRepeatCount] = findRepeatedSublist(index);
size_t repeatedSublistCoverage = sublistLength * sublistRepeatCount;
if (sublistRepeatCount > 1 && sublistLength > 1 && repeatedSublistCoverage > flat.covered) {
printOpenDelimiter(stream, ListDelimiter::Paren);
printCompressedIntegerEntries(stream, values.slice(index, sublistLength));
printCloseDelimiter(stream, ListDelimiter::Paren);
stream << " x" << sublistRepeatCount;
index += repeatedSublistCoverage;
continue;
}
switch (flat.kind) {
case FlatCompression::Kind::Progression:
stream << flat.firstValue << " to " << flat.lastValue;
if (flat.step != 1)
stream << " by " << flat.step;
if (flat.repeatCount > 1)
stream << " x" << flat.repeatCount;
index += flat.covered;
break;
case FlatCompression::Kind::EqualRun:
stream << flat.firstValue << " x" << flat.repeatCount;
index += flat.covered;
break;
case FlatCompression::Kind::Single:
stream << flat.firstValue;
index += flat.covered;
break;
}
}
}
template <typename StreamT, typename IntT>
inline void printCompressedIntegerSequence(StreamT& stream, ArrayRef<IntT> values, ListDelimiter delimiter) {
printOpenDelimiter(stream, delimiter);
printCompressedIntegerEntries(stream, values);
printCloseDelimiter(stream, delimiter);
}
template <typename IntT>
inline ParseResult parseCompressedIntegerList(OpAsmParser& parser, SmallVectorImpl<IntT>& values) {
return parseCompressedIntegerSequence(parser, ListDelimiter::Square, values);
}
template <typename IntT>
inline void printCompressedIntegerList(OpAsmPrinter& printer, ArrayRef<IntT> values) {
printCompressedIntegerSequence(printer, values, ListDelimiter::Square);
}
inline void printCompressedValueSequence(OpAsmPrinter& printer, ValueRange values) {
for (size_t index = 0; index < values.size();) {
size_t equalRunEnd = index + 1;
while (equalRunEnd < values.size() && values[equalRunEnd] == values[index])
++equalRunEnd;
if (index != 0)
printer << ", ";
if (equalRunEnd - index > 1) {
printer.printOperand(values[index]);
printer << " x" << (equalRunEnd - index);
index = equalRunEnd;
continue;
}
size_t rangeEnd = index + 1;
if (auto firstResult = dyn_cast<OpResult>(values[index])) {
while (rangeEnd < values.size()) {
auto nextResult = dyn_cast<OpResult>(values[rangeEnd]);
if (!nextResult || nextResult.getOwner() != firstResult.getOwner()
|| nextResult.getResultNumber() != firstResult.getResultNumber() + (rangeEnd - index)) {
break;
}
++rangeEnd;
}
}
else if (auto firstArg = dyn_cast<BlockArgument>(values[index])) {
while (rangeEnd < values.size()) {
auto nextArg = dyn_cast<BlockArgument>(values[rangeEnd]);
if (!nextArg || nextArg.getOwner() != firstArg.getOwner()
|| nextArg.getArgNumber() != firstArg.getArgNumber() + (rangeEnd - index)) {
break;
}
++rangeEnd;
}
}
printer.printOperand(values[index]);
if (rangeEnd - index >= 3) {
printer << " to ";
printer.printOperand(values[rangeEnd - 1]);
}
else if (rangeEnd - index == 2) {
printer << ", ";
printer.printOperand(values[index + 1]);
}
index = rangeEnd;
}
}
inline void printCompressedTypeSequence(OpAsmPrinter& printer, TypeRange types) {
printCompressedEqualRuns(printer, types, [&](Type type) { printer.printType(type); });
}
inline void printCompressedValueList(OpAsmPrinter& printer, ValueRange values, ListDelimiter delimiter) {
printOpenDelimiter(printer, delimiter);
printCompressedValueSequence(printer, values);
printCloseDelimiter(printer, delimiter);
}
inline void printCompressedTypeList(OpAsmPrinter& printer, TypeRange types, ListDelimiter delimiter) {
printOpenDelimiter(printer, delimiter);
printCompressedTypeSequence(printer, types);
printCloseDelimiter(printer, delimiter);
}
inline ParseResult parseCompressedTypeSequence(OpAsmParser& parser, SmallVectorImpl<Type>& types, bool allowEmpty) {
Type firstType;
OptionalParseResult firstTypeResult = parser.parseOptionalType(firstType);
if (!firstTypeResult.has_value()) {
if (allowEmpty)
return success();
return parser.emitError(parser.getCurrentLocation(), "expected type");
}
if (failed(*firstTypeResult))
return failure();
auto appendType = [&](Type type) -> ParseResult {
int64_t repeatCount = 1;
if (succeeded(parser.parseOptionalKeyword("x"))) {
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
}
for (int64_t index = 0; index < repeatCount; ++index)
types.push_back(type);
return success();
};
if (appendType(firstType))
return failure();
while (succeeded(parser.parseOptionalComma())) {
Type nextType;
if (parser.parseType(nextType) || appendType(nextType))
return failure();
}
return success();
}
inline ParseResult parseCompressedOperandEntryWithFirst(OpAsmParser& parser,
OpAsmParser::UnresolvedOperand firstOperand,
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
if (succeeded(parser.parseOptionalKeyword("to"))) {
OpAsmParser::UnresolvedOperand lastOperand;
if (parser.parseOperand(lastOperand))
return failure();
if (firstOperand.name != lastOperand.name || firstOperand.number > lastOperand.number)
return parser.emitError(parser.getCurrentLocation(), "invalid operand range");
for (unsigned number = firstOperand.number; number <= lastOperand.number; ++number)
operands.push_back({firstOperand.location, firstOperand.name, number});
return success();
}
int64_t repeatCount = 1;
if (succeeded(parser.parseOptionalKeyword("x"))) {
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
}
for (int64_t index = 0; index < repeatCount; ++index)
operands.push_back(firstOperand);
return success();
}
inline ParseResult parseOneCompressedOperandEntry(OpAsmParser& parser,
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
OpAsmParser::UnresolvedOperand firstOperand;
if (parser.parseOperand(firstOperand))
return failure();
return parseCompressedOperandEntryWithFirst(parser, firstOperand, operands);
}
inline ParseResult parseCompressedOperandList(OpAsmParser& parser,
ListDelimiter delimiter,
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
if (parseOpenDelimiter(parser, delimiter))
return failure();
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
return success();
while (true) {
if (parseOneCompressedOperandEntry(parser, operands))
return failure();
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
return success();
if (parser.parseComma())
return failure();
}
}
inline ParseResult parseCompressedOperandSequence(OpAsmParser& parser,
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
if (parseOneCompressedOperandEntry(parser, operands))
return failure();
while (succeeded(parser.parseOptionalComma()))
if (parseOneCompressedOperandEntry(parser, operands))
return failure();
return success();
}
inline ParseResult parseCompressedTypeList(OpAsmParser& parser, ListDelimiter delimiter, SmallVectorImpl<Type>& types) {
if (parseOpenDelimiter(parser, delimiter))
return failure();
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
return success();
if (parseCompressedTypeSequence(parser, types, /*allowEmpty=*/false))
return failure();
return parseOptionalCloseDelimiter(parser, delimiter);
}
inline bool hasRepeatedTuple(ValueRange values, size_t tupleSize) {
if (tupleSize == 0 || values.empty() || values.size() % tupleSize != 0)
return false;
SmallVector<Value> valueVec(values.begin(), values.end());
ArrayRef<Value> tuple(valueVec.data(), tupleSize);
for (size_t index = tupleSize; index < values.size(); index += tupleSize)
if (!llvm::equal(tuple, ArrayRef<Value>(valueVec).slice(index, tupleSize)))
return false;
return true;
}
inline bool hasRepeatedTuple(TypeRange types, size_t tupleSize) {
if (tupleSize == 0 || types.empty() || types.size() % tupleSize != 0)
return false;
SmallVector<Type> typeVec(types.begin(), types.end());
ArrayRef<Type> tuple(typeVec.data(), tupleSize);
for (size_t index = tupleSize; index < types.size(); index += tupleSize)
if (!llvm::equal(tuple, ArrayRef<Type>(typeVec).slice(index, tupleSize)))
return false;
return true;
}
inline void printValueTupleRun(OpAsmPrinter& printer, ValueRange values, size_t tupleSize, ListDelimiter delimiter) {
printOpenDelimiter(printer, delimiter);
printOpenDelimiter(printer, ListDelimiter::Paren);
for (size_t index = 0; index < tupleSize; ++index) {
if (index != 0)
printer << ", ";
printer.printOperand(values[index]);
}
printCloseDelimiter(printer, ListDelimiter::Paren);
printer << " x" << (values.size() / tupleSize);
printCloseDelimiter(printer, delimiter);
}
inline void printTypeTupleRun(OpAsmPrinter& printer, TypeRange types, size_t tupleSize, ListDelimiter delimiter) {
printOpenDelimiter(printer, delimiter);
printOpenDelimiter(printer, ListDelimiter::Paren);
for (size_t index = 0; index < tupleSize; ++index) {
if (index != 0)
printer << ", ";
printer.printType(types[index]);
}
printCloseDelimiter(printer, ListDelimiter::Paren);
printer << " x" << (types.size() / tupleSize);
printCloseDelimiter(printer, delimiter);
}
inline ParseResult parseCompressedOrTupleOperandList(OpAsmParser& parser,
ListDelimiter delimiter,
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
if (parseOpenDelimiter(parser, delimiter))
return failure();
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
return success();
if (succeeded(parser.parseOptionalLParen())) {
SmallVector<OpAsmParser::UnresolvedOperand> tupleOperands;
if (parseCompressedOperandSequence(parser, tupleOperands) || parser.parseRParen())
return failure();
int64_t repeatCount = 1;
if (succeeded(parser.parseOptionalKeyword("x"))) {
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
}
for (int64_t repeat = 0; repeat < repeatCount; ++repeat)
llvm::append_range(operands, tupleOperands);
while (succeeded(parser.parseOptionalComma())) {
if (parser.parseLParen())
return failure();
tupleOperands.clear();
if (parseCompressedOperandSequence(parser, tupleOperands) || parser.parseRParen())
return failure();
repeatCount = 1;
if (succeeded(parser.parseOptionalKeyword("x"))) {
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
}
for (int64_t repeat = 0; repeat < repeatCount; ++repeat)
llvm::append_range(operands, tupleOperands);
}
return parseOptionalCloseDelimiter(parser, delimiter);
}
while (true) {
if (parseOneCompressedOperandEntry(parser, operands))
return failure();
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
return success();
if (parser.parseComma())
return failure();
}
}
inline ParseResult
parseCompressedOrTupleTypeList(OpAsmParser& parser, ListDelimiter delimiter, SmallVectorImpl<Type>& types) {
if (parseOpenDelimiter(parser, delimiter))
return failure();
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
return success();
if (succeeded(parser.parseOptionalLParen())) {
SmallVector<Type> tupleTypes;
if (parseCompressedTypeSequence(parser, tupleTypes, /*allowEmpty=*/false) || parser.parseRParen())
return failure();
int64_t repeatCount = 1;
if (succeeded(parser.parseOptionalKeyword("x"))) {
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
}
for (int64_t repeat = 0; repeat < repeatCount; ++repeat)
llvm::append_range(types, tupleTypes);
while (succeeded(parser.parseOptionalComma())) {
if (parser.parseLParen())
return failure();
tupleTypes.clear();
if (parseCompressedTypeSequence(parser, tupleTypes, /*allowEmpty=*/false) || parser.parseRParen())
return failure();
repeatCount = 1;
if (succeeded(parser.parseOptionalKeyword("x"))) {
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
}
for (int64_t repeat = 0; repeat < repeatCount; ++repeat)
llvm::append_range(types, tupleTypes);
}
return parseOptionalCloseDelimiter(parser, delimiter);
}
while (true) {
Type type;
if (parser.parseType(type))
return failure();
int64_t repeatCount = 1;
if (succeeded(parser.parseOptionalKeyword("x"))) {
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
}
for (int64_t repeat = 0; repeat < repeatCount; ++repeat)
types.push_back(type);
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
return success();
if (parser.parseComma())
return failure();
}
}
inline void printArgumentBindings(OpAsmPrinter& printer, Block& block, ValueRange operands) {
if (block.getNumArguments() == 0) {
printer << "() = ()";
return;
}
if (block.getNumArguments() == 1) {
printer.printOperand(block.getArgument(0));
printer << " = ";
printCompressedValueList(printer, operands, ListDelimiter::Paren);
return;
}
printCompressedValueList(printer, ValueRange(block.getArguments()), ListDelimiter::Paren);
printer << " = ";
printCompressedValueList(printer, operands, ListDelimiter::Paren);
}
inline ParseResult parseCompressedArgumentEntryWithFirst(OpAsmParser& parser,
OpAsmParser::Argument firstArgument,
SmallVectorImpl<OpAsmParser::Argument>& arguments) {
if (succeeded(parser.parseOptionalKeyword("to"))) {
OpAsmParser::Argument lastArgument;
if (parser.parseArgument(lastArgument))
return failure();
if (firstArgument.ssaName.name != lastArgument.ssaName.name
|| firstArgument.ssaName.number > lastArgument.ssaName.number) {
return parser.emitError(parser.getCurrentLocation(), "invalid argument range");
}
for (unsigned number = firstArgument.ssaName.number; number <= lastArgument.ssaName.number; ++number) {
OpAsmParser::Argument argument;
argument.ssaName = {firstArgument.ssaName.location, firstArgument.ssaName.name, number};
arguments.push_back(argument);
}
return success();
}
arguments.push_back(firstArgument);
return success();
}
inline ParseResult parseOneCompressedArgumentEntry(OpAsmParser& parser,
SmallVectorImpl<OpAsmParser::Argument>& arguments) {
OpAsmParser::Argument firstArgument;
if (parser.parseArgument(firstArgument))
return failure();
return parseCompressedArgumentEntryWithFirst(parser, firstArgument, arguments);
}
inline void applyArgumentTypes(ArrayRef<Type> inputTypes, SmallVectorImpl<OpAsmParser::Argument>& arguments) {
for (auto [argument, inputType] : llvm::zip_equal(arguments, inputTypes))
argument.type = inputType;
}
inline ParseResult parseArgumentBindings(OpAsmParser& parser,
SmallVectorImpl<OpAsmParser::Argument>& arguments,
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
if (succeeded(parser.parseOptionalLParen())) {
if (succeeded(parser.parseOptionalRParen())) {
if (parser.parseEqual() || parseCompressedOperandList(parser, ListDelimiter::Paren, operands))
return failure();
return success();
}
OpAsmParser::Argument firstArgument;
if (parser.parseArgument(firstArgument) || parseCompressedArgumentEntryWithFirst(parser, firstArgument, arguments))
return failure();
while (succeeded(parser.parseOptionalComma()))
if (parseOneCompressedArgumentEntry(parser, arguments))
return failure();
if (parser.parseRParen() || parser.parseEqual()
|| parseCompressedOperandList(parser, ListDelimiter::Paren, operands)) {
return failure();
}
return success();
}
OpAsmParser::Argument argument;
if (parser.parseArgument(argument) || parser.parseEqual()
|| parseCompressedOperandList(parser, ListDelimiter::Paren, operands)) {
return failure();
}
arguments.push_back(argument);
return success();
}
} // namespace compact_asm
} // namespace onnx_mlir
#endif
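For reference, an illustrative (made-up) input showing the notation the compressed printer above emits: equal runs collapse to "N xK", arithmetic progressions to "A to B by S", and repeated tuples become a parenthesized group. The include path of the header is assumed.
#include <string>
#include "llvm/ADT/ArrayRef.h"
#include "llvm/Support/raw_ostream.h"
#include "src/Accelerators/PIM/Dialect/Pim/CompactAsmUtils.hpp" // assumed path
static std::string compressionDemo() {
  int64_t values[] = {7, 7, 7, 0, 4, 8, 12, 1, 2, 1, 2, 1, 2};
  std::string out;
  llvm::raw_string_ostream stream(out);
  onnx_mlir::compact_asm::printCompressedIntegerEntries(
      stream, llvm::ArrayRef<int64_t>(values));
  stream.flush();
  // out == "7 x3, 0 to 12 by 4, (1, 2) x3"
  return out;
}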
+68
@@ -0,0 +1,68 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "src/Accelerators/PIM/Common/IR/CoreBlockUtils.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
namespace onnx_mlir {
bool isCoreStaticAddressOp(mlir::Operation* op) {
return mlir::isa<mlir::arith::ConstantOp,
mlir::arith::AddIOp,
mlir::arith::SubIOp,
mlir::arith::MulIOp,
mlir::arith::DivUIOp,
mlir::arith::MinUIOp,
mlir::arith::RemUIOp,
mlir::arith::IndexCastOp,
mlir::memref::AllocOp,
mlir::memref::SubViewOp,
mlir::memref::CastOp,
mlir::memref::CollapseShapeOp,
mlir::memref::ExpandShapeOp>(op);
}
mlir::LogicalResult
walkPimCoreBlock(mlir::Block& block,
const StaticValueKnowledge& knowledge,
llvm::function_ref<mlir::LogicalResult(mlir::Operation&, const StaticValueKnowledge&)> callback) {
bool hasFailure = false;
for (mlir::Operation& op : block) {
if (mlir::isa<pim::PimHaltOp, mlir::scf::YieldOp>(op) || isCoreStaticAddressOp(&op))
continue;
if (auto forOp = mlir::dyn_cast<mlir::scf::ForOp>(op)) {
mlir::Block& loopBody = forOp.getRegion().front();
auto lowerBound = resolveIndexValue(forOp.getLowerBound(), knowledge);
auto upperBound = resolveIndexValue(forOp.getUpperBound(), knowledge);
auto step = resolveIndexValue(forOp.getStep(), knowledge);
if (failed(lowerBound) || failed(upperBound) || failed(step) || *step <= 0) {
forOp.emitOpError("requires statically evaluable scf.for bounds for PIM codegen");
hasFailure = true;
continue;
}
llvm::SmallVector<mlir::Value> iterValues(forOp.getInitArgs().begin(), forOp.getInitArgs().end());
for (int64_t inductionValue = *lowerBound; inductionValue < *upperBound; inductionValue += *step) {
StaticValueKnowledge loopKnowledge = knowledge;
loopKnowledge.indexValues[forOp.getInductionVar()] = inductionValue;
for (auto [iterArg, iterValue] : llvm::zip_equal(forOp.getRegionIterArgs(), iterValues))
loopKnowledge.aliases[iterArg] = iterValue;
if (failed(walkPimCoreBlock(loopBody, loopKnowledge, callback)))
hasFailure = true;
auto yieldOp = mlir::cast<mlir::scf::YieldOp>(loopBody.getTerminator());
for (auto [index, yieldedValue] : llvm::enumerate(yieldOp.getOperands()))
iterValues[index] = resolveLoopCarriedAlias(yieldedValue, loopKnowledge);
}
continue;
}
if (failed(callback(op, knowledge)))
hasFailure = true;
}
return mlir::success(!hasFailure);
}
} // namespace onnx_mlir
+24
@@ -0,0 +1,24 @@
#pragma once
#include "mlir/IR/Block.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/STLFunctionalExtras.h"
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
namespace onnx_mlir {
/// Returns true for ops in a `pim.core` body that only participate in static
/// address or index computation and therefore do not emit PIM instructions.
bool isCoreStaticAddressOp(mlir::Operation* op);
/// Walks a `pim.core` body, statically unrolling nested `scf.for` loops when
/// their bounds are known and invoking `callback` only on instruction-emitting
/// operations.
mlir::LogicalResult
walkPimCoreBlock(mlir::Block& block,
const StaticValueKnowledge& knowledge,
llvm::function_ref<mlir::LogicalResult(mlir::Operation&, const StaticValueKnowledge&)> callback);
} // namespace onnx_mlir
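A hedged sketch of the intended call pattern for walkPimCoreBlock: the callback only sees instruction-emitting ops, and statically bounded scf.for loops are already unrolled, so each trip is visited separately. The counting helper is invented.
#include "src/Accelerators/PIM/Common/IR/CoreBlockUtils.hpp"
namespace onnx_mlir {
static llvm::FailureOr<int64_t> countEmittedInstructions(mlir::Block& coreBody) {
  int64_t count = 0;
  StaticValueKnowledge knowledge; // no index values or aliases known up front
  auto result = walkPimCoreBlock(coreBody, knowledge,
      [&](mlir::Operation&, const StaticValueKnowledge&) -> mlir::LogicalResult {
        ++count; // one increment per unrolled, instruction-emitting op
        return mlir::success();
      });
  if (mlir::failed(result))
    return mlir::failure();
  return count;
}
} // namespace onnx_mlir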
+45
@@ -0,0 +1,45 @@
#include "src/Accelerators/PIM/Common/IR/EntryPointUtils.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
namespace onnx_mlir {
llvm::FailureOr<mlir::func::FuncOp> getPimEntryFunc(mlir::ModuleOp moduleOp) {
if (!moduleOp)
return mlir::failure();
llvm::SmallVector<mlir::ONNXEntryPointOp> entryPoints(moduleOp.getOps<mlir::ONNXEntryPointOp>());
if (entryPoints.size() > 1) {
moduleOp.emitError("PIM pipeline requires a single ONNX entry point, but found ") << entryPoints.size();
return mlir::failure();
}
if (!entryPoints.empty()) {
auto entryPointAttr =
entryPoints.front()->getAttrOfType<mlir::SymbolRefAttr>(mlir::ONNXEntryPointOp::getEntryPointFuncAttrName());
if (!entryPointAttr) {
entryPoints.front().emitOpError("is missing the entry point function attribute");
return mlir::failure();
}
auto entryFunc = moduleOp.lookupSymbol<mlir::func::FuncOp>(entryPointAttr.getLeafReference().getValue());
if (!entryFunc) {
entryPoints.front().emitOpError("references an unknown entry function ")
<< entryPointAttr.getLeafReference().getValue();
return mlir::failure();
}
return entryFunc;
}
if (auto mainGraphFunc = moduleOp.lookupSymbol<mlir::func::FuncOp>("main_graph"))
return mainGraphFunc;
llvm::SmallVector<mlir::func::FuncOp> nonExternalFuncs;
for (auto funcOp : moduleOp.getOps<mlir::func::FuncOp>())
if (!funcOp.isExternal())
nonExternalFuncs.push_back(funcOp);
if (nonExternalFuncs.size() == 1)
return nonExternalFuncs.front();
moduleOp.emitError("could not resolve a unique PIM entry function");
return mlir::failure();
}
} // namespace onnx_mlir
+13
@@ -0,0 +1,13 @@
#pragma once
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
namespace onnx_mlir {
/// Resolves the function the PIM pipeline should treat as its entry point.
/// Prefers ONNX entry-point metadata, then `main_graph`, then the only
/// non-external function if the module is otherwise unambiguous.
llvm::FailureOr<mlir::func::FuncOp> getPimEntryFunc(mlir::ModuleOp moduleOp);
} // namespace onnx_mlir
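A small sketch (pass scaffolding omitted) of the intended use of getPimEntryFunc; since the helper already emits diagnostics for the ambiguous cases, callers only propagate the failure.
#include "src/Accelerators/PIM/Common/IR/EntryPointUtils.hpp"
static mlir::LogicalResult processEntryPoint(mlir::ModuleOp moduleOp) {
  auto entryFunc = onnx_mlir::getPimEntryFunc(moduleOp);
  if (mlir::failed(entryFunc))
    return mlir::failure(); // diagnostics already emitted by the helper
  // ... rewrite the body of *entryFunc here ...
  return mlir::success();
}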
+89
@@ -0,0 +1,89 @@
#include "llvm/ADT/STLExtras.h"
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
namespace onnx_mlir {
llvm::SmallVector<int64_t> computeRowMajorStrides(llvm::ArrayRef<int64_t> shape) {
llvm::SmallVector<int64_t> strides(shape.size(), 1);
for (int64_t dim = static_cast<int64_t>(shape.size()) - 2; dim >= 0; --dim)
strides[dim] = strides[dim + 1] * shape[dim + 1];
return strides;
}
llvm::SmallVector<int64_t>
delinearizeIndex(int64_t linearIndex, llvm::ArrayRef<int64_t> shape, llvm::ArrayRef<int64_t> strides) {
llvm::SmallVector<int64_t> indices(shape.size(), 0);
for (auto [dim, stride] : llvm::enumerate(strides)) {
indices[dim] = linearIndex / stride;
linearIndex %= stride;
}
return indices;
}
int64_t linearizeIndex(llvm::ArrayRef<int64_t> indices, llvm::ArrayRef<int64_t> strides) {
int64_t linearIndex = 0;
for (auto [index, stride] : llvm::zip_equal(indices, strides))
linearIndex += index * stride;
return linearIndex;
}
int64_t getNumElements(llvm::ArrayRef<int64_t> shape) {
int64_t numElements = 1;
for (int64_t dim : shape)
numElements *= dim;
return numElements;
}
bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape,
llvm::ArrayRef<int64_t> offsets,
llvm::ArrayRef<int64_t> sizes,
llvm::ArrayRef<int64_t> strides) {
if (std::any_of(strides.begin(), strides.end(), [](int64_t stride) -> bool { return stride != 1; }))
return false;
auto offsetsAndSizesAndShape = llvm::zip_equal(llvm::make_range(offsets.rbegin(), offsets.rend()),
llvm::make_range(sizes.rbegin(), sizes.rend()),
llvm::make_range(srcShape.rbegin(), srcShape.rend()));
auto firstNonZeroOffset = std::find_if(
offsetsAndSizesAndShape.begin(), offsetsAndSizesAndShape.end(), [&](auto offsetAndSizeAndShape) -> bool {
auto [offset, _size, _dimension] = offsetAndSizeAndShape;
return offset != 0;
});
if (firstNonZeroOffset != offsetsAndSizesAndShape.end()) {
auto [offset, size, dimension] = *firstNonZeroOffset;
if (size > dimension - offset)
return false;
++firstNonZeroOffset;
if (std::any_of(firstNonZeroOffset, offsetsAndSizesAndShape.end(), [](auto offsetAndSizeAndShape) -> bool {
auto [_offset, size, _dimension] = offsetAndSizeAndShape;
return size != 1;
}))
return false;
}
auto sizesAndShape = llvm::zip_equal(llvm::make_range(sizes.rbegin(), sizes.rend()),
llvm::make_range(srcShape.rbegin(), srcShape.rend()));
auto firstDifferentSize = std::find_if(sizesAndShape.begin(), sizesAndShape.end(), [&](auto sizeAndShape) -> bool {
auto [size, dimension] = sizeAndShape;
return size != dimension;
});
if (firstDifferentSize != sizesAndShape.end()) {
++firstDifferentSize;
if (std::any_of(firstDifferentSize, sizesAndShape.end(), [](auto sizeAndShape) -> bool {
auto [size, _dimension] = sizeAndShape;
return size != 1;
}))
return false;
}
return true;
}
} // namespace onnx_mlir
+22
@@ -0,0 +1,22 @@
#pragma once
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
namespace onnx_mlir {
llvm::SmallVector<int64_t> computeRowMajorStrides(llvm::ArrayRef<int64_t> shape);
llvm::SmallVector<int64_t>
delinearizeIndex(int64_t linearIndex, llvm::ArrayRef<int64_t> shape, llvm::ArrayRef<int64_t> strides);
int64_t linearizeIndex(llvm::ArrayRef<int64_t> indices, llvm::ArrayRef<int64_t> strides);
int64_t getNumElements(llvm::ArrayRef<int64_t> shape);
bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape,
llvm::ArrayRef<int64_t> offsets,
llvm::ArrayRef<int64_t> sizes,
llvm::ArrayRef<int64_t> strides);
} // namespace onnx_mlir
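A worked example with invented shapes to pin down the conventions of the helpers above; the asserts spell out the expected values.
#include <cassert>
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
static void shapeUtilsExample() {
  int64_t shape[] = {2, 3, 4};
  auto strides = onnx_mlir::computeRowMajorStrides(shape); // {12, 4, 1}
  int64_t indices[] = {1, 2, 3};
  assert(onnx_mlir::linearizeIndex(indices, strides) == 1 * 12 + 2 * 4 + 3);
  assert(onnx_mlir::getNumElements(shape) == 24);
  // Taking one full outer slice keeps memory contiguous ...
  int64_t offsets[] = {1, 0, 0}, fullSlice[] = {1, 3, 4}, unit[] = {1, 1, 1};
  assert(onnx_mlir::isMemoryContiguous(shape, offsets, fullSlice, unit));
  // ... while a partial inner row taken from every slice does not.
  int64_t zeros[] = {0, 0, 0}, partial[] = {2, 3, 2};
  assert(!onnx_mlir::isMemoryContiguous(shape, zeros, partial, unit));
}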
+84
@@ -0,0 +1,84 @@
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
using namespace mlir;
namespace onnx_mlir {
Value stripMemRefCasts(Value value) {
while (auto castOp = value.getDefiningOp<memref::CastOp>())
value = castOp.getSource();
return value;
}
Value stripMemRefViewOps(Value value) {
while (true) {
if (auto castOp = value.getDefiningOp<memref::CastOp>()) {
value = castOp.getSource();
continue;
}
if (auto collapseOp = value.getDefiningOp<memref::CollapseShapeOp>()) {
value = collapseOp.getSrc();
continue;
}
if (auto expandOp = value.getDefiningOp<memref::ExpandShapeOp>()) {
value = expandOp.getSrc();
continue;
}
return value;
}
}
bool hasAllStaticSubviewParts(memref::SubViewOp subview) {
return llvm::all_of(subview.getStaticOffsets(), [](int64_t value) { return !ShapedType::isDynamic(value); })
&& llvm::all_of(subview.getStaticSizes(), [](int64_t value) { return !ShapedType::isDynamic(value); })
&& llvm::all_of(subview.getStaticStrides(), [](int64_t value) { return !ShapedType::isDynamic(value); });
}
FailureOr<StaticSubviewInfo> getStaticSubviewInfo(Value value) {
value = stripMemRefViewOps(value);
auto subviewOp = value.getDefiningOp<memref::SubViewOp>();
if (!subviewOp)
return failure();
auto source = stripMemRefCasts(subviewOp.getSource());
auto sourceType = dyn_cast<MemRefType>(source.getType());
auto subviewType = dyn_cast<MemRefType>(subviewOp.getType());
if (!sourceType || !subviewType || !sourceType.hasStaticShape() || !subviewType.hasStaticShape())
return failure();
StaticSubviewInfo info;
info.source = source;
info.sourceShape.assign(sourceType.getShape().begin(), sourceType.getShape().end());
SmallVector<OpFoldResult> mixedOffsets = subviewOp.getMixedOffsets();
info.offsets.assign(mixedOffsets.begin(), mixedOffsets.end());
for (OpFoldResult size : subviewOp.getMixedSizes()) {
auto staticSize = getConstantIntValue(size);
if (!staticSize)
return failure();
info.sizes.push_back(*staticSize);
}
for (OpFoldResult stride : subviewOp.getMixedStrides()) {
auto staticStride = getConstantIntValue(stride);
if (!staticStride)
return failure();
info.strides.push_back(*staticStride);
}
return info;
}
FailureOr<SmallVector<int64_t>> getStaticSubviewOffsets(const StaticSubviewInfo& info) {
SmallVector<int64_t> staticOffsets;
staticOffsets.reserve(info.offsets.size());
for (OpFoldResult offset : info.offsets) {
auto staticOffset = getConstantIntValue(offset);
if (!staticOffset)
return failure();
staticOffsets.push_back(*staticOffset);
}
return staticOffsets;
}
} // namespace onnx_mlir
+30
@@ -0,0 +1,30 @@
#pragma once
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/LogicalResult.h"
namespace onnx_mlir {
struct StaticSubviewInfo {
mlir::Value source;
llvm::SmallVector<int64_t> sourceShape;
llvm::SmallVector<mlir::OpFoldResult> offsets;
llvm::SmallVector<int64_t> sizes;
llvm::SmallVector<int64_t> strides;
};
mlir::Value stripMemRefCasts(mlir::Value value);
mlir::Value stripMemRefViewOps(mlir::Value value);
bool hasAllStaticSubviewParts(mlir::memref::SubViewOp subview);
llvm::FailureOr<StaticSubviewInfo> getStaticSubviewInfo(mlir::Value value);
/// Returns the offsets in `info` as int64_t, failing if any offset is dynamic.
llvm::FailureOr<llvm::SmallVector<int64_t>> getStaticSubviewOffsets(const StaticSubviewInfo& info);
} // namespace onnx_mlir
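A hedged sketch chaining the two helpers above with the ShapeUtils strides: when every offset in the subview chain is static, the linear element offset into the source buffer falls out directly. The helper name is invented.
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
static llvm::FailureOr<int64_t> staticLinearOffset(mlir::Value view) {
  auto info = onnx_mlir::getStaticSubviewInfo(view);
  if (mlir::failed(info))
    return mlir::failure();
  auto offsets = onnx_mlir::getStaticSubviewOffsets(*info);
  if (mlir::failed(offsets))
    return mlir::failure();
  auto strides = onnx_mlir::computeRowMajorStrides(info->sourceShape);
  return onnx_mlir::linearizeIndex(*offsets, strides); // element offset, not bytes
}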
+108
@@ -0,0 +1,108 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallSet.h"
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
namespace onnx_mlir {
bool hasWeightAlways(mlir::Operation* op) { return op && op->getAttr(PimWeightAlwaysAttrName) != nullptr; }
void markWeightAlways(mlir::Operation* op) {
assert(op && "expected valid op");
op->setAttr(PimWeightAlwaysAttrName, mlir::UnitAttr::get(op->getContext()));
}
namespace {
template <typename MVMOpTy, typename VMMOpTy, typename ParentOpTy>
bool hasMvmVmmWeightUse(ParentOpTy parentOp, unsigned weightIndex) {
bool found = false;
parentOp.walk([&](mlir::Operation* op) {
if (auto mvmOp = mlir::dyn_cast<MVMOpTy>(op))
found |= mvmOp.getWeightIndex() == weightIndex;
else if (auto vmmOp = mlir::dyn_cast<VMMOpTy>(op))
found |= vmmOp.getWeightIndex() == weightIndex;
});
return found;
}
template <typename MVMOpTy, typename VMMOpTy, typename ParentOpTy>
void walkMvmVmmWeightUses(ParentOpTy parentOp, llvm::function_ref<void(mlir::OpOperand&)> callback) {
auto weights = parentOp.getWeights();
llvm::SmallSet<unsigned, 8> visited;
auto walkWeightIndex = [&](unsigned weightIndex) {
if (weightIndex < weights.size() && visited.insert(weightIndex).second)
callback(parentOp->getOpOperand(weightIndex));
};
parentOp.walk([&](MVMOpTy op) { walkWeightIndex(op.getWeightIndex()); });
parentOp.walk([&](VMMOpTy op) { walkWeightIndex(op.getWeightIndex()); });
}
} // namespace
bool isSpatialMvmVmmWeightUse(mlir::OpOperand& use) {
mlir::Operation* user = use.getOwner();
unsigned operandIndex = use.getOperandNumber();
auto computeOp = mlir::dyn_cast<spatial::SpatCompute>(user);
if (!computeOp || operandIndex >= computeOp.getWeights().size())
return false;
return hasMvmVmmWeightUse<spatial::SpatMVMOp, spatial::SpatVMMOp>(computeOp, operandIndex);
}
bool hasOnlySpatialMvmVmmWeightUses(mlir::Value value) {
llvm::SmallPtrSet<mlir::Value, 8> visited;
auto walkUses = [&](mlir::Value currentValue, auto& self) -> bool {
if (!visited.insert(currentValue).second)
return true;
if (currentValue.use_empty())
return false;
return llvm::all_of(currentValue.getUses(), [&](mlir::OpOperand& use) {
if (isSpatialMvmVmmWeightUse(use))
return true;
mlir::Operation* user = use.getOwner();
if (auto extractSliceOp = mlir::dyn_cast<mlir::tensor::ExtractSliceOp>(user))
return extractSliceOp.getSource() == currentValue && self(extractSliceOp.getResult(), self);
if (auto expandShapeOp = mlir::dyn_cast<mlir::tensor::ExpandShapeOp>(user))
return expandShapeOp.getSrc() == currentValue && self(expandShapeOp.getResult(), self);
if (auto collapseShapeOp = mlir::dyn_cast<mlir::tensor::CollapseShapeOp>(user))
return collapseShapeOp.getSrc() == currentValue && self(collapseShapeOp.getResult(), self);
if (auto transposeOp = mlir::dyn_cast<mlir::ONNXTransposeOp>(user))
return transposeOp.getData() == currentValue && self(transposeOp.getResult(), self);
return false;
});
};
return walkUses(value, walkUses);
}
void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref<void(mlir::OpOperand&)> callback) {
assert(root && "expected valid root op");
root->walk([&](pim::PimCoreOp coreOp) {
coreOp.walk([&](pim::PimVMMOp vmmOp) {
auto weights = coreOp.getWeights();
unsigned weightIndex = vmmOp.getWeightIndex();
if (weightIndex < weights.size())
callback(coreOp->getOpOperand(weightIndex));
});
});
root->walk([&](pim::PimCoreBatchOp coreBatchOp) {
auto weights = coreBatchOp.getWeights();
for (auto weight : weights)
for (mlir::OpOperand& use : weight.getUses())
if (use.getOwner() == coreBatchOp.getOperation())
callback(use);
});
}
} // namespace onnx_mlir
+29
@@ -0,0 +1,29 @@
#pragma once
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/STLFunctionalExtras.h"
#include "llvm/ADT/StringRef.h"
inline constexpr llvm::StringRef PimWeightAlwaysAttrName = "weightAlways";
namespace onnx_mlir {
bool hasWeightAlways(mlir::Operation* op);
/// Tags an op as producing a value that should stay materialized as a reusable
/// weight across later PIM lowering/codegen stages.
void markWeightAlways(mlir::Operation* op);
bool isSpatialMvmVmmWeightUse(mlir::OpOperand& use);
/// Returns true when a value flows only into Spatial weighted MVM/VMM operands,
/// allowing later passes to preserve it as a dedicated weight-like object.
bool hasOnlySpatialMvmVmmWeightUses(mlir::Value value);
/// Visits weight operands consumed by Pim core ops/core batches so downstream
/// passes can identify globals that must remain weight-backed.
void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref<void(mlir::OpOperand&)> callback);
} // namespace onnx_mlir
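A minimal call-site sketch of how these helpers compose; this is hypothetical (the walk over `funcOp` and the choice of ONNXConstantOp are illustrative assumptions, not part of this diff):
// Hypothetical sketch: tag tensor constants that feed only Spatial MVM/VMM
// weight operands, so later PIM stages keep them materialized as weights.
funcOp.walk([&](mlir::ONNXConstantOp constOp) {
  if (hasOnlySpatialMvmVmmWeightUses(constOp.getResult()))
    markWeightAlways(constOp);
});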
-546
@@ -1,546 +0,0 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "llvm/Support/raw_os_ostream.h"
#include <filesystem>
#include <fstream>
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Compiler/CompilerOptions.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
std::string getOutputDir() {
if (outputBaseName.empty() || outputBaseName == "-")
return {};
size_t lastSlash = outputBaseName.find_last_of('/');
if (lastSlash == std::string::npos)
return ".";
return outputBaseName.substr(0, lastSlash);
}
void createDirectory(const std::string& directory) {
std::error_code errorCode;
std::filesystem::create_directories(directory, errorCode);
assert(!errorCode && ("Failed to create directory: " + errorCode.message()).data());
}
void dumpModule(ModuleOp moduleOp, const std::string& name) {
std::string outputDir = getOutputDir();
if (outputDir.empty())
return;
std::string dialectsDir = outputDir + "/dialects";
createDirectory(dialectsDir);
std::fstream file(dialectsDir + "/" + name + ".mlir", std::ios::out);
llvm::raw_os_ostream os(file);
os << *moduleOp;
os.flush();
file.close();
}
FailureOr<func::FuncOp> getPimEntryFunc(ModuleOp moduleOp) {
if (!moduleOp)
return failure();
SmallVector<ONNXEntryPointOp> entryPoints(moduleOp.getOps<ONNXEntryPointOp>());
if (entryPoints.size() > 1) {
moduleOp.emitError("PIM pipeline requires a single ONNX entry point, but found ") << entryPoints.size();
return failure();
}
if (!entryPoints.empty()) {
auto entryPointAttr =
entryPoints.front()->getAttrOfType<SymbolRefAttr>(ONNXEntryPointOp::getEntryPointFuncAttrName());
if (!entryPointAttr) {
entryPoints.front().emitOpError("is missing the entry point function attribute");
return failure();
}
auto entryFunc = moduleOp.lookupSymbol<func::FuncOp>(entryPointAttr.getLeafReference().getValue());
if (!entryFunc) {
entryPoints.front().emitOpError("references an unknown entry function ")
<< entryPointAttr.getLeafReference().getValue();
return failure();
}
return entryFunc;
}
if (auto mainGraphFunc = moduleOp.lookupSymbol<func::FuncOp>("main_graph"))
return mainGraphFunc;
SmallVector<func::FuncOp> nonExternalFuncs;
for (auto funcOp : moduleOp.getOps<func::FuncOp>())
if (!funcOp.isExternal())
nonExternalFuncs.push_back(funcOp);
if (nonExternalFuncs.size() == 1)
return nonExternalFuncs.front();
moduleOp.emitError("could not resolve a unique PIM entry function");
return failure();
}
bool hasWeightAlways(Operation* op) { return op && op->getAttr(PimWeightAlwaysAttrName) != nullptr; }
void markWeightAlways(Operation* op) {
assert(op && "expected valid op");
op->setAttr(PimWeightAlwaysAttrName, UnitAttr::get(op->getContext()));
}
memref::GlobalOp lookupGlobalForGetGlobal(ModuleOp moduleOp, memref::GetGlobalOp getGlobalOp) {
if (!moduleOp || !getGlobalOp)
return {};
return moduleOp.lookupSymbol<memref::GlobalOp>(getGlobalOp.getName());
}
FailureOr<Operation*> getOtherEndOfChannel(Operation* op, bool opIsReceive, RewriterBase& rewriter) {
auto channelNewOp = op->getOperand(0).getDefiningOp<spatial::SpatChannelNewOp>();
if (!channelNewOp) {
op->emitError("User of Channel must have the first operand created by ChannelNewOp.");
return failure();
}
// channelNewOp should have two users: `op` and a
// `ChannelSendOp`/`ChannelReceiveOp`
auto channelUsers = channelNewOp->getUsers();
auto usersIterator = channelUsers.begin();
auto firstUser = *usersIterator;
usersIterator++;
if (usersIterator == channelUsers.end()) {
op->emitError("Operand generated by ChannelNewOp must have two users, "
"only one found.");
channelNewOp->dump();
op->dump();
channelNewOp->getParentOp()->dump();
return failure();
}
auto secondUser = *usersIterator;
usersIterator++;
if (usersIterator != channelUsers.end()) {
op->emitError("Operand generated by ChannelNewOp must have two users, "
"more than two found.");
return failure();
}
Operation* notOpUser;
if (firstUser == op) {
notOpUser = secondUser;
}
else if (secondUser == op) {
notOpUser = firstUser;
}
else {
op->emitError("Operand generated by ChannelNewOp must have two users, "
"and one of them must be me, but"
"none of them is actually me.");
return failure();
}
if (opIsReceive) {
if (!isa<spatial::SpatChannelSendOp>(notOpUser)) {
op->emitError("Operand generated by ChannelNewOp has two user, one is "
"me, the other is not a ChannelSendOp.");
return failure();
}
return notOpUser;
}
else {
if (!isa<spatial::SpatChannelReceiveOp>(notOpUser)) {
op->emitError("Operand generated by ChannelNewOp has two user, one is "
"me, the other is not a ChannelReceiveOp.");
return failure();
}
return notOpUser;
}
}
SmallVector<int64_t> computeRowMajorStrides(ArrayRef<int64_t> shape) {
SmallVector<int64_t> strides(shape.size(), 1);
for (int64_t dim = static_cast<int64_t>(shape.size()) - 2; dim >= 0; --dim)
strides[dim] = strides[dim + 1] * shape[dim + 1];
return strides;
}
SmallVector<int64_t> delinearizeIndex(int64_t linearIndex, ArrayRef<int64_t> shape, ArrayRef<int64_t> strides) {
SmallVector<int64_t> indices(shape.size(), 0);
for (auto [dim, stride] : llvm::enumerate(strides)) {
indices[dim] = linearIndex / stride;
linearIndex %= stride;
}
return indices;
}
int64_t linearizeIndex(ArrayRef<int64_t> indices, ArrayRef<int64_t> strides) {
int64_t linearIndex = 0;
for (auto [index, stride] : llvm::zip_equal(indices, strides))
linearIndex += index * stride;
return linearIndex;
}
int64_t getNumElements(ArrayRef<int64_t> shape) {
int64_t numElements = 1;
for (int64_t dim : shape)
numElements *= dim;
return numElements;
}
bool isMemoryContiguous(ArrayRef<int64_t> srcShape,
ArrayRef<int64_t> offsets,
ArrayRef<int64_t> sizes,
ArrayRef<int64_t> strides) {
if (std::any_of(strides.begin(), strides.end(), [](int64_t stride) -> bool { return stride != 1; }))
return false;
auto offsetsAndSizesAndShape = llvm::zip_equal(llvm::make_range(offsets.rbegin(), offsets.rend()),
llvm::make_range(sizes.rbegin(), sizes.rend()),
llvm::make_range(srcShape.rbegin(), srcShape.rend()));
auto firstNonZeroOffset = std::find_if(
offsetsAndSizesAndShape.begin(), offsetsAndSizesAndShape.end(), [&](auto offsetAndSizeAndShape) -> bool {
auto [offset, _size, _dimension] = offsetAndSizeAndShape;
return offset != 0;
});
if (firstNonZeroOffset != offsetsAndSizesAndShape.end()) {
auto [offset, size, dimension] = *firstNonZeroOffset;
if (size > dimension - offset)
return false;
++firstNonZeroOffset;
if (std::any_of(firstNonZeroOffset, offsetsAndSizesAndShape.end(), [](auto offsetAndSizeAndShape) -> bool {
auto [_offset, size, _dimension] = offsetAndSizeAndShape;
return size != 1;
}))
return false;
}
auto sizesAndShape = llvm::zip_equal(llvm::make_range(sizes.rbegin(), sizes.rend()),
llvm::make_range(srcShape.rbegin(), srcShape.rend()));
auto firstDifferentSize = std::find_if(sizesAndShape.begin(), sizesAndShape.end(), [&](auto sizeAndShape) -> bool {
auto [size, dimension] = sizeAndShape;
return size != dimension;
});
if (firstDifferentSize != sizesAndShape.end()) {
++firstDifferentSize;
if (std::any_of(firstDifferentSize, sizesAndShape.end(), [](auto sizeAndShape) -> bool {
auto [size, _dimension] = sizeAndShape;
return size != 1;
}))
return false;
}
return true;
}
static Value resolveAlias(Value value, const StaticValueKnowledge* knowledge) {
if (!knowledge)
return value;
auto iter = knowledge->aliases.find(value);
while (iter != knowledge->aliases.end()) {
value = iter->second;
iter = knowledge->aliases.find(value);
}
return value;
}
// Walks through view-like ops and DPS tied operands to find the "underlying" memref value
// behind an scf.for iter-arg. Used both when resolving a contiguous address inside a loop
// and when propagating yielded values across iterations during static unrolling.
static Value resolveLoopCarriedAliasImpl(Value value, const StaticValueKnowledge* knowledge) {
value = resolveAlias(value, knowledge);
if (auto blockArgument = dyn_cast<BlockArgument>(value))
return value;
Operation* definingOp = value.getDefiningOp();
if (!definingOp)
return value;
if (auto dpsDefiningOp = dyn_cast<DestinationStyleOpInterface>(definingOp)) {
if (auto result = dyn_cast<OpResult>(value))
if (OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(result))
return resolveLoopCarriedAliasImpl(tiedOperand->get(), knowledge);
}
if (auto castOp = dyn_cast<memref::CastOp>(definingOp))
return resolveLoopCarriedAliasImpl(castOp.getSource(), knowledge);
if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(definingOp))
return resolveLoopCarriedAliasImpl(collapseOp.getSrc(), knowledge);
if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(definingOp))
return resolveLoopCarriedAliasImpl(expandOp.getSrc(), knowledge);
return value;
}
static FailureOr<int64_t> resolveOpFoldResult(OpFoldResult ofr, const StaticValueKnowledge* knowledge);
static FailureOr<int64_t> resolveIndexValueImpl(Value value, const StaticValueKnowledge* knowledge) {
value = resolveAlias(value, knowledge);
if (knowledge) {
auto iter = knowledge->indexValues.find(value);
if (iter != knowledge->indexValues.end())
return iter->second;
}
auto constantOp = value.getDefiningOp<arith::ConstantOp>();
if (constantOp) {
if (auto integerAttr = dyn_cast<IntegerAttr>(constantOp.getValue()))
return integerAttr.getInt();
}
Operation* definingOp = value.getDefiningOp();
if (!definingOp)
return failure();
if (auto indexCastOp = dyn_cast<arith::IndexCastOp>(definingOp))
return resolveIndexValueImpl(indexCastOp.getIn(), knowledge);
if (auto addOp = dyn_cast<arith::AddIOp>(definingOp)) {
auto lhs = resolveIndexValueImpl(addOp.getLhs(), knowledge);
auto rhs = resolveIndexValueImpl(addOp.getRhs(), knowledge);
if (failed(lhs) || failed(rhs))
return failure();
return *lhs + *rhs;
}
if (auto subOp = dyn_cast<arith::SubIOp>(definingOp)) {
auto lhs = resolveIndexValueImpl(subOp.getLhs(), knowledge);
auto rhs = resolveIndexValueImpl(subOp.getRhs(), knowledge);
if (failed(lhs) || failed(rhs))
return failure();
return *lhs - *rhs;
}
if (auto mulOp = dyn_cast<arith::MulIOp>(definingOp)) {
auto lhs = resolveIndexValueImpl(mulOp.getLhs(), knowledge);
auto rhs = resolveIndexValueImpl(mulOp.getRhs(), knowledge);
if (failed(lhs) || failed(rhs))
return failure();
return *lhs * *rhs;
}
if (auto divOp = dyn_cast<arith::DivUIOp>(definingOp)) {
auto lhs = resolveIndexValueImpl(divOp.getLhs(), knowledge);
auto rhs = resolveIndexValueImpl(divOp.getRhs(), knowledge);
if (failed(lhs) || failed(rhs) || *rhs == 0)
return failure();
return static_cast<int64_t>(static_cast<uint64_t>(*lhs) / static_cast<uint64_t>(*rhs));
}
if (auto remOp = dyn_cast<arith::RemUIOp>(definingOp)) {
auto lhs = resolveIndexValueImpl(remOp.getLhs(), knowledge);
auto rhs = resolveIndexValueImpl(remOp.getRhs(), knowledge);
if (failed(lhs) || failed(rhs) || *rhs == 0)
return failure();
return static_cast<int64_t>(static_cast<uint64_t>(*lhs) % static_cast<uint64_t>(*rhs));
}
return failure();
}
static FailureOr<int64_t> resolveOpFoldResult(OpFoldResult ofr, const StaticValueKnowledge* knowledge) {
if (auto attr = dyn_cast<Attribute>(ofr)) {
auto integerAttr = dyn_cast<IntegerAttr>(attr);
if (!integerAttr)
return failure();
return integerAttr.getInt();
}
return resolveIndexValueImpl(cast<Value>(ofr), knowledge);
}
static FailureOr<ResolvedContiguousAddress> resolveContiguousAddressImpl(Value value,
const StaticValueKnowledge* knowledge) {
int64_t byteOffset = 0;
value = resolveAlias(value, knowledge);
while (true) {
if (isa<BlockArgument>(value))
return ResolvedContiguousAddress {value, byteOffset};
Operation* definingOp = value.getDefiningOp();
if (!definingOp)
return failure();
if (auto dpsDefiningOp = dyn_cast<DestinationStyleOpInterface>(definingOp)) {
OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(dyn_cast<OpResult>(value));
if (!tiedOperand)
return failure();
value = resolveAlias(tiedOperand->get(), knowledge);
continue;
}
if (auto forOp = dyn_cast<scf::ForOp>(definingOp)) {
auto result = dyn_cast<OpResult>(value);
if (!result)
return failure();
// Trace the loop carry back to its underlying memref, then if that memref is the
// loop's own iter-arg we know the base comes from the corresponding init arg
// (every iteration yields the same backing memory in the DPS sense).
auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
Value yieldedValue = resolveLoopCarriedAliasImpl(yieldOp.getOperand(result.getResultNumber()), knowledge);
if (auto blockArgument = dyn_cast<BlockArgument>(yieldedValue)) {
if (blockArgument.getOwner() == forOp.getBody() && blockArgument.getArgNumber() > 0
&& static_cast<unsigned>(blockArgument.getArgNumber() - 1) < forOp.getInitArgs().size()) {
value = resolveAlias(forOp.getInitArgs()[blockArgument.getArgNumber() - 1], knowledge);
continue;
}
}
value = yieldedValue;
continue;
}
if (auto subviewOp = dyn_cast<memref::SubViewOp>(definingOp)) {
auto sourceType = dyn_cast<MemRefType>(subviewOp.getSource().getType());
auto subviewType = dyn_cast<MemRefType>(subviewOp.getType());
if (!sourceType || !subviewType || !sourceType.hasStaticShape() || !subviewType.hasStaticShape())
return failure();
SmallVector<int64_t> offsets;
SmallVector<int64_t> sizes;
SmallVector<int64_t> strides;
offsets.reserve(subviewOp.getMixedOffsets().size());
sizes.reserve(subviewOp.getMixedSizes().size());
strides.reserve(subviewOp.getMixedStrides().size());
for (OpFoldResult offset : subviewOp.getMixedOffsets()) {
auto resolvedOffset = resolveOpFoldResult(offset, knowledge);
if (failed(resolvedOffset))
return failure();
offsets.push_back(*resolvedOffset);
}
for (OpFoldResult size : subviewOp.getMixedSizes()) {
auto resolvedSize = resolveOpFoldResult(size, knowledge);
if (failed(resolvedSize))
return failure();
sizes.push_back(*resolvedSize);
}
for (OpFoldResult stride : subviewOp.getMixedStrides()) {
auto resolvedStride = resolveOpFoldResult(stride, knowledge);
if (failed(resolvedStride))
return failure();
strides.push_back(*resolvedStride);
}
if (!isMemoryContiguous(sourceType.getShape(), offsets, sizes, strides))
return failure();
auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
byteOffset += linearizeIndex(offsets, sourceStrides) * subviewType.getElementTypeBitWidth() / 8;
value = resolveAlias(subviewOp.getSource(), knowledge);
continue;
}
if (auto castOp = dyn_cast<memref::CastOp>(definingOp)) {
value = resolveAlias(castOp.getSource(), knowledge);
continue;
}
if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(definingOp)) {
value = resolveAlias(collapseOp.getSrc(), knowledge);
continue;
}
if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(definingOp)) {
value = resolveAlias(expandOp.getSrc(), knowledge);
continue;
}
if (isa<memref::AllocOp, memref::GetGlobalOp>(definingOp))
return ResolvedContiguousAddress {value, byteOffset};
return failure();
}
}
FailureOr<int64_t> resolveIndexValue(Value value) { return resolveIndexValueImpl(value, nullptr); }
FailureOr<int64_t> resolveIndexValue(Value value, const StaticValueKnowledge& knowledge) {
return resolveIndexValueImpl(value, &knowledge);
}
FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(Value value) {
return resolveContiguousAddressImpl(value, nullptr);
}
FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(Value value, const StaticValueKnowledge& knowledge) {
return resolveContiguousAddressImpl(value, &knowledge);
}
Value resolveLoopCarriedAlias(Value value, const StaticValueKnowledge& knowledge) {
return resolveLoopCarriedAliasImpl(value, &knowledge);
}
bool isCoreStaticAddressOp(Operation* op) {
return isa<arith::ConstantOp,
arith::AddIOp,
arith::SubIOp,
arith::MulIOp,
arith::DivUIOp,
arith::RemUIOp,
arith::IndexCastOp,
memref::AllocOp,
memref::SubViewOp,
memref::CastOp,
memref::CollapseShapeOp,
memref::ExpandShapeOp>(op);
}
LogicalResult walkPimCoreBlock(Block& block,
const StaticValueKnowledge& knowledge,
llvm::function_ref<LogicalResult(Operation&, const StaticValueKnowledge&)> callback) {
bool hasFailure = false;
for (Operation& op : block) {
if (isa<pim::PimHaltOp, scf::YieldOp>(op) || isCoreStaticAddressOp(&op))
continue;
if (auto forOp = dyn_cast<scf::ForOp>(op)) {
Block& loopBody = forOp.getRegion().front();
auto lowerBound = resolveIndexValue(forOp.getLowerBound(), knowledge);
auto upperBound = resolveIndexValue(forOp.getUpperBound(), knowledge);
auto step = resolveIndexValue(forOp.getStep(), knowledge);
if (failed(lowerBound) || failed(upperBound) || failed(step) || *step <= 0) {
forOp.emitOpError("requires statically evaluable scf.for bounds for PIM codegen");
hasFailure = true;
continue;
}
SmallVector<Value> iterValues(forOp.getInitArgs().begin(), forOp.getInitArgs().end());
for (int64_t inductionValue = *lowerBound; inductionValue < *upperBound; inductionValue += *step) {
StaticValueKnowledge loopKnowledge = knowledge;
loopKnowledge.indexValues[forOp.getInductionVar()] = inductionValue;
for (auto [iterArg, iterValue] : llvm::zip_equal(forOp.getRegionIterArgs(), iterValues))
loopKnowledge.aliases[iterArg] = iterValue;
if (failed(walkPimCoreBlock(loopBody, loopKnowledge, callback)))
hasFailure = true;
auto yieldOp = cast<scf::YieldOp>(loopBody.getTerminator());
for (auto [index, yieldedValue] : llvm::enumerate(yieldOp.getOperands()))
iterValues[index] = resolveLoopCarriedAlias(yieldedValue, loopKnowledge);
}
continue;
}
if (failed(callback(op, knowledge)))
hasFailure = true;
}
return success(!hasFailure);
}
} // namespace onnx_mlir
+10 -70
@@ -7,82 +7,22 @@
#include "mlir/IR/Value.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLFunctionalExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
#include "src/Accelerators/PIM/Common/IR/CoreBlockUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/EntryPointUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
#include "src/Accelerators/PIM/Common/Support/DebugDump.hpp"
#include "src/Accelerators/PIM/Common/Support/FileSystemUtils.hpp"
#include "src/Compiler/CompilerOptions.hpp"
inline constexpr llvm::StringRef PimWeightAlwaysAttrName = "weightAlways";
namespace onnx_mlir {
struct ResolvedContiguousAddress {
mlir::Value base;
int64_t byteOffset = 0;
};
struct StaticValueKnowledge {
llvm::DenseMap<mlir::Value, int64_t> indexValues;
llvm::DenseMap<mlir::Value, mlir::Value> aliases;
StaticValueKnowledge() {}
};
std::string getOutputDir();
void createDirectory(const std::string& directory);
void dumpModule(mlir::ModuleOp moduleOp, const std::string& name);
llvm::FailureOr<mlir::func::FuncOp> getPimEntryFunc(mlir::ModuleOp moduleOp);
bool hasWeightAlways(mlir::Operation* op);
void markWeightAlways(mlir::Operation* op);
mlir::memref::GlobalOp lookupGlobalForGetGlobal(mlir::ModuleOp moduleOp, mlir::memref::GetGlobalOp getGlobalOp);
llvm::FailureOr<mlir::Operation*>
getOtherEndOfChannel(mlir::Operation* op, bool opIsReceive, mlir::RewriterBase& rewriter);
llvm::SmallVector<int64_t> computeRowMajorStrides(llvm::ArrayRef<int64_t> shape);
llvm::SmallVector<int64_t>
delinearizeIndex(int64_t linearIndex, llvm::ArrayRef<int64_t> shape, llvm::ArrayRef<int64_t> strides);
int64_t linearizeIndex(llvm::ArrayRef<int64_t> indices, llvm::ArrayRef<int64_t> strides);
int64_t getNumElements(llvm::ArrayRef<int64_t> shape);
bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape,
llvm::ArrayRef<int64_t> offsets,
llvm::ArrayRef<int64_t> sizes,
llvm::ArrayRef<int64_t> strides);
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value);
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value,
const StaticValueKnowledge& knowledge);
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value);
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value, const StaticValueKnowledge& knowledge);
/// Follows alias and view/DPS chains using `knowledge` to find the value an scf.for
/// iter-arg is ultimately backed by. Used when interpreting scf.for loop carries.
mlir::Value resolveLoopCarriedAlias(mlir::Value value, const StaticValueKnowledge& knowledge);
/// Returns true for ops inside a pim.core body that do not emit any PIM instruction and
/// only contribute to static addressing or index computations (arith integer math,
/// memref view ops, memref.alloc, arith.constant).
bool isCoreStaticAddressOp(mlir::Operation* op);
/// Walks `block` (the body of a pim.core region or an scf.for nested in it), statically
/// unrolling any scf.for with resolvable bounds using `knowledge`. For each remaining op
/// that is not skipped (pim.halt, scf.yield, or isCoreStaticAddressOp), `callback` is
/// invoked with the op and the in-scope knowledge. The walker keeps going after a callback
/// failure so callers can collect multiple diagnostics, but propagates the overall result.
mlir::LogicalResult
walkPimCoreBlock(mlir::Block& block,
const StaticValueKnowledge& knowledge,
llvm::function_ref<mlir::LogicalResult(mlir::Operation&, const StaticValueKnowledge&)> callback);
inline constexpr llvm::StringLiteral kCoreIdAttrName = "coreId";
inline constexpr llvm::StringLiteral kCoreIdsAttrName = "coreIds";
} // namespace onnx_mlir
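The stride helpers declared above implement plain row-major addressing. The following self-contained snippet (standard C++ only, re-implemented here purely for illustration) spells out the arithmetic:
#include <cassert>
#include <cstdint>
#include <vector>

// Mirrors computeRowMajorStrides/linearizeIndex from the header above,
// using std:: containers so the example compiles on its own.
static std::vector<int64_t> rowMajorStrides(const std::vector<int64_t>& shape) {
  std::vector<int64_t> strides(shape.size(), 1);
  for (int64_t dim = static_cast<int64_t>(shape.size()) - 2; dim >= 0; --dim)
    strides[dim] = strides[dim + 1] * shape[dim + 1];
  return strides;
}

static int64_t linearize(const std::vector<int64_t>& indices, const std::vector<int64_t>& strides) {
  int64_t linear = 0;
  for (size_t dim = 0; dim < indices.size(); ++dim)
    linear += indices[dim] * strides[dim];
  return linear;
}

int main() {
  // Shape 2x3x4 has strides {12, 4, 1}; element (1, 2, 3) sits at 1*12 + 2*4 + 3 = 23.
  std::vector<int64_t> strides = rowMajorStrides({2, 3, 4});
  assert(strides[0] == 12 && strides[1] == 4 && strides[2] == 1);
  assert(linearize({1, 2, 3}, strides) == 23);
  return 0;
}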
+27
@@ -0,0 +1,27 @@
#include "llvm/Support/raw_os_ostream.h"
#include <fstream>
#include "src/Accelerators/PIM/Common/Support/DebugDump.hpp"
#include "src/Accelerators/PIM/Common/Support/FileSystemUtils.hpp"
namespace onnx_mlir {
void dumpModule(mlir::ModuleOp moduleOp, const std::string& name) {
std::string outputDir = getOutputDir();
if (outputDir.empty())
return;
std::string dialectsDir = outputDir + "/dialects";
createDirectory(dialectsDir);
std::fstream file(dialectsDir + "/" + name + ".mlir", std::ios::out);
llvm::raw_os_ostream os(file);
mlir::OpPrintingFlags flags;
flags.elideLargeElementsAttrs();
moduleOp.print(os, flags);
os.flush();
file.close();
}
} // namespace onnx_mlir
+13
@@ -0,0 +1,13 @@
#pragma once
#include "mlir/IR/BuiltinOps.h"
#include <string>
namespace onnx_mlir {
/// Emits an MLIR snapshot under the current compiler output
/// directory for pass-level debugging.
void dumpModule(mlir::ModuleOp moduleOp, const std::string& name);
} // namespace onnx_mlir
+41
@@ -0,0 +1,41 @@
#include "llvm/ADT/STLExtras.h"
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
namespace onnx_mlir::pim {
mlir::InFlightDiagnostic emitUnsupportedStaticShapeDiagnostic(mlir::Operation* op, llvm::StringRef valueDescription) {
return op->emitOpError() << "requires statically shaped " << valueDescription;
}
mlir::InFlightDiagnostic emitUnsupportedRankDiagnostic(mlir::Operation* op,
llvm::StringRef valueDescription,
int64_t actualRank,
llvm::ArrayRef<int64_t> supportedRanks) {
auto diag = op->emitOpError() << "has unsupported rank " << actualRank << " for " << valueDescription;
if (supportedRanks.empty())
return diag;
diag << "; supported rank";
if (supportedRanks.size() != 1)
diag << 's';
diag << ' ';
llvm::interleaveComma(supportedRanks, diag, [&](int64_t rank) { diag << rank; });
return diag;
}
mlir::InFlightDiagnostic
emitMissingSymbolDiagnostic(mlir::Operation* op, llvm::StringRef symbolKind, llvm::StringRef symbolName) {
return op->emitOpError() << "references missing " << symbolKind << " `" << symbolName << "`";
}
mlir::LogicalResult emitFileSystemError(mlir::Location loc,
llvm::StringRef action,
llvm::StringRef path,
const std::error_code& errorCode) {
mlir::emitError(loc) << "failed to " << action << " `" << path << "`: " << errorCode.message();
return mlir::failure();
}
} // namespace onnx_mlir::pim
+62
@@ -0,0 +1,62 @@
#pragma once
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Operation.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/StringRef.h"
#include <cstdint>
#include <system_error>
namespace onnx_mlir::pim {
struct CappedDiagnosticReporter {
explicit CappedDiagnosticReporter(int64_t maxReportedFailures = 8) : maxReportedFailures(maxReportedFailures) {}
template <typename EmitFn>
void report(mlir::Operation* op, EmitFn&& emit) {
numFailures++;
if (numFailures <= maxReportedFailures)
emit(op);
}
void emitSuppressedSummary(mlir::Operation* op, llvm::StringRef failureDescription) const {
if (numFailures > maxReportedFailures)
op->emitError() << "suppressed " << (numFailures - maxReportedFailures) << " additional "
<< failureDescription;
}
bool hasFailure() const { return numFailures != 0; }
private:
int64_t maxReportedFailures;
int64_t numFailures = 0;
};
/// Emits a consistent diagnostic for target paths that require static shapes.
mlir::InFlightDiagnostic emitUnsupportedStaticShapeDiagnostic(mlir::Operation* op, llvm::StringRef valueDescription);
/// Emits a consistent diagnostic for unsupported ranks while listing the ranks
/// accepted by the current lowering/codegen path.
mlir::InFlightDiagnostic emitUnsupportedRankDiagnostic(mlir::Operation* op,
llvm::StringRef valueDescription,
int64_t actualRank,
llvm::ArrayRef<int64_t> supportedRanks);
/// Emits a consistent diagnostic for missing symbol/global references.
mlir::InFlightDiagnostic
emitMissingSymbolDiagnostic(mlir::Operation* op, llvm::StringRef symbolKind, llvm::StringRef symbolName);
/// Converts a filesystem error into an MLIR failure diagnostic anchored at
/// the relevant IR location.
mlir::LogicalResult
emitFileSystemError(mlir::Location loc, llvm::StringRef action, llvm::StringRef path, const std::error_code& errorCode);
template <typename T>
mlir::LogicalResult failureOrToLogicalResult(const llvm::FailureOr<T>& value) {
return mlir::success(succeeded(value));
}
} // namespace onnx_mlir::pim
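The capping behaviour of CappedDiagnosticReporter is easiest to see in a dependency-free analogue (standard C++ with stdout standing in for op->emitError(); the real reporter takes an mlir::Operation* plus an emit callback):
#include <cstdint>
#include <iostream>

int main() {
  const int64_t maxReported = 3;             // analogue of maxReportedFailures
  int64_t numFailures = 0;
  for (int op = 0; op < 10; ++op) {          // ten failing "ops"
    ++numFailures;
    if (numFailures <= maxReported)          // report(): emit only the first N
      std::cout << "error: op #" << op << " has a dynamic shape\n";
  }
  if (numFailures > maxReported)             // analogue of emitSuppressedSummary()
    std::cout << "suppressed " << (numFailures - maxReported)
              << " additional static-shape failures\n";
  return numFailures == 0 ? 0 : 1;           // analogue of hasFailure()
}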
@@ -0,0 +1,24 @@
#include <filesystem>
#include "src/Accelerators/PIM/Common/Support/FileSystemUtils.hpp"
#include "src/Compiler/CompilerOptions.hpp"
namespace onnx_mlir {
std::string getOutputDir() {
if (outputBaseName.empty() || outputBaseName == "-")
return {};
size_t lastSlash = outputBaseName.find_last_of('/');
if (lastSlash == std::string::npos)
return ".";
return outputBaseName.substr(0, lastSlash);
}
void createDirectory(const std::string& directory) {
std::error_code errorCode;
std::filesystem::create_directories(directory, errorCode);
assert(!errorCode && ("Failed to create directory: " + errorCode.message()).data());
}
} // namespace onnx_mlir
@@ -0,0 +1,13 @@
#pragma once
#include <string>
namespace onnx_mlir {
/// Returns the directory that should hold PIM artifacts/debug dumps for the
/// current compiler invocation.
std::string getOutputDir();
void createDirectory(const std::string& directory);
} // namespace onnx_mlir
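The convention encoded by getOutputDir() is simply "everything before the last slash of outputBaseName", with "." when there is no slash and an empty string for the stdout-style "-". A self-contained sketch of that rule (standard C++, re-implemented for illustration):
#include <cassert>
#include <string>

static std::string dirOf(const std::string& base) {
  if (base.empty() || base == "-")
    return {};
  size_t lastSlash = base.find_last_of('/');
  if (lastSlash == std::string::npos)
    return ".";
  return base.substr(0, lastSlash);
}

int main() {
  assert(dirOf("build/out/model") == "build/out");
  assert(dirOf("model") == ".");
  assert(dirOf("-").empty());
  return 0;
}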
+62
@@ -0,0 +1,62 @@
#include "llvm/Support/Format.h"
#include "src/Accelerators/PIM/Common/Support/FileSystemUtils.hpp"
#include "src/Accelerators/PIM/Common/Support/ReportUtils.hpp"
namespace onnx_mlir {
std::fstream openReportFile(const std::string& name) {
std::string outputDir = getOutputDir();
if (outputDir.empty())
return {};
std::string reportsDir = outputDir + "/reports";
createDirectory(reportsDir);
return std::fstream(reportsDir + "/" + name + ".txt", std::ios::out);
}
std::string formatReportMemory(uint64_t bytes) {
const char* units[] = {"B", "KB", "MB", "GB", "TB", "PB", "EB"};
int i = 0;
double size = static_cast<double>(bytes);
while (size >= 1024 && i < 6) {
size /= 1024;
i++;
}
std::string out;
llvm::raw_string_ostream rss(out);
rss << llvm::format("%.2f ", size) << units[i];
return rss.str();
}
void printReportFlatFields(llvm::raw_ostream& os, llvm::ArrayRef<ReportField> fields) {
for (const ReportField& field : fields)
os << "\t" << field.label << ": " << field.value << "\n";
}
void printReportFieldBlock(llvm::raw_ostream& os, llvm::StringRef title, llvm::ArrayRef<ReportField> fields) {
os << "\t" << title << ":\n";
for (const ReportField& field : fields)
os << "\t " << field.label << ": " << field.value << "\n";
}
void printReportTotalsBlock(llvm::raw_ostream& os, llvm::ArrayRef<ReportField> fields) {
os << "Totals:\n";
for (const ReportField& field : fields)
os << "\t" << field.label << ": " << field.value << "\n";
}
void printReportPerCoreAndTotalFields(llvm::raw_ostream& os,
llvm::ArrayRef<ReportField> perCoreFields,
llvm::ArrayRef<ReportField> totalFields) {
printReportFieldBlock(os, "Per core", perCoreFields);
printReportFieldBlock(os, "Total", totalFields);
}
void printReportEntrySeparator(llvm::raw_ostream& os, bool hasNextEntry) {
if (hasNextEntry)
os << "\n";
}
} // namespace onnx_mlir
+47
@@ -0,0 +1,47 @@
#pragma once
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/raw_ostream.h"
#include <fstream>
#include <limits>
#include <string>
namespace onnx_mlir {
std::fstream openReportFile(const std::string& name);
std::string formatReportMemory(uint64_t bytes);
struct ReportField {
std::string label;
std::string value;
};
void printReportFlatFields(llvm::raw_ostream& os, llvm::ArrayRef<ReportField> fields);
void printReportFieldBlock(llvm::raw_ostream& os, llvm::StringRef title, llvm::ArrayRef<ReportField> fields);
void printReportTotalsBlock(llvm::raw_ostream& os, llvm::ArrayRef<ReportField> fields);
void printReportPerCoreAndTotalFields(llvm::raw_ostream& os,
llvm::ArrayRef<ReportField> perCoreFields,
llvm::ArrayRef<ReportField> totalFields);
void printReportEntrySeparator(llvm::raw_ostream& os, bool hasNextEntry);
template <typename EntryTy>
int32_t getFirstReportCoreId(const EntryTy& entry) {
if (entry.coreIds.empty())
return std::numeric_limits<int32_t>::max();
return entry.coreIds.front();
}
template <typename EntryRange>
void sortReportEntriesByFirstCore(EntryRange& entries) {
llvm::stable_sort(entries, [](const auto& lhs, const auto& rhs) {
int32_t lhsFirstCore = getFirstReportCoreId(lhs);
int32_t rhsFirstCore = getFirstReportCoreId(rhs);
if (lhsFirstCore != rhsFirstCore)
return lhsFirstCore < rhsFirstCore;
return lhs.id < rhs.id;
});
}
} // namespace onnx_mlir
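formatReportMemory scales byte counts by powers of 1024 and prints two decimals plus the unit. A standalone re-implementation (standard C++ only) showing the expected output:
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <string>

// Mirrors formatReportMemory: divide by 1024 until the value drops below
// 1024, then format it with two decimals and the matching unit.
static std::string formatMemory(uint64_t bytes) {
  const char* units[] = {"B", "KB", "MB", "GB", "TB", "PB", "EB"};
  int unit = 0;
  double size = static_cast<double>(bytes);
  while (size >= 1024 && unit < 6) {
    size /= 1024;
    ++unit;
  }
  char buffer[32];
  std::snprintf(buffer, sizeof(buffer), "%.2f %s", size, units[unit]);
  return buffer;
}

int main() {
  assert(formatMemory(512) == "512.00 B");
  assert(formatMemory(1536) == "1.50 KB");
  assert(formatMemory(3ull * 1024 * 1024) == "3.00 MB");
  return 0;
}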
+4
@@ -15,7 +15,10 @@ add_pim_library(OMPimCompilerOptions
add_pim_library(OMPimCompilerUtils
PimCompilerUtils.cpp
PimArtifactWriter.cpp
PimBatchEmission.cpp
PimCodeGen.cpp
PimWeightEmitter.cpp
EXCLUDE_FROM_OM_LIBS
@@ -26,6 +29,7 @@ add_pim_library(OMPimCompilerUtils
OMPimCompilerOptions
OMPimCommon
OMPimBufferization
OMPimStaticMemoryCoalescing
OMPimPasses
OMONNXToSpatial
OMSpatialToPim
+106
@@ -0,0 +1,106 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/raw_ostream.h"
#include <algorithm>
#include <cassert>
#include <cstring>
#include <vector>
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
#include "src/Accelerators/PIM/Compiler/PimArtifactWriter.hpp"
#include "src/Accelerators/PIM/Compiler/PimBinaryFormat.hpp"
#include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
using namespace llvm;
using namespace mlir;
namespace onnx_mlir {
OnnxMlirCompilerErrorCodes
writeMemoryBinary(ModuleOp moduleOp, func::FuncOp funcOp, PimAcceleratorMemory& memory, StringRef outputDirPath) {
auto memoryFilePath = (outputDirPath + "/memory.bin").str();
std::error_code errorCode;
raw_fd_ostream memoryFileStream(memoryFilePath, errorCode, sys::fs::OF_None);
if (errorCode) {
errs() << "Error while opening memory file " << memoryFilePath << ": " << errorCode.message() << '\n';
return InvalidOutputFileAccess;
}
std::vector<char> memoryBuffer(memory.hostMem.getFirstAvailableAddress(), 0);
SmallPtrSet<Operation*, 16> writtenGlobals;
funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
if (hasWeightAlways(getGlobalOp))
return;
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
if (!globalOp)
return;
if (!writtenGlobals.insert(globalOp.getOperation()).second)
return;
auto initialValue = globalOp.getInitialValue();
if (!initialValue)
return;
auto denseAttr = dyn_cast<DenseElementsAttr>(*initialValue);
if (!denseAttr)
return;
MemEntry memEntry = memory.hostMem.getMemEntry(getGlobalOp.getResult());
ArrayRef<char> rawData = denseAttr.getRawData();
char* dst = memoryBuffer.data() + memEntry.address;
if (denseAttr.isSplat()) {
size_t elementSize = rawData.size();
assert(elementSize * getGlobalOp.getType().getNumElements() == memEntry.size && "Data size mismatch");
for (size_t offset = 0; offset < memEntry.size; offset += elementSize)
std::memcpy(dst + offset, rawData.data(), std::min(elementSize, memEntry.size - offset));
}
else {
assert(rawData.size() == memEntry.size && "Data size mismatch");
std::memcpy(dst, rawData.data(), rawData.size());
}
});
memoryFileStream.write(memoryBuffer.data(), memoryBuffer.size());
memoryFileStream.close();
return CompilerSuccess;
}
OnnxMlirCompilerErrorCodes writeConfigJson(func::FuncOp funcOp,
PimAcceleratorMemory& memory,
size_t maxCoreId,
json::Object xbarsPerArrayGroup,
StringRef outputDirPath) {
json::Object configJson;
configJson["core_cnt"] = maxCoreId + 1;
configJson["xbar_size"] = {crossbarSize.getValue(), crossbarSize.getValue()};
configJson["array_group_map"] = std::move(xbarsPerArrayGroup);
json::Array inputsAddresses;
for (BlockArgument input : funcOp.getArguments())
inputsAddresses.push_back(memory.getValueAddress(input));
configJson["inputs_addresses"] = std::move(inputsAddresses);
json::Array outputsAddresses;
for (func::ReturnOp returnOp : funcOp.getOps<func::ReturnOp>())
for (mlir::Value output : returnOp.getOperands())
outputsAddresses.push_back(memory.getValueAddress(output));
configJson["outputs_addresses"] = std::move(outputsAddresses);
auto configPath = (outputDirPath + "/config.json").str();
std::error_code errorCode;
raw_fd_ostream jsonOS(configPath, errorCode);
if (errorCode) {
errs() << "Error while opening config file: " << errorCode.message() << '\n';
return InvalidOutputFileAccess;
}
jsonOS << json::Value(std::move(configJson)) << '\n';
jsonOS.close();
return CompilerSuccess;
}
} // namespace onnx_mlir
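The splat branch of writeMemoryBinary tiles a single element's raw bytes across the whole memory entry. A standalone illustration of that fill (standard C++ only; the element bytes and entry size are hypothetical):
#include <algorithm>
#include <cassert>
#include <cstring>
#include <vector>

int main() {
  const char splat[4] = {0x01, 0x02, 0x03, 0x04};  // raw bytes of one element
  const size_t entrySize = 16;                     // four elements in the entry
  std::vector<char> dst(entrySize, 0);
  // Same loop shape as the splat case above: copy the element repeatedly,
  // clamping the final copy to the remaining bytes.
  for (size_t offset = 0; offset < entrySize; offset += sizeof(splat))
    std::memcpy(dst.data() + offset, splat, std::min(sizeof(splat), entrySize - offset));
  for (size_t i = 0; i < entrySize; ++i)
    assert(dst[i] == splat[i % 4]);
  return 0;
}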
+25
@@ -0,0 +1,25 @@
#pragma once
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/JSON.h"
#include "onnx-mlir/Compiler/OMCompilerTypes.h"
namespace onnx_mlir {
class PimAcceleratorMemory;
OnnxMlirCompilerErrorCodes writeMemoryBinary(mlir::ModuleOp moduleOp,
mlir::func::FuncOp funcOp,
PimAcceleratorMemory& memory,
llvm::StringRef outputDirPath);
OnnxMlirCompilerErrorCodes writeConfigJson(mlir::func::FuncOp funcOp,
PimAcceleratorMemory& memory,
size_t maxCoreId,
llvm::json::Object xbarsPerArrayGroup,
llvm::StringRef outputDirPath);
} // namespace onnx_mlir
+136
@@ -0,0 +1,136 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IRMapping.h"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
static SmallVector<int32_t> getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp) {
auto coreIdsAttr = coreBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName);
assert(coreIdsAttr && "pim.core_batch requires coreIds array attribute");
return SmallVector<int32_t>(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end());
}
static SmallVector<int32_t> getLaneChunkCoreIds(ArrayRef<int32_t> coreIds, size_t laneCount, unsigned lane) {
SmallVector<int32_t> laneCoreIds;
laneCoreIds.reserve(coreIds.size() / laneCount);
for (size_t chunkIndex = 0; chunkIndex < coreIds.size() / laneCount; ++chunkIndex)
laneCoreIds.push_back(coreIds[chunkIndex * laneCount + lane]);
return laneCoreIds;
}
static void scalarizeBatchOpsInCore(pim::PimCoreOp scalarCore, size_t laneCount, unsigned lane) {
IRRewriter rewriter(scalarCore.getContext());
SmallVector<Operation*> batchOps;
scalarCore.walk([&](Operation* op) {
if (isa<pim::PimSendBatchOp,
pim::PimSendTensorBatchOp,
pim::PimReceiveBatchOp,
pim::PimReceiveTensorBatchOp,
pim::PimMemCopyHostToDevBatchOp>(op)) {
batchOps.push_back(op);
}
});
for (Operation* op : batchOps) {
rewriter.setInsertionPoint(op);
if (auto sendBatchOp = dyn_cast<pim::PimSendBatchOp>(op)) {
pim::PimSendOp::create(rewriter,
sendBatchOp.getLoc(),
sendBatchOp.getInput(),
sendBatchOp.getSizeAttr(),
rewriter.getI32IntegerAttr(sendBatchOp.getTargetCoreIds()[lane]));
rewriter.eraseOp(op);
continue;
}
if (auto sendTensorBatchOp = dyn_cast<pim::PimSendTensorBatchOp>(op)) {
pim::PimSendTensorOp::create(
rewriter,
sendTensorBatchOp.getLoc(),
sendTensorBatchOp.getInput(),
rewriter.getDenseI32ArrayAttr(getLaneChunkCoreIds(sendTensorBatchOp.getTargetCoreIds(), laneCount, lane)));
rewriter.eraseOp(op);
continue;
}
if (auto receiveBatchOp = dyn_cast<pim::PimReceiveBatchOp>(op)) {
auto scalarReceive =
pim::PimReceiveOp::create(rewriter,
receiveBatchOp.getLoc(),
receiveBatchOp.getOutput().getType(),
receiveBatchOp.getOutputBuffer(),
receiveBatchOp.getSizeAttr(),
rewriter.getI32IntegerAttr(receiveBatchOp.getSourceCoreIds()[lane]));
rewriter.replaceOp(op, scalarReceive->getResults());
continue;
}
if (auto receiveTensorBatchOp = dyn_cast<pim::PimReceiveTensorBatchOp>(op)) {
auto scalarReceive = pim::PimReceiveTensorOp::create(
rewriter,
receiveTensorBatchOp.getLoc(),
receiveTensorBatchOp.getOutput().getType(),
receiveTensorBatchOp.getOutputBuffer(),
rewriter.getDenseI32ArrayAttr(getLaneChunkCoreIds(receiveTensorBatchOp.getSourceCoreIds(), laneCount, lane)));
rewriter.replaceOp(op, scalarReceive->getResults());
continue;
}
auto memcpBatchOp = cast<pim::PimMemCopyHostToDevBatchOp>(op);
auto scalarCopy = pim::PimMemCopyHostToDevOp::create(rewriter,
memcpBatchOp.getLoc(),
memcpBatchOp.getOutput().getType(),
memcpBatchOp.getDeviceTarget(),
memcpBatchOp.getHostSource(),
memcpBatchOp.getDeviceTargetOffsetAttr(),
memcpBatchOp.getHostSourceOffsetAttr(),
memcpBatchOp.getSizeAttr());
rewriter.replaceOp(op, scalarCopy->getResults());
}
}
} // namespace
LogicalResult withScalarCoreFromBatchLane(pim::PimCoreBatchOp coreBatchOp,
unsigned lane,
llvm::function_ref<LogicalResult(pim::PimCoreOp)> callback) {
OwningOpRef<ModuleOp> scratchModule = ModuleOp::create(coreBatchOp.getLoc());
OpBuilder builder(scratchModule->getContext());
builder.setInsertionPointToStart(scratchModule->getBody());
size_t laneCount = static_cast<size_t>(coreBatchOp.getLaneCount());
size_t weightsPerLane = coreBatchOp.getWeights().size() / laneCount;
SmallVector<Value> laneWeights;
laneWeights.reserve(weightsPerLane);
for (size_t weightIndex = 0; weightIndex < weightsPerLane; ++weightIndex)
laneWeights.push_back(coreBatchOp.getWeights()[lane * weightsPerLane + weightIndex]);
auto coreIds = getBatchCoreIds(coreBatchOp);
auto scalarCore = pim::PimCoreOp::create(
builder, coreBatchOp.getLoc(), ValueRange(laneWeights), builder.getI32IntegerAttr(coreIds[lane]));
Block* block = builder.createBlock(&scalarCore.getBody(), scalarCore.getBody().end());
IRMapping mapper;
if (coreBatchOp.getBody().front().getNumArguments() == 1)
mapper.map(coreBatchOp.getBody().front().getArgument(0), coreBatchOp.getInputs()[lane]);
builder.setInsertionPointToEnd(block);
for (Operation& op : coreBatchOp.getBody().front()) {
Operation* cloned = builder.clone(op, mapper);
for (auto [originalResult, clonedResult] : llvm::zip(op.getResults(), cloned->getResults()))
mapper.map(originalResult, clonedResult);
}
if (block->empty() || !isa<pim::PimHaltOp>(block->back()))
pim::PimHaltOp::create(builder, coreBatchOp.getLoc());
scalarizeBatchOpsInCore(scalarCore, laneCount, lane);
return callback(scalarCore);
}
} // namespace onnx_mlir
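getLaneChunkCoreIds assumes the batch's core ids are laid out chunk-major (all lanes of chunk 0, then chunk 1, and so on), so lane `l` keeps every laneCount-th id starting at position `l`. A standalone check of that indexing (standard C++; the ids are hypothetical):
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  const std::vector<int32_t> coreIds = {0, 1, 2, 3, 4, 5};  // 3 chunks x 2 lanes
  const size_t laneCount = 2;
  const unsigned lane = 1;
  std::vector<int32_t> laneCoreIds;
  for (size_t chunk = 0; chunk < coreIds.size() / laneCount; ++chunk)
    laneCoreIds.push_back(coreIds[chunk * laneCount + lane]);
  assert((laneCoreIds == std::vector<int32_t>{1, 3, 5}));
  return 0;
}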
+13
@@ -0,0 +1,13 @@
#pragma once
#include "llvm/ADT/STLFunctionalExtras.h"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
namespace onnx_mlir {
mlir::LogicalResult withScalarCoreFromBatchLane(pim::PimCoreBatchOp coreBatchOp,
unsigned lane,
llvm::function_ref<mlir::LogicalResult(pim::PimCoreOp)> callback);
} // namespace onnx_mlir
+374
@@ -0,0 +1,374 @@
#pragma once
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Endian.h"
#include "llvm/Support/JSON.h"
#include "llvm/Support/raw_ostream.h"
#include <array>
#include <cassert>
#include <limits>
namespace onnx_mlir::pim_binary {
inline constexpr char kMagic[4] = {'P', 'I', 'M', 'B'};
inline constexpr uint32_t kVersion = 1;
inline constexpr uint64_t kCountOffset = 8;
inline constexpr size_t kHeaderSize = 12;
inline constexpr size_t kRecordSize = 20;
enum class Opcode : uint32_t {
nop = 0,
sldi = 1,
sld = 2,
sadd = 3,
ssub = 4,
smul = 5,
saddi = 6,
smuli = 7,
setbw = 8,
mvmul = 9,
vvadd = 10,
vvsub = 11,
vvmul = 12,
vvdmul = 13,
vvmax = 14,
vvsll = 15,
vvsra = 16,
vavg = 17,
vrelu = 18,
vtanh = 19,
vsigm = 20,
vsoftmax = 21,
vmv = 22,
vrsu = 23,
vrsl = 24,
ld = 25,
st = 26,
lldi = 27,
lmv = 28,
send = 29,
recv = 30,
wait = 31,
sync = 32,
};
struct InstructionRecord {
Opcode opcode = Opcode::nop;
uint8_t rd = 0;
uint8_t r1 = 0;
int32_t r2OrImm = 0;
int32_t generic1 = 0;
int32_t generic2 = 0;
int32_t generic3 = 0;
uint8_t flags = 0;
};
inline void writeUint32LE(llvm::raw_ostream& os, uint32_t value) {
std::array<char, sizeof(uint32_t)> bytes;
llvm::support::endian::write32le(bytes.data(), value);
os.write(bytes.data(), bytes.size());
}
inline void writeInt32LE(llvm::raw_ostream& os, int32_t value) { writeUint32LE(os, static_cast<uint32_t>(value)); }
inline void writeHeader(llvm::raw_ostream& os) {
os.write(kMagic, sizeof(kMagic));
writeUint32LE(os, kVersion);
writeUint32LE(os, 0);
}
inline void patchInstructionCount(llvm::raw_pwrite_stream& os, uint32_t instructionCount) {
std::array<char, sizeof(uint32_t)> bytes;
llvm::support::endian::write32le(bytes.data(), instructionCount);
os.pwrite(bytes.data(), bytes.size(), kCountOffset);
}
inline void writeInstructionRecord(llvm::raw_ostream& os, const InstructionRecord& record) {
os << static_cast<char>(static_cast<uint8_t>(record.opcode));
os << static_cast<char>(record.rd);
os << static_cast<char>(record.r1);
os << static_cast<char>(record.flags);
writeInt32LE(os, record.r2OrImm);
writeInt32LE(os, record.generic1);
writeInt32LE(os, record.generic2);
writeInt32LE(os, record.generic3);
}
inline int32_t toI32(int64_t value) {
assert(value >= std::numeric_limits<int32_t>::min() && value <= std::numeric_limits<int32_t>::max()
&& "PIM binary field out of int32 range");
return static_cast<int32_t>(value);
}
inline uint8_t toU8(int64_t value) {
assert(value >= 0 && value <= std::numeric_limits<uint8_t>::max() && "PIM binary field out of uint8 range");
return static_cast<uint8_t>(value);
}
inline int32_t getOptionalInt(const llvm::json::Object& object, llvm::StringRef key, int32_t defaultValue = 0) {
if (std::optional<int64_t> value = object.getInteger(key))
return toI32(*value);
return defaultValue;
}
inline Opcode opcodeFromString(llvm::StringRef opName) {
if (opName == "nop")
return Opcode::nop;
if (opName == "sldi")
return Opcode::sldi;
if (opName == "sld")
return Opcode::sld;
if (opName == "sadd")
return Opcode::sadd;
if (opName == "ssub")
return Opcode::ssub;
if (opName == "smul")
return Opcode::smul;
if (opName == "saddi")
return Opcode::saddi;
if (opName == "smuli")
return Opcode::smuli;
if (opName == "setbw")
return Opcode::setbw;
if (opName == "mvmul")
return Opcode::mvmul;
if (opName == "vvadd")
return Opcode::vvadd;
if (opName == "vvsub")
return Opcode::vvsub;
if (opName == "vvmul")
return Opcode::vvmul;
if (opName == "vvdmul")
return Opcode::vvdmul;
if (opName == "vvmax")
return Opcode::vvmax;
if (opName == "vvsll")
return Opcode::vvsll;
if (opName == "vvsra")
return Opcode::vvsra;
if (opName == "vavg")
return Opcode::vavg;
if (opName == "vrelu")
return Opcode::vrelu;
if (opName == "vtanh")
return Opcode::vtanh;
if (opName == "vsigm")
return Opcode::vsigm;
if (opName == "vsoftmax")
return Opcode::vsoftmax;
if (opName == "vmv")
return Opcode::vmv;
if (opName == "vrsu")
return Opcode::vrsu;
if (opName == "vrsl")
return Opcode::vrsl;
if (opName == "ld")
return Opcode::ld;
if (opName == "st")
return Opcode::st;
if (opName == "lldi")
return Opcode::lldi;
if (opName == "lmv")
return Opcode::lmv;
if (opName == "send")
return Opcode::send;
if (opName == "recv")
return Opcode::recv;
if (opName == "wait")
return Opcode::wait;
if (opName == "sync")
return Opcode::sync;
llvm_unreachable("Unsupported PIM binary opcode");
}
inline llvm::StringRef opcodeToString(Opcode opcode) {
switch (opcode) {
case Opcode::nop: return "nop";
case Opcode::sldi: return "sldi";
case Opcode::sld: return "sld";
case Opcode::sadd: return "sadd";
case Opcode::ssub: return "ssub";
case Opcode::smul: return "smul";
case Opcode::saddi: return "saddi";
case Opcode::smuli: return "smuli";
case Opcode::setbw: return "setbw";
case Opcode::mvmul: return "mvmul";
case Opcode::vvadd: return "vvadd";
case Opcode::vvsub: return "vvsub";
case Opcode::vvmul: return "vvmul";
case Opcode::vvdmul: return "vvdmul";
case Opcode::vvmax: return "vvmax";
case Opcode::vvsll: return "vvsll";
case Opcode::vvsra: return "vvsra";
case Opcode::vavg: return "vavg";
case Opcode::vrelu: return "vrelu";
case Opcode::vtanh: return "vtanh";
case Opcode::vsigm: return "vsigm";
case Opcode::vsoftmax: return "vsoftmax";
case Opcode::vmv: return "vmv";
case Opcode::vrsu: return "vrsu";
case Opcode::vrsl: return "vrsl";
case Opcode::ld: return "ld";
case Opcode::st: return "st";
case Opcode::lldi: return "lldi";
case Opcode::lmv: return "lmv";
case Opcode::send: return "send";
case Opcode::recv: return "recv";
case Opcode::wait: return "wait";
case Opcode::sync: return "sync";
}
llvm_unreachable("Unsupported PIM binary opcode");
}
inline InstructionRecord makeInstructionRecord(const llvm::json::Object& instruction) {
InstructionRecord record;
std::optional<llvm::StringRef> opName = instruction.getString("op");
assert(opName && "Missing op field in PIM instruction");
record.opcode = opcodeFromString(*opName);
record.rd = toU8(getOptionalInt(instruction, "rd"));
record.r1 = toU8(getOptionalInt(instruction, "rs1"));
switch (record.opcode) {
case Opcode::sldi:
case Opcode::saddi:
case Opcode::smuli:
case Opcode::lldi: record.r2OrImm = getOptionalInt(instruction, "imm"); break;
case Opcode::mvmul:
record.r2OrImm = getOptionalInt(instruction, "mbiw");
record.generic1 = getOptionalInt(instruction, "relu");
record.generic2 = getOptionalInt(instruction, "group");
break;
case Opcode::setbw:
record.generic1 = getOptionalInt(instruction, "ibiw");
record.generic2 = getOptionalInt(instruction, "obiw");
break;
case Opcode::send:
case Opcode::recv:
record.r2OrImm = getOptionalInt(instruction, "core");
record.generic3 = getOptionalInt(instruction, "size");
break;
default: record.r2OrImm = getOptionalInt(instruction, "rs2"); break;
}
if (record.opcode != Opcode::mvmul && record.opcode != Opcode::setbw) {
if (auto* offsetValue = instruction.getObject("offset")) {
record.generic1 = getOptionalInt(*offsetValue, "offset_select");
record.generic2 = getOptionalInt(*offsetValue, "offset_value");
}
}
if (instruction.get("len"))
record.generic3 = getOptionalInt(instruction, "len");
else if (instruction.get("size") && record.opcode != Opcode::send && record.opcode != Opcode::recv)
record.generic3 = getOptionalInt(instruction, "size");
return record;
}
inline llvm::json::Object makeInstructionJson(const InstructionRecord& record) {
llvm::json::Object instruction;
instruction["op"] = opcodeToString(record.opcode).str();
auto addOffset = [&](int32_t offsetSelect, int32_t offsetValue) {
llvm::json::Object offset;
offset["offset_select"] = offsetSelect;
offset["offset_value"] = offsetValue;
instruction["offset"] = std::move(offset);
};
switch (record.opcode) {
case Opcode::sldi:
instruction["rd"] = static_cast<int64_t>(record.rd);
instruction["imm"] = record.r2OrImm;
break;
case Opcode::sld:
instruction["rd"] = static_cast<int64_t>(record.rd);
instruction["rs1"] = static_cast<int64_t>(record.r1);
addOffset(record.generic1, record.generic2);
break;
case Opcode::sadd:
case Opcode::ssub:
case Opcode::smul:
instruction["rd"] = static_cast<int64_t>(record.rd);
instruction["rs1"] = static_cast<int64_t>(record.r1);
instruction["rs2"] = record.r2OrImm;
break;
case Opcode::saddi:
case Opcode::smuli:
instruction["rd"] = static_cast<int64_t>(record.rd);
instruction["rs1"] = static_cast<int64_t>(record.r1);
instruction["imm"] = record.r2OrImm;
break;
case Opcode::setbw:
instruction["ibiw"] = record.generic1;
instruction["obiw"] = record.generic2;
break;
case Opcode::mvmul:
instruction["rd"] = static_cast<int64_t>(record.rd);
instruction["rs1"] = static_cast<int64_t>(record.r1);
instruction["mbiw"] = record.r2OrImm;
instruction["relu"] = record.generic1;
instruction["group"] = record.generic2;
break;
case Opcode::vvadd:
case Opcode::vvsub:
case Opcode::vvmul:
case Opcode::vvdmul:
case Opcode::vvmax:
case Opcode::vvsll:
case Opcode::vvsra:
case Opcode::vavg:
case Opcode::vmv:
case Opcode::vrsu:
case Opcode::vrsl:
instruction["rd"] = static_cast<int64_t>(record.rd);
instruction["rs1"] = static_cast<int64_t>(record.r1);
instruction["rs2"] = record.r2OrImm;
addOffset(record.generic1, record.generic2);
instruction["len"] = record.generic3;
break;
case Opcode::vrelu:
case Opcode::vtanh:
case Opcode::vsigm:
case Opcode::vsoftmax:
instruction["rd"] = static_cast<int64_t>(record.rd);
instruction["rs1"] = static_cast<int64_t>(record.r1);
addOffset(record.generic1, record.generic2);
instruction["len"] = record.generic3;
break;
case Opcode::ld:
case Opcode::st:
instruction["rd"] = static_cast<int64_t>(record.rd);
instruction["rs1"] = static_cast<int64_t>(record.r1);
addOffset(record.generic1, record.generic2);
instruction["size"] = record.generic3;
break;
case Opcode::lldi:
instruction["rd"] = static_cast<int64_t>(record.rd);
instruction["imm"] = record.r2OrImm;
addOffset(record.generic1, record.generic2);
instruction["len"] = record.generic3;
break;
case Opcode::lmv:
instruction["rd"] = static_cast<int64_t>(record.rd);
instruction["rs1"] = static_cast<int64_t>(record.r1);
addOffset(record.generic1, record.generic2);
instruction["len"] = record.generic3;
break;
case Opcode::send:
case Opcode::recv:
instruction["rd"] = static_cast<int64_t>(record.rd);
instruction["core"] = record.r2OrImm;
addOffset(record.generic1, record.generic2);
instruction["size"] = record.generic3;
break;
case Opcode::wait:
case Opcode::sync:
case Opcode::nop: break;
}
return instruction;
}
} // namespace onnx_mlir::pim_binary
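From the constants above, a .pim file is a 12-byte header ('PIMB' magic, version, and a 32-bit instruction count at byte offset 8) followed by 20-byte little-endian records. A standalone sanity check of those sizes (standard C++ only; the instruction count is hypothetical):
#include <cassert>
#include <cstddef>
#include <cstdint>

int main() {
  const size_t headerSize = 12;  // 4-byte magic + 4-byte version + 4-byte count
  const size_t recordSize = 20;  // opcode, rd, r1, flags (1 byte each) + 4 x int32
  const size_t numInstructions = 57;
  // Total file size grows linearly with the instruction count.
  assert(headerSize + numInstructions * recordSize == 1152);
  // The per-record layout adds up to kRecordSize.
  assert(4 * sizeof(uint8_t) + 4 * sizeof(int32_t) == recordSize);
  return 0;
}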
File diff suppressed because it is too large
+70 -9
@@ -1,10 +1,19 @@
#pragma once
#include "mlir/IR/Operation.h"
#include "llvm-project/clang/include/clang/Basic/LLVM.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/JSON.h"
#include "llvm/Support/raw_os_ostream.h"
#include <fstream>
#include <optional>
#include "onnx-mlir/Compiler/OMCompilerTypes.h"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/Support/ReportUtils.hpp"
#include "src/Accelerators/PIM/Compiler/PimBinaryFormat.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
namespace onnx_mlir {
@@ -14,16 +23,42 @@ struct MemEntry {
size_t size;
};
struct MemoryReportRow {
uint64_t numAlloca = 0;
uint64_t sizeAlloca = 0;
uint64_t numGlobal = 0;
uint64_t sizeGlobal = 0;
bool operator==(const MemoryReportRow& other) const {
return numAlloca == other.numAlloca && sizeAlloca == other.sizeAlloca && numGlobal == other.numGlobal
&& sizeGlobal == other.sizeGlobal;
}
};
struct MemoryReportEntry {
enum class Kind {
Core,
Batch
};
Kind kind = Kind::Core;
uint64_t id = 0;
llvm::SmallVector<int32_t, 8> coreIds;
MemoryReportRow row;
uint64_t totalAllocaCount = 0;
uint64_t totalAllocaBytes = 0;
};
class PimMemory {
llvm::SmallVector<std::pair<MemEntry, mlir::Value>, 32> memEntries;
llvm::SmallDenseMap<mlir::Value, MemEntry, 32>& globalMemEntriesMap;
llvm::SmallDenseMap<mlir::Value, MemEntry, 32> ownedMemEntriesMap;
size_t maxSize = 0; // 0 for unbounded memory
size_t startAddress = 0;
size_t minAlignment = 4;
size_t firstAvailableAddress = 0;
MemEntry* gatherMemEntry(mlir::Value value);
void allocateGatheredMemory();
void allocateMemoryForValue(mlir::Value value, MemEntry& memEntry);
public:
@@ -32,6 +67,8 @@ public:
void allocateHost(mlir::ModuleOp moduleOp, mlir::func::FuncOp funcOp);
void allocateCore(mlir::Operation* op);
MemoryReportRow getReportRow() const;
void remove(mlir::Value val);
size_t getFirstAvailableAddress() const { return firstAvailableAddress; }
MemEntry getMemEntry(mlir::Value value) const;
@@ -44,26 +81,41 @@ public:
private:
llvm::SmallDenseMap<size_t, PimMemory> deviceMem;
std::fstream fileReport;
std::optional<MemoryReportRow> hostReportRow;
llvm::SmallVector<MemoryReportEntry, 32> reportEntries;
public:
PimAcceleratorMemory()
: hostMem(memEntriesMap) {}
: hostMem(memEntriesMap), fileReport(openReportFile("memory_report")) {}
PimMemory& getOrCreateDeviceMem(size_t id);
size_t getValueAddress(mlir::Value value, const StaticValueKnowledge& knowledge = {}) const;
void reportHost();
void recordCoreReport(size_t coreId, const MemoryReportRow& row);
void recordBatchReport(uint64_t batchId,
llvm::ArrayRef<int32_t> coreIds,
const MemoryReportRow& perCoreRow,
uint64_t totalAllocaCount,
uint64_t totalAllocaBytes);
void flushReport();
void clean(mlir::Operation* op);
};
class PimCodeGen {
PimAcceleratorMemory& memory;
llvm::raw_fd_ostream& coreFileStream;
llvm::raw_fd_ostream& coreBinaryStream;
llvm::raw_fd_ostream* coreJsonStream;
const llvm::DenseMap<size_t, size_t>& emittedCoreIds;
mutable uint32_t emittedInstructionCount = 0;
size_t addressOf(mlir::Value value, const StaticValueKnowledge& knowledge) const {
return memory.getValueAddress(value, knowledge);
}
size_t remapCoreId(size_t coreId) const;
static llvm::json::Object createEmptyOffset();
void emitInstruction(llvm::json::Object instruction) const;
void emitInstruction(const pim_binary::InstructionRecord& instruction) const;
void genSetRegisterImmediateUnsigned(size_t registerNumber, size_t immediate) const;
void setupRd(size_t rdAddress, size_t rdOffset) const;
@@ -82,15 +134,23 @@ class PimCodeGen {
void emitMvmOp(size_t groupId, size_t rdAddr, size_t rdOffset, size_t rs1Addr, size_t rs1Offset) const;
public:
PimCodeGen(PimAcceleratorMemory& memory, llvm::raw_fd_ostream& coreJson)
: memory(memory), coreFileStream(coreJson) {}
PimCodeGen(PimAcceleratorMemory& memory,
llvm::raw_fd_ostream& coreBinary,
llvm::raw_fd_ostream* coreJson,
const llvm::DenseMap<size_t, size_t>& emittedCoreIds)
: memory(memory), coreBinaryStream(coreBinary), coreJsonStream(coreJson), emittedCoreIds(emittedCoreIds) {}
uint32_t getEmittedInstructionCount() const { return emittedInstructionCount; }
void codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp, const StaticValueKnowledge& knowledge) const;
void codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp, const StaticValueKnowledge& knowledge) const;
void codeGenLmvOp(pim::PimMemCopyOp lmvOp, const StaticValueKnowledge& knowledge) const;
void codeGenReceiveOp(pim::PimReceiveOp receiveOp, const StaticValueKnowledge& knowledge) const;
void codeGenReceiveTensorOp(pim::PimReceiveTensorOp receiveTensorOp, const StaticValueKnowledge& knowledge) const;
void codeGenSendOp(pim::PimSendOp sendOp, const StaticValueKnowledge& knowledge) const;
void codeGenSendTensorOp(pim::PimSendTensorOp sendTensorOp, const StaticValueKnowledge& knowledge) const;
void codeGenConcatOp(pim::PimConcatOp concatOp, const StaticValueKnowledge& knowledge) const;
template <typename MVMTy>
void codeGenMVMLikeOp(size_t mvmId, MVMTy mvmLikeOp, bool transposeMatrix, const StaticValueKnowledge& knowledge);
@@ -105,9 +165,10 @@ public:
void codeGenVTanhOp(pim::PimVTanhOp vtanhOp, const StaticValueKnowledge& knowledge) const;
void codeGenVSigmOp(pim::PimVSigmOp vsigmOp, const StaticValueKnowledge& knowledge) const;
void codeGenVSoftmaxOp(pim::PimVSoftmaxOp vsoftmaxOp, const StaticValueKnowledge& knowledge) const;
void codeGetGlobalOp(mlir::memref::GetGlobalOp getGlobalOp, const StaticValueKnowledge& knowledge) const;
void codeGenTransposeOp(pim::PimTransposeOp transposeOp, const StaticValueKnowledge& knowledge) const;
};
OnnxMlirCompilerErrorCodes compileToPimJson(mlir::ModuleOp& moduleOpRef, std::string& outputDirName);
OnnxMlirCompilerErrorCodes compileToPimCode(mlir::ModuleOp& moduleOpRef, std::string& outputDirName);
} // namespace onnx_mlir
+12 -14
@@ -1,16 +1,3 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/
//===------------------------- PimCompilerOptions.cpp --------------------===//
//
// Copyright 2022 The IBM Research Authors.
//
// =============================================================================
//
// Compiler Options for PIM
//
//===----------------------------------------------------------------------===//
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#define DEBUG_TYPE "PimCompilerOptions"
@@ -37,16 +24,27 @@ llvm::cl::opt<bool> useExperimentalConvImpl("use-experimental-conv-impl",
llvm::cl::init(false),
llvm::cl::cat(OnnxMlirOptions));
llvm::cl::opt<bool> pimEmitJson("pim-emit-json",
llvm::cl::desc("Also emit per-core JSON instruction files alongside binary .pim files"),
llvm::cl::init(false),
llvm::cl::cat(OnnxMlirOptions));
llvm::cl::opt<size_t>
crossbarSize("crossbar-size", llvm::cl::desc("Width and heigth of a single crossbar"), llvm::cl::init(2));
llvm::cl::opt<size_t>
crossbarCountInCore("crossbar-count", llvm::cl::desc("Number of crossbars in each core"), llvm::cl::init(2));
crossbarCountInCore("crossbar-count", llvm::cl::desc("Number of crossbars in each core"), llvm::cl::init(256));
llvm::cl::opt<long> coresCount("core-count",
llvm::cl::desc("Number of cores in the chip. `-1` to use the minimum amount of cores."),
llvm::cl::init(-1));
llvm::cl::opt<size_t> dcpCriticalWindowSize(
"dcp-critical-window-size",
llvm::cl::desc("Number of lowest-slack virtual nodes considered by each DCP coarsening iteration. "
"Use 0 to run the legacy full-graph DCP analysis."),
llvm::cl::init(4000));
llvm::cl::opt<bool>
ignoreConcatError("ignore-concat-error",
llvm::cl::desc("Ignore ConcatOp corner case: do not assert and do a simplification"),
+2
@@ -25,10 +25,12 @@ extern llvm::cl::opt<PimEmissionTargetType> pimEmissionTarget;
extern llvm::cl::opt<bool> pimOnlyCodegen;
extern llvm::cl::opt<bool> useExperimentalConvImpl;
extern llvm::cl::opt<bool> pimEmitJson;
extern llvm::cl::opt<size_t> crossbarSize;
extern llvm::cl::opt<size_t> crossbarCountInCore;
extern llvm::cl::opt<long> coresCount;
extern llvm::cl::opt<size_t> dcpCriticalWindowSize;
// This option, disabled by default, ignores an error when resolving
// specific tiles of the operands of a concat. This specific case is when the
+3 -2
@@ -41,6 +41,7 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
if (pimEmissionTarget >= EmitPimBufferized) {
pm.addPass(createPimBufferizationPass());
pm.addPass(createPimStaticMemoryCoalescingPass());
// pm.addPass(createCountInstructionPass());
pm.addPass(createMessagePass("Pim bufferized"));
}
@@ -51,9 +52,9 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
pm.addPass(createPimMaterializeHostConstantsPass());
pm.addPass(createPimVerificationPass());
pm.addPass(createMessagePass("Pim verified"));
pm.addPass(createEmitPimJsonPass());
pm.addPass(createEmitPimCodePass());
// pm.addPass(createCountInstructionPass());
pm.addPass(createMessagePass("Pim json code emitted"));
pm.addPass(createMessagePass("Pim code emitted"));
}
}
+250
@@ -0,0 +1,250 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/raw_ostream.h"
#include <cassert>
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
#include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp"
#include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Compiler/PimWeightEmitter.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace llvm;
using namespace mlir;
namespace onnx_mlir {
namespace {
struct DenseWeightView {
DenseElementsAttr denseAttr;
SmallVector<int64_t> shape;
SmallVector<int64_t> strides;
int64_t offset = 0;
};
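// Walks a weight value back through static subview / cast / collapse / expand
// ops to its defining memref.global and folds the whole view chain into one
// (offset, shape, strides) description of the dense initializer. Illustrative
// example (not from this patch): a [64x64] global seen through a subview with
// offsets [0, 32], sizes [64, 32] and unit strides resolves to offset 32,
// shape [64, 32] and strides [64, 1].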
FailureOr<DenseWeightView> resolveDenseWeightView(ModuleOp moduleOp, mlir::Value weight) {
SmallVector<Operation*> viewOps;
mlir::Value current = weight;
memref::GetGlobalOp getGlobalOp;
while (true) {
Operation* defOp = current.getDefiningOp();
if (!defOp)
return failure();
if ((getGlobalOp = dyn_cast<memref::GetGlobalOp>(defOp)))
break;
if (auto subview = dyn_cast<memref::SubViewOp>(defOp)) {
if (!hasAllStaticSubviewParts(subview))
return failure();
viewOps.push_back(subview);
current = subview.getSource();
continue;
}
if (auto cast = dyn_cast<memref::CastOp>(defOp)) {
current = cast.getSource();
continue;
}
if (auto collapse = dyn_cast<memref::CollapseShapeOp>(defOp)) {
auto srcType = dyn_cast<MemRefType>(collapse.getSrc().getType());
auto resultType = dyn_cast<MemRefType>(collapse.getResult().getType());
if (!srcType || !resultType || !srcType.hasStaticShape() || !resultType.hasStaticShape())
return failure();
viewOps.push_back(collapse);
current = collapse.getSrc();
continue;
}
if (auto expand = dyn_cast<memref::ExpandShapeOp>(defOp)) {
auto srcType = dyn_cast<MemRefType>(expand.getSrc().getType());
auto resultType = dyn_cast<MemRefType>(expand.getResult().getType());
if (!srcType || !resultType || !srcType.hasStaticShape() || !resultType.hasStaticShape())
return failure();
viewOps.push_back(expand);
current = expand.getSrc();
continue;
}
return failure();
}
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
if (!globalOp || !globalOp.getInitialValue())
return failure();
auto denseAttr = dyn_cast<DenseElementsAttr>(*globalOp.getInitialValue());
if (!denseAttr)
return failure();
DenseWeightView view;
view.denseAttr = denseAttr;
view.shape.assign(denseAttr.getType().getShape().begin(), denseAttr.getType().getShape().end());
view.strides = computeRowMajorStrides(view.shape);
for (Operation* viewOp : llvm::reverse(viewOps)) {
if (auto subview = dyn_cast<memref::SubViewOp>(viewOp)) {
SmallVector<int64_t> nextStrides;
nextStrides.reserve(subview.getStaticStrides().size());
for (auto [offset, stride, sourceStride] :
llvm::zip_equal(subview.getStaticOffsets(), subview.getStaticStrides(), view.strides)) {
view.offset += offset * sourceStride;
nextStrides.push_back(stride * sourceStride);
}
view.shape.assign(subview.getStaticSizes().begin(), subview.getStaticSizes().end());
view.strides = std::move(nextStrides);
continue;
}
// Collapse/expand are accepted only as contiguous static reshapes of a
// dense global view, so a row-major stride recomputation preserves layout.
if (auto collapse = dyn_cast<memref::CollapseShapeOp>(viewOp)) {
if (view.strides != computeRowMajorStrides(view.shape))
return failure();
auto resultType = cast<MemRefType>(collapse.getResult().getType());
view.shape.assign(resultType.getShape().begin(), resultType.getShape().end());
view.strides = computeRowMajorStrides(view.shape);
continue;
}
if (auto expand = dyn_cast<memref::ExpandShapeOp>(viewOp)) {
if (view.strides != computeRowMajorStrides(view.shape))
return failure();
auto resultType = cast<MemRefType>(expand.getResult().getType());
view.shape.assign(resultType.getShape().begin(), resultType.getShape().end());
view.strides = computeRowMajorStrides(view.shape);
continue;
}
}
return view;
}
SmallVector<unsigned, 8> getUsedWeightIndices(Block& block) {
SmallVector<unsigned, 8> indices;
auto addIndex = [&](unsigned weightIndex) {
if (!llvm::is_contained(indices, weightIndex))
indices.push_back(weightIndex);
};
block.walk([&](pim::PimVMMOp vmmOp) { addIndex(vmmOp.getWeightIndex()); });
llvm::sort(indices);
return indices;
}
SmallVector<unsigned, 8> getUsedWeightIndices(pim::PimCoreOp coreOp) {
return getUsedWeightIndices(coreOp.getBody().front());
}
SmallVector<Operation*> collectTopLevelCoreLikeOps(func::FuncOp funcOp) {
SmallVector<Operation*> coreLikeOps;
for (Operation& op : funcOp.getBody().front())
if (dyn_cast<pim::PimCoreOp>(&op) || dyn_cast<pim::PimCoreBatchOp>(&op))
coreLikeOps.push_back(&op);
return coreLikeOps;
}
} // namespace
llvm::DenseMap<size_t, llvm::DenseMap<mlir::Value, std::string>>
createAndPopulateWeightFolder(func::FuncOp funcOp, StringRef outputDirPath) {
ModuleOp moduleOp = funcOp->getParentOfType<ModuleOp>();
auto coreWeightsDirPath = outputDirPath + "/weights";
auto error = sys::fs::create_directory(coreWeightsDirPath);
assert(!error && "Error creating weights directory");
size_t indexFileName = 0;
int64_t xbarSize = crossbarSize.getValue();
llvm::DenseMap<size_t, llvm::DenseMap<mlir::Value, std::string>> mapCoreWeightToFileName;
llvm::DenseMap<memref::GlobalOp, std::string> mapGlobalOpToFileName;
SmallVector<Operation*> coreLikeOps = collectTopLevelCoreLikeOps(funcOp);
for (Operation* op : coreLikeOps) {
auto processCore = [&](pim::PimCoreOp coreOp) {
size_t coreId = static_cast<size_t>(coreOp.getCoreId());
for (unsigned index : getUsedWeightIndices(coreOp)) {
if (index >= coreOp.getWeights().size()) {
coreOp.emitWarning("Weight index " + std::to_string(index) + " is out of range");
assert(index < coreOp.getWeights().size() && "Weight index is out of range");
}
mlir::Value weight = coreOp.getWeights()[index];
auto weightView = resolveDenseWeightView(moduleOp, weight);
if (failed(weightView)) {
coreOp.emitWarning("Weight is not from a memref.get_global at index " + std::to_string(index));
assert(succeeded(weightView) && "Weight is not from a dense memref.global view");
}
if (mapCoreWeightToFileName[coreId].contains(weight))
continue;
auto getGlobalOp = weight.getDefiningOp<memref::GetGlobalOp>();
auto globalOp = getGlobalOp ? lookupGlobalForGetGlobal(moduleOp, getGlobalOp) : memref::GlobalOp {};
if (globalOp && mapGlobalOpToFileName.contains(globalOp)) {
auto& fileName = mapGlobalOpToFileName[globalOp];
mapCoreWeightToFileName[coreId].insert({weight, fileName});
continue;
}
DenseElementsAttr denseAttr = weightView->denseAttr;
ArrayRef<int64_t> shape = weightView->shape;
assert(isMatrixShape(shape) && "Weight matrix must be 2-dimensional");
int64_t numRows = shape[0];
int64_t numCols = shape[1];
assert(numRows <= xbarSize && numCols <= xbarSize && "Weight dimensions must not exceed crossbar size");
size_t elementByteWidth = denseAttr.getElementType().getIntOrFloatBitWidth() / 8;
std::string newFileName = "crossbar_" + std::to_string(indexFileName++) + ".bin";
auto weightFilePath = (coreWeightsDirPath + "/" + newFileName).str();
std::error_code errorCode;
raw_fd_ostream weightFileStream(weightFilePath, errorCode, sys::fs::OF_None);
if (errorCode) {
errs() << "Error while opening weight file `" << weightFilePath << "`: " << errorCode.message() << '\n';
assert(!errorCode && "Error while opening weight file");
}
uint64_t zero = 0;
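// Pad the weight to the full crossbar geometry: in-range elements are written
// verbatim and every remaining cell is zero-filled, so each crossbar_<n>.bin
// holds exactly xbarSize * xbarSize elements (e.g. a 3x2 weight in a 4x4
// crossbar emits 6 payload words and 10 zero words).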
for (int64_t row = 0; row < xbarSize; row++) {
for (int64_t col = 0; col < xbarSize; col++) {
if (row < numRows && col < numCols) {
int64_t elementIndex = weightView->offset + row * weightView->strides[0] + col * weightView->strides[1];
APInt bits = denseAttr.getValues<APFloat>()[elementIndex].bitcastToAPInt();
uint64_t word = bits.getZExtValue();
weightFileStream.write(reinterpret_cast<const char*>(&word), elementByteWidth);
}
else {
weightFileStream.write(reinterpret_cast<const char*>(&zero), elementByteWidth);
}
}
}
weightFileStream.close();
if (globalOp)
mapGlobalOpToFileName.insert({globalOp, newFileName});
mapCoreWeightToFileName[coreId].insert({weight, newFileName});
}
return success();
};
if (auto coreOp = dyn_cast<pim::PimCoreOp>(op)) {
(void) processCore(coreOp);
continue;
}
auto coreBatchOp = cast<pim::PimCoreBatchOp>(op);
for (unsigned lane = 0; lane < static_cast<unsigned>(coreBatchOp.getLaneCount()); ++lane)
if (failed(withScalarCoreFromBatchLane(coreBatchOp, lane, processCore)))
return mapCoreWeightToFileName;
}
return mapCoreWeightToFileName;
}
} // namespace onnx_mlir
+16
@@ -0,0 +1,16 @@
#pragma once
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/StringRef.h"
#include <string>
namespace onnx_mlir {
llvm::DenseMap<size_t, llvm::DenseMap<mlir::Value, std::string>>
createAndPopulateWeightFolder(mlir::func::FuncOp funcOp, llvm::StringRef outputDirPath);
} // namespace onnx_mlir
@@ -3,6 +3,11 @@ mlir_tablegen(ONNXToSpatial.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}")
add_public_tablegen_target(ONNXToSpatialIncGen)
add_pim_library(OMONNXToSpatial
ConversionPatterns.cpp
HostFoldability.cpp
HostLegality.cpp
PrePatterns.cpp
PostPatterns.cpp
Patterns/Math/Conv.cpp
Patterns/Math/Elementwise.cpp
Patterns/Math/Gemm.cpp
@@ -18,7 +23,9 @@ add_pim_library(OMONNXToSpatial
Patterns/Tensor/Reshape.cpp
Patterns/Tensor/Split.cpp
ONNXToSpatialPass.cpp
Common.cpp
Common/ComputeRegionBuilder.cpp
Common/ShapeTilingUtils.cpp
Common/WeightMaterialization.cpp
EXCLUDE_FROM_OM_LIBS
-279
@@ -1,279 +0,0 @@
#pragma once
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Transforms/DialectConversion.h"
#include <cassert>
#include <type_traits>
#include <utility>
#include "llvm/ADT/SmallPtrSet.h"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
namespace onnx_mlir {
template <class ShapedType>
inline auto getImageWidth(const ShapedType& shapedType) {
return shapedType.getDimSize(2);
}
template <class ShapedType>
inline auto getImageHeight(const ShapedType& shapedType) {
return shapedType.getDimSize(3);
}
template <class ShapedType>
inline auto getImageChannel(const ShapedType& shapedType) {
return shapedType.getDimSize(1);
}
template <class ShapedType>
inline auto getImageN(const ShapedType& shapedType) {
return shapedType.getDimSize(0);
}
template <class ShapedType>
inline auto getKernelWidth(const ShapedType& shapedType) {
return shapedType.getDimSize(2);
}
template <class ShapedType>
inline auto getKernelHeight(const ShapedType& shapedType) {
return shapedType.getDimSize(3);
}
template <class ShapedType>
inline auto getFilterCount(const ShapedType& shapedType) {
return shapedType.getDimSize(0);
}
using HSliceId = size_t;
using CoreId = size_t;
template <class A, class B, class C = std::common_type_t<A, B>>
constexpr C ceilIntegerDivide(A a, B b) {
static_assert(std::is_integral_v<A>, "A must be an integer type");
static_assert(std::is_integral_v<B>, "B must be an integer type");
C ac = static_cast<C>(a);
C bc = static_cast<C>(b);
return 1 + (ac - 1) / bc;
}
template <class A, class B, class C = std::common_type_t<A, B>>
constexpr std::pair<C, C> ceilIntegerDivideWithRemainder(A a, B b) {
static_assert(std::is_integral_v<A>, "A must be an integer type");
static_assert(std::is_integral_v<B>, "B must be an integer type");
C ac = static_cast<C>(a);
C bc = static_cast<C>(b);
return {ceilIntegerDivide(ac, bc), ac % bc};
}
template <class T>
bool isVectorShape(mlir::ArrayRef<T> shape) {
return shape.size() == 2 && (shape[0] == 1 || shape[1] == 1);
}
template <class T>
bool isMatrixShape(mlir::ArrayRef<T> shape) {
return shape.size() == 2;
}
template <class T>
bool isHVectorShape(mlir::ArrayRef<T> shape) {
return shape.size() == 2 && shape[0] == 1;
}
template <class T>
bool isVVectorShape(mlir::ArrayRef<T> shape) {
return shape.size() == 2 && shape[1] == 1;
}
template <class T>
T getVectorLength(mlir::ArrayRef<T> shape) {
assert(isVectorShape(shape));
return shape[0] != 1 ? shape[0] : shape[1];
}
inline auto getTensorShape(mlir::Value tensor) {
return mlir::cast<mlir::RankedTensorType>(tensor.getType()).getShape();
}
inline bool isWeightLikeComputeOperand(mlir::Value value) {
auto rankedType = mlir::dyn_cast<mlir::RankedTensorType>(value.getType());
if (!rankedType || !isMatrixShape(rankedType.getShape()))
return false;
llvm::SmallPtrSet<mlir::Operation*, 8> visited;
while (auto* definingOp = value.getDefiningOp()) {
if (!visited.insert(definingOp).second)
return false;
if (hasWeightAlways(definingOp))
return true;
if (auto extractSliceOp = mlir::dyn_cast<mlir::tensor::ExtractSliceOp>(definingOp)) {
value = extractSliceOp.getSource();
continue;
}
if (auto expandShapeOp = mlir::dyn_cast<mlir::tensor::ExpandShapeOp>(definingOp)) {
value = expandShapeOp.getSrc();
continue;
}
if (auto collapseShapeOp = mlir::dyn_cast<mlir::tensor::CollapseShapeOp>(definingOp)) {
value = collapseShapeOp.getSrc();
continue;
}
if (auto transposeOp = mlir::dyn_cast<mlir::ONNXTransposeOp>(definingOp)) {
value = transposeOp.getData();
continue;
}
return false;
}
return false;
}
namespace detail {
inline mlir::ValueRange getBlockArgs(mlir::Block* block) { return mlir::ValueRange(block->getArguments()); }
template <typename Fn, size_t... Is>
decltype(auto) invokeWithBlockArgs(Fn&& fn, mlir::Block* block, std::index_sequence<Is...>) {
return std::forward<Fn>(fn)(block->getArgument(Is)...);
}
template <typename Fn, size_t... Is>
decltype(auto) invokeWithValues(Fn&& fn, mlir::ArrayRef<mlir::Value> values, std::index_sequence<Is...>) {
return std::forward<Fn>(fn)(values[Is]...);
}
template <size_t>
using ValueArg = mlir::Value;
template <typename Fn, typename Seq>
struct InvokeWithBlockArgsResult;
template <typename Fn, size_t... Is>
struct InvokeWithBlockArgsResult<Fn, std::index_sequence<Is...>> {
using type = std::invoke_result_t<Fn, ValueArg<Is>...>;
};
template <typename Fn, typename Seq>
using InvokeWithBlockArgsResultT = typename InvokeWithBlockArgsResult<Fn, Seq>::type;
template <typename Fn>
using InvokeWithValueRangeResultT = std::invoke_result_t<Fn, mlir::ValueRange>;
} // namespace detail
template <size_t NumInputs, typename RewriterT, typename BodyFn>
auto createSpatCompute(RewriterT& rewriter,
mlir::Location loc,
mlir::TypeRange resultTypes,
mlir::ValueRange weights,
mlir::ValueRange inputs,
BodyFn&& body) {
assert(inputs.size() == NumInputs && "NumInputs must match the number of input values");
auto computeOp = spatial::SpatWeightedCompute::create(rewriter, loc, resultTypes, weights, inputs);
auto* block = new mlir::Block();
for (mlir::Value input : inputs)
block->addArgument(input.getType(), loc);
computeOp.getBody().push_back(block);
rewriter.setInsertionPointToStart(block);
using BodyResult = detail::InvokeWithBlockArgsResultT<std::decay_t<BodyFn>, std::make_index_sequence<NumInputs>>;
if constexpr (std::is_same_v<BodyResult, mlir::LogicalResult>) {
auto bodyResult =
detail::invokeWithBlockArgs(std::forward<BodyFn>(body), block, std::make_index_sequence<NumInputs> {});
if (mlir::failed(bodyResult)) {
rewriter.setInsertionPointAfter(computeOp);
rewriter.eraseOp(computeOp);
return mlir::FailureOr<spatial::SpatWeightedCompute>(mlir::failure());
}
rewriter.setInsertionPointAfter(computeOp);
return mlir::FailureOr<spatial::SpatWeightedCompute>(computeOp);
}
else {
static_assert(std::is_same_v<BodyResult, void>, "createSpatCompute body must return void or mlir::LogicalResult");
detail::invokeWithBlockArgs(std::forward<BodyFn>(body), block, std::make_index_sequence<NumInputs> {});
rewriter.setInsertionPointAfter(computeOp);
return computeOp;
}
}
template <typename RewriterT, typename BodyFn>
auto createSpatCompute(RewriterT& rewriter,
mlir::Location loc,
mlir::TypeRange resultTypes,
mlir::ValueRange weights,
mlir::ValueRange inputs,
BodyFn&& body) {
auto computeOp = spatial::SpatWeightedCompute::create(rewriter, loc, resultTypes, weights, inputs);
auto* block = new mlir::Block();
for (mlir::Value input : inputs)
block->addArgument(input.getType(), loc);
computeOp.getBody().push_back(block);
rewriter.setInsertionPointToStart(block);
using BodyResult = detail::InvokeWithValueRangeResultT<std::decay_t<BodyFn>>;
if constexpr (std::is_same_v<BodyResult, mlir::LogicalResult>) {
auto bodyResult = std::forward<BodyFn>(body)(detail::getBlockArgs(block));
if (mlir::failed(bodyResult)) {
rewriter.setInsertionPointAfter(computeOp);
rewriter.eraseOp(computeOp);
return mlir::FailureOr<spatial::SpatWeightedCompute>(mlir::failure());
}
rewriter.setInsertionPointAfter(computeOp);
return mlir::FailureOr<spatial::SpatWeightedCompute>(computeOp);
}
else {
static_assert(std::is_same_v<BodyResult, void>, "createSpatCompute body must return void or mlir::LogicalResult");
std::forward<BodyFn>(body)(detail::getBlockArgs(block));
rewriter.setInsertionPointAfter(computeOp);
return computeOp;
}
}
llvm::SmallVector<mlir::Value> sliceTensor(const mlir::Value& tensorToSlice,
size_t axis,
int64_t sliceSize,
mlir::ConversionPatternRewriter& rewriter,
mlir::Location loc);
llvm::SmallVector<mlir::Value> sliceVector(const mlir::Value& vectorToSlice,
int64_t sliceSize,
mlir::ConversionPatternRewriter& rewriter,
mlir::Location loc);
llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>> sliceVectorPerCrossbarPerCore(
const mlir::Value& vectorToSlice, mlir::ConversionPatternRewriter& rewriter, mlir::Location loc);
llvm::DenseMap<HSliceId, llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>>>
tileMatrix(mlir::Value& matrixToTile,
int64_t hSliceSize,
int64_t vSliceSize,
mlir::ConversionPatternRewriter& rewriter,
mlir::Location& loc);
mlir::tensor::SplatOp broadcastToVector(mlir::Value scalarToBroadcast,
int64_t length,
mlir::ConversionPatternRewriter& rewriter,
mlir::Location loc);
mlir::Value sumTensors(mlir::ArrayRef<mlir::Value> tensors, mlir::ConversionPatternRewriter& rewriter);
}; // namespace onnx_mlir
@@ -0,0 +1,7 @@
#pragma once
#include "ComputeRegionBuilder.hpp"
#include "ShapeTilingUtils.hpp"
#include "WeightMaterialization.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -0,0 +1,39 @@
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallVector.h"
#include "ComputeRegionBuilder.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
namespace onnx_mlir {
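// Reduces the inputs with a balanced pairwise tree of spat.vadd ops instead of
// a sequential chain, e.g. {a, b, c, d, e} -> {a+b, c+d, e} ->
// {(a+b)+(c+d), e} -> ((a+b)+(c+d))+e, keeping the dependency depth
// logarithmic in the number of tensors.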
Value sumTensors(ArrayRef<Value> tensors, ConversionPatternRewriter& rewriter) {
if (tensors.size() == 1)
return tensors[0];
SmallVector<Value> tensors1 = {tensors.begin(), tensors.end()};
SmallVector<Value> tensors2;
tensors2.reserve(tensors.size() / 2);
auto* currTensors = &tensors1;
auto* nextTensors = &tensors2;
while (currTensors->size() > 1) {
for (size_t i = 0; i < currTensors->size() - 1; i += 2) {
Value a = (*currTensors)[i];
Value b = (*currTensors)[i + 1];
rewriter.setInsertionPointAfterValue(b);
auto addedValue = spatial::SpatVAddOp::create(rewriter, a.getLoc(), a.getType(), a, b);
nextTensors->push_back(addedValue);
}
if (currTensors->size() % 2 == 1)
nextTensors->push_back(currTensors->back());
std::swap(currTensors, nextTensors);
nextTensors->clear();
}
assert(currTensors->size() == 1 && "Expected a single input at this point.");
return (*currTensors)[0];
}
} // namespace onnx_mlir
@@ -0,0 +1,153 @@
#pragma once
#include "mlir/IR/Block.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Transforms/DialectConversion.h"
#include <cassert>
#include <cstddef>
#include <type_traits>
#include <utility>
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
namespace onnx_mlir {
namespace detail {
inline mlir::ValueRange getBlockArgs(mlir::Block* block) { return mlir::ValueRange(block->getArguments()); }
template <typename Fn, size_t... Is>
decltype(auto) invokeWithBlockArgs(Fn&& fn, mlir::Block* block, std::index_sequence<Is...>) {
return std::forward<Fn>(fn)(block->getArgument(Is)...);
}
template <typename Fn, size_t... Is>
decltype(auto) invokeWithValues(Fn&& fn, mlir::ArrayRef<mlir::Value> values, std::index_sequence<Is...>) {
return std::forward<Fn>(fn)(values[Is]...);
}
template <size_t>
using ValueArg = mlir::Value;
template <typename Fn, typename Seq>
struct InvokeWithBlockArgsResult;
template <typename Fn, size_t... Is>
struct InvokeWithBlockArgsResult<Fn, std::index_sequence<Is...>> {
using type = std::invoke_result_t<Fn, ValueArg<Is>...>;
};
template <typename Fn, typename Seq>
using InvokeWithBlockArgsResultT = typename InvokeWithBlockArgsResult<Fn, Seq>::type;
template <typename Fn>
using InvokeWithValueRangeResultT = std::invoke_result_t<Fn, mlir::ValueRange>;
} // namespace detail
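/// Concatenates the inputs along `axis` as a single spat.concat, inferring the
/// result type from the operands: static sizes on the concat dim are summed
/// (two 1x64 inputs on axis 1 give a 1x128 result) and any dynamic operand
/// makes the concat dim dynamic.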
template <typename RewriterT>
inline mlir::Value createSpatConcat(RewriterT& rewriter, mlir::Location loc, int64_t axis, mlir::ValueRange inputs) {
assert(!inputs.empty() && "spat.concat requires at least one input");
if (inputs.size() == 1)
return inputs.front();
auto firstType = mlir::cast<mlir::RankedTensorType>(inputs.front().getType());
auto outputShape = llvm::to_vector(firstType.getShape());
int64_t concatDimSize = 0;
bool concatDimDynamic = false;
for (mlir::Value input : inputs) {
auto inputType = mlir::cast<mlir::RankedTensorType>(input.getType());
assert(inputType.getRank() == firstType.getRank() && "spat.concat expects same-rank inputs");
if (mlir::ShapedType::isDynamic(inputType.getDimSize(axis)))
concatDimDynamic = true;
else
concatDimSize += inputType.getDimSize(axis);
}
outputShape[axis] = concatDimDynamic ? mlir::ShapedType::kDynamic : concatDimSize;
auto outputType = mlir::RankedTensorType::get(outputShape, firstType.getElementType(), firstType.getEncoding());
return spatial::SpatConcatOp::create(rewriter, loc, outputType, rewriter.getI64IntegerAttr(axis), inputs).getOutput();
}
/// Builds a `spat.compute` with a fixed number of SSA inputs and erases it if
/// the body callback reports failure.
template <size_t NumInputs, typename RewriterT, typename BodyFn>
auto createSpatCompute(RewriterT& rewriter,
mlir::Location loc,
mlir::TypeRange resultTypes,
mlir::ValueRange weights,
mlir::ValueRange inputs,
BodyFn&& body) {
assert(inputs.size() == NumInputs && "NumInputs must match the number of input values");
auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs);
auto* block = new mlir::Block();
for (mlir::Value input : inputs)
block->addArgument(input.getType(), loc);
computeOp.getBody().push_back(block);
rewriter.setInsertionPointToStart(block);
using BodyResult = detail::InvokeWithBlockArgsResultT<std::decay_t<BodyFn>, std::make_index_sequence<NumInputs>>;
if constexpr (std::is_same_v<BodyResult, void>) {
detail::invokeWithBlockArgs(std::forward<BodyFn>(body), block, std::make_index_sequence<NumInputs> {});
rewriter.setInsertionPointAfter(computeOp);
return computeOp;
}
else {
auto bodyResult =
detail::invokeWithBlockArgs(std::forward<BodyFn>(body), block, std::make_index_sequence<NumInputs> {});
if (mlir::failed(bodyResult)) {
rewriter.setInsertionPointAfter(computeOp);
rewriter.eraseOp(computeOp);
return mlir::FailureOr<spatial::SpatCompute>(mlir::failure());
}
rewriter.setInsertionPointAfter(computeOp);
return mlir::FailureOr<spatial::SpatCompute>(computeOp);
}
}
/// Builds a `spat.compute` whose body consumes the block arguments as a single
/// `ValueRange`, which is convenient for variadic reductions/concats.
template <typename RewriterT, typename BodyFn>
auto createSpatCompute(RewriterT& rewriter,
mlir::Location loc,
mlir::TypeRange resultTypes,
mlir::ValueRange weights,
mlir::ValueRange inputs,
BodyFn&& body) {
auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs);
auto* block = new mlir::Block();
for (mlir::Value input : inputs)
block->addArgument(input.getType(), loc);
computeOp.getBody().push_back(block);
rewriter.setInsertionPointToStart(block);
using BodyResult = detail::InvokeWithValueRangeResultT<std::decay_t<BodyFn>>;
if constexpr (std::is_same_v<BodyResult, void>) {
std::forward<BodyFn>(body)(detail::getBlockArgs(block));
rewriter.setInsertionPointAfter(computeOp);
return computeOp;
}
else {
auto bodyResult = std::forward<BodyFn>(body)(detail::getBlockArgs(block));
if (mlir::failed(bodyResult)) {
rewriter.setInsertionPointAfter(computeOp);
rewriter.eraseOp(computeOp);
return mlir::FailureOr<spatial::SpatCompute>(mlir::failure());
}
rewriter.setInsertionPointAfter(computeOp);
return mlir::FailureOr<spatial::SpatCompute>(computeOp);
}
}
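/// Illustrative sketch, not part of this patch: how a lowering pattern might
/// use the fixed-arity overload to wrap an element-wise add in a spat.compute.
/// All parameter names are hypothetical; the body receives one block argument
/// per input and must terminate with spat.yield.
template <typename RewriterT>
inline auto createExampleAddCompute(RewriterT& rewriter, mlir::Location loc,
    mlir::RankedTensorType resultType, mlir::Value lhs, mlir::Value rhs) {
  return createSpatCompute<2>(rewriter, loc, mlir::TypeRange {resultType},
      /*weights=*/mlir::ValueRange {}, mlir::ValueRange {lhs, rhs},
      [&](mlir::Value a, mlir::Value b) {
        // The block arguments mirror `lhs` and `rhs` inside the compute body.
        mlir::Value sum = spatial::SpatVAddOp::create(rewriter, loc, a.getType(), a, b);
        spatial::SpatYieldOp::create(rewriter, loc, sum);
      });
}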
mlir::Value sumTensors(mlir::ArrayRef<mlir::Value> tensors, mlir::ConversionPatternRewriter& rewriter);
} // namespace onnx_mlir
@@ -1,24 +1,12 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/Casting.h"
#include <cassert>
#include <optional>
#include <utility>
#include "Common.hpp"
#include "ShapeTilingUtils.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
using namespace mlir;
@@ -44,10 +32,29 @@ SmallVector<Value> sliceTensor(
for (int64_t i = 0; i < numSlices; i++) {
offsets[axis] = rewriter.getIndexAttr(i * sliceSize);
if (i == numSlices - 1 && lastSliceSize != 0)
int64_t currentSliceSize = sliceSize;
if (i == numSlices - 1 && lastSliceSize != 0) {
currentSliceSize = lastSliceSize;
sizes[axis] = rewriter.getIndexAttr(lastSliceSize);
}
Value slice = tensor::ExtractSliceOp::create(rewriter, loc, tensorToSlice, offsets, sizes, strides);
SmallVector<int64_t> sliceShape(shape.begin(), shape.end());
sliceShape[axis] = currentSliceSize;
auto sliceType =
RankedTensorType::get(sliceShape, cast<RankedTensorType>(tensorToSlice.getType()).getElementType());
Value slice;
if (isHostFoldableValue(tensorToSlice)) {
slice = tensor::ExtractSliceOp::create(rewriter, loc, tensorToSlice, offsets, sizes, strides);
}
else {
auto sliceCompute =
createSpatCompute<1>(rewriter, loc, TypeRange {sliceType}, {}, ValueRange {tensorToSlice}, [&](Value input) {
Value computedSlice = tensor::ExtractSliceOp::create(rewriter, loc, input, offsets, sizes, strides);
spatial::SpatYieldOp::create(rewriter, loc, computedSlice);
});
slice = sliceCompute.getResult(0);
}
slices.push_back(slice);
}
@@ -93,45 +100,27 @@ DenseMap<HSliceId, DenseMap<CoreId, SmallVector<Value>>> tileMatrix(
return tiles;
}
tensor::SplatOp
broadcastToVector(Value scalarToBroadcast, int64_t length, ConversionPatternRewriter& rewriter, Location loc) {
Value broadcastToVector(Value scalarToBroadcast, int64_t length, ConversionPatternRewriter& rewriter, Location loc) {
auto oldType = cast<RankedTensorType>(scalarToBroadcast.getType());
Type elementType = oldType.getElementType();
int64_t shape[2] = {1, length};
Type type = oldType.cloneWith(ArrayRef(shape), elementType);
auto buildBroadcast = [&](Value input) -> Value {
auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0).getResult();
SmallVector<Value> index(oldType.getRank(), zero);
auto elementValue = tensor::ExtractOp::create(rewriter, loc, scalarToBroadcast, index).getResult();
auto elementValue = tensor::ExtractOp::create(rewriter, loc, input, index).getResult();
return tensor::SplatOp::create(rewriter, loc, type, elementValue);
};
if (isHostFoldableValue(scalarToBroadcast))
return buildBroadcast(scalarToBroadcast);
auto broadcastCompute =
createSpatCompute<1>(rewriter, loc, TypeRange {type}, {}, ValueRange {scalarToBroadcast}, [&](Value input) {
spatial::SpatYieldOp::create(rewriter, loc, buildBroadcast(input));
});
return broadcastCompute.getResult(0);
}
Value sumTensors(ArrayRef<Value> tensors, ConversionPatternRewriter& rewriter) {
if (tensors.size() == 1)
return tensors[0];
SmallVector<Value> tensors1 = {tensors.begin(), tensors.end()};
SmallVector<Value> tensors2;
tensors2.reserve(tensors.size() / 2);
auto* currTensors = &tensors1;
auto* nextTensors = &tensors2;
while (currTensors->size() > 1) {
for (size_t i = 0; i < currTensors->size() - 1; i += 2) {
Value a = (*currTensors)[i];
Value b = (*currTensors)[i + 1];
rewriter.setInsertionPointAfterValue(b);
auto addedValue = spatial::SpatVAddOp::create(rewriter, a.getLoc(), a.getType(), a, b);
nextTensors->push_back(addedValue);
}
if (currTensors->size() % 2 == 1)
nextTensors->push_back(currTensors->back());
std::swap(currTensors, nextTensors);
nextTensors->clear();
}
assert(currTensors->size() == 1 && "Expected a single input at this point.");
return (*currTensors)[0];
}
}; // namespace onnx_mlir
} // namespace onnx_mlir
@@ -0,0 +1,144 @@
#pragma once
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"
#include <cassert>
#include <cstddef>
#include <type_traits>
#include <utility>
namespace onnx_mlir {
template <class ShapedType>
inline auto getImageWidth(const ShapedType& shapedType) {
return shapedType.getDimSize(2);
}
template <class ShapedType>
inline auto getImageHeight(const ShapedType& shapedType) {
return shapedType.getDimSize(3);
}
template <class ShapedType>
inline auto getImageChannel(const ShapedType& shapedType) {
return shapedType.getDimSize(1);
}
template <class ShapedType>
inline auto getImageN(const ShapedType& shapedType) {
return shapedType.getDimSize(0);
}
template <class ShapedType>
inline auto getKernelWidth(const ShapedType& shapedType) {
return shapedType.getDimSize(2);
}
template <class ShapedType>
inline auto getKernelHeight(const ShapedType& shapedType) {
return shapedType.getDimSize(3);
}
template <class ShapedType>
inline auto getFilterCount(const ShapedType& shapedType) {
return shapedType.getDimSize(0);
}
using HSliceId = size_t;
using CoreId = size_t;
template <class A, class B, class C = std::common_type_t<A, B>>
constexpr C ceilIntegerDivide(A a, B b) {
static_assert(std::is_integral_v<A>, "A must be an integer type");
static_assert(std::is_integral_v<B>, "B must be an integer type");
C ac = static_cast<C>(a);
C bc = static_cast<C>(b);
return 1 + (ac - 1) / bc;
}
template <class A, class B, class C = std::common_type_t<A, B>>
constexpr std::pair<C, C> ceilIntegerDivideWithRemainder(A a, B b) {
static_assert(std::is_integral_v<A>, "A must be an integer type");
static_assert(std::is_integral_v<B>, "B must be an integer type");
C ac = static_cast<C>(a);
C bc = static_cast<C>(b);
return {ceilIntegerDivide(ac, bc), ac % bc};
}
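// Illustrative checks, not part of the original patch: tiling 10 rows onto
// crossbars of height 4 needs 3 slices, and the last slice carries the 2
// leftover rows; exact multiples do not over-allocate.
static_assert(ceilIntegerDivide(10, 4) == 3, "ceil(10 / 4) is 3 slices");
static_assert(ceilIntegerDivide(8, 4) == 2, "exact multiples need no extra slice");
static_assert(ceilIntegerDivideWithRemainder(10, 4) == std::pair<int, int> {3, 2},
    "two full slices of 4 plus a final slice of 2");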
template <class T>
bool isVectorShape(mlir::ArrayRef<T> shape) {
return shape.size() == 2 && (shape[0] == 1 || shape[1] == 1);
}
template <class T>
bool isMatrixShape(mlir::ArrayRef<T> shape) {
return shape.size() == 2;
}
template <class T>
bool isHVectorShape(mlir::ArrayRef<T> shape) {
return shape.size() == 2 && shape[0] == 1;
}
template <class T>
bool isVVectorShape(mlir::ArrayRef<T> shape) {
return shape.size() == 2 && shape[1] == 1;
}
template <class T>
T getVectorLength(mlir::ArrayRef<T> shape) {
assert(isVectorShape(shape));
return shape[0] != 1 ? shape[0] : shape[1];
}
inline auto getTensorShape(mlir::Value tensor) {
return mlir::cast<mlir::RankedTensorType>(tensor.getType()).getShape();
}
inline bool haveSameStaticShape(mlir::Value lhs, mlir::Value rhs) {
auto lhsType = mlir::dyn_cast<mlir::RankedTensorType>(lhs.getType());
auto rhsType = mlir::dyn_cast<mlir::RankedTensorType>(rhs.getType());
return lhsType && rhsType && lhsType.hasStaticShape() && rhsType.hasStaticShape()
&& lhsType.getShape() == rhsType.getShape();
}
/// Slices a statically shaped tensor along one axis into contiguous pieces of
/// at most `sliceSize` elements.
llvm::SmallVector<mlir::Value> sliceTensor(const mlir::Value& tensorToSlice,
size_t axis,
int64_t sliceSize,
mlir::ConversionPatternRewriter& rewriter,
mlir::Location loc);
llvm::SmallVector<mlir::Value> sliceVector(const mlir::Value& vectorToSlice,
int64_t sliceSize,
mlir::ConversionPatternRewriter& rewriter,
mlir::Location loc);
/// Partitions one logical vector into per-core crossbar-sized slices using the
/// current PIM target geometry.
llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>> sliceVectorPerCrossbarPerCore(
const mlir::Value& vectorToSlice, mlir::ConversionPatternRewriter& rewriter, mlir::Location loc);
/// Tiles a matrix first across output columns and then across input rows so it
/// can be assigned to crossbars grouped by core.
llvm::DenseMap<HSliceId, llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>>>
tileMatrix(mlir::Value& matrixToTile,
int64_t hSliceSize,
int64_t vSliceSize,
mlir::ConversionPatternRewriter& rewriter,
mlir::Location& loc);
mlir::Value broadcastToVector(mlir::Value scalarToBroadcast,
int64_t length,
mlir::ConversionPatternRewriter& rewriter,
mlir::Location loc);
} // namespace onnx_mlir
@@ -0,0 +1,114 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "ShapeTilingUtils.hpp"
#include "WeightMaterialization.hpp"
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
bool isWeightLikeComputeOperand(Value value) {
auto rankedType = dyn_cast<RankedTensorType>(value.getType());
if (!rankedType || !isMatrixShape(rankedType.getShape()))
return false;
llvm::SmallPtrSet<Operation*, 8> visited;
while (auto* definingOp = value.getDefiningOp()) {
if (!visited.insert(definingOp).second)
return false;
if (isa<arith::ConstantOp, ONNXConstantOp>(definingOp) || hasWeightAlways(definingOp))
return true;
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(definingOp)) {
value = extractSliceOp.getSource();
continue;
}
if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(definingOp)) {
value = expandShapeOp.getSrc();
continue;
}
if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(definingOp)) {
value = collapseShapeOp.getSrc();
continue;
}
if (auto transposeOp = dyn_cast<ONNXTransposeOp>(definingOp)) {
value = transposeOp.getData();
continue;
}
return false;
}
return false;
}
FailureOr<Value> materializeWeightLikeValueInBlock(Value value, IRRewriter& rewriter, IRMapping& mapper) {
if (auto mapped = mapper.lookupOrNull(value))
return cast<Value>(mapped);
Operation* definingOp = value.getDefiningOp();
if (!definingOp)
return failure();
if (isa<arith::ConstantOp, ONNXConstantOp>(definingOp)) {
auto tensorType = dyn_cast<RankedTensorType>(value.getType());
if (!tensorType || !tensorType.hasStaticShape())
return failure();
SmallVector<OpFoldResult> offsets(tensorType.getRank(), rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> sizes;
SmallVector<OpFoldResult> strides(tensorType.getRank(), rewriter.getIndexAttr(1));
sizes.reserve(tensorType.getRank());
for (int64_t dim : tensorType.getShape())
sizes.push_back(rewriter.getIndexAttr(dim));
auto referencedValue =
tensor::ExtractSliceOp::create(rewriter, value.getLoc(), tensorType, value, offsets, sizes, strides);
mapper.map(value, referencedValue.getResult());
return referencedValue.getResult();
}
if (!isa<tensor::ExtractSliceOp, tensor::ExpandShapeOp, tensor::CollapseShapeOp, ONNXTransposeOp>(definingOp))
return failure();
IRMapping localMapper;
for (Value operand : definingOp->getOperands()) {
if (auto mapped = mapper.lookupOrNull(operand)) {
localMapper.map(operand, cast<Value>(mapped));
continue;
}
if (isWeightLikeComputeOperand(operand)) {
auto clonedOperand = materializeWeightLikeValueInBlock(operand, rewriter, mapper);
if (failed(clonedOperand))
return failure();
localMapper.map(operand, *clonedOperand);
continue;
}
localMapper.map(operand, operand);
}
Operation* clonedOp = rewriter.clone(*definingOp, localMapper);
for (auto [oldResult, newResult] : llvm::zip(definingOp->getResults(), clonedOp->getResults()))
mapper.map(oldResult, newResult);
auto mapped = mapper.lookupOrNull(value);
if (!mapped)
return failure();
return cast<Value>(mapped);
}
} // namespace onnx_mlir
@@ -0,0 +1,18 @@
#pragma once
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
namespace onnx_mlir {
/// Returns true when a matrix-valued compute operand is ultimately backed by a
/// weight-marked constant/view chain and can be promoted into weights.
bool isWeightLikeComputeOperand(mlir::Value value);
/// Rebuilds the view/transpose chain of a promoted weight operand inside a new
/// compute body while reusing already-materialized intermediate values.
llvm::FailureOr<mlir::Value>
materializeWeightLikeValueInBlock(mlir::Value value, mlir::IRRewriter& rewriter, mlir::IRMapping& mapper);
} // namespace onnx_mlir
@@ -0,0 +1,32 @@
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.hpp.inc"
} // namespace
void populateConversionPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx) {
patterns.add<removeLRN>(ctx);
populateElementwisePatterns(patterns, ctx);
populateGemmPatterns(patterns, ctx);
populateConvPatterns(patterns, ctx);
populatePoolPatterns(patterns, ctx);
populateReduceMeanPatterns(patterns, ctx);
populateReluPatterns(patterns, ctx);
populateSigmoidPatterns(patterns, ctx);
populateSoftmaxPatterns(patterns, ctx);
populateConcatPatterns(patterns, ctx);
populateGatherPatterns(patterns, ctx);
populateResizePatterns(patterns, ctx);
populateReshapePatterns(patterns, ctx);
populateSplitPatterns(patterns, ctx);
}
} // namespace onnx_mlir
@@ -5,6 +5,8 @@
namespace onnx_mlir {
void populateConversionPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateConvPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateElementwisePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
@@ -0,0 +1,256 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
static bool hasStaticUnitStrides(tensor::ExtractSliceOp extractSliceOp) {
return llvm::all_of(extractSliceOp.getStaticStrides(), [](int64_t stride) { return stride == 1; });
}
static bool hasConstantIndices(tensor::ExtractOp extractOp) {
return llvm::all_of(extractOp.getIndices(),
[](Value index) { return isa_and_nonnull<arith::ConstantIndexOp>(index.getDefiningOp()); });
}
static bool isStaticTensorResult(Operation* op) {
return llvm::all_of(op->getResultTypes(), [](Type type) {
auto shapedType = dyn_cast<ShapedType>(type);
return shapedType && shapedType.hasStaticShape();
});
}
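// Row-major strides for a static shape, e.g. shape [2, 3, 4] yields strides
// [12, 4, 1], so advancing one step in the outermost dim skips 12 elements.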
static SmallVector<int64_t> computeRowMajorStrides(ArrayRef<int64_t> shape) {
SmallVector<int64_t> strides(shape.size(), 1);
for (int64_t dim = static_cast<int64_t>(shape.size()) - 2; dim >= 0; --dim)
strides[dim] = strides[dim + 1] * shape[dim + 1];
return strides;
}
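// Re-linearizes every element under the permuted shape. For example, a [2x3]
// attribute transposed with perms = [1, 0] becomes [3x2], and the source
// element at (0, 1) lands at (1, 0) of the result.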
static FailureOr<DenseElementsAttr> transposeDenseElements(DenseElementsAttr denseAttr, ArrayRef<int64_t> perms) {
auto tensorType = dyn_cast<RankedTensorType>(denseAttr.getType());
if (!tensorType)
return failure();
int64_t rank = tensorType.getRank();
if (static_cast<int64_t>(perms.size()) != rank)
return failure();
llvm::SmallBitVector seen(rank);
SmallVector<int64_t> transposedShape;
transposedShape.reserve(rank);
for (int64_t perm : perms) {
if (perm < 0 || perm >= rank || seen.test(perm))
return failure();
seen.set(perm);
transposedShape.push_back(tensorType.getShape()[perm]);
}
auto transposedType = RankedTensorType::get(transposedShape, tensorType.getElementType(), tensorType.getEncoding());
if (denseAttr.isSplat())
return DenseElementsAttr::get(transposedType, denseAttr.getSplatValue<Attribute>());
SmallVector<Attribute> originalValues(denseAttr.getValues<Attribute>());
SmallVector<Attribute> transposedValues(originalValues.size());
SmallVector<int64_t> originalStrides = computeRowMajorStrides(tensorType.getShape());
SmallVector<int64_t> transposedStrides = computeRowMajorStrides(transposedShape);
SmallVector<int64_t> originalIndices(rank);
for (auto [linearIndex, value] : llvm::enumerate(originalValues)) {
int64_t remaining = static_cast<int64_t>(linearIndex);
for (int64_t dim = 0; dim < rank; ++dim) {
originalIndices[dim] = remaining / originalStrides[dim];
remaining %= originalStrides[dim];
}
int64_t transposedLinearIndex = 0;
for (int64_t dim = 0; dim < rank; ++dim)
transposedLinearIndex += originalIndices[perms[dim]] * transposedStrides[dim];
transposedValues[transposedLinearIndex] = value;
}
return DenseElementsAttr::get(transposedType, transposedValues);
}
static FailureOr<DenseElementsAttr> reshapeDenseElements(DenseElementsAttr denseAttr, RankedTensorType resultType) {
auto sourceType = dyn_cast<RankedTensorType>(denseAttr.getType());
if (!sourceType || !resultType || sourceType.getNumElements() != resultType.getNumElements())
return failure();
if (denseAttr.isSplat())
return DenseElementsAttr::get(resultType, denseAttr.getSplatValue<Attribute>());
SmallVector<Attribute> values(denseAttr.getValues<Attribute>());
return DenseElementsAttr::get(resultType, values);
}
static FailureOr<DenseElementsAttr> extractSliceDenseElements(DenseElementsAttr denseAttr,
tensor::ExtractSliceOp extractSliceOp) {
auto sourceType = dyn_cast<RankedTensorType>(denseAttr.getType());
auto resultType = dyn_cast<RankedTensorType>(extractSliceOp.getType());
if (!sourceType || !resultType || !sourceType.hasStaticShape() || !resultType.hasStaticShape())
return failure();
ArrayRef<int64_t> offsets = extractSliceOp.getStaticOffsets();
ArrayRef<int64_t> sizes = extractSliceOp.getStaticSizes();
ArrayRef<int64_t> strides = extractSliceOp.getStaticStrides();
if (llvm::any_of(offsets, [](int64_t value) { return ShapedType::isDynamic(value); })
|| llvm::any_of(sizes, [](int64_t value) { return ShapedType::isDynamic(value); })
|| llvm::any_of(strides, [](int64_t stride) { return ShapedType::isDynamic(stride) || stride != 1; }))
return failure();
if (denseAttr.isSplat())
return DenseElementsAttr::get(resultType, denseAttr.getSplatValue<Attribute>());
SmallVector<Attribute> sourceValues(denseAttr.getValues<Attribute>());
SmallVector<int64_t> sourceStrides = computeRowMajorStrides(sourceType.getShape());
SmallVector<int64_t> resultStrides = computeRowMajorStrides(resultType.getShape());
SmallVector<Attribute> resultValues;
resultValues.reserve(resultType.getNumElements());
for (int64_t linearIndex = 0; linearIndex < resultType.getNumElements(); ++linearIndex) {
int64_t remaining = linearIndex;
int64_t sourceLinearIndex = 0;
for (int64_t dim = 0; dim < resultType.getRank(); ++dim) {
const int64_t resultIndex = resultStrides.empty() ? 0 : remaining / resultStrides[dim];
remaining = resultStrides.empty() ? 0 : remaining % resultStrides[dim];
sourceLinearIndex += (offsets[dim] + resultIndex) * sourceStrides[dim];
}
resultValues.push_back(sourceValues[sourceLinearIndex]);
}
return DenseElementsAttr::get(resultType, resultValues);
}
static DenseElementsAttr getDirectDenseConstantAttr(Value value) {
if (auto constantOp = value.getDefiningOp<arith::ConstantOp>())
return dyn_cast<DenseElementsAttr>(constantOp.getValue());
if (auto constantOp = value.getDefiningOp<ONNXConstantOp>())
return dyn_cast_or_null<DenseElementsAttr>(constantOp.getValueAttr());
return nullptr;
}
static DenseElementsAttr getHostFoldableDenseElementsAttrImpl(Value value, llvm::SmallPtrSetImpl<Operation*>& visited) {
auto* definingOp = value.getDefiningOp();
if (!definingOp || !visited.insert(definingOp).second)
return nullptr;
// Rebuild dense attributes through view-only host-foldable chains so later
// lowering stages can still recognize grouped/sliced constants.
if (auto denseAttr = getDirectDenseConstantAttr(value))
return denseAttr;
if (auto transposeOp = dyn_cast<ONNXTransposeOp>(definingOp)) {
auto inputAttr = getHostFoldableDenseElementsAttrImpl(transposeOp.getData(), visited);
if (!inputAttr)
return nullptr;
SmallVector<int64_t> perm;
perm.reserve(transposeOp.getPermAttr().size());
for (IntegerAttr attr : transposeOp.getPermAttr().getAsRange<IntegerAttr>())
perm.push_back(attr.getInt());
auto transposedAttr = transposeDenseElements(inputAttr, perm);
return succeeded(transposedAttr) ? *transposedAttr : nullptr;
}
if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(definingOp)) {
auto inputAttr = getHostFoldableDenseElementsAttrImpl(collapseShapeOp.getSrc(), visited);
if (!inputAttr)
return nullptr;
auto reshapedAttr = reshapeDenseElements(inputAttr, cast<RankedTensorType>(collapseShapeOp.getType()));
return succeeded(reshapedAttr) ? *reshapedAttr : nullptr;
}
if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(definingOp)) {
auto inputAttr = getHostFoldableDenseElementsAttrImpl(expandShapeOp.getSrc(), visited);
if (!inputAttr)
return nullptr;
auto reshapedAttr = reshapeDenseElements(inputAttr, cast<RankedTensorType>(expandShapeOp.getType()));
return succeeded(reshapedAttr) ? *reshapedAttr : nullptr;
}
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(definingOp)) {
auto inputAttr = getHostFoldableDenseElementsAttrImpl(extractSliceOp.getSource(), visited);
if (!inputAttr)
return nullptr;
auto slicedAttr = extractSliceDenseElements(inputAttr, extractSliceOp);
return succeeded(slicedAttr) ? *slicedAttr : nullptr;
}
return nullptr;
}
static bool isHostFoldableOpImpl(Operation* op, llvm::SmallPtrSetImpl<Operation*>& visited) {
if (!op || !visited.insert(op).second)
return false;
if (isa<arith::ConstantOp, ONNXConstantOp, ONNXNoneOp>(op))
return true;
if (auto extractOp = dyn_cast<tensor::ExtractOp>(op))
return hasConstantIndices(extractOp) && isHostFoldableValue(extractOp.getTensor());
if (!isStaticTensorResult(op))
return false;
if (auto transposeOp = dyn_cast<ONNXTransposeOp>(op))
return isHostFoldableValue(transposeOp.getData());
if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(op))
return isHostFoldableValue(collapseShapeOp.getSrc());
if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(op))
return isHostFoldableValue(expandShapeOp.getSrc());
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op))
return hasStaticUnitStrides(extractSliceOp) && isHostFoldableValue(extractSliceOp.getSource());
if (auto splatOp = dyn_cast<tensor::SplatOp>(op))
return isHostFoldableValue(splatOp.getInput());
if (auto extractRowsOp = dyn_cast<spatial::SpatExtractRowsOp>(op))
return isHostFoldableValue(extractRowsOp.getInput());
if (auto concatOp = dyn_cast<spatial::SpatConcatOp>(op))
return llvm::all_of(concatOp.getInputs(), isHostFoldableValue);
return false;
}
} // namespace
bool isHostFoldableValue(Value value) {
auto* definingOp = value.getDefiningOp();
if (!definingOp)
return false;
llvm::SmallPtrSet<Operation*, 8> visited;
return isHostFoldableOpImpl(definingOp, visited);
}
bool isHostFoldableOp(Operation* op) {
llvm::SmallPtrSet<Operation*, 8> visited;
return isHostFoldableOpImpl(op, visited);
}
DenseElementsAttr getHostFoldableDenseElementsAttr(Value value) {
llvm::SmallPtrSet<Operation*, 8> visited;
return getHostFoldableDenseElementsAttrImpl(value, visited);
}
} // namespace onnx_mlir
@@ -0,0 +1,15 @@
#pragma once
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"
namespace onnx_mlir {
bool isHostFoldableValue(mlir::Value value);
bool isHostFoldableOp(mlir::Operation* op);
mlir::DenseElementsAttr getHostFoldableDenseElementsAttr(mlir::Value value);
} // namespace onnx_mlir
@@ -0,0 +1,34 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostLegality.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
namespace onnx_mlir {
LogicalResult verifyONNXToSpatialHostLegality(func::FuncOp funcOp) {
pim::CappedDiagnosticReporter diagnostics;
for (Operation& op : funcOp.getFunctionBody().front()) {
if (isa<func::ReturnOp, spatial::SpatCompute, spatial::SpatComputeBatch>(&op))
continue;
if (isHostFoldableOp(&op))
continue;
diagnostics.report(&op, [](Operation* illegalOp) {
illegalOp->emitOpError("non-foldable top-level runtime op remains after ONNX-to-Spatial; lower it inside "
"spat.compute");
});
}
diagnostics.emitSuppressedSummary(funcOp, "ONNX-to-Spatial host legality failures");
return success(!diagnostics.hasFailure());
}
} // namespace onnx_mlir
@@ -0,0 +1,10 @@
#pragma once
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Support/LogicalResult.h"
namespace onnx_mlir {
mlir::LogicalResult verifyONNXToSpatialHostLegality(mlir::func::FuncOp funcOp);
} // namespace onnx_mlir
@@ -1,31 +1,23 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_os_ostream.h"
#include <fstream>
#include <iterator>
#include <utility>
#include "Common.hpp"
#include "Common/Common.hpp"
#include "Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostLegality.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PostPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PrePatterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.hpp"
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
#include "src/Compiler/CompilerOptions.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -33,12 +25,8 @@ using namespace mlir;
namespace onnx_mlir {
bool haveSameStaticShape(Value lhs, Value rhs);
namespace {
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.hpp.inc"
struct ONNXToSpatialPass : PassWrapper<ONNXToSpatialPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ONNXToSpatialPass)
StringRef getArgument() const override { return "convert-onnx-to-spatial"; }
@@ -48,35 +36,117 @@ struct ONNXToSpatialPass : PassWrapper<ONNXToSpatialPass, OperationPass<ModuleOp
ONNXToSpatialPass(const ONNXToSpatialPass& pass) {}
void runOnOperation() override;
private:
void annotateWeightsConstants(func::FuncOp funcOp) const;
void encapsulateGlobalInstruction(func::FuncOp funcOp);
void mergeTriviallyConnectedComputes(func::FuncOp funcOp);
LogicalResult promoteConstantInputsToWeights(func::FuncOp funcOp);
};
} // namespace
static void populateEmptyFunction(func::FuncOp funcOp) {
IRRewriter rewriter(funcOp.getContext());
IRMapping mapper;
SmallVector<spatial::SpatCompute> computes(funcOp.getOps<spatial::SpatCompute>());
if (!computes.empty())
return;
auto returnOp = cast<func::ReturnOp>(funcOp.getFunctionBody().front().getTerminator());
rewriter.setInsertionPoint(returnOp);
SmallVector<Type> sourceTypes;
SmallVector<Location> sourceLocs;
sourceTypes.reserve(funcOp.getNumArguments());
sourceLocs.reserve(funcOp.getNumArguments());
for (Value source : funcOp.getArguments()) {
sourceTypes.push_back(source.getType());
sourceLocs.push_back(source.getLoc());
}
auto newCompute = spatial::SpatCompute::create(
rewriter, returnOp.getLoc(), returnOp.getOperandTypes(), funcOp.getArguments(), {}, {});
auto* newBlock = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), sourceTypes, sourceLocs);
for (auto [blockArg, computeArg] : llvm::zip(newBlock->getArguments(), newCompute.getOperands()))
mapper.map(computeArg, blockArg);
newCompute.getProperties().setOperandSegmentSizes({0, static_cast<int>(sourceTypes.size())});
rewriter.setInsertionPointToEnd(newBlock);
for (Operation& op : funcOp.getOps())
if (!isa<spatial::SpatCompute, func::ReturnOp>(&op))
rewriter.clone(op, mapper);
auto yield = spatial::SpatYieldOp::create(rewriter, funcOp.getLoc(), returnOp.getOperands());
for (size_t i = 0; i < yield.getNumOperands(); ++i)
yield.setOperand(i, mapper.lookupOrDefault(yield.getOperand(i)));
for (Operation& op : llvm::make_early_inc_range(funcOp.getOps()))
if (!isa<spatial::SpatCompute, func::ReturnOp>(&op)) {
op.dropAllUses();
rewriter.eraseOp(&op);
}
for (auto [index, computeResult] : llvm::enumerate(newCompute.getResults()))
returnOp.setOperand(index, computeResult);
}
static void wrapTopLevelRuntimeTransposes(func::FuncOp funcOp) {
IRRewriter rewriter(funcOp.getContext());
Block& entryBlock = funcOp.getFunctionBody().front();
for (Operation& op : llvm::make_early_inc_range(entryBlock)) {
auto transposeOp = dyn_cast<ONNXTransposeOp>(&op);
if (!transposeOp || isHostFoldableOp(transposeOp))
continue;
// Transpose stays globally legal because constant/view-only cases are
// allowed on the host. Any residual runtime transpose must be sunk into
// spat.compute before the host legality check.
auto resultType = transposeOp.getResult().getType();
rewriter.setInsertionPoint(transposeOp);
auto computeOp = createSpatCompute<1>(
rewriter, transposeOp.getLoc(), TypeRange {resultType}, {}, ValueRange {transposeOp.getData()}, [&](Value input) {
Value transposed =
ONNXTransposeOp::create(rewriter, transposeOp.getLoc(), resultType, input, transposeOp.getPermAttr());
spatial::SpatYieldOp::create(rewriter, transposeOp.getLoc(), transposed);
});
rewriter.replaceOp(transposeOp, computeOp.getResult(0));
}
}
void ONNXToSpatialPass::runOnOperation() {
ModuleOp moduleOp = getOperation();
MLIRContext* ctx = &getContext();
RewritePatternSet mergeActivationPatterns(ctx);
mergeActivationPatterns.add<onnxToArithConstant>(ctx);
mergeActivationPatterns.add<convAddToConvWithBiasLeft>(ctx);
mergeActivationPatterns.add<convAddToConvWithBiasRight>(ctx);
mergeActivationPatterns.add<matMulAddToGemm>(ctx);
mergeActivationPatterns.add<matMulToGemm>(ctx);
mergeActivationPatterns.add<removeFlattenSameShape>(ctx);
populateMatMulRewritePatterns(mergeActivationPatterns, ctx);
ConversionTarget preTarget(*ctx);
preTarget.addLegalDialect<spatial::SpatialDialect,
ONNXDialect,
tensor::TensorDialect,
arith::ArithDialect,
scf::SCFDialect>();
preTarget.addIllegalOp<ONNXConstantOp, ONNXFlattenOp>();
if (failed(applyPatternsGreedily(moduleOp, std::move(mergeActivationPatterns))))
llvm::dbgs() << "Failed to merge activation patterns, continuing...\n";
RewritePatternSet prePatterns(ctx);
populatePrePatterns(prePatterns, ctx);
if (failed(applyPartialConversion(moduleOp, preTarget, std::move(prePatterns)))) {
moduleOp.emitError("failed to apply ONNX-to-Spatial pre-rewrites");
signalPassFailure();
return;
}
IRRewriter rewriter(moduleOp);
auto entryFunc = getPimEntryFunc(moduleOp);
if (failed(entryFunc)) {
moduleOp.emitError("failed to locate the PIM entry function during ONNX-to-Spatial lowering");
signalPassFailure();
return;
}
RewritePatternSet matmulPatterns(ctx);
populateMatMulRewritePatterns(matmulPatterns, ctx);
walkAndApplyPatterns(moduleOp, std::move(matmulPatterns));
bool hasUnloweredMatMul = false;
moduleOp.walk([&](ONNXMatMulOp matmulOp) {
hasUnloweredMatMul = true;
matmulOp.emitOpError("remaining ONNX MatMul before the required ONNX-to-Spatial conversion");
});
if (hasUnloweredMatMul) {
moduleOp.emitError("failed to lower all ONNX MatMul ops before ONNX-to-Spatial conversion");
signalPassFailure();
return;
}
@@ -87,8 +157,7 @@ void ONNXToSpatialPass::runOnOperation() {
tensor::TensorDialect,
arith::ArithDialect,
scf::SCFDialect>();
target.addDynamicallyLegalOp<ONNXMatMulOp>(
[](ONNXMatMulOp op) { return cast<ShapedType>(op.getY().getType()).getRank() != 2; });
target.addIllegalOp<ONNXMatMulOp>();
target.addIllegalOp<ONNXAddOp>();
target.addIllegalOp<ONNXDivOp>();
target.addIllegalOp<ONNXMulOp>();
@@ -107,370 +176,70 @@ void ONNXToSpatialPass::runOnOperation() {
target.addIllegalOp<ONNXReduceMeanV13Op>();
target.addIllegalOp<ONNXSplitOp>();
RewritePatternSet patterns(ctx);
patterns.add<removeLRN>(ctx);
populateElementwisePatterns(patterns, ctx);
populateGemmPatterns(patterns, ctx);
populateConvPatterns(patterns, ctx);
populatePoolPatterns(patterns, ctx);
populateReduceMeanPatterns(patterns, ctx);
populateReluPatterns(patterns, ctx);
populateSigmoidPatterns(patterns, ctx);
populateSoftmaxPatterns(patterns, ctx);
populateConcatPatterns(patterns, ctx);
populateGatherPatterns(patterns, ctx);
populateResizePatterns(patterns, ctx);
populateReshapePatterns(patterns, ctx);
populateSplitPatterns(patterns, ctx);
if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
RewritePatternSet conversionPatterns(ctx);
populateConversionPatterns(conversionPatterns, ctx);
if (failed(applyPartialConversion(moduleOp, target, std::move(conversionPatterns)))) {
moduleOp.emitError("failed to convert required ONNX ops to Spatial ops");
signalPassFailure();
return;
}
// Count the number of compute ops and check they do not exceed the core count
if (coresCount != -1) {
int computeOpsCount = 0;
for (auto& op : entryFunc->getFunctionBody().front().getOperations())
if (isa<spatial::SpatWeightedCompute>(op))
computeOpsCount++;
ConversionTarget earlyPostTarget(*ctx);
earlyPostTarget.addLegalDialect<spatial::SpatialDialect,
ONNXDialect,
tensor::TensorDialect,
arith::ArithDialect,
scf::SCFDialect>();
earlyPostTarget.addDynamicallyLegalOp<spatial::SpatComputeBatch>(
[](spatial::SpatComputeBatch batchOp) { return !requiresEarlyPostRewrite(batchOp); });
if (computeOpsCount > coresCount) {
llvm::dbgs() << "Number of compute ops exceeds the core count\n";
RewritePatternSet earlyPostPatterns(ctx);
populateEarlyPostPatterns(earlyPostPatterns, ctx);
if (failed(applyPartialConversion(*entryFunc, earlyPostTarget, std::move(earlyPostPatterns)))) {
moduleOp.emitError("failed to normalize single-lane spat.compute_batch ops before core assignment checks");
signalPassFailure();
return;
}
}
PassManager cleanupPM(ctx);
cleanupPM.addPass(createCanonicalizerPass());
if (failed(cleanupPM.run(moduleOp)))
llvm::dbgs() << "Failed to run canonicalization cleanup, continuing...\n";
moduleOp.emitWarning("failed to run ONNX-to-Spatial canonicalization cleanup; continuing");
annotateWeightsConstants(*entryFunc);
encapsulateGlobalInstruction(*entryFunc);
if (failed(promoteConstantInputsToWeights(*entryFunc))) {
ConversionTarget postTarget(*ctx);
postTarget.addLegalDialect<spatial::SpatialDialect,
ONNXDialect,
tensor::TensorDialect,
arith::ArithDialect,
scf::SCFDialect>();
postTarget.addDynamicallyLegalOp<spatial::SpatCompute>(
[](spatial::SpatCompute computeOp) { return !requiresPostRewrite(computeOp); });
postTarget.addDynamicallyLegalOp<spatial::SpatComputeBatch>(
[](spatial::SpatComputeBatch computeOp) { return !requiresPostRewrite(computeOp); });
RewritePatternSet postPatterns(ctx);
populatePostPatterns(postPatterns, ctx);
if (failed(applyPartialConversion(*entryFunc, postTarget, std::move(postPatterns)))) {
moduleOp.emitError("failed to normalize weight-like Spatial compute operands before Spatial-to-PIM lowering");
signalPassFailure();
return;
}
mergeTriviallyConnectedComputes(*entryFunc);
wrapTopLevelRuntimeTransposes(*entryFunc);
if (failed(verifyONNXToSpatialHostLegality(*entryFunc))) {
moduleOp.emitError("ONNX-to-Spatial host legality verification failed");
signalPassFailure();
return;
}
populateEmptyFunction(*entryFunc);
// Dump the module to a file for debugging
dumpModule(moduleOp, "spatial0");
}
template <typename T>
bool encapsulator(IRRewriter& rewriter, Location loc, Operation* inst, std::function<Value(T)> funcSource) {
if (T toRemoveOp = llvm::dyn_cast_if_present<T>(inst)) {
Value source = funcSource(toRemoveOp);
rewriter.setInsertionPointAfter(toRemoveOp);
if (isa_and_present<spatial::SpatWeightedCompute>(source.getDefiningOp())) {
auto newCompute = spatial::SpatWeightedCompute::create(rewriter, loc, inst->getResultTypes().front(), source);
auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), {source.getType()}, {loc});
newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) 1});
rewriter.setInsertionPointToEnd(BB);
IRMapping mapper;
mapper.map(source, BB->getArgument(0));
auto newInst = rewriter.clone(*inst, mapper);
spatial::SpatYieldOp::create(rewriter, loc, newInst->getResult(0));
inst->replaceAllUsesWith(newCompute);
inst->erase();
return true;
}
}
return false;
}
bool encapsulateConcat(IRRewriter& rewriter, Location loc, Operation* inst) {
if (auto toRemoveOp = llvm::dyn_cast_if_present<tensor::ConcatOp>(inst)) {
auto sources = toRemoveOp.getInputs();
rewriter.setInsertionPointAfter(toRemoveOp);
if (llvm::any_of(
sources, [](auto source) { return isa_and_present<spatial::SpatWeightedCompute>(source.getDefiningOp()); })) {
auto newCompute = spatial::SpatWeightedCompute::create(rewriter, loc, inst->getResultTypes().front(), sources);
SmallVector<Type> sourceTypes;
SmallVector<Location> sourceLoc;
for (auto source : sources) {
sourceTypes.push_back(source.getType());
sourceLoc.push_back(loc);
}
auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), sourceTypes, sourceLoc);
newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) sources.size()});
rewriter.setInsertionPointToEnd(BB);
IRMapping mapper;
for (auto [source, bbArg] : llvm::zip(sources, BB->getArguments()))
mapper.map(source, bbArg);
auto newConcat = rewriter.clone(*inst, mapper);
spatial::SpatYieldOp::create(rewriter, loc, newConcat->getResult(0));
inst->replaceAllUsesWith(newCompute);
inst->erase();
return true;
}
}
return false;
}
static FailureOr<Value> materializeWeightLikeValueInBlock(Value value, IRRewriter& rewriter, IRMapping& mapper) {
if (auto mapped = mapper.lookupOrNull(value))
return cast<Value>(mapped);
Operation* definingOp = value.getDefiningOp();
if (!definingOp)
return failure();
if (isa<arith::ConstantOp, ONNXConstantOp>(definingOp)) {
auto tensorType = dyn_cast<RankedTensorType>(value.getType());
if (!tensorType || !tensorType.hasStaticShape())
return failure();
SmallVector<OpFoldResult> offsets(tensorType.getRank(), rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> sizes;
SmallVector<OpFoldResult> strides(tensorType.getRank(), rewriter.getIndexAttr(1));
sizes.reserve(tensorType.getRank());
for (int64_t dim : tensorType.getShape())
sizes.push_back(rewriter.getIndexAttr(dim));
auto referencedValue =
tensor::ExtractSliceOp::create(rewriter, value.getLoc(), tensorType, value, offsets, sizes, strides);
mapper.map(value, referencedValue.getResult());
return referencedValue.getResult();
}
if (!isa<tensor::ExtractSliceOp, tensor::ExpandShapeOp, tensor::CollapseShapeOp, ONNXTransposeOp>(definingOp))
return failure();
IRMapping localMapper;
for (Value operand : definingOp->getOperands()) {
if (auto mapped = mapper.lookupOrNull(operand)) {
localMapper.map(operand, cast<Value>(mapped));
continue;
}
if (isWeightLikeComputeOperand(operand)) {
auto clonedOperand = materializeWeightLikeValueInBlock(operand, rewriter, mapper);
if (failed(clonedOperand))
return failure();
localMapper.map(operand, *clonedOperand);
continue;
}
localMapper.map(operand, operand);
}
Operation* clonedOp = rewriter.clone(*definingOp, localMapper);
for (auto [oldResult, newResult] : llvm::zip(definingOp->getResults(), clonedOp->getResults()))
mapper.map(oldResult, newResult);
auto mapped = mapper.lookupOrNull(value);
if (!mapped)
return failure();
return cast<Value>(mapped);
}
// TODO: decide which instructions we want to keep at the global (top-level function) scope.
void ONNXToSpatialPass::encapsulateGlobalInstruction(func::FuncOp funcOp) {
Location loc = funcOp.getLoc();
IRRewriter rewriter(&getContext());
bool keep = true;
while (keep) {
keep = false;
for (auto& instruction : llvm::make_early_inc_range(funcOp.getOps())) {
keep |= encapsulator<tensor::ExtractSliceOp>(
rewriter, loc, &instruction, [](tensor::ExtractSliceOp extract) { return extract.getSource(); });
keep |= encapsulator<tensor::ExpandShapeOp>(
rewriter, loc, &instruction, [](tensor::ExpandShapeOp expand) { return expand.getSrc(); });
keep |= encapsulator<ONNXTransposeOp>(
rewriter, loc, &instruction, [](ONNXTransposeOp transpose) { return transpose.getData(); });
keep |= encapsulator<tensor::CollapseShapeOp>(
rewriter, loc, &instruction, [](tensor::CollapseShapeOp collapse) { return collapse.getSrc(); });
keep |= encapsulateConcat(rewriter, loc, &instruction);
}
}
}
void ONNXToSpatialPass::mergeTriviallyConnectedComputes(func::FuncOp funcOp) {
Location loc = funcOp.getLoc();
IRRewriter rewriter(&getContext());
SmallVector<spatial::SpatWeightedCompute> trivialComputes;
llvm::SmallSet<spatial::SpatWeightedCompute, 8> toErase;
for (auto compute : funcOp.getOps<spatial::SpatWeightedCompute>())
if (compute->hasOneUse()) {
auto user = dyn_cast<spatial::SpatWeightedCompute>(*compute->getUsers().begin());
if (user && user.getInputs().size() == 1)
trivialComputes.push_back(compute);
}
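// Illustrative sketch of the merge (assumed IR): when %a = spat.weighted_compute
// has a single use feeding %b = spat.weighted_compute(%a) whose only input is %a,
// the pair collapses into one compute that keeps %a's operands, appends %b's
// weights, and re-indexes the weighted MVM/VMM ops cloned from %b's body.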
while (!trivialComputes.empty()) {
auto compute = trivialComputes.front();
if (compute.use_empty()) {
std::swap(trivialComputes.front(), trivialComputes.back());
trivialComputes.pop_back();
continue;
}
auto child = cast<spatial::SpatWeightedCompute>(*compute->getUsers().begin());
rewriter.setInsertionPointAfter(compute.getOperation());
auto newCompute =
spatial::SpatWeightedCompute::create(rewriter, loc, child.getResultTypes(), compute.getOperands());
newCompute.getProperties().setOperandSegmentSizes(
{static_cast<int>(compute.getWeights().size()), static_cast<int>(compute.getInputs().size())});
IRMapping mapper;
auto weightMutableIter = newCompute.getWeightsMutable();
for (auto weight : child.getWeights()) {
auto found = llvm::find(newCompute.getWeights(), weight);
if (found == newCompute.getWeights().end()) {
weightMutableIter.append(weight);
auto last = weightMutableIter.end();
last = std::prev(last, 1);
mapper.map(weight, last->get());
}
else {
mapper.map(weight, *found);
}
}
compute.getBodyRegion().cloneInto(&newCompute.getBodyRegion(), mapper);
auto newTerminator = newCompute.getBody().front().getTerminator();
mapper.map(*child.getBody().front().getArguments().begin(), newTerminator->getOperand(0));
newTerminator->erase();
rewriter.setInsertionPoint(&newCompute.getBody().front(), newCompute.getBody().front().end());
for (auto& op : child.getBody().front()) {
auto newInst = rewriter.clone(op, mapper);
if (auto vmOp = llvm::dyn_cast<spatial::SpatWeightedMVMOp>(newInst)) {
auto oldIndex = vmOp.getWeightIndex();
auto newWeight = mapper.lookup(*std::next(child.getWeights().begin(), oldIndex));
auto newIndex = std::distance(newCompute.getWeights().begin(), llvm::find(newCompute.getWeights(), newWeight));
vmOp.setWeightIndex(newIndex);
}
if (auto vmOp = llvm::dyn_cast<spatial::SpatWeightedVMMOp>(newInst)) {
auto oldIndex = vmOp.getWeightIndex();
auto newWeight = mapper.lookup(*std::next(child.getWeights().begin(), oldIndex));
auto newIndex = std::distance(newCompute.getWeights().begin(), llvm::find(newCompute.getWeights(), newWeight));
vmOp.setWeightIndex(newIndex);
}
}
child.replaceAllUsesWith(newCompute);
toErase.insert(child);
std::swap(trivialComputes.front(), trivialComputes.back());
trivialComputes.pop_back();
toErase.insert(compute);
if (newCompute->hasOneUse()) {
auto user = dyn_cast<spatial::SpatWeightedCompute>(*newCompute->getUsers().begin());
if (user && user.getInputs().size() == 1)
trivialComputes.push_back(newCompute);
}
}
for (auto compute : toErase) {
compute.getResult(0).dropAllUses();
compute.erase();
}
}
void ONNXToSpatialPass::annotateWeightsConstants(func::FuncOp funcOp) const {
funcOp.walk([&](arith::ConstantOp constantOp) {
bool isAlwaysWeight =
llvm::all_of(constantOp->getUsers(), [](auto user) -> bool { return isa<spatial::SpatWeightedCompute>(user); });
if (isAlwaysWeight)
markWeightAlways(constantOp);
});
}
LogicalResult ONNXToSpatialPass::promoteConstantInputsToWeights(func::FuncOp funcOp) {
IRRewriter rewriter(&getContext());
SmallVector<spatial::SpatWeightedCompute> computes(funcOp.getOps<spatial::SpatWeightedCompute>());
for (auto compute : computes) {
SmallVector<bool> promoteInput(compute.getInputs().size(), false);
bool needsRewrite = false;
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
if (!isWeightLikeComputeOperand(input))
continue;
promoteInput[inputIdx] = true;
needsRewrite = true;
}
if (!needsRewrite)
continue;
rewriter.setInsertionPointAfter(compute);
SmallVector<Value> newWeights(compute.getWeights().begin(), compute.getWeights().end());
SmallVector<Value> newInputs;
SmallVector<Type> newInputTypes;
SmallVector<Location> newInputLocs;
newWeights.reserve(compute.getWeights().size() + compute.getInputs().size());
newInputs.reserve(compute.getInputs().size());
newInputTypes.reserve(compute.getInputs().size());
newInputLocs.reserve(compute.getInputs().size());
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
if (promoteInput[inputIdx]) {
newWeights.push_back(input);
continue;
}
newInputs.push_back(input);
newInputTypes.push_back(input.getType());
newInputLocs.push_back(input.getLoc());
}
auto newCompute =
spatial::SpatWeightedCompute::create(rewriter, compute.getLoc(), compute.getResultTypes(), newWeights, newInputs);
auto* newBlock =
rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), newInputTypes, newInputLocs);
newCompute.getProperties().setOperandSegmentSizes(
{static_cast<int>(newWeights.size()), static_cast<int>(newInputs.size())});
rewriter.setInsertionPointToStart(newBlock);
IRMapping mapper;
auto& oldBlock = compute.getBody().front();
size_t newInputIdx = 0;
for (auto [oldInputIdx, oldArg] : llvm::enumerate(oldBlock.getArguments())) {
if (!promoteInput[oldInputIdx]) {
mapper.map(oldArg, newBlock->getArgument(newInputIdx++));
continue;
}
auto clonedValue = materializeWeightLikeValueInBlock(compute.getInputs()[oldInputIdx], rewriter, mapper);
if (failed(clonedValue))
return compute.emitError("failed to materialize promoted weight-like operand inside compute body");
mapper.map(oldArg, *clonedValue);
}
for (auto& op : oldBlock.without_terminator())
rewriter.clone(op, mapper);
auto oldYield = cast<spatial::SpatYieldOp>(oldBlock.getTerminator());
SmallVector<Value> newYieldOperands;
newYieldOperands.reserve(oldYield.getOutputs().size());
for (Value operand : oldYield.getOutputs()) {
auto mapped = mapper.lookupOrNull(operand);
newYieldOperands.push_back(mapped ? cast<Value>(mapped) : operand);
}
spatial::SpatYieldOp::create(rewriter, oldYield.getLoc(), newYieldOperands);
compute.replaceAllUsesWith(newCompute);
compute.erase();
}
return success();
}
std::unique_ptr<Pass> createONNXToSpatialPass() { return std::make_unique<ONNXToSpatialPass>(); }
} // namespace onnx_mlir
@@ -7,11 +7,11 @@
#include "llvm/ADT/SmallVector.h"
#include <algorithm>
#include <cassert>
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -28,16 +28,6 @@ struct ConvToGemm : OpConversionPattern<ONNXConvOp> {
ConversionPatternRewriter& rewriter) const override;
};
static DenseElementsAttr getDenseConstantAttr(Value value) {
if (auto constantOp = value.getDefiningOp<arith::ConstantOp>())
return dyn_cast<DenseElementsAttr>(constantOp.getValue());
if (auto constantOp = value.getDefiningOp<ONNXConstantOp>())
return dyn_cast_or_null<DenseElementsAttr>(constantOp.getValueAttr());
return nullptr;
}
static int64_t getI64FromArrayAttr(ArrayAttr arr, size_t idx) { return cast<IntegerAttr>(arr[idx]).getInt(); }
static Value expandBiasIfNeeded(Value bias, ConversionPatternRewriter& rewriter, Location loc) {
@@ -147,10 +137,11 @@ static Value buildPackedBias(bool hasBias,
return arith::ConstantOp::create(rewriter, loc, packedBiasType, packedBiasAttr).getResult();
}
static Value createIm2colCompute(Value x,
static Value createIm2colRowComputes(Value x,
RankedTensorType xType,
RankedTensorType im2colType,
RankedTensorType rowType,
RankedTensorType im2colRowType,
RankedTensorType gemmInputRowsType,
int64_t batchSize,
int64_t numChannelsIn,
int64_t xHeight,
@@ -169,11 +160,14 @@ static Value createIm2colCompute(Value x,
int64_t patchSize,
int64_t numPatches,
int64_t numPatchesPerBatch,
int64_t packFactor,
ConversionPatternRewriter& rewriter,
Location loc) {
auto elemType = xType.getElementType();
constexpr size_t numInputs = 1;
auto im2colComputeOp = createSpatCompute<numInputs>(rewriter, loc, im2colType, {}, x, [&](Value xArg) {
const int64_t packedNumRows = ceilIntegerDivide(numPatches, packFactor);
auto im2colComputeOp =
createSpatCompute<numInputs>(rewriter, loc, TypeRange {gemmInputRowsType}, {}, x, [&](Value xArg) {
Value paddedInput = xArg;
// Pad input with zeros if needed:
@@ -240,7 +234,7 @@ static Value createIm2colCompute(Value x,
Value row = tensor::CollapseShapeOp::create(rewriter,
loc,
rowType,
im2colRowType,
patch,
SmallVector<ReassociationIndices> {
{0},
@@ -256,28 +250,13 @@ static Value createIm2colCompute(Value x,
rewriter.setInsertionPointAfter(im2colLoop);
Value im2col = im2colLoop.getResult(0);
spatial::SpatYieldOp::create(rewriter, loc, im2col);
});
return im2colComputeOp.getResult(0);
}
static Value createPackedIm2colRows(Value im2col,
RankedTensorType im2colType,
Type elemType,
int64_t numPatches,
int64_t patchSize,
int64_t packFactor,
ConversionPatternRewriter& rewriter,
Location loc) {
if (packFactor == 1)
return im2col;
const int64_t packedNumRows = ceilIntegerDivide(numPatches, packFactor);
Value gemmInputRows = im2col;
if (packFactor != 1) {
const int64_t paddedNumPatches = packedNumRows * packFactor;
auto groupedType = RankedTensorType::get({packedNumRows, packFactor, patchSize}, elemType);
auto packedType = RankedTensorType::get({packedNumRows, packFactor * patchSize}, elemType);
auto packedComputeOp = createSpatCompute<1>(rewriter, loc, packedType, {}, im2col, [&](Value im2colArg) {
Value paddedIm2col = createPaddedRows(im2colArg, im2colType, paddedNumPatches, rewriter, loc);
Value paddedIm2col = createPaddedRows(im2col, im2colType, paddedNumPatches, rewriter, loc);
Value groupedIm2col = tensor::ExpandShapeOp::create(rewriter,
loc,
groupedType,
@@ -286,7 +265,7 @@ static Value createPackedIm2colRows(Value im2col,
{0, 1},
{2}
});
Value packedIm2col = tensor::CollapseShapeOp::create(rewriter,
gemmInputRows = tensor::CollapseShapeOp::create(rewriter,
loc,
packedType,
groupedIm2col,
@@ -294,31 +273,39 @@ static Value createPackedIm2colRows(Value im2col,
{0},
{1, 2}
});
spatial::SpatYieldOp::create(rewriter, loc, packedIm2col);
}
spatial::SpatYieldOp::create(rewriter, loc, gemmInputRows);
});
return packedComputeOp.getResult(0);
return im2colComputeOp.getResult(0);
}
static Value createUnpackedOutput(Value packedOutput,
static Value createCollectedConvOutput(ValueRange gemmRows,
Type convType,
RankedTensorType gemmOutType,
RankedTensorType nhwcType,
RankedTensorType outType,
int64_t numPatches,
int64_t numChannelsOut,
int64_t packFactor,
ConversionPatternRewriter& rewriter,
Location loc) {
if (packFactor == 1)
return packedOutput;
const int64_t packedNumRows = ceilIntegerDivide(numPatches, packFactor);
const int64_t paddedNumPatches = packedNumRows * packFactor;
auto collectComputeOp = createSpatCompute(rewriter, loc, convType, {}, gemmRows, [&](ValueRange gemmRowArgs) {
Value gemmOut;
if (packFactor == 1) {
gemmOut = createSpatConcat(rewriter, loc, /*axis=*/0, gemmRowArgs);
}
else {
auto expandedType = RankedTensorType::get({packedNumRows, packFactor, numChannelsOut}, outType.getElementType());
auto paddedType = RankedTensorType::get({paddedNumPatches, numChannelsOut}, outType.getElementType());
auto unpackComputeOp = createSpatCompute<1>(rewriter, loc, gemmOutType, {}, packedOutput, [&](Value packedOutputArg) {
Value packedOutput = createSpatConcat(rewriter, loc, /*axis=*/0, gemmRowArgs);
Value expandedOutput = tensor::ExpandShapeOp::create(rewriter,
loc,
expandedType,
packedOutputArg,
packedOutput,
SmallVector<ReassociationIndices> {
{0},
{1, 2}
@@ -332,29 +319,14 @@ static Value createUnpackedOutput(Value packedOutput,
{2}
});
Value unpackedOutput = paddedOutput;
gemmOut = paddedOutput;
if (paddedNumPatches != numPatches) {
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(numPatches), rewriter.getIndexAttr(numChannelsOut)};
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
unpackedOutput =
tensor::ExtractSliceOp::create(rewriter, loc, gemmOutType, paddedOutput, offsets, sizes, strides);
gemmOut = tensor::ExtractSliceOp::create(rewriter, loc, gemmOutType, paddedOutput, offsets, sizes, strides);
}
}
spatial::SpatYieldOp::create(rewriter, loc, unpackedOutput);
});
return unpackComputeOp.getResult(0);
}
static Value createCollectedConvOutput(Value gemmOut,
Type convType,
RankedTensorType nhwcType,
RankedTensorType outType,
ConversionPatternRewriter& rewriter,
Location loc) {
auto collectComputeOp =
createSpatCompute(rewriter, loc, convType, {}, ValueRange {gemmOut}, [&](ValueRange gemmOutArgs) {
Value gemmOutArg = gemmOutArgs.front();
// Restore to NCHW layout:
// [numPatches, numChannelsOut]
@@ -363,7 +335,7 @@ static Value createCollectedConvOutput(Value gemmOut,
Value nhwcOut = tensor::ExpandShapeOp::create(rewriter,
loc,
nhwcType,
gemmOutArg,
gemmOut,
SmallVector<ReassociationIndices> {
{0, 1, 2},
{3}
@@ -374,6 +346,160 @@ static Value createCollectedConvOutput(Value gemmOut,
return collectComputeOp.getResult(0);
}
static Value lowerSingleConvGroup(Value x,
Value w,
Value b,
RankedTensorType xType,
RankedTensorType wType,
RankedTensorType outType,
int64_t padHeightBegin,
int64_t padHeightEnd,
int64_t padWidthBegin,
int64_t padWidthEnd,
int64_t strideHeight,
int64_t strideWidth,
int64_t dilationHeight,
int64_t dilationWidth,
ConversionPatternRewriter& rewriter,
Location loc) {
const int64_t batchSize = xType.getDimSize(0);
const int64_t numChannelsIn = xType.getDimSize(1);
const int64_t xHeight = xType.getDimSize(2);
const int64_t xWidth = xType.getDimSize(3);
const int64_t numChannelsOut = wType.getDimSize(0);
const int64_t wHeight = wType.getDimSize(2);
const int64_t wWidth = wType.getDimSize(3);
const int64_t outHeight = outType.getDimSize(2);
const int64_t outWidth = outType.getDimSize(3);
// im2col layout (flipped with respect to the standard, so filters sit in B = crossbar):
// A (im2col): [numPatches, patchSize] -- one row per output spatial position
// B (weights): [patchSize, cOut] -- W^T, stored in crossbar columns
// Gemm output: [numPatches, cOut]
const int64_t patchSize = numChannelsIn * wHeight * wWidth;
const int64_t numPatchesPerBatch = outHeight * outWidth;
const int64_t numPatches = batchSize * numPatchesPerBatch;
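// Running example used below (assumed sizes, not taken from any model): a 1x3x5x5
// input with a 3x3 kernel, stride 1, no padding and cOut = 8 gives a 3x3 output,
// so patchSize = 3 * 3 * 3 = 27, numPatchesPerBatch = 9 and numPatches = 9.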
auto elemType = xType.getElementType();
auto im2colType = RankedTensorType::get({numPatches, patchSize}, elemType);
auto rowType = RankedTensorType::get({1, patchSize}, elemType);
auto wFlatType = RankedTensorType::get({numChannelsOut, patchSize}, wType.getElementType());
auto wTransType = RankedTensorType::get({patchSize, numChannelsOut}, wType.getElementType());
auto gemmOutType = RankedTensorType::get({numPatches, numChannelsOut}, outType.getElementType());
auto nhwcType = RankedTensorType::get({batchSize, outHeight, outWidth, numChannelsOut}, outType.getElementType());
const int64_t xbarSize = static_cast<int64_t>(crossbarSize.getValue());
const int64_t wMaxDim = std::max(patchSize, numChannelsOut);
const int64_t maxParallelPixels = std::max<int64_t>(1, xbarSize / wMaxDim);
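// Continuing the example with an assumed crossbarSize of 64: wMaxDim = max(27, 8)
// = 27, so maxParallelPixels = 64 / 27 = 2; had wMaxDim exceeded the crossbar
// width, the factor would clamp to 1 and each GEMM row would hold a single pixel.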
auto wDenseAttr = getHostFoldableDenseElementsAttr(w);
// Prepare weight matrix W for crossbar storage:
// W: [numChannelsOut, numChannelsIn, wHeight, wWidth] -> [numChannelsOut, patchSize] -> [patchSize, numChannelsOut]
Value wFlat = tensor::CollapseShapeOp::create(rewriter,
loc,
wFlatType,
w,
SmallVector<ReassociationIndices> {
{0},
{1, 2, 3}
});
Value wTrans = ONNXTransposeOp::create(rewriter, loc, wTransType, wFlat, rewriter.getI64ArrayAttr({1, 0}));
// Pass bias through directly; Gemm handles rank-1 C canonicalization.
bool hasB = !isa<ONNXNoneOp>(b.getDefiningOp());
Value gemmBias = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());
Value biasMatrix;
DenseElementsAttr biasDenseAttr;
if (hasB) {
gemmBias = b;
biasDenseAttr = getHostFoldableDenseElementsAttr(b);
biasMatrix = expandBiasIfNeeded(b, rewriter, loc);
}
const bool canPackWeightsAsConstants = static_cast<bool>(wDenseAttr);
const bool canPackBiasAsConstants = !hasB || static_cast<bool>(biasDenseAttr);
const int64_t effectiveMaxParallelPixels =
(canPackWeightsAsConstants && canPackBiasAsConstants) ? maxParallelPixels : 1;
// Keep the standard im2col view of convolution:
// A (im2col): [numPatches, patchSize] -- one row per output spatial position
// B (weights): [patchSize, cOut] -- W^T, stored in crossbar columns
// and optionally repack several im2col rows into one GEMM row to make better use of the available crossbar size.
//
// We want to process N pixels at the same time. Instead of doing N separate operations
// of (1 x patchSize) x (patchSize x cOut), we construct a block-diagonal weight matrix
// containing N copies of W^T and concatenate N im2col rows into one longer row:
// A_packed: [ceil(numPatches / N), N * patchSize]
// B_packed: [N * patchSize, N * cOut]
// Y_packed: [ceil(numPatches / N), N * cOut]
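// Continuing the example with N = 2:
// A_packed: [ceil(9 / 2), 2 * 27] = [5, 54] (the padded 10th patch row is zero)
// B_packed: [2 * 27, 2 * 8] = [54, 16] (block-diagonal, two copies of W^T)
// Y_packed: [5, 2 * 8] = [5, 16], later unpacked back to [9, 8]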
const int64_t packedNumRows = ceilIntegerDivide(numPatches, effectiveMaxParallelPixels);
auto gemmInputRowsType = RankedTensorType::get({packedNumRows, effectiveMaxParallelPixels * patchSize}, elemType);
auto gemmOutputRowsType =
RankedTensorType::get({packedNumRows, effectiveMaxParallelPixels * numChannelsOut}, outType.getElementType());
Value gemmInputRows = createIm2colRowComputes(x,
xType,
im2colType,
rowType,
gemmInputRowsType,
batchSize,
numChannelsIn,
xHeight,
xWidth,
wHeight,
wWidth,
padHeightBegin,
padHeightEnd,
padWidthBegin,
padWidthEnd,
strideHeight,
strideWidth,
dilationHeight,
dilationWidth,
outWidth,
patchSize,
numPatches,
numPatchesPerBatch,
effectiveMaxParallelPixels,
rewriter,
loc);
Value gemmB = buildPackedWeight(wDenseAttr,
wTrans,
wType,
numChannelsIn,
numChannelsOut,
wHeight,
wWidth,
patchSize,
effectiveMaxParallelPixels,
rewriter,
loc);
Value gemmC = buildPackedBias(
hasB, gemmBias, biasMatrix, biasDenseAttr, outType, numChannelsOut, effectiveMaxParallelPixels, rewriter, loc);
Value gemmRows = ONNXGemmOp::create(rewriter,
loc,
gemmOutputRowsType,
gemmInputRows,
gemmB,
gemmC,
rewriter.getF32FloatAttr(1.0f),
rewriter.getF32FloatAttr(1.0f),
rewriter.getBoolAttr(false),
rewriter.getBoolAttr(false))
.getY();
return createCollectedConvOutput(ValueRange {gemmRows},
outType,
gemmOutType,
nhwcType,
outType,
numPatches,
numChannelsOut,
effectiveMaxParallelPixels,
rewriter,
loc);
}
} // namespace
LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
@@ -388,11 +514,34 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
auto wType = cast<RankedTensorType>(w.getType());
auto outType = cast<RankedTensorType>(convOp.getY().getType());
assert("Only support static shapes" && xType.hasStaticShape() && wType.hasStaticShape() && outType.hasStaticShape());
assert("Only support 2D convolution" && xType.getRank() == 4);
// We need to understand what is group
assert("Only support group=1" && convOp.getGroup() == 1);
if (!xType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv input");
return failure();
}
if (!wType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv weight");
return failure();
}
if (!outType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv result");
return failure();
}
if (xType.getRank() != 4) {
pim::emitUnsupportedRankDiagnostic(convOp, "conv input", xType.getRank(), {4});
return failure();
}
if (wType.getRank() != 4) {
pim::emitUnsupportedRankDiagnostic(convOp, "conv weight", wType.getRank(), {4});
return failure();
}
if (outType.getRank() != 4) {
pim::emitUnsupportedRankDiagnostic(convOp, "conv result", outType.getRank(), {4});
return failure();
}
if (convOp.getGroup() < 1) {
convOp.emitOpError("requires group >= 1 for Spatial lowering");
return failure();
}
const int64_t batchSize = xType.getDimSize(0);
const int64_t numChannelsIn = xType.getDimSize(1);
@@ -403,12 +552,51 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
const int64_t wWidth = wType.getDimSize(3);
const int64_t outHeight = outType.getDimSize(2);
const int64_t outWidth = outType.getDimSize(3);
const int64_t group = convOp.getGroup();
const bool hasB = !isa<ONNXNoneOp>(b.getDefiningOp());
if (numChannelsIn % group != 0) {
convOp.emitOpError() << "requires input channels " << numChannelsIn << " to be divisible by group " << group
<< " for Spatial lowering";
return failure();
}
if (numChannelsOut % group != 0) {
convOp.emitOpError() << "requires output channels " << numChannelsOut << " to be divisible by group " << group
<< " for Spatial lowering";
return failure();
}
const int64_t numChannelsInPerGroup = numChannelsIn / group;
const int64_t numChannelsOutPerGroup = numChannelsOut / group;
if (wType.getDimSize(1) != numChannelsInPerGroup) {
convOp.emitOpError() << "requires grouped conv weight input channels " << wType.getDimSize(1)
<< " to match input channels per group " << numChannelsInPerGroup << " for Spatial lowering";
return failure();
}
if (wType.getDimSize(0) != numChannelsOut) {
convOp.emitOpError() << "requires weight output channels " << wType.getDimSize(0) << " to match result channels "
<< numChannelsOut << " for Spatial lowering";
return failure();
}
// Read optional conv attributes (ONNX defaults: stride=1, dilation=1, pad=0)
const auto stridesAttr = convOp.getStrides();
const auto dilationsAttr = convOp.getDilations();
const auto padsAttr = convOp.getPads();
if (stridesAttr && stridesAttr->size() != 2) {
convOp.emitOpError("requires exactly two stride values for Spatial lowering");
return failure();
}
if (dilationsAttr && dilationsAttr->size() != 2) {
convOp.emitOpError("requires exactly two dilation values for Spatial lowering");
return failure();
}
if (padsAttr && padsAttr->size() != 4) {
convOp.emitOpError("requires exactly four pad values for 2D Spatial lowering");
return failure();
}
const int64_t strideHeight = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 0) : 1;
const int64_t strideWidth = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 1) : 1;
const int64_t dilationHeight = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 0) : 1;
@@ -449,67 +637,21 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
padWidthBegin = totalPadW - padWidthEnd;
}
}
else if (autoPad != "NOTSET" && autoPad != "VALID") {
convOp.emitOpError() << "unsupported auto_pad value `" << autoPad << "` for Spatial lowering";
return failure();
}
// "NOTSET" or "VALID" -> all pads stay 0
}
// im2col layout (flipped with respect to the standard, so filters sit in B = crossbar):
// A (im2col): [numPatches, patchSize] -- one row per output spatial position
// B (weights): [patchSize, cOut] -- W^T, stored in crossbar columns
// Gemm output: [numPatches, cOut]
const int64_t patchSize = numChannelsIn * wHeight * wWidth;
const int64_t numPatchesPerBatch = outHeight * outWidth;
const int64_t numPatches = batchSize * numPatchesPerBatch;
auto elemType = xType.getElementType();
auto im2colType = RankedTensorType::get({numPatches, patchSize}, elemType);
auto rowType = RankedTensorType::get({1, patchSize}, elemType);
auto wFlatType = RankedTensorType::get({numChannelsOut, patchSize}, wType.getElementType());
auto wTransType = RankedTensorType::get({patchSize, numChannelsOut}, wType.getElementType());
auto gemmOutType = RankedTensorType::get({numPatches, numChannelsOut}, outType.getElementType());
auto nhwcType = RankedTensorType::get({batchSize, outHeight, outWidth, numChannelsOut}, outType.getElementType());
const int64_t xbarSize = static_cast<int64_t>(crossbarSize.getValue());
const int64_t wMaxDim = std::max(patchSize, numChannelsOut);
const int64_t maxParallelPixels = std::max<int64_t>(1, xbarSize / wMaxDim);
auto wDenseAttr = getDenseConstantAttr(w);
// Prepare weight matrix W for crossbar storage:
// W: [numChannelsOut, numChannelsIn, wHeight, wWidth] -> [numChannelsOut, patchSize] -> [patchSize, numChannelsOut]
Value wFlat = tensor::CollapseShapeOp::create(rewriter,
loc,
wFlatType,
if (group == 1) {
rewriter.replaceOp(convOp,
lowerSingleConvGroup(x,
w,
SmallVector<ReassociationIndices> {
{0},
{1, 2, 3}
});
Value wTrans = ONNXTransposeOp::create(rewriter, loc, wTransType, wFlat, rewriter.getI64ArrayAttr({1, 0}));
// Pass bias through directly; Gemm handles rank-1 C canonicalization.
bool hasB = !isa<ONNXNoneOp>(b.getDefiningOp());
Value gemmC = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());
Value biasMatrix;
DenseElementsAttr biasDenseAttr;
if (hasB) {
gemmC = b;
biasDenseAttr = getDenseConstantAttr(b);
biasMatrix = expandBiasIfNeeded(b, rewriter, loc);
}
const bool canPackWeightsAsConstants = static_cast<bool>(wDenseAttr);
const bool canPackBiasAsConstants = !hasB || static_cast<bool>(biasDenseAttr);
const int64_t effectiveMaxParallelPixels =
(canPackWeightsAsConstants && canPackBiasAsConstants) ? maxParallelPixels : 1;
Value im2col = createIm2colCompute(x,
b,
xType,
im2colType,
rowType,
batchSize,
numChannelsIn,
xHeight,
xWidth,
wHeight,
wWidth,
wType,
outType,
padHeightBegin,
padHeightEnd,
padWidthBegin,
@@ -518,76 +660,74 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
strideWidth,
dilationHeight,
dilationWidth,
outWidth,
patchSize,
numPatches,
numPatchesPerBatch,
rewriter,
loc);
loc));
return success();
}
Value gemmOut;
if (effectiveMaxParallelPixels == 1) {
// Fallback to the plain im2col GEMM when a single crossbar cannot fit multiple pixels.
gemmOut = ONNXGemmOp::create(rewriter,
loc,
gemmOutType,
im2col,
wTrans,
gemmC,
rewriter.getF32FloatAttr(1.0f),
rewriter.getF32FloatAttr(1.0f),
rewriter.getBoolAttr(false),
rewriter.getBoolAttr(false))
.getY();
SmallVector<Value> xSlices = sliceTensor(x, /*axis=*/1, numChannelsInPerGroup, rewriter, loc);
SmallVector<Value> wSlices = sliceTensor(w, /*axis=*/0, numChannelsOutPerGroup, rewriter, loc);
SmallVector<Value> bSlices;
if (hasB) {
auto biasType = cast<RankedTensorType>(b.getType());
int64_t biasAxis = -1;
if (biasType.getRank() == 1)
biasAxis = 0;
else if (biasType.getRank() == 2)
biasAxis = biasType.getDimSize(0) != 1 ? 0 : 1;
else {
convOp.emitOpError() << "requires rank-1 or rank-2 bias for grouped convolution Spatial lowering, but got rank "
<< biasType.getRank();
return failure();
}
bSlices = sliceTensor(b, biasAxis, numChannelsOutPerGroup, rewriter, loc);
}
if (xSlices.size() != static_cast<size_t>(group) || wSlices.size() != static_cast<size_t>(group)
|| (hasB && bSlices.size() != static_cast<size_t>(group))) {
convOp.emitOpError("failed to partition grouped convolution operands for Spatial lowering");
return failure();
}
SmallVector<Value> groupResults;
groupResults.reserve(group);
auto groupOutType =
RankedTensorType::get({batchSize, numChannelsOutPerGroup, outHeight, outWidth}, outType.getElementType());
Value noBias = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());
for (int64_t groupId = 0; groupId < group; groupId++) {
Value groupX = xSlices[groupId];
Value groupW = wSlices[groupId];
Value groupB = hasB ? bSlices[groupId] : noBias;
groupResults.push_back(lowerSingleConvGroup(groupX,
groupW,
groupB,
cast<RankedTensorType>(groupX.getType()),
cast<RankedTensorType>(groupW.getType()),
groupOutType,
padHeightBegin,
padHeightEnd,
padWidthBegin,
padWidthEnd,
strideHeight,
strideWidth,
dilationHeight,
dilationWidth,
rewriter,
loc));
}
Value result;
if (llvm::all_of(groupResults, isHostFoldableValue)) {
result = createSpatConcat(rewriter, loc, /*axis=*/1, groupResults);
}
else {
// Keep the standard im2col view of convolution:
// A (im2col): [numPatches, patchSize] -- one row per output spatial position
// B (weights): [patchSize, cOut] -- W^T, stored in crossbar columns
// but repack several old rows into one new row so we use the available crossbar size better.
//
// We want to process N spatial pixels at the exact same time. Instead of doing N separate
// operations of (1 x patchSize) x (patchSize x cOut), we construct a block-diagonal weight matrix
// containing N copies of W^T and concatenate N im2col rows into one longer row:
// A_packed: [ceil(numPatches / N), N * patchSize]
// B_packed: [N * patchSize, N * cOut]
// Y_packed: [ceil(numPatches / N), N * cOut]
// The downstream GemmToManyGemv pass still splits by row, but now there are fewer, longer rows.
const int64_t packedNumRows = ceilIntegerDivide(numPatches, effectiveMaxParallelPixels);
auto packedOutType =
RankedTensorType::get({packedNumRows, effectiveMaxParallelPixels * numChannelsOut}, outType.getElementType());
Value packedA = createPackedIm2colRows(
im2col, im2colType, elemType, numPatches, patchSize, effectiveMaxParallelPixels, rewriter, loc);
Value packedB = buildPackedWeight(wDenseAttr,
wTrans,
wType,
numChannelsIn,
numChannelsOut,
wHeight,
wWidth,
patchSize,
effectiveMaxParallelPixels,
rewriter,
loc);
Value packedC = buildPackedBias(
hasB, gemmC, biasMatrix, biasDenseAttr, outType, numChannelsOut, effectiveMaxParallelPixels, rewriter, loc);
Value packedOut = ONNXGemmOp::create(rewriter,
loc,
packedOutType,
packedA,
packedB,
packedC,
rewriter.getF32FloatAttr(1.0f),
rewriter.getF32FloatAttr(1.0f),
rewriter.getBoolAttr(false),
rewriter.getBoolAttr(false))
.getY();
gemmOut = createUnpackedOutput(
packedOut, gemmOutType, outType, numPatches, numChannelsOut, effectiveMaxParallelPixels, rewriter, loc);
auto concatCompute = createSpatCompute(rewriter, loc, TypeRange {outType}, {}, groupResults, [&](ValueRange args) {
spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, /*axis=*/1, args));
});
result = concatCompute.getResult(0);
}
rewriter.replaceOp(convOp, createCollectedConvOutput(gemmOut, convOp.getType(), nhwcType, outType, rewriter, loc));
rewriter.replaceOp(convOp, result);
return success();
}
@@ -5,8 +5,9 @@
#include "llvm/ADT/SmallVector.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -15,13 +16,6 @@ using namespace mlir;
namespace onnx_mlir {
namespace {
static SmallVector<int64_t> computeRowMajorStrides(ArrayRef<int64_t> shape) {
SmallVector<int64_t> strides(shape.size(), 1);
for (int64_t i = static_cast<int64_t>(shape.size()) - 2; i >= 0; --i)
strides[i] = strides[i + 1] * shape[i + 1];
return strides;
}
static DenseElementsAttr getDenseConstantAttr(Value value) {
if (auto constantOp = value.getDefiningOp<arith::ConstantOp>())
return dyn_cast<DenseElementsAttr>(constantOp.getValue());
@@ -1,16 +1,17 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Location.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallVector.h"
#include <cassert>
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -49,6 +50,45 @@ materializeScaledConstantTensor(Value value, float factor, ConversionPatternRewr
return arith::ConstantOp::create(rewriter, loc, denseAttr.getType(), scaledAttr).getResult();
}
static Value transposeForSpatial(Value value,
RankedTensorType resultType,
ArrayRef<int64_t> permutation,
ConversionPatternRewriter& rewriter,
Location loc) {
if (isHostFoldableValue(value))
return ONNXTransposeOp::create(rewriter, loc, resultType, value, rewriter.getI64ArrayAttr(permutation));
auto computeOp = createSpatCompute<1>(rewriter, loc, TypeRange {resultType}, {}, value, [&](Value input) {
Value transposed = ONNXTransposeOp::create(rewriter, loc, resultType, input, rewriter.getI64ArrayAttr(permutation));
spatial::SpatYieldOp::create(rewriter, loc, transposed);
});
return computeOp.getResult(0);
}
static Value
expandRankOneBias(Value value, RankedTensorType resultType, ConversionPatternRewriter& rewriter, Location loc) {
if (isHostFoldableValue(value))
return tensor::ExpandShapeOp::create(rewriter,
loc,
resultType,
value,
SmallVector<ReassociationIndices> {
{0, 1}
});
auto computeOp = createSpatCompute<1>(rewriter, loc, TypeRange {resultType}, {}, value, [&](Value input) {
Value expanded = tensor::ExpandShapeOp::create(rewriter,
loc,
resultType,
input,
SmallVector<ReassociationIndices> {
{0, 1}
});
spatial::SpatYieldOp::create(rewriter, loc, expanded);
});
return computeOp.getResult(0);
}
struct GemmToManyGemv : OpConversionPattern<ONNXGemmOp> {
using OpConversionPattern::OpConversionPattern;
@@ -65,6 +105,72 @@ struct GemvToSpatialCompute : OpConversionPattern<ONNXGemmOp> {
ConversionPatternRewriter& rewriter) const override;
};
struct GemmToSpatialComputeBatch : OpConversionPattern<ONNXGemmOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(ONNXGemmOp gemmOp,
ONNXGemmOpAdaptor gemmOpAdaptor,
ConversionPatternRewriter& rewriter) const override;
};
static SmallVector<Value> materializeBatchRowSlices(Value matrix,
RankedTensorType matrixType,
ConversionPatternRewriter& rewriter,
Location loc) {
const int64_t numRows = matrixType.getDimSize(0);
auto rowType = RankedTensorType::get({1, matrixType.getDimSize(1)}, matrixType.getElementType());
SmallVector<Type> resultTypes(static_cast<size_t>(numRows), rowType);
if (isHostFoldableValue(matrix)) {
auto extractRowsOp = spatial::SpatExtractRowsOp::create(rewriter, loc, TypeRange(resultTypes), matrix);
return SmallVector<Value>(extractRowsOp->result_begin(), extractRowsOp->result_end());
}
auto buildRowSlices = [&](Value matrixArg) {
auto extractRowsOp = spatial::SpatExtractRowsOp::create(rewriter, loc, TypeRange(resultTypes), matrixArg);
return SmallVector<Value>(extractRowsOp->result_begin(), extractRowsOp->result_end());
};
auto cloneBatchInputChainIntoSliceCompute =
[&](Value rootInput, SmallVector<Operation*> chainOps, Value rootValue) -> SmallVector<Value> {
auto sliceCompute =
createSpatCompute<1>(rewriter, loc, TypeRange(resultTypes), {}, ValueRange {rootInput}, [&](Value input) {
Value transformedMatrix = input;
if (!chainOps.empty()) {
IRMapping mapper;
mapper.map(rootValue, input);
for (Operation* chainOp : chainOps)
rewriter.clone(*chainOp, mapper);
transformedMatrix = cast<Value>(mapper.lookup(matrix));
}
spatial::SpatYieldOp::create(rewriter, loc, buildRowSlices(transformedMatrix));
});
SmallVector<Value> rowSlices(sliceCompute->result_begin(), sliceCompute->result_end());
return rowSlices;
};
SmallVector<Operation*> chainOps;
Value rootValue = matrix;
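// Walk the producer chain of single-operand view ops (extract_slice,
// expand/collapse_shape, transpose). Illustrative case (assumed IR): for
// matrix = collapse_shape(transpose(%c)) with %c defined by a spat.compute, the
// transpose and collapse are cloned inside the new slicing compute so no runtime
// view op is left at the top level, and the rows are extracted from the clone.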
while (Operation* definingOp = rootValue.getDefiningOp()) {
if (auto rootCompute = dyn_cast<spatial::SpatCompute>(definingOp)) {
SmallVector<Operation*> reversedChainOps(chainOps.rbegin(), chainOps.rend());
return cloneBatchInputChainIntoSliceCompute(
rootCompute.getResult(cast<OpResult>(rootValue).getResultNumber()), reversedChainOps, rootValue);
}
if (definingOp->getNumOperands() != 1)
break;
if (!isa<tensor::ExtractSliceOp, tensor::ExpandShapeOp, tensor::CollapseShapeOp, ONNXTransposeOp>(definingOp))
break;
chainOps.push_back(definingOp);
rootValue = definingOp->getOperand(0);
}
SmallVector<Operation*> reversedChainOps(chainOps.rbegin(), chainOps.rend());
return cloneBatchInputChainIntoSliceCompute(rootValue, reversedChainOps, rootValue);
}
} // namespace
LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
@@ -75,13 +181,23 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
Value b = gemmOpAdaptor.getB();
Value c = gemmOpAdaptor.getC();
assert("A should have been transposed already" && !gemmOpAdaptor.getTransA());
if (gemmOpAdaptor.getTransA()) {
gemmOp.emitOpError("requires transA=false before Gemm row decomposition");
return failure();
}
bool hasC = !isa<ONNXNoneOp>(c.getDefiningOp());
auto aType = cast<RankedTensorType>(a.getType());
auto outType = cast<RankedTensorType>(gemmOp.getY().getType());
assert("Only support static shapes" && aType.hasStaticShape() && outType.hasStaticShape());
if (!aType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm input A");
return failure();
}
if (!outType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm result");
return failure();
}
const int64_t numOutRows = aType.getDimSize(0);
@@ -105,47 +221,43 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
// Expand rank-1 bias [N] to rank-2 [1, N] for uniform handling
if (cType.getRank() == 1) {
auto expandedType = RankedTensorType::get({1, cType.getDimSize(0)}, cType.getElementType());
c = tensor::ExpandShapeOp::create(rewriter,
loc,
expandedType,
c,
SmallVector<ReassociationIndices> {
{0, 1}
});
c = expandRankOneBias(c, expandedType, rewriter, loc);
cType = expandedType;
}
assert("Only support rank 2 tensor for C" && cType.getRank() == 2);
if (!cType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm bias");
return failure();
}
if (cType.getRank() != 2) {
pim::emitUnsupportedRankDiagnostic(gemmOp, "Gemm bias", cType.getRank(), {1, 2});
return failure();
}
cHasNumOutRows = cType.getDimSize(0) == numOutRows;
}
auto outRowType = RankedTensorType::get({1, outType.getDimSize(1)}, outType.getElementType());
SmallVector<Value> aSlices = materializeBatchRowSlices(a, aType, rewriter, loc);
SmallVector<Value> cSlices;
if (hasC && cHasNumOutRows)
cSlices = materializeBatchRowSlices(c, cType, rewriter, loc);
SmallVector<Value> gemvOps;
gemvOps.reserve(numOutRows);
gemvOps.reserve(static_cast<size_t>(numOutRows));
for (int64_t rowIdx = 0; rowIdx < numOutRows; rowIdx++) {
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(rowIdx), rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(aType.getDimSize(1))};
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
auto aSliceType = RankedTensorType::get({1, aType.getDimSize(1)}, aType.getElementType());
auto aSlice = tensor::ExtractSliceOp::create(rewriter, loc, aSliceType, a, offsets, sizes, strides).getResult();
Value cSlice = c;
if (hasC) {
if (cHasNumOutRows) {
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(rowIdx), rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(cType.getDimSize(1))};
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
auto cSliceType = RankedTensorType::get({1, cType.getDimSize(1)}, cType.getElementType());
cSlice = tensor::ExtractSliceOp::create(rewriter, loc, cSliceType, c, offsets, sizes, strides).getResult();
if (cHasNumOutRows)
cSlice = cSlices[static_cast<size_t>(rowIdx)];
else if (!isVectorShape(getTensorShape(c))) {
gemmOp.emitOpError("requires Gemm bias C to be vector-like when shared across decomposed rows");
return failure();
}
else
assert("C should be a vector" && isVectorShape(getTensorShape(c)));
}
auto gemvOp = ONNXGemmOp::create(rewriter,
loc,
outRowType,
aSlice,
aSlices[static_cast<size_t>(rowIdx)],
b,
cSlice,
rewriter.getF32FloatAttr(1.0f),
@@ -156,8 +268,7 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
}
auto concatComputeOp = createSpatCompute(rewriter, loc, gemmOp.getType(), {}, gemvOps, [&](ValueRange gemvOpsArgs) {
auto concatOp = tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, gemvOpsArgs);
spatial::SpatYieldOp::create(rewriter, loc, concatOp.getResult());
spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, /*axis=*/0, gemvOpsArgs));
});
rewriter.replaceOp(gemmOp, concatComputeOp);
@@ -189,20 +300,31 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
// Expand rank-1 bias [N] to rank-2 [1, N] for uniform handling
if (cType.getRank() == 1) {
auto expandedType = RankedTensorType::get({1, cType.getDimSize(0)}, cType.getElementType());
c = tensor::ExpandShapeOp::create(rewriter,
gemmLoc,
expandedType,
c,
SmallVector<ReassociationIndices> {
{0, 1}
});
c = expandRankOneBias(c, expandedType, rewriter, gemmLoc);
cType = expandedType;
}
assert("Only support rank 2 tensor for C" && cType.getRank() == 2);
if (!cType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm bias");
return failure();
}
if (cType.getRank() != 2) {
pim::emitUnsupportedRankDiagnostic(gemmOp, "Gemm bias", cType.getRank(), {1, 2});
return failure();
}
}
assert("Only support static shapes" && aType.hasStaticShape() && bType.hasStaticShape()
&& (!hasC || cType.hasStaticShape()) && outType.hasStaticShape());
if (!aType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm input A");
return failure();
}
if (!bType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm input B");
return failure();
}
if (!outType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm result");
return failure();
}
if (!isVectorShape(aType.getShape()) || (hasC && !isVectorShape(cType.getShape())))
// Not a gemv
@@ -210,13 +332,14 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
if (transA) {
auto aShape = aType.getShape();
auto transposedType = RankedTensorType::get({aShape[1], aShape[0]}, aType.getElementType());
a = transposeForSpatial(a, transposedType, {1, 0}, rewriter, gemmLoc);
aType = cast<RankedTensorType>(a.getType());
}
if (transB) {
auto bShape = bType.getShape();
auto transposedType = RankedTensorType::get({bShape[1], bShape[0]}, bType.getElementType());
b = transposeForSpatial(b, transposedType, {1, 0}, rewriter, gemmLoc);
bType = cast<RankedTensorType>(b.getType());
}
@@ -240,7 +363,6 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
auto [aNumHSlices, aLastHSliceSize] = ceilIntegerDivideWithRemainder(aType.getDimSize(1), crossbarSize.getValue());
auto [bNumHSlices, bLastHSliceSize] = ceilIntegerDivideWithRemainder(bType.getDimSize(1), crossbarSize.getValue());
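// The 1 x K input and the K x N weight are tiled into crossbarSize-wide slices along K and N, with
// the last slice keeping the remainder (e.g. K = 300 on a 128-wide crossbar gives slices of 128, 128 and 44).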
auto bNumVSlices = aNumHSlices;
auto bLastVSliceSize = aLastHSliceSize;
auto cNumHSlices = bNumHSlices;
auto cLastHSliceSize = bLastHSliceSize;
auto outNumHSlices = cNumHSlices;
@@ -281,19 +403,25 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
weights.push_back(bTiles[outSliceId][coreId][aSliceId]);
auto computeOp = createSpatCompute(
rewriter, gemmLoc, currOutHSliceType, weights, aHSlices[coreId], [&](ValueRange aHSlicesArgs) {
rewriter, gemmLoc, currOutHSliceType, weights, aHSlices[coreId], [&](ValueRange aHSlicesArgs) -> LogicalResult {
SmallVector<Value> vmmOutputs;
vmmOutputs.reserve(aHSlicesArgs.size());
for (auto [aHSliceId, computeArg] : llvm::enumerate(aHSlicesArgs))
vmmOutputs.push_back(
spatial::SpatVMMOp::create(rewriter, gemmLoc, currOutHSliceType, aHSliceId, computeArg));
if (vmmOutputs.empty()) {
gemmOp.emitOpError("requires at least one non-empty slice when lowering tiled Gemm to Spatial VMMs");
return failure();
}
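// Each SpatVMMOp covers one crossbar-sized K-slice; their sum is this core's partial result for the
// current output slice.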
Value partialVmmSum = sumTensors(vmmOutputs, rewriter);
spatial::SpatYieldOp::create(rewriter, gemmLoc, partialVmmSum);
return success();
});
if (failed(computeOp))
return failure();
partialResults.push_back(computeOp->getResult(0));
}
if (hasC) {
@@ -313,8 +441,123 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
auto concatComputeOp =
createSpatCompute(rewriter, gemmLoc, gemmOp.getType(), {}, outHSlices, [&](ValueRange blockArgs) {
spatial::SpatYieldOp::create(rewriter, gemmLoc, createSpatConcat(rewriter, gemmLoc, /*axis=*/1, blockArgs));
});
rewriter.replaceOp(gemmOp, concatComputeOp);
return success();
}
LogicalResult GemmToSpatialComputeBatch::matchAndRewrite(ONNXGemmOp gemmOp,
ONNXGemmOpAdaptor gemmOpAdaptor,
ConversionPatternRewriter& rewriter) const {
Location loc = gemmOp.getLoc();
Value a = gemmOpAdaptor.getA();
Value b = gemmOpAdaptor.getB();
Value c = gemmOpAdaptor.getC();
if (gemmOpAdaptor.getTransA()) {
gemmOp.emitOpError("requires transA=false before batch Gemm lowering");
return failure();
}
bool hasC = !isa<ONNXNoneOp>(c.getDefiningOp());
auto aType = cast<RankedTensorType>(a.getType());
auto bType = cast<RankedTensorType>(b.getType());
auto outType = cast<RankedTensorType>(gemmOp.getY().getType());
if (!aType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm input A");
return failure();
}
if (!bType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm input B");
return failure();
}
if (!outType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm result");
return failure();
}
const int64_t numOutRows = aType.getDimSize(0);
if (numOutRows <= 1)
return failure();
// Only handle the single-tile case: K <= crossbarSize and N <= crossbarSize
if (aType.getDimSize(1) > static_cast<int64_t>(crossbarSize.getValue())
|| outType.getDimSize(1) > static_cast<int64_t>(crossbarSize.getValue()))
return failure();
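// Bailing out here lets the row-wise (GemmToManyGemv) and tiled (GemvToSpatialCompute) patterns
// pick the op up instead.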
auto scaledB = materializeScaledConstantTensor(b, gemmOpAdaptor.getAlpha().convertToFloat(), rewriter, loc);
if (failed(scaledB))
return failure();
b = *scaledB;
bType = cast<RankedTensorType>(b.getType());
if (gemmOpAdaptor.getTransB()) {
auto bShape = bType.getShape();
auto transposedType = RankedTensorType::get({bShape[1], bShape[0]}, bType.getElementType());
b = transposeForSpatial(b, transposedType, {1, 0}, rewriter, loc);
bType = cast<RankedTensorType>(b.getType());
}
(void) bType;
Value sharedBias;
if (hasC) {
auto scaledC = materializeScaledConstantTensor(c, gemmOpAdaptor.getBeta().convertToFloat(), rewriter, loc);
if (failed(scaledC))
return failure();
c = *scaledC;
auto cType = cast<RankedTensorType>(c.getType());
if (cType.getRank() == 1) {
auto expandedType = RankedTensorType::get({1, cType.getDimSize(0)}, cType.getElementType());
c = expandRankOneBias(c, expandedType, rewriter, loc);
cType = cast<RankedTensorType>(c.getType());
}
if (!cType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm bias");
return failure();
}
if (cType.getRank() != 2) {
pim::emitUnsupportedRankDiagnostic(gemmOp, "Gemm bias", cType.getRank(), {1, 2});
return failure();
}
// Row-specific bias can't share a single template body; fall through to GemmToManyGemv
if (cType.getDimSize(0) == numOutRows && numOutRows > 1)
return failure();
if (cType.getDimSize(0) == 1 && cType.getDimSize(1) == 1)
c = broadcastToVector(c, outType.getDimSize(1), rewriter, loc);
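// A 1x1 bias is broadcast to a full [1, N] row so it can be added uniformly to every lane's result.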
sharedBias = c;
}
SmallVector<Value> aSlices = materializeBatchRowSlices(a, aType, rewriter, loc);
auto aSliceType = cast<RankedTensorType>(aSlices.front().getType());
auto outRowType = RankedTensorType::get({1, outType.getDimSize(1)}, outType.getElementType());
SmallVector<Type> resultTypes(static_cast<size_t>(numOutRows), outRowType);
SmallVector<Value> weights(static_cast<size_t>(numOutRows), b);
auto batchOp = spatial::SpatComputeBatch::create(rewriter,
loc,
TypeRange(resultTypes),
rewriter.getI32IntegerAttr(static_cast<int32_t>(numOutRows)),
ValueRange(weights),
ValueRange(aSlices));
Block* body = rewriter.createBlock(
&batchOp.getBody(), batchOp.getBody().end(), TypeRange {aSliceType}, SmallVector<Location>(1, loc));
rewriter.setInsertionPointToEnd(body);
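// The batch op carries one shared body: each lane runs the same VMM on its own row slice, adds the
// optional shared bias, and yields a [1, N] row.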
Value vmmResult = spatial::SpatVMMOp::create(rewriter, loc, outRowType, 0, body->getArgument(0)).getResult();
Value laneResult = vmmResult;
if (sharedBias)
laneResult = spatial::SpatVAddOp::create(rewriter, loc, outRowType, vmmResult, sharedBias).getResult();
spatial::SpatYieldOp::create(rewriter, loc, laneResult);
rewriter.setInsertionPointAfter(batchOp);
SmallVector<Value> laneResults(batchOp->result_begin(), batchOp->result_end());
auto concatComputeOp = createSpatCompute(rewriter, loc, gemmOp.getType(), {}, laneResults, [&](ValueRange args) {
spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, /*axis=*/0, args));
});
rewriter.replaceOp(gemmOp, concatComputeOp);
@@ -322,6 +565,7 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
}
void populateGemmPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.insert<GemmToSpatialComputeBatch>(ctx, PatternBenefit(2));
patterns.insert<GemmToManyGemv>(ctx);
patterns.insert<GemvToSpatialCompute>(ctx);
}
@@ -2,10 +2,15 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include <functional>
#include <numeric>
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -14,7 +19,191 @@ using namespace mlir;
namespace onnx_mlir {
namespace {
static bool haveStaticPositiveShape(ArrayRef<int64_t> shape) {
return llvm::all_of(shape, [](int64_t dim) { return dim > 0; });
}
static int64_t getStaticShapeElementCount(ArrayRef<int64_t> shape) {
return std::accumulate(shape.begin(), shape.end(), int64_t {1}, std::multiplies<int64_t> {});
}
static FailureOr<SmallVector<int64_t>> inferSupportedBatchShape(ArrayRef<int64_t> lhsBatchShape,
ArrayRef<int64_t> rhsBatchShape) {
if (lhsBatchShape.empty())
return SmallVector<int64_t>(rhsBatchShape.begin(), rhsBatchShape.end());
if (rhsBatchShape.empty())
return SmallVector<int64_t>(lhsBatchShape.begin(), lhsBatchShape.end());
if (!llvm::equal(lhsBatchShape, rhsBatchShape))
return failure();
return SmallVector<int64_t>(lhsBatchShape.begin(), lhsBatchShape.end());
}
static Value collapseBatchDims(Value value,
int64_t batchSize,
int64_t rows,
int64_t cols,
PatternRewriter& rewriter,
Location loc) {
auto type = cast<RankedTensorType>(value.getType());
if (type.getRank() == 2 || type.getRank() == 3)
return value;
auto collapsedType =
RankedTensorType::get({batchSize, rows, cols}, type.getElementType(), type.getEncoding());
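// Collapse every leading batch dim into a single one, e.g. a [2, 3, 4, 5] operand with batchSize = 6
// becomes [6, 4, 5]; rank-2 and rank-3 values were already returned unchanged above.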
SmallVector<ReassociationIndices> reassociation = {
ReassociationIndices {},
ReassociationIndices {static_cast<int64_t>(type.getRank() - 2)},
ReassociationIndices {static_cast<int64_t>(type.getRank() - 1)}
};
for (int64_t dim = 0; dim < type.getRank() - 2; ++dim)
reassociation.front().push_back(dim);
auto buildCollapsed = [&](Value input) -> Value {
return tensor::CollapseShapeOp::create(rewriter, loc, collapsedType, input, reassociation);
};
if (isHostFoldableValue(value))
return buildCollapsed(value);
auto collapseCompute =
createSpatCompute<1>(rewriter, loc, TypeRange {collapsedType}, {}, ValueRange {value}, [&](Value input) {
spatial::SpatYieldOp::create(rewriter, loc, buildCollapsed(input));
});
return collapseCompute.getResult(0);
}
static Value expandBatchDims(Value value,
RankedTensorType outputType,
size_t batchRank,
PatternRewriter& rewriter,
Location loc) {
if (cast<RankedTensorType>(value.getType()) == outputType)
return value;
SmallVector<ReassociationIndices> reassociation = {
ReassociationIndices {},
ReassociationIndices {static_cast<int64_t>(batchRank)},
ReassociationIndices {static_cast<int64_t>(batchRank + 1)}
};
for (size_t dim = 0; dim < batchRank; ++dim)
reassociation.front().push_back(static_cast<int64_t>(dim));
auto expandCompute =
createSpatCompute<1>(rewriter, loc, TypeRange {outputType}, {}, ValueRange {value}, [&](Value input) {
Value expanded = tensor::ExpandShapeOp::create(rewriter, loc, outputType, input, reassociation);
spatial::SpatYieldOp::create(rewriter, loc, expanded);
});
return expandCompute.getResult(0);
}
static Value extractBatchMatrix(Value value,
int64_t batchIndex,
int64_t batchSize,
int64_t rows,
int64_t cols,
PatternRewriter& rewriter,
Location loc) {
auto type = cast<RankedTensorType>(value.getType());
if (type.getRank() == 2)
return value;
auto sliceType = RankedTensorType::get({1, rows, cols}, type.getElementType());
SmallVector<OpFoldResult> offsets = {
rewriter.getIndexAttr(batchSize == 1 ? 0 : batchIndex), rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> sizes = {
rewriter.getIndexAttr(1), rewriter.getIndexAttr(rows), rewriter.getIndexAttr(cols)};
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
auto matrixType = RankedTensorType::get({rows, cols}, type.getElementType());
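// Slice batch `batchIndex` out as [1, rows, cols] and collapse the unit dim to a [rows, cols] matrix;
// batchSize == 1 always reads index 0 so a shared operand can be broadcast across batches.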
auto buildMatrix = [&](Value input) -> Value {
Value slice = tensor::ExtractSliceOp::create(rewriter, loc, sliceType, input, offsets, sizes, strides);
return tensor::CollapseShapeOp::create(rewriter,
loc,
matrixType,
slice,
SmallVector<ReassociationIndices> {
{0, 1},
{2}
});
};
if (isHostFoldableValue(value))
return buildMatrix(value);
auto batchMatrixCompute =
createSpatCompute<1>(rewriter, loc, TypeRange {matrixType}, {}, ValueRange {value}, [&](Value input) {
spatial::SpatYieldOp::create(rewriter, loc, buildMatrix(input));
});
return batchMatrixCompute.getResult(0);
}
static Value transposeLastTwoDims(Value value, PatternRewriter& rewriter, Location loc) {
auto type = cast<RankedTensorType>(value.getType());
auto shape = type.getShape();
RankedTensorType transposedType;
SmallVector<int64_t> perm;
if (type.getRank() == 2) {
transposedType = RankedTensorType::get({shape[1], shape[0]}, type.getElementType());
perm = {1, 0};
}
else {
transposedType = RankedTensorType::get({shape[0], shape[2], shape[1]}, type.getElementType());
perm = {0, 2, 1};
}
auto buildTranspose = [&](Value input) -> Value {
return ONNXTransposeOp::create(rewriter, loc, transposedType, input, rewriter.getI64ArrayAttr(perm));
};
if (isHostFoldableValue(value))
return buildTranspose(value);
auto transposeCompute =
createSpatCompute<1>(rewriter, loc, TypeRange {transposedType}, {}, ValueRange {value}, [&](Value input) {
spatial::SpatYieldOp::create(rewriter, loc, buildTranspose(input));
});
return transposeCompute.getResult(0);
}
static Value transposeLastTwoDimsInCompute(Value value, PatternRewriter& rewriter, Location loc) {
auto type = cast<RankedTensorType>(value.getType());
auto shape = type.getShape();
RankedTensorType transposedType;
SmallVector<int64_t> perm;
if (type.getRank() == 2) {
transposedType = RankedTensorType::get({shape[1], shape[0]}, type.getElementType());
perm = {1, 0};
}
else {
transposedType = RankedTensorType::get({shape[0], shape[2], shape[1]}, type.getElementType());
perm = {0, 2, 1};
}
auto transposeCompute = createSpatCompute<1>(rewriter, loc, transposedType, {}, ValueRange {value}, [&](Value input) {
Value transposed = ONNXTransposeOp::create(rewriter, loc, transposedType, input, rewriter.getI64ArrayAttr(perm));
spatial::SpatYieldOp::create(rewriter, loc, transposed);
});
return transposeCompute.getResult(0);
}
static Value concatValues(ValueRange inputs, int64_t axis, PatternRewriter& rewriter, Location loc) {
auto firstType = cast<RankedTensorType>(inputs.front().getType());
SmallVector<int64_t> outputShape(firstType.getShape().begin(), firstType.getShape().end());
int64_t concatDimSize = 0;
for (Value input : inputs)
concatDimSize += cast<RankedTensorType>(input.getType()).getDimSize(axis);
outputShape[axis] = concatDimSize;
auto resultType = RankedTensorType::get(outputShape, firstType.getElementType(), firstType.getEncoding());
if (llvm::all_of(inputs, isHostFoldableValue))
return createSpatConcat(rewriter, loc, axis, inputs);
auto concatCompute = createSpatCompute(rewriter, loc, TypeRange {resultType}, {}, inputs, [&](ValueRange args) {
spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, axis, args));
});
return concatCompute.getResult(0);
}
struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(ONNXMatMulOp matmulOp, PatternRewriter& rewriter) const override {
@@ -24,80 +213,129 @@ struct MatMulRank3ToGemm : OpRewritePattern<ONNXMatMulOp> {
if (!lhsType || !rhsType || !outType || !lhsType.hasStaticShape() || !rhsType.hasStaticShape()
|| !outType.hasStaticShape())
return failure();
if (lhsType.getRank() < 2 || rhsType.getRank() < 2 || outType.getRank() < 2)
return failure();
if (!haveStaticPositiveShape(lhsType.getShape()) || !haveStaticPositiveShape(rhsType.getShape())
|| !haveStaticPositiveShape(outType.getShape()))
return failure();
SmallVector<int64_t> lhsBatchShape(lhsType.getShape().begin(), lhsType.getShape().end() - 2);
SmallVector<int64_t> rhsBatchShape(rhsType.getShape().begin(), rhsType.getShape().end() - 2);
auto batchShape = inferSupportedBatchShape(lhsBatchShape, rhsBatchShape);
if (failed(batchShape))
return failure();
const int64_t lhsBatch = lhsBatchShape.empty() ? 1 : getStaticShapeElementCount(lhsBatchShape);
const int64_t rhsBatch = rhsBatchShape.empty() ? 1 : getStaticShapeElementCount(rhsBatchShape);
const int64_t batch = batchShape->empty() ? 1 : getStaticShapeElementCount(*batchShape);
const int64_t m = lhsType.getDimSize(lhsType.getRank() - 2);
const int64_t k = lhsType.getDimSize(lhsType.getRank() - 1);
const int64_t rhsK = rhsType.getDimSize(rhsType.getRank() - 2);
const int64_t n = rhsType.getDimSize(rhsType.getRank() - 1);
if (k != rhsK)
return failure();
if (outType.getRank() == 2) {
if (batch != 1 || outType.getDimSize(0) != m || outType.getDimSize(1) != n)
return failure();
}
else {
SmallVector<int64_t> outBatchShape(outType.getShape().begin(), outType.getShape().end() - 2);
if (!llvm::equal(outBatchShape, *batchShape) || outType.getDimSize(outType.getRank() - 2) != m
|| outType.getDimSize(outType.getRank() - 1) != n)
return failure();
}
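// At this point batch, m, k and n describe a (possibly flattened) batched matmul:
// [batch, m, k] x [batch, k, n] -> [batch, m, n], with batch == 1 for the pure 2-D case.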
Location loc = matmulOp.getLoc();
bool useTransposedForm = isHostFoldableValue(matmulOp.getA()) && !isHostFoldableValue(matmulOp.getB());
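// When A is host-foldable but B is not, A*B is lowered as (B^T * A^T)^T: the foldable operand then
// sits in the Gemm's B (weight) slot, and the outer transpose is undone on the result below.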
Value lhs = collapseBatchDims(matmulOp.getA(), lhsBatch, m, k, rewriter, loc);
Value rhs = collapseBatchDims(matmulOp.getB(), rhsBatch, k, n, rewriter, loc);
int64_t lhsBatchForGemm = lhsBatch;
int64_t rhsBatchForGemm = rhsBatch;
int64_t gemmM = m;
int64_t gemmK = k;
int64_t gemmN = n;
if (useTransposedForm) {
lhs = transposeLastTwoDimsInCompute(matmulOp.getB(), rewriter, loc);
lhsBatchForGemm = rhsBatch;
rhs = transposeLastTwoDims(matmulOp.getA(), rewriter, loc);
rhsBatchForGemm = lhsBatch;
gemmM = n;
gemmN = m;
}
auto gemmType = RankedTensorType::get({gemmM, gemmN}, outType.getElementType());
auto batchedOutType = RankedTensorType::get({1, m, n}, outType.getElementType());
Value none = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());
if (outType.getRank() == 2) {
Value lhsMatrix = extractBatchMatrix(lhs, /*batchIndex=*/0, lhsBatchForGemm, gemmM, gemmK, rewriter, loc);
Value rhsMatrix = extractBatchMatrix(rhs, /*batchIndex=*/0, rhsBatchForGemm, gemmK, gemmN, rewriter, loc);
Value gemmResult = ONNXGemmOp::create(rewriter,
loc,
gemmType,
lhsMatrix,
rhsMatrix,
none,
rewriter.getF32FloatAttr(1.0f),
rewriter.getF32FloatAttr(1.0f),
rewriter.getBoolAttr(false),
rewriter.getBoolAttr(false))
.getY();
if (useTransposedForm) {
auto transposeCompute =
createSpatCompute<1>(rewriter, loc, TypeRange {outType}, {}, gemmResult, [&](Value input) {
Value transposed = ONNXTransposeOp::create(rewriter, loc, outType, input, rewriter.getI64ArrayAttr({1, 0}));
spatial::SpatYieldOp::create(rewriter, loc, transposed);
});
gemmResult = transposeCompute.getResult(0);
}
rewriter.replaceOp(matmulOp, gemmResult);
return success();
}
SmallVector<Value> batchResults;
batchResults.reserve(batch);
for (int64_t batchIdx = 0; batchIdx < batch; batchIdx++) {
Value lhsMatrix = extractBatchMatrix(lhs, batchIdx, lhsBatchForGemm, gemmM, gemmK, rewriter, loc);
Value rhsMatrix = extractBatchMatrix(rhs, batchIdx, rhsBatchForGemm, gemmK, gemmN, rewriter, loc);
Value gemmResult = ONNXGemmOp::create(rewriter,
loc,
gemmType,
lhsMatrix,
rhsMatrix,
none,
rewriter.getF32FloatAttr(1.0f),
rewriter.getF32FloatAttr(1.0f),
rewriter.getBoolAttr(false),
rewriter.getBoolAttr(false))
.getY();
auto batchResultCompute =
createSpatCompute<1>(rewriter, loc, TypeRange {batchedOutType}, {}, gemmResult, [&](Value input) {
Value resultMatrix = input;
if (useTransposedForm) {
resultMatrix = ONNXTransposeOp::create(rewriter,
loc,
RankedTensorType::get({m, n}, outType.getElementType()),
input,
rewriter.getI64ArrayAttr({1, 0}));
}
Value expanded = tensor::ExpandShapeOp::create(rewriter,
loc,
batchedOutType,
resultMatrix,
SmallVector<ReassociationIndices> {
{0, 1},
{2}
});
spatial::SpatYieldOp::create(rewriter, loc, expanded);
});
batchResults.push_back(batchResultCompute.getResult(0));
}
Value result = concatValues(batchResults, /*axis=*/0, rewriter, loc);
result = expandBatchDims(result, outType, batchShape->size(), rewriter, loc);
rewriter.replaceOp(matmulOp, result);
return success();
}
@@ -106,7 +344,7 @@ struct MatMulRank3ToGemm : OpRewritePattern<ONNXMatMulOp> {
} // namespace
void populateMatMulRewritePatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.insert<MatMulToGemm>(ctx);
}
} // namespace onnx_mlir
@@ -5,8 +5,9 @@
#include <algorithm>
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -81,6 +82,24 @@ createAverageCompute(Value input, RankedTensorType resultType, ConversionPattern
return computeOp.getResult(0);
}
static Value concatValues(ValueRange inputs, int64_t axis, ConversionPatternRewriter& rewriter, Location loc) {
auto firstType = cast<RankedTensorType>(inputs.front().getType());
SmallVector<int64_t> outputShape(firstType.getShape().begin(), firstType.getShape().end());
int64_t concatDimSize = 0;
for (Value input : inputs)
concatDimSize += cast<RankedTensorType>(input.getType()).getDimSize(axis);
outputShape[axis] = concatDimSize;
auto resultType = RankedTensorType::get(outputShape, firstType.getElementType(), firstType.getEncoding());
if (llvm::all_of(inputs, isHostFoldableValue))
return createSpatConcat(rewriter, loc, axis, inputs);
auto concatCompute = createSpatCompute(rewriter, loc, TypeRange {resultType}, {}, inputs, [&](ValueRange args) {
spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, axis, args));
});
return concatCompute.getResult(0);
}
static Value buildReduceMeanKeepdims(Value input,
ArrayRef<bool> reducedAxes,
int64_t axis,
@@ -100,8 +119,7 @@ static Value buildReduceMeanKeepdims(Value input,
for (Value slice : slices)
reducedSlices.push_back(buildReduceMeanKeepdims(slice, reducedAxes, axis + 1, leafType, rewriter, loc));
return concatValues(reducedSlices, axis, rewriter, loc);
}
static Value squeezeReducedAxes(Value keepdimsValue,
@@ -116,9 +134,16 @@ static Value squeezeReducedAxes(Value keepdimsValue,
return tensor::FromElementsOp::create(rewriter, loc, resultType, ValueRange {element});
}
auto reassociation = buildCollapseReassociation(reducedAxes);
if (isHostFoldableValue(keepdimsValue))
return tensor::CollapseShapeOp::create(rewriter, loc, resultType, keepdimsValue, reassociation).getResult();
auto squeezeCompute =
createSpatCompute<1>(rewriter, loc, TypeRange {resultType}, {}, ValueRange {keepdimsValue}, [&](Value input) {
Value collapsed = tensor::CollapseShapeOp::create(rewriter, loc, resultType, input, reassociation);
spatial::SpatYieldOp::create(rewriter, loc, collapsed);
});
return squeezeCompute.getResult(0);
}
struct ReduceMeanToSpatialCompute : OpConversionPattern<ONNXReduceMeanV13Op> {
@@ -1,18 +1,20 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/SmallVector.h"
#include <algorithm>
#include <cassert>
#include <optional>
#include <type_traits>
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -31,13 +33,6 @@ static int64_t getOptionalI64(std::optional<ArrayAttrT> arrayAttr, size_t index,
return arrayAttr ? getI64(*arrayAttr, index) : defaultValue;
}
static Value concatAlongAxis(ConversionPatternRewriter& rewriter, Location loc, int64_t axis, ArrayRef<Value> values) {
assert(!values.empty() && "Expected at least one value to concatenate.");
if (values.size() == 1)
return values.front();
return tensor::ConcatOp::create(rewriter, loc, axis, values);
}
static Value materializeContiguousTile(ConversionPatternRewriter& rewriter, Location loc, Value tile) {
auto tileType = cast<RankedTensorType>(tile.getType());
Value empty = tensor::EmptyOp::create(rewriter, loc, tileType.getShape(), tileType.getElementType());
@@ -52,27 +47,126 @@ static Value materializeContiguousTile(ConversionPatternRewriter& rewriter, Loca
return tensor::InsertSliceOp::create(rewriter, loc, tile, empty, offsets, sizes, strides);
}
template <typename ReduceOp>
static Value reduceWindowValues(ConversionPatternRewriter& rewriter, Location loc, ArrayRef<Value> windowValues) {
assert(!windowValues.empty() && "Expected at least one pool window value.");
static Value
createPoolFillElement(ConversionPatternRewriter& rewriter, Location loc, Type elementType, bool useMinimumValue) {
if (!useMinimumValue)
return arith::ConstantOp::create(rewriter, loc, elementType, rewriter.getZeroAttr(elementType));
Value reduced = windowValues.front();
for (Value value : windowValues.drop_front())
reduced = ReduceOp::create(rewriter, loc, reduced.getType(), reduced, value);
return reduced;
if (auto floatType = dyn_cast<FloatType>(elementType)) {
auto minValue = llvm::APFloat::getInf(floatType.getFloatSemantics(), /*Negative=*/true);
return arith::ConstantOp::create(rewriter, loc, elementType, rewriter.getFloatAttr(floatType, minValue));
}
if (auto integerType = dyn_cast<IntegerType>(elementType)) {
auto minValue = llvm::APInt::getSignedMinValue(integerType.getWidth());
return arith::ConstantOp::create(rewriter, loc, elementType, rewriter.getIntegerAttr(integerType, minValue));
}
llvm_unreachable("unsupported pool element type");
}
static Value
scaleAverageWindow(ConversionPatternRewriter& rewriter, Location loc, Value reducedWindow, int64_t divisor) {
assert(divisor > 0 && "AveragePool divisor must be positive.");
if (divisor == 1)
return reducedWindow;
static Value createPoolFillTensor(ConversionPatternRewriter& rewriter,
Location loc,
RankedTensorType tensorType,
bool useMinimumValue) {
auto fillElement = createPoolFillElement(rewriter, loc, tensorType.getElementType(), useMinimumValue);
return tensor::SplatOp::create(rewriter, loc, tensorType, fillElement);
}
auto tileType = cast<RankedTensorType>(reducedWindow.getType());
double scale = 1.0 / static_cast<double>(divisor);
auto scaleAttr = DenseElementsAttr::get(tileType, rewriter.getFloatAttr(tileType.getElementType(), scale));
Value scaleTensor = arith::ConstantOp::create(rewriter, loc, tileType, scaleAttr);
return spatial::SpatVMulOp::create(rewriter, loc, tileType, reducedWindow, scaleTensor);
template <typename PoolOp>
static Value createPaddedPoolInput(ConversionPatternRewriter& rewriter,
Location loc,
PoolOp poolOp,
Value input,
RankedTensorType inputType,
int64_t padTop,
int64_t padLeft,
int64_t padBottom,
int64_t padRight) {
if (padTop == 0 && padLeft == 0 && padBottom == 0 && padRight == 0)
return input;
auto paddedType = RankedTensorType::get({inputType.getDimSize(0),
inputType.getDimSize(1),
inputType.getDimSize(2) + padTop + padBottom,
inputType.getDimSize(3) + padLeft + padRight},
inputType.getElementType(),
inputType.getEncoding());
SmallVector<OpFoldResult> lowPads = {
rewriter.getIndexAttr(0), rewriter.getIndexAttr(0), rewriter.getIndexAttr(padTop), rewriter.getIndexAttr(padLeft)};
SmallVector<OpFoldResult> highPads = {rewriter.getIndexAttr(0),
rewriter.getIndexAttr(0),
rewriter.getIndexAttr(padBottom),
rewriter.getIndexAttr(padRight)};
auto padOp = tensor::PadOp::create(rewriter, loc, paddedType, input, lowPads, highPads);
auto* padBlock = new Block();
for (int index = 0; index < paddedType.getRank(); ++index)
padBlock->addArgument(rewriter.getIndexType(), loc);
padOp.getRegion().push_back(padBlock);
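// The pad region yields -inf (or INT_MIN) for max-pool and 0 for average-pool, so padded taps never
// win the max and contribute nothing to the sum.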
rewriter.setInsertionPointToStart(padBlock);
Value padValue =
createPoolFillElement(rewriter, loc, inputType.getElementType(), std::is_same_v<PoolOp, ONNXMaxPoolSingleOutOp>);
tensor::YieldOp::create(rewriter, loc, padValue);
rewriter.setInsertionPointAfter(padOp);
return padOp.getResult();
}
static FailureOr<Value> createAverageScaleTensor(ConversionPatternRewriter& rewriter,
Location loc,
Operation* op,
RankedTensorType outType,
int64_t channels,
int64_t inputHeight,
int64_t inputWidth,
int64_t outputHeight,
int64_t outputWidth,
int64_t kernelHeight,
int64_t kernelWidth,
int64_t strideHeight,
int64_t strideWidth,
int64_t dilationHeight,
int64_t dilationWidth,
int64_t padTop,
int64_t padLeft,
bool countIncludePad) {
auto elemType = dyn_cast<FloatType>(outType.getElementType());
if (!elemType) {
op->emitOpError("AveragePool lowering requires a floating-point element type");
return failure();
}
auto scaleType = RankedTensorType::get({1, channels, outputHeight, outputWidth}, elemType, outType.getEncoding());
SmallVector<Attribute> scaleValues;
scaleValues.reserve(static_cast<size_t>(channels * outputHeight * outputWidth));
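// Each output pixel gets its own 1/divisor scale. With count_include_pad == false only in-bounds
// taps count: e.g. a 3x3 kernel with pad 1 sees a 2x2 valid window at a corner pixel, so that pixel
// is scaled by 1/4 instead of 1/9.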
for (int64_t channel = 0; channel < channels; ++channel) {
(void) channel;
for (int64_t outH = 0; outH < outputHeight; ++outH) {
for (int64_t outW = 0; outW < outputWidth; ++outW) {
int64_t validCount = 0;
for (int64_t kernelH = 0; kernelH < kernelHeight; ++kernelH) {
const int64_t inH = outH * strideHeight + kernelH * dilationHeight - padTop;
if (inH < 0 || inH >= inputHeight)
continue;
for (int64_t kernelW = 0; kernelW < kernelWidth; ++kernelW) {
const int64_t inW = outW * strideWidth + kernelW * dilationWidth - padLeft;
if (inW < 0 || inW >= inputWidth)
continue;
++validCount;
}
}
const int64_t divisor = countIncludePad ? kernelHeight * kernelWidth : validCount;
if (divisor <= 0) {
op->emitOpError("AveragePool divisor must be positive");
return failure();
}
scaleValues.push_back(rewriter.getFloatAttr(elemType, 1.0 / static_cast<double>(divisor)));
}
}
}
auto scaleAttr = DenseElementsAttr::get(scaleType, scaleValues);
return arith::ConstantOp::create(rewriter, loc, scaleType, scaleAttr).getResult();
}
template <typename PoolOp>
@@ -150,89 +244,133 @@ struct PoolToSpatialComputeBase : public OpConversionPattern<PoolOp> {
}
}
(void) padBottom;
(void) padRight;
const int64_t xbarSize = static_cast<int64_t>(crossbarSize.getValue());
const int64_t channelTileCount = (channels + xbarSize - 1) / xbarSize;
const int64_t outputPatchCount = batchSize * outputHeight * outputWidth;
const bool countIncludePad = [&]() {
if constexpr (std::is_same_v<PoolOp, ONNXAveragePoolOp>)
return poolOp.getCountIncludePad() == 1;
return true;
}();
Value averageScaleTensor;
if constexpr (std::is_same_v<PoolOp, ONNXAveragePoolOp>) {
auto maybeAverageScaleTensor = createAverageScaleTensor(rewriter,
loc,
poolOp,
outType,
channels,
inputHeight,
inputWidth,
outputHeight,
outputWidth,
kernelHeight,
kernelWidth,
strideHeight,
strideWidth,
dilationHeight,
dilationWidth,
padTop,
padLeft,
countIncludePad);
if (failed(maybeAverageScaleTensor))
return failure();
averageScaleTensor = *maybeAverageScaleTensor;
}
constexpr size_t numInputs = 1;
auto computeOp =
createSpatCompute<numInputs>(rewriter, loc, outType, {}, ValueRange {x}, [&](Value xArg) -> LogicalResult {
SmallVector<Value> batchResults;
batchResults.reserve(batchSize);
Value paddedInput =
createPaddedPoolInput(rewriter, loc, poolOp, xArg, xType, padTop, padLeft, padBottom, padRight);
Value pooledOutputInit = tensor::EmptyOp::create(rewriter, loc, outType.getShape(), outType.getElementType());
for (int64_t batch = 0; batch < batchSize; ++batch) {
SmallVector<Value> rows;
rows.reserve(outputHeight);
Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
Value c1 = arith::ConstantIndexOp::create(rewriter, loc, 1);
Value cOutputPatchCount = arith::ConstantIndexOp::create(rewriter, loc, outputPatchCount);
Value cOutputPixelsPerBatch = arith::ConstantIndexOp::create(rewriter, loc, outputHeight * outputWidth);
Value cOutputWidth = arith::ConstantIndexOp::create(rewriter, loc, outputWidth);
Value cStrideHeight = arith::ConstantIndexOp::create(rewriter, loc, strideHeight);
Value cStrideWidth = arith::ConstantIndexOp::create(rewriter, loc, strideWidth);
for (int64_t outH = 0; outH < outputHeight; ++outH) {
SmallVector<Value> rowPixels;
rowPixels.reserve(outputWidth);
auto outputLoop = scf::ForOp::create(rewriter, loc, c0, cOutputPatchCount, c1, ValueRange {pooledOutputInit});
rewriter.setInsertionPointToStart(outputLoop.getBody());
for (int64_t outW = 0; outW < outputWidth; ++outW) {
SmallVector<Value> outputChannelTiles;
outputChannelTiles.reserve(channelTileCount);
Value outputPatchIndex = outputLoop.getInductionVar();
Value pooledOutputAcc = outputLoop.getRegionIterArgs().front();
Value batchIndex = arith::DivUIOp::create(rewriter, loc, outputPatchIndex, cOutputPixelsPerBatch);
Value batchPatchIndex = arith::RemUIOp::create(rewriter, loc, outputPatchIndex, cOutputPixelsPerBatch);
Value outHeightIndex = arith::DivUIOp::create(rewriter, loc, batchPatchIndex, cOutputWidth);
Value outWidthIndex = arith::RemUIOp::create(rewriter, loc, batchPatchIndex, cOutputWidth);
Value windowBaseH = arith::MulIOp::create(rewriter, loc, outHeightIndex, cStrideHeight);
Value windowBaseW = arith::MulIOp::create(rewriter, loc, outWidthIndex, cStrideWidth);
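// The flat output-patch index is decoded with div/rem into (batch, outH, outW); e.g. for a 4x4
// output map, patch 21 decomposes to batch 1, outH 1, outW 1.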
Value updatedOutput = pooledOutputAcc;
for (int64_t channelTile = 0; channelTile < channelTileCount; ++channelTile) {
const int64_t tileChannels = std::min<int64_t>(xbarSize, channels - channelTile * xbarSize);
auto tileType = RankedTensorType::get({1, tileChannels, 1, 1}, outType.getElementType());
Value reducedWindow =
createPoolFillTensor(rewriter, loc, tileType, std::is_same_v<PoolOp, ONNXMaxPoolSingleOutOp>);
SmallVector<Value> windowValues;
windowValues.reserve(kernelHeight * kernelWidth);
for (int64_t kernelH = 0; kernelH < kernelHeight; ++kernelH) {
const int64_t inH = outH * strideHeight + kernelH * dilationHeight - padTop;
if (inH < 0 || inH >= inputHeight)
continue;
Value paddedInH = windowBaseH;
if (kernelH * dilationHeight != 0) {
Value kernelHOffset = arith::ConstantIndexOp::create(rewriter, loc, kernelH * dilationHeight);
paddedInH = arith::AddIOp::create(rewriter, loc, paddedInH, kernelHOffset);
}
for (int64_t kernelW = 0; kernelW < kernelWidth; ++kernelW) {
const int64_t inW = outW * strideWidth + kernelW * dilationWidth - padLeft;
if (inW < 0 || inW >= inputWidth)
continue;
Value paddedInW = windowBaseW;
if (kernelW * dilationWidth != 0) {
Value kernelWOffset = arith::ConstantIndexOp::create(rewriter, loc, kernelW * dilationWidth);
paddedInW = arith::AddIOp::create(rewriter, loc, paddedInW, kernelWOffset);
}
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(batch),
rewriter.getIndexAttr(channelTile * xbarSize),
rewriter.getIndexAttr(inH),
rewriter.getIndexAttr(inW)};
SmallVector<OpFoldResult> offsets = {
batchIndex, rewriter.getIndexAttr(channelTile * xbarSize), paddedInH, paddedInW};
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
rewriter.getIndexAttr(tileChannels),
rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1)};
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1),
SmallVector<OpFoldResult> strides = {
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
Value windowValue =
tensor::ExtractSliceOp::create(rewriter, loc, tileType, paddedInput, offsets, sizes, strides);
windowValue = materializeContiguousTile(rewriter, loc, windowValue);
reducedWindow = ReduceOp::create(rewriter, loc, tileType, reducedWindow, windowValue);
}
}
if constexpr (std::is_same_v<PoolOp, ONNXAveragePoolOp>) {
SmallVector<OpFoldResult> scaleOffsets = {
rewriter.getIndexAttr(0), rewriter.getIndexAttr(channelTile * xbarSize), outHeightIndex, outWidthIndex};
SmallVector<OpFoldResult> scaleSizes = {rewriter.getIndexAttr(1),
rewriter.getIndexAttr(tileChannels),
rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1)};
Value windowValue =
tensor::ExtractSliceOp::create(rewriter, loc, tileType, xArg, offsets, sizes, strides);
windowValue = materializeContiguousTile(rewriter, loc, windowValue);
windowValues.push_back(windowValue);
}
SmallVector<OpFoldResult> scaleStrides = {
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
Value scaleSlice = tensor::ExtractSliceOp::create(
rewriter, loc, tileType, averageScaleTensor, scaleOffsets, scaleSizes, scaleStrides);
scaleSlice = materializeContiguousTile(rewriter, loc, scaleSlice);
reducedWindow = spatial::SpatVMulOp::create(rewriter, loc, tileType, reducedWindow, scaleSlice);
}
if (windowValues.empty())
return rewriter.notifyMatchFailure(poolOp, "pool window resolved to zero valid elements.");
Value reducedWindow = reduceWindowValues<ReduceOp>(rewriter, loc, windowValues);
if constexpr (std::is_same_v<PoolOp, ONNXAveragePoolOp>) {
const bool countIncludePad = poolOp.getCountIncludePad() == 1;
const int64_t divisor =
countIncludePad ? kernelHeight * kernelWidth : static_cast<int64_t>(windowValues.size());
reducedWindow = scaleAverageWindow(rewriter, loc, reducedWindow, divisor);
SmallVector<OpFoldResult> outputOffsets = {
batchIndex, rewriter.getIndexAttr(channelTile * xbarSize), outHeightIndex, outWidthIndex};
SmallVector<OpFoldResult> outputSizes = {rewriter.getIndexAttr(1),
rewriter.getIndexAttr(tileChannels),
rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1)};
SmallVector<OpFoldResult> outputStrides = {
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
updatedOutput = tensor::InsertSliceOp::create(
rewriter, loc, reducedWindow, updatedOutput, outputOffsets, outputSizes, outputStrides);
}
outputChannelTiles.push_back(reducedWindow);
}
scf::YieldOp::create(rewriter, loc, updatedOutput);
rowPixels.push_back(concatAlongAxis(rewriter, loc, /*axis=*/1, outputChannelTiles));
}
rows.push_back(concatAlongAxis(rewriter, loc, /*axis=*/3, rowPixels));
}
batchResults.push_back(concatAlongAxis(rewriter, loc, /*axis=*/2, rows));
}
Value pooledOutput = concatAlongAxis(rewriter, loc, /*axis=*/0, batchResults);
spatial::SpatYieldOp::create(rewriter, loc, pooledOutput);
rewriter.setInsertionPointAfter(outputLoop);
spatial::SpatYieldOp::create(rewriter, loc, outputLoop.getResult(0));
return success();
});
if (failed(computeOp))
@@ -1,6 +1,6 @@
#include "mlir/Transforms/DialectConversion.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -1,6 +1,6 @@
#include "mlir/Transforms/DialectConversion.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -1,8 +1,10 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -21,36 +23,83 @@ static SmallVector<int64_t> permuteShape(ArrayRef<int64_t> shape, ArrayRef<int64
return permutedShape;
}
static Value buildLoopSoftmaxSlice(Value input,
Value accumulator,
RankedTensorType inputType,
ArrayRef<Value> outerIndices,
ConversionPatternRewriter& rewriter,
Location loc) {
int64_t rank = inputType.getRank();
SmallVector<int64_t> sliceShape(static_cast<size_t>(rank - 1), 1);
sliceShape.push_back(inputType.getDimSize(rank - 1));
auto sliceType = RankedTensorType::get(sliceShape, inputType.getElementType(), inputType.getEncoding());
SmallVector<OpFoldResult> offsets;
SmallVector<OpFoldResult> sizes;
SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
offsets.reserve(rank);
sizes.reserve(rank);
for (Value outerIndex : outerIndices) {
offsets.push_back(outerIndex);
sizes.push_back(rewriter.getIndexAttr(1));
}
offsets.push_back(rewriter.getIndexAttr(0));
sizes.push_back(rewriter.getIndexAttr(inputType.getDimSize(rank - 1)));
Value inputSlice = tensor::ExtractSliceOp::create(rewriter, loc, sliceType, input, offsets, sizes, strides);
Value softmaxSlice = spatial::SpatSoftmaxOp::create(rewriter, loc, sliceType, inputSlice).getResult();
return tensor::InsertSliceOp::create(rewriter, loc, softmaxSlice, accumulator, offsets, sizes, strides);
}
static Value buildLoopSoftmaxNest(Value input,
Value accumulator,
RankedTensorType inputType,
int64_t axis,
SmallVectorImpl<Value>& outerIndices,
ConversionPatternRewriter& rewriter,
Location loc) {
if (axis == inputType.getRank() - 1)
return buildLoopSoftmaxSlice(input, accumulator, inputType, outerIndices, rewriter, loc);
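// One scf.for per leading dimension: each level fixes one outer index and recurses until only the
// innermost (softmax) axis remains, which is handled as a single slice above.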
Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
Value c1 = arith::ConstantIndexOp::create(rewriter, loc, 1);
Value cUpper = arith::ConstantIndexOp::create(rewriter, loc, inputType.getDimSize(axis));
auto loop = scf::ForOp::create(rewriter, loc, c0, cUpper, c1, ValueRange {accumulator});
rewriter.setInsertionPointToStart(loop.getBody());
Value loopIndex = loop.getInductionVar();
Value loopAccumulator = loop.getRegionIterArgs().front();
outerIndices.push_back(loopIndex);
Value updatedAccumulator =
buildLoopSoftmaxNest(input, loopAccumulator, inputType, axis + 1, outerIndices, rewriter, loc);
outerIndices.pop_back();
scf::YieldOp::create(rewriter, loc, updatedAccumulator);
rewriter.setInsertionPointAfter(loop);
return loop.getResult(0);
}
static Value createLoopSoftmaxCompute(Value input, ConversionPatternRewriter& rewriter, Location loc) {
auto inputType = cast<RankedTensorType>(input.getType());
constexpr size_t numInputs = 1;
auto computeOp =
createSpatCompute<numInputs>(rewriter, loc, TypeRange {inputType}, {}, ValueRange {input}, [&](Value x) {
if (inputType.getRank() == 1) {
Value softmax = spatial::SpatSoftmaxOp::create(rewriter, loc, inputType, x).getResult();
spatial::SpatYieldOp::create(rewriter, loc, softmax);
return;
}
Value outputInit = tensor::EmptyOp::create(rewriter, loc, inputType.getShape(), inputType.getElementType());
SmallVector<Value> outerIndices;
Value result = buildLoopSoftmaxNest(x, outputInit, inputType, /*axis=*/0, outerIndices, rewriter, loc);
spatial::SpatYieldOp::create(rewriter, loc, result);
});
return computeOp.getResult(0);
}
struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> {
using OpConversionPattern::OpConversionPattern;
@@ -68,7 +117,7 @@ struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> {
Value input = adaptor.getInput();
Value result;
if (axis == inputType.getRank() - 1) {
result = createLoopSoftmaxCompute(input, rewriter, softmaxOp.getLoc());
}
else {
SmallVector<int64_t> permutation;
@@ -91,10 +140,14 @@ struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> {
spatial::SpatYieldOp::create(rewriter, softmaxOp.getLoc(), transposed);
});
Value transposedInput = preTransposeCompute.getResult(0);
Value transposedResult = createLoopSoftmaxCompute(transposedInput, rewriter, softmaxOp.getLoc());
auto postTransposeCompute =
createSpatCompute<1>(rewriter, softmaxOp.getLoc(), TypeRange {inputType}, {}, transposedResult, [&](Value x) {
Value transposed = ONNXTransposeOp::create(
rewriter, softmaxOp.getLoc(), inputType, x, rewriter.getI64ArrayAttr(inversePermutation));
spatial::SpatYieldOp::create(rewriter, softmaxOp.getLoc(), transposed);
});
result = postTransposeCompute.getResult(0);
}
rewriter.replaceOp(softmaxOp, result);
@@ -1,7 +1,10 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/ComputeRegionBuilder.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
@@ -17,7 +20,17 @@ struct Concat : public OpConversionPattern<ONNXConcatOp> {
auto inputs = adaptor.getInputs();
int64_t axis = adaptor.getAxis();
if (llvm::all_of(inputs, isHostFoldableValue)) {
rewriter.replaceOp(maxpoolOp, createSpatConcat(rewriter, maxpoolOp.getLoc(), axis, inputs));
return success();
}
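// Otherwise the concat is built inside a spatial compute region, with the operands passed in as
// block arguments and the result yielded back.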
auto computeOp = createSpatCompute(
rewriter, maxpoolOp.getLoc(), TypeRange {maxpoolOp.getResult().getType()}, {}, inputs, [&](ValueRange args) {
spatial::SpatYieldOp::create(
rewriter, maxpoolOp.getLoc(), createSpatConcat(rewriter, maxpoolOp.getLoc(), axis, args));
});
rewriter.replaceOp(maxpoolOp, computeOp.getResults());
return success();
}
@@ -5,8 +5,8 @@
#include "llvm/ADT/SmallVector.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -49,7 +49,7 @@ static Value concatGatherSlices(Value data,
}
if (slices.empty())
return {};
return createSpatConcat(rewriter, loc, axis, slices);
}
static Value addLeadingGatherDim(Value value, int64_t axis, ConversionPatternRewriter& rewriter, Location loc) {
@@ -130,9 +130,7 @@ struct Gather : OpConversionPattern<ONNXGatherOp> {
return failure();
rows.push_back(addLeadingGatherDim(gatheredRow, axis, rewriter, loc));
}
result = createSpatConcat(rewriter, loc, /*axis=*/axis, rows);
}
else {
return failure();
@@ -3,7 +3,10 @@
#include "llvm/ADT/SmallVector.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
@@ -77,6 +80,22 @@ static bool inferExpandReassociation(ArrayRef<int64_t> sourceShape,
return sourceIdx == sourceShape.size() && resultIdx == resultShape.size();
}
static SmallVector<ReassociationIndices> getCollapseTo1DReassociation(size_t rank) {
SmallVector<ReassociationIndices> reassociation(1);
reassociation.front().reserve(rank);
for (size_t dim = 0; dim < rank; ++dim)
reassociation.front().push_back(static_cast<int64_t>(dim));
return reassociation;
}
static SmallVector<ReassociationIndices> getExpandFrom1DReassociation(size_t rank) {
SmallVector<ReassociationIndices> reassociation(1);
reassociation.front().reserve(rank);
for (size_t dim = 0; dim < rank; ++dim)
reassociation.front().push_back(static_cast<int64_t>(dim));
return reassociation;
}
struct Reshape : OpConversionPattern<ONNXReshapeOp> {
using OpConversionPattern::OpConversionPattern;
@@ -95,18 +114,50 @@ struct Reshape : OpConversionPattern<ONNXReshapeOp> {
return success();
}
auto replaceWithReshape = [&](auto buildReshape) -> LogicalResult {
if (isHostFoldableValue(adaptor.getData())) {
rewriter.replaceOp(reshapeOp, buildReshape(adaptor.getData()));
return success();
}
auto computeOp = createSpatCompute<1>(
rewriter, reshapeOp.getLoc(), TypeRange {resultType}, {}, adaptor.getData(), [&](Value data) {
Value reshaped = buildReshape(data);
spatial::SpatYieldOp::create(rewriter, reshapeOp.getLoc(), reshaped);
});
rewriter.replaceOp(reshapeOp, computeOp.getResults());
return success();
};
SmallVector<ReassociationIndices> reassociation;
if (sourceType.getRank() > resultType.getRank()
&& inferCollapseReassociation(sourceType.getShape(), resultType.getShape(), reassociation))
return replaceWithReshape([&](Value data) {
return tensor::CollapseShapeOp::create(rewriter, reshapeOp.getLoc(), resultType, data, reassociation);
});
if (sourceType.getRank() < resultType.getRank()
&& inferExpandReassociation(sourceType.getShape(), resultType.getShape(), reassociation))
return replaceWithReshape([&](Value data) {
return tensor::ExpandShapeOp::create(rewriter, reshapeOp.getLoc(), resultType, data, reassociation);
});
if (sourceType.getNumElements() != resultType.getNumElements())
return failure();
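// Generic fallback when no direct collapse/expand reassociation exists: flatten the source to 1-D,
// then expand to the result shape (element counts already match).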
return replaceWithReshape([&](Value data) -> Value {
Value reshaped = data;
if (sourceType.getRank() != 1) {
auto flatType = RankedTensorType::get({sourceType.getNumElements()}, sourceType.getElementType());
reshaped = tensor::CollapseShapeOp::create(
rewriter, reshapeOp.getLoc(), flatType, reshaped, getCollapseTo1DReassociation(sourceType.getRank()));
}
if (resultType.getRank() == 1)
return reshaped;
return tensor::ExpandShapeOp::create(
rewriter, reshapeOp.getLoc(), resultType, reshaped, getExpandFrom1DReassociation(resultType.getRank()))
.getResult();
});
}
