42 Commits

Author SHA1 Message Date
NiccoloN 41de3cb150 add memory coalescing pass
Validate Operations / validate-operations (push) Has been cancelled
better reports
refactor for more code-reuse and patter usage
fixes
2026-05-12 18:17:00 +02:00
NiccoloN 4f3570520c add pim.vmm verifier and fix vmm lowering
reuse code for subviews
2026-05-12 15:13:50 +02:00
NiccoloN 628dc630a4 compact syntax for spatial tensor ops
Validate Operations / validate-operations (push) Has been cancelled
better IR compaction after dcp merge
remove pim.mvm op
better memory report
2026-05-12 13:35:25 +02:00
NiccoloN 80a7298552 fix pool lowering
Validate Operations / validate-operations (push) Has been cancelled
better reports (dcp merge and memory)
2026-05-12 12:32:23 +02:00
ilgeco 8ad504fcdf Yolo splitted at conv boundary
Validate Operations / validate-operations (push) Has been cancelled
2026-05-12 11:33:15 +02:00
ilgeco e6f442c5d2 Merge branch 'main' of chef.heaplab.deib.polimi.it:nnicolosi/Raptor
Validate Operations / validate-operations (push) Has been cancelled
2026-05-12 10:43:01 +02:00
ilgeco f6b97b3813 Fix report memory 2026-05-12 10:42:38 +02:00
ilgeco 26317ea7d0 Shorter Memory Reporty 2026-05-12 10:38:35 +02:00
NiccoloN 909c4acfdd huge refactor for high RewritePatterns usage and less ad-hoc cpp code
Validate Operations / validate-operations (push) Has been cancelled
remove Spatial many ops in favor of tensor ops like in pim
2026-05-12 10:35:44 +02:00
ilgeco feaff820e1 pim-sim TraceTime + faer
Validate Operations / validate-operations (push) Has been cancelled
2026-05-11 18:19:30 +02:00
NiccoloN 1e279ae9bb minor fix
Validate Operations / validate-operations (push) Has been cancelled
2026-05-11 16:01:42 +02:00
NiccoloN 57f0cca8c0 remove duplicated code
Validate Operations / validate-operations (push) Has been cancelled
quieter validation scripts (with optional verbose flag)
2026-05-11 15:52:26 +02:00
NiccoloN 5ff364027b big cleanup: remove remaining pim many operations, simplify bufferization logic
Validate Operations / validate-operations (push) Has been cancelled
2026-05-11 14:38:13 +02:00
NiccoloN b1272d2283 fast pim bufferization using tensors
Validate Operations / validate-operations (push) Successful in 24m29s
2026-05-08 14:21:45 +02:00
NiccoloN 58e6587697 Merge remote-tracking branch 'origin/main' 2026-05-08 13:12:47 +02:00
NiccoloN f6c8cc4aa5 sightly better bufferization
minor fixes
2026-05-07 17:53:47 +02:00
ilgeco 566630b99a Removed SpatNopPattern
Validate Operations / validate-operations (push) Successful in 22m36s
2026-05-07 17:03:35 +02:00
ilgeco 74931ad75b Single Concat Fix 2026-05-07 16:47:01 +02:00
NiccoloN f2fe147961 compact pim IR
Validate Operations / validate-operations (push) Successful in 22m15s
2026-05-06 17:16:51 +02:00
ilgeco 7bb58e80de Merge branch 'main' of chef.heaplab.deib.polimi.it:nnicolosi/Raptor into main
Validate Operations / validate-operations (push) Successful in 24m25s
2026-05-06 12:25:29 +02:00
NiccoloN b2dc9c38b6 better spatial IR compaction with better custom syntax, scf.for and
Validate Operations / validate-operations (push) Has been cancelled
spat.map
2026-05-06 12:21:58 +02:00
ilgeco 3cb6a1abc5 Memory report 2026-05-06 10:47:04 +02:00
NiccoloN 285773fa55 rework actually broken dcp merge + compute re-batching (still to refine) 2026-05-04 19:30:40 +02:00
NiccoloN bdacb9871d fix dcp merge bug
Validate Operations / validate-operations (push) Failing after 15m54s
2026-05-04 15:58:14 +02:00
NiccoloN 5b9bb0c191 refactor spatial ops
Validate Operations / validate-operations (push) Successful in 24m55s
2026-05-04 14:19:30 +02:00
NiccoloN f789954ad7 Refactor ONNXToSpatial Common and diagnostics 2026-05-04 13:42:43 +02:00
ilgeco b6ba1e4fea Fix DCPTest using old constructor
Validate Operations / validate-operations (push) Successful in 24m15s
2026-05-04 10:58:51 +02:00
NiccoloN 717ad160cd Refactor PIM/Common (splitting in files, adding helpers, adding brief
Validate Operations / validate-operations (push) Failing after 18m36s
docs)
2026-05-04 09:20:43 +02:00
NiccoloN 905fa9f9a7 Merge remote changes
Validate Operations / validate-operations (push) Failing after 18m42s
2026-05-03 23:09:32 +02:00
NiccoloN 62b0a6e19d merge remote changes 2026-05-03 22:30:46 +02:00
NiccoloN b605585b1f compact spatial IR through different new operations and dedicated syntax
fast spatial node merging with batch operations
2026-05-03 14:14:14 +02:00
ilgeco 08b0fcd850 Parallel bufferization
Validate Operations / validate-operations (push) Successful in 21m49s
2026-04-30 11:48:17 +02:00
ilgeco 9dccc2c701 Translate global constant to symble 2026-04-28 12:42:01 +02:00
ilgeco 5c839e62c1 Func Input converted to symbol 2026-04-27 13:48:03 +02:00
NiccoloN 15e8edb9c4 better spat computes merging
Validate Operations / validate-operations (push) Successful in 21m14s
2026-04-25 19:24:09 +02:00
ilgeco 951baca106 Merge Node update fix comparison bug
Validate Operations / validate-operations (push) Successful in 20m21s
2026-04-23 19:52:16 +02:00
ilgeco fc5bccb487 Merge Node update status file
Validate Operations / validate-operations (push) Has started running
2026-04-23 19:42:56 +02:00
ilgeco 49dea15b95 DCP Merge status
Validate Operations / validate-operations (push) Successful in 22m29s
2026-04-23 18:40:33 +02:00
NiccoloN 5545b0f672 fix MatMul pattern non-contiguous extract_slices
Validate Operations / validate-operations (push) Successful in 22m31s
2026-04-23 14:44:30 +02:00
NiccoloN cff929a083 fix sigmoid implementation stability in pim-simulator
Validate Operations / validate-operations (push) Successful in 23m4s
2026-04-23 10:34:29 +02:00
NiccoloN 89b3501aa8 fix weightAlways attribute in spatial 2026-04-23 10:04:47 +02:00
NiccoloN 412ca957f6 multiple-output spat computes
Validate Operations / validate-operations (push) Successful in 22m38s
2026-04-23 09:28:57 +02:00
193 changed files with 17471 additions and 4307 deletions
+12
View File
@@ -1,5 +1,17 @@
.zed
.idea
**/.vscode
.claude
.codex
AGENTS.md
CMakeUserPresets.json
build
build_release
cmake-build-debug
cmake-build-release
compile.sh
**/__*
+158
View File
@@ -1,5 +1,163 @@
# Raptor
Raptor is a domain-specific MLIR compiler for neural networks (ONNX format)
targeting in-memory computing / processing-in-memory (PIM) architectures.
It progressively lowers ONNX-MLIR through a set of MLIR dialects down to
target-specific artifacts (currently JSON code for the `pimsim-nn` simulator).
## Overview
PIM architectures perform most of the computation directly in memory.
Raptor's first supported target is `pimsim-nn`, which simulates a chip with:
- a shared host memory,
- a number of cores that do most of the computation directly in their memory
(vector ops, vmm/mvm on ReRAM crossbars),
- no branching instructions (branchless architecture) and no hardware loop
support — any repeated work (e.g. convolutions) must be unrolled into
explicit per-iteration instructions.
Because of this, the amount of emitted instructions explodes quickly and the
compiler must optimize aggressively at every stage to keep compilation
tractable.
A second target, `PulPim`, is planned for an accelerator with RISC-V cores
each carrying its own in-memory computing unit and crossbars. It will live in
a dedicated dialect (future work).
### Targets and simulators
`pimsim-nn` (under `backend-simulators/pim/pimsim-nn`) is used for
**performance** estimates (latency, energy), but does not functionally execute
the JSON code it consumes. To validate the numerical correctness of the JSON
code produced by Raptor (or, for comparison, by the `pimcomp` compiler), we use
a Rust simulator we maintain in-tree at
`backend-simulators/pim/pim-simulator`.
## Compilation pipeline
The PIM-related sources live under `src/PIM` and the tests under `test/PIM`.
When working on this codebase, most changes should stay confined to those
trees (you only need to look outside, e.g. at `onnx-mlir` or `llvm`, for
framework-level details).
High-level lowering flow:
```
ONNX-MLIR ──► Spatial ──► Pim (tensor) ──► Pim (bufferized) ──► PIM JSON
```
1. **ONNX → Spatial** (`src/PIM/Conversion/ONNXToSpatial`).
Lowers ONNX ops into the `spat` dialect (`src/PIM/Dialect/Spatial`).
Spatial models a high-level spatial in-memory accelerator: vmm/mvm
operations are accelerated by storing a constant RHS matrix into a
crossbar. Crossbars cannot be re-programmed during execution, have a
limited fixed size, and there is a limited number of them per core.
Conversion patterns are split by op family under
`Conversion/ONNXToSpatial/Patterns/{Math,NN,Tensor}` (Conv, Gemm, MatMul,
Elementwise, ReduceMean, Pool, Relu, Sigmoid, Softmax, Concat, Gather,
Reshape, Resize, Split).
2. **Spatial → Pim** (`src/PIM/Conversion/SpatialToPim`).
Lowers Spatial to the `pim` dialect (`src/PIM/Dialect/Pim`), which
materializes PIM cores (`pim.core`), inter-core communication
(`pim.send` / `pim.receive`), halts, and crossbar-level operations.
3. **Merge compute nodes** (`src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes`).
A DCP-inspired heuristic (Dynamic Critical Path — see the original
scheduling paper by Kwok & Ahmad,
[DCP-eScience2007](https://clouds.cis.unimelb.edu.au/papers/DCP-eScience2007.pdf))
that coarsens the virtual node graph and decides how to group compute
nodes onto cores. Our implementation is only DCP-*inspired*: it is a
heuristic with different assumptions from the paper (different cost
model, constraints from crossbar capacity / core resources, and a
windowed coarsening loop instead of full-graph reprioritization). The
`dcp-critical-window-size` option controls how many lowest-slack virtual
nodes each coarsening iteration considers (0 = legacy full-graph
analysis). Related sources: `DCPGraph/DCPAnalysis.cpp`, `Graph.cpp/.hpp`,
`MergeComputeNodesPass.cpp`.
4. **Bufferization** (`src/PIM/Dialect/Pim/Transforms/Bufferization`).
Converts tensor-semantics PIM IR into memref-semantics PIM IR using the
standard MLIR `BufferizableOpInterface` machinery
(`OpBufferizationInterfaces.*`, `PimBufferization.td`).
5. **Static memory coalescing** (`src/PIM/Dialect/Pim/Transforms/StaticMemoryCoalescing`).
Conservatively reuses same-typed local memref allocations inside PIM cores
after bufferization and before code generation.
6. **PIM code generation** (`src/PIM/Pass/PimCodegen`):
- `HostConstantFolding` — folds host-side constants.
- `MaterializeHostConstantsPass` — materializes the remaining host
constants for emission.
- `VerificationPass` — checks invariants before emission.
- `EmitPimJsonPass` — emits the final PIM JSON consumed by `pimsim-nn`
and `pim-simulator`.
Supporting pieces:
- `src/PIM/Compiler` — PIM-specific compiler options (crossbar size/count,
core count, DCP window, experimental conv impl, concat error handling, …)
and `PimCodeGen` entry points.
- `src/PIM/Common` — shared utilities (`PimCommon`, `LabeledList`).
- `src/PIM/Pass` — auxiliary passes (`MessagePass`, `CountInstructionPass`)
and the `PIMPasses.h` registry used by `PimAccelerator`.
- `src/PIM/PimAccelerator.{cpp,hpp}` — accelerator entry point: registers
dialects, passes, and plugs Raptor into the ONNX-MLIR driver.
## Key compiler options
Pass these on the `onnx-mlir` command line when compiling for PIM:
- `--maccel=PIM` — select the PIM accelerator.
- `--EmitSpatial` / `--EmitPim` / `--EmitPimBufferized` / `--EmitPimCodegen`
— stop the pipeline at the requested stage (default: `EmitPimCodegen`).
- `--pim-only-codegen` — assume the input is already bufferized PIM IR and
run only the codegen tail.
- `--crossbar-size=<N>` / `--crossbar-count=<N>` — crossbar dimensions and
per-core count.
- `--core-count=<N>` — number of cores (`-1` picks the minimum).
- `--dcp-critical-window-size=<N>` — DCP coarsening window (0 = legacy).
- `--use-experimental-conv-impl` — alternative convolution lowering.
- `--ignore-concat-error` — soft-fail corner case in `ConcatOp`.
## Validation
Functional validation lives in `validation/` and drives the Rust
`pim-simulator` to compare Raptor's output against a reference.
Per-operation validation (from `validation/`):
```
validate.py \
--raptor-path ../cmake-build-release/Release/bin/onnx-mlir \
--onnx-include-dir ../onnx-mlir/include
```
End-to-end network validation (example: first 4 layers of YOLOv11n):
```
validate.py \
--raptor-path ../cmake-build-release/Release/bin/onnx-mlir \
--onnx-include-dir ../onnx-mlir/include \
--operations-dir ./networks/yolo11n/depth_04 \
--crossbar-size 2048 --crossbar-count 256 --core-count 1000
```
Available networks under `validation/networks/`: `vgg16`, `yolo11n`.
Available operations under `validation/operations/`: `add`, `conv`, `div`,
`gather`, `gemm`, `gemv`, `mul`, `pool`, `reduce_mean`, `relu`, `resize`,
`sigmoid`, `softmax`, `split`.
## Rebuilding
Release build (fast):
```
cmake --build /home/nico/raptor/raptor/cmake-build-release --target onnx-mlir -j 30
```
A slower debug build is also available — configure it the same way but with
`-DCMAKE_BUILD_TYPE=Debug` (see installation instructions below).
## Build
### Protobuf
File diff suppressed because it is too large Load Diff
@@ -13,8 +13,9 @@ name = "pimcore"
path = "src/lib/pimcore.rs"
[features]
default = ["tracing"]
default = []
tracing = []
profile_time = ["dep:plotly", "dep:comfy-table", "dep:statrs"]
@@ -27,3 +28,9 @@ hex = "0"
paste = "1"
serde = { version = "1", features = ["derive"] }
serde_json = "1"
statrs = {version="0.16", optional=true}
comfy-table = {version="7.1", optional=true}
plotly = {version="0.8", optional=true}
rayon = "1.12.0"
faer = "0.24.0"
faer-traits = "0.24.0"
@@ -1,14 +1,19 @@
use crate::{
cpu::{CPU, crossbar}, instruction_set::{
cpu::{CPU, crossbar},
instruction_set::{
Instruction, InstructionData, InstructionStatus, InstructionType, VectorBitWith,
helper::add_all,
}, memory_manager::{
},
memory_manager::{
MemoryStorable,
type_traits::{FromFloat, UpcastDestTraits, UpcastSlice},
}, tracing::TRACER, utility::{add_offset_r1, add_offset_r2, add_offset_rd}
},
tracing::TRACER,
utility::{add_offset_r1, add_offset_r2, add_offset_rd},
};
use aligned_vec::{AVec, ConstAlign};
use anyhow::{Context, Result, ensure};
use rayon::prelude::*;
use paste::paste;
use std::{borrow::Cow, cell::OnceCell, collections::HashMap};
@@ -76,8 +81,7 @@ pub fn functor_to_name(functor: usize) -> &'static str {
///////////////////////////////////////////////////////////////
/////////////////Scalar/register Instructions//////////////////
///////////////////////////////////////////////////////////////
pub fn sldi(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
{
pub fn sldi(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
TRACER.lock().unwrap().pre_sldi(cores, data);
let (core_indx, rd, imm) = data.get_core_rd_imm();
let core = cores.core(core_indx);
@@ -229,25 +233,30 @@ where
[F]: UpcastSlice<T> + UpcastSlice<M>,
[M]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
M: UpcastDestTraits<M> + MemoryStorable + FromFloat,
// Add faer::ComplexField HERE, directly bounding M for this function only
M: UpcastDestTraits<M> + MemoryStorable + FromFloat + faer_traits::ComplexField,
F: UpcastDestTraits<F> + MemoryStorable,
{
TRACER.lock().unwrap().pre_mvm::<F,M,T>(cores, data);
TRACER.lock().unwrap().pre_mvm::<F, M, T>(cores, data);
let (core_indx, rd, r1, mbiw, relu, group) = data.get_core_rd_r1_mbiw_immrelu_immgroup();
let group: usize = group.try_into().context("group can not be negative")?;
let core = cores.core(core_indx);
let r1_val = core.register(r1);
let rd_val = core.register(rd);
let (memory, crossbars) = core.get_memory_crossbar();
let crossbar = crossbars.get_mut(group).unwrap();
let crossbar_stored_bytes = crossbar.stored_bytes();
let crossbar_byte_width = crossbar.width();
//Fix this
let crossbar_elem_width = crossbar_byte_width / size_of::<M>();
ensure!(
crossbar_byte_width & size_of::<M>() == 0,
crossbar_byte_width % size_of::<M>() == 0,
"M not divisor of the crosbbar size"
);
let crossbar_height = crossbar.height();
let crossbar_byte_size = crossbar_byte_width * crossbar_height;
@@ -257,19 +266,29 @@ where
let load = loads[0];
let vec: Cow<[M]> = load.up();
let matrix = crossbar.load::<M>(crossbar_byte_size)?[0];
let mut res = Vec::with_capacity(crossbar_elem_width);
let mut partial :AVec<M, _> = AVec::<M, ConstAlign<64>>::with_capacity(64, vec.len());
partial.resize(vec.len(), M::from_f32(0.0));
for x in 0..crossbar_elem_width {
partial[0] = vec[0] * matrix[x];
for y in 1..crossbar_height {
partial[y] = vec[y] * matrix[y * crossbar_elem_width + x];
}
// --- FAER IMPLEMENTATION ---
// 1. Explicitly create a Matrix Reference (MatRef)
let matrix_view = faer::mat::MatRef::from_row_major_slice(
matrix.as_ref(),
crossbar_height,
crossbar_elem_width,
);
// 2. Explicitly create a Column Vector Reference (ColRef)
// Using `ColRef` here guarantees we don't accidentally get a RowRef (Fixes E0277)
let vec_view = faer::col::ColRef::from_slice(vec.as_ref());
let res_col: faer::col::Col<M> = matrix_view.transpose() * vec_view;
// 4. Convert back to standard Rust Vec
// try_as_slice() returns an Option<&[M]>.
// We can safely unwrap() because a freshly allocated, owned Col is ALWAYS contiguous!
let mut res: Vec<M> = (0..crossbar_elem_width).map(|i| res_col[i]).collect();
// --- END FAER ---
let mut acc = add_all(partial.as_slice());
res.push(acc);
}
if relu != 0 {
res.iter_mut().for_each(|x| {
if *x < M::from_f32(0.0) {
@@ -277,13 +296,16 @@ where
}
});
}
ensure!(
res.len() == crossbar_elem_width,
"mvm generate a vector bigger thant it's requested elements"
"mvm generate a vector bigger thant it's requested elements"
);
let res_up: Cow<[T]> = res.as_slice().up();
core.execute_store(rd_val, res_up.as_ref());
TRACER.lock().unwrap().post_mvm::<F,M,T>(cores, data);
TRACER.lock().unwrap().post_mvm::<F, M, T>(cores, data);
Ok(InstructionStatus::Completed)
}
@@ -317,7 +339,7 @@ where
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable,
{
TRACER.lock().unwrap().pre_vvadd::<F,T>(cores, data);
TRACER.lock().unwrap().pre_vvadd::<F, T>(cores, data);
let (core_indx, rd, r1, r2, imm_len, offset_select, offset_value) =
data.get_core_rd_r1_r2_immlen_offset();
let core = cores.core(core_indx);
@@ -345,7 +367,7 @@ where
);
let res_up: Cow<[T]> = res.as_slice().up();
core.execute_store(rd_val, res_up.as_ref());
TRACER.lock().unwrap().post_vvadd::<F,T>(cores, data);
TRACER.lock().unwrap().post_vvadd::<F, T>(cores, data);
Ok(InstructionStatus::Completed)
}
@@ -359,7 +381,7 @@ where
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable,
{
TRACER.lock().unwrap().pre_vvsub::<F,T>(cores, data);
TRACER.lock().unwrap().pre_vvsub::<F, T>(cores, data);
let (core_indx, rd, r1, r2, imm_len, offset_select, offset_value) =
data.get_core_rd_r1_r2_immlen_offset();
let core = cores.core(core_indx);
@@ -400,7 +422,7 @@ where
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable,
{
TRACER.lock().unwrap().pre_vvmul::<F,T>(cores, data);
TRACER.lock().unwrap().pre_vvmul::<F, T>(cores, data);
let (core_indx, rd, r1, r2, imm_len, offset_select, offset_value) =
data.get_core_rd_r1_r2_immlen_offset();
let core = cores.core(core_indx);
@@ -440,7 +462,7 @@ where
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable,
{
TRACER.lock().unwrap().pre_vvdmul::<F,T>(cores, data);
TRACER.lock().unwrap().pre_vvdmul::<F, T>(cores, data);
let (core_indx, rd, r1, r2, imm_len, offset_select, offset_value) =
data.get_core_rd_r1_r2_immlen_offset();
let core = cores.core(core_indx);
@@ -476,7 +498,7 @@ where
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable,
{
TRACER.lock().unwrap().pre_vvmax::<F,T>(cores, data);
TRACER.lock().unwrap().pre_vvmax::<F, T>(cores, data);
let (core_indx, rd, r1, r2, imm_len, offset_select, offset_value) =
data.get_core_rd_r1_r2_immlen_offset();
let core = cores.core(core_indx);
@@ -525,7 +547,7 @@ where
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable,
{
TRACER.lock().unwrap().pre_vavg::<F,T>(cores, data);
TRACER.lock().unwrap().pre_vavg::<F, T>(cores, data);
let (core_indx, rd, r1, r2, imm_len, offset_select, offset_value) =
data.get_core_rd_r1_r2_immlen_offset();
let core = cores.core(core_indx);
@@ -533,7 +555,10 @@ where
let r2_val = r2;
ensure!(r2_val == 1, "Stride different than 1 not supported");
let rd_val = core.register(rd);
ensure!(offset_select == 1, "Offset select cannot be different from 1");
ensure!(
offset_select == 1,
"Offset select cannot be different from 1"
);
let r1_val = add_offset_r1(r1_val, offset_select, offset_value);
let loads = core.reserve_load(r1_val, imm_len)?.execute_load::<F>()?;
let load1 = loads[0];
@@ -555,7 +580,7 @@ where
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable + From<f32>,
{
TRACER.lock().unwrap().pre_vrelu::<F,T>(cores, data);
TRACER.lock().unwrap().pre_vrelu::<F, T>(cores, data);
let (core_indx, rd, r1, r2, imm_len, offset_select, offset_value) =
data.get_core_rd_r1_r2_immlen_offset();
let core = cores.core(core_indx);
@@ -585,7 +610,7 @@ where
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable + From<f32>,
{
TRACER.lock().unwrap().pre_vtanh::<F,T>(cores, data);
TRACER.lock().unwrap().pre_vtanh::<F, T>(cores, data);
let (core_indx, rd, r1, r2, imm_len, offset_select, offset_value) =
data.get_core_rd_r1_r2_immlen_offset();
let core = cores.core(core_indx);
@@ -613,7 +638,7 @@ where
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable + From<f32>,
{
TRACER.lock().unwrap().pre_vsigm::<F,T>(cores, data);
TRACER.lock().unwrap().pre_vsigm::<F, T>(cores, data);
let (core_indx, rd, r1, r2, imm_len, offset_select, offset_value) =
data.get_core_rd_r1_r2_immlen_offset();
let core = cores.core(core_indx);
@@ -633,13 +658,16 @@ pub fn vsoftmax(cores: &mut CPU, data: InstructionData) -> Result<InstructionSta
panic!("You are calling a placeholder, the real call is the generic version");
}
pub(super) fn vsoftmax_impl<F, T>(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
pub(super) fn vsoftmax_impl<F, T>(
cores: &mut CPU,
data: InstructionData,
) -> Result<InstructionStatus>
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable + From<f32>,
{
TRACER.lock().unwrap().pre_vsoftmax::<F,T>(cores, data);
TRACER.lock().unwrap().pre_vsoftmax::<F, T>(cores, data);
let (core_indx, rd, r1, r2, imm_len, offset_select, offset_value) =
data.get_core_rd_r1_r2_immlen_offset();
let core = cores.core(core_indx);
@@ -656,16 +684,15 @@ where
.reduce(|a, b| if a > b { a } else { b })
.unwrap();
let exp_values: Vec<F> = load1.iter().map(|&a| (a - max_val).exp()).collect();
let sum = exp_values
.iter()
.copied()
.reduce(|a, b| a + b)
.unwrap();
ensure!(sum > 0.0.into(), "vsoftmax normalization sum must be positive");
let sum = exp_values.iter().copied().reduce(|a, b| a + b).unwrap();
ensure!(
sum > 0.0.into(),
"vsoftmax normalization sum must be positive"
);
let res: Vec<F> = exp_values.iter().map(|&a| a / sum).collect();
let res_up: Cow<[T]> = res.as_slice().up();
core.execute_store(rd_val, res_up.as_ref());
TRACER.lock().unwrap().post_vsoftmax::<F,T>(cores, data);
TRACER.lock().unwrap().post_vsoftmax::<F, T>(cores, data);
Ok(InstructionStatus::Completed)
}
@@ -749,12 +776,10 @@ pub fn lmv(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
}
pub fn send(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
TRACER.lock().unwrap().pre_send(cores, data);
Ok(InstructionStatus::Sending(data))
}
pub fn recv(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
TRACER.lock().unwrap().pre_recv(cores, data);
Ok(InstructionStatus::Reciving(data))
}
@@ -15,9 +15,9 @@ use crate::{
};
pub fn json_to_executor<'a>(
pub fn json_to_executor<'a, 'b>(
config: Value,
mut cores: impl Iterator<Item = &'a Value>,
mut cores: impl Iterator<Item = &'b Value>,
crossbars : Vec<Vec<&'a Crossbar>>
) -> Executable<'a> {
let cell_precision = config.get("cell_precision").unwrap().as_i64().unwrap() as i32;
@@ -55,15 +55,23 @@ pub trait HasSigm {
impl HasSigm for f32 {
fn sigm(self) -> Self {
let ex = self.exp();
ex / (1.0 + ex)
if self >= 0.0 {
1.0 / (1.0 + (-self).exp())
} else {
let ex = self.exp();
ex / (1.0 + ex)
}
}
}
impl HasSigm for f64 {
fn sigm(self) -> Self {
let ex = self.exp();
ex / (1.0 + ex)
if self >= 0.0 {
1.0 / (1.0 + (-self).exp())
} else {
let ex = self.exp();
ex / (1.0 + ex)
}
}
}
@@ -169,6 +169,9 @@ impl<'a> Executable<'a> {
}
}
print_status(cores_instructions);
#[cfg(feature = "profile_time")]
TRACER.lock().unwrap().report();
}
pub fn cpu(&self) -> &CPU<'a> {
@@ -58,6 +58,20 @@ where 'a : 'b
&& sender.internal_core == receiver.external_core
&& receiver.internal_core == sender.external_core
{
{
let sender = &mut core_instructions[sender.internal_core];
let pc = sender.program_counter;
let inst = sender.instructions.get(pc).unwrap();
let data = inst.data;
TRACER.lock().unwrap().pre_send(cpu, data);
}
{
let recv = &mut core_instructions[receiver.internal_core];
let pc = recv.program_counter;
let inst = recv.instructions.get(pc).unwrap();
let data = inst.data;
TRACER.lock().unwrap().pre_recv(cpu, data);
}
let [sender_core, reciver_core] =
cpu.get_multiple_cores([sender.internal_core, receiver.internal_core]);
let memory = sender_core
@@ -13,7 +13,7 @@ use crate::{
};
use std::io::Write;
#[cfg(not(feature = "tracing"))]
#[cfg(not(any(feature = "tracing", feature = "profile_time")))]
impl Trace {
///////////////////////////////////////////////////////////////
/////////////////Scalar/register Instructions//////////////////
@@ -1,52 +1,32 @@
mod tracing_isa;
mod disable;
mod pretty_print;
use std::{fs::File, path::{ PathBuf}};
#[cfg(feature = "profile_time")]
mod profile;
#[cfg(feature = "profile_time")]
use profile::Trace;
#[cfg(feature = "tracing")]
mod trace;
#[cfg(feature = "tracing")]
use trace::Trace;
use crate::Executable;
#[cfg(not(any(feature = "tracing", feature = "profile_time")))]
use std::path::PathBuf;
use std::sync::{LazyLock, Mutex};
use crate::Executable;
#[cfg(not(any(feature = "tracing", feature = "profile_time")))]
pub struct Trace {}
#[cfg(feature = "tracing")]
pub struct Trace {
out_files : Vec<File>
}
#[cfg(feature = "tracing")]
#[cfg(not(any(feature = "tracing", feature = "profile_time")))]
impl Trace {
fn new() -> Self {
Self { out_files : Vec::new()}
Self {}
}
pub fn init(&mut self, num_core : usize , mut path : PathBuf) {
path.pop();
for i in 0..num_core {
path.push(format!("TraceCore{}", i));
let file = File::create(&path).expect("Can not create file");
self.out_files.push(file);
path.pop();
}
}
}
#[cfg(not(feature = "tracing"))]
pub struct Trace {
pub fn init(&mut self, num_core: usize, path: PathBuf) {}
}
#[cfg(not(feature = "tracing"))]
impl Trace {
fn new() -> Self {
Self { }
}
pub fn init(&mut self, num_core : usize, path : PathBuf ) {
}
}
pub static TRACER: LazyLock<Mutex<Trace>> = LazyLock::new(|| { Trace::new().into()});
pub static TRACER: LazyLock<Mutex<Trace>> = LazyLock::new(|| Trace::new().into());
@@ -0,0 +1,73 @@
use std::{collections::HashMap, path::PathBuf, time::Instant};
use crate::tracing::profile::profile_analysis::{
analyze_timings, generate_interactive_report, print_textual_report,
};
pub mod profile_analysis;
pub mod profile_isa;
pub struct Trace {
instruction_times: HashMap<String, Vec<(u128,u128)>>,
core_start_time: HashMap<usize, Option<Instant>>,
start_time: Instant,
}
impl Trace {
pub fn new() -> Self {
let mut instruction_times = HashMap::new();
instruction_times.insert("sldi".to_string(), Vec::with_capacity(20000));
instruction_times.insert("sld".to_string(), Vec::with_capacity(20000));
instruction_times.insert("sadd".to_string(), Vec::with_capacity(20000));
instruction_times.insert("ssub".to_string(), Vec::with_capacity(20000));
instruction_times.insert("smul".to_string(), Vec::with_capacity(20000));
instruction_times.insert("saddi".to_string(), Vec::with_capacity(20000));
instruction_times.insert("smuli".to_string(), Vec::with_capacity(20000));
instruction_times.insert("setbw".to_string(), Vec::with_capacity(20000));
instruction_times.insert("mvmul".to_string(), Vec::with_capacity(20000));
instruction_times.insert("vvadd".to_string(), Vec::with_capacity(20000));
instruction_times.insert("vvsub".to_string(), Vec::with_capacity(20000));
instruction_times.insert("vvmul".to_string(), Vec::with_capacity(20000));
instruction_times.insert("vvdmul".to_string(), Vec::with_capacity(20000));
instruction_times.insert("vvmax".to_string(), Vec::with_capacity(20000));
instruction_times.insert("vvsll".to_string(), Vec::with_capacity(20000));
instruction_times.insert("vvsra".to_string(), Vec::with_capacity(20000));
instruction_times.insert("vavg".to_string(), Vec::with_capacity(20000));
instruction_times.insert("vrelu".to_string(), Vec::with_capacity(20000));
instruction_times.insert("vtanh".to_string(), Vec::with_capacity(20000));
instruction_times.insert("vsigm".to_string(), Vec::with_capacity(20000));
instruction_times.insert("vsoftmax".to_string(), Vec::with_capacity(20000));
instruction_times.insert("vmv".to_string(), Vec::with_capacity(20000));
instruction_times.insert("vrsu".to_string(), Vec::with_capacity(20000));
instruction_times.insert("vrsl".to_string(), Vec::with_capacity(20000));
instruction_times.insert("ld".to_string(), Vec::with_capacity(20000));
instruction_times.insert("st".to_string(), Vec::with_capacity(20000));
instruction_times.insert("lldi".to_string(), Vec::with_capacity(20000));
instruction_times.insert("lmv".to_string(), Vec::with_capacity(20000));
instruction_times.insert("send".to_string(), Vec::with_capacity(20000));
instruction_times.insert("recv".to_string(), Vec::with_capacity(20000));
instruction_times.insert("wait".to_string(), Vec::with_capacity(20000));
instruction_times.insert("sync".to_string(), Vec::with_capacity(20000));
Self {
instruction_times,
core_start_time: HashMap::new(),
start_time: Instant::now()
}
}
pub fn init(&mut self, num_core: usize, path: PathBuf) {
for i in 0..num_core {
self.core_start_time.insert(i, None);
}
}
pub fn report(&self) {
let res = analyze_timings(&self.instruction_times);
print_textual_report(&res);
generate_interactive_report(
&self.instruction_times,
&["mvmul", "recv"],
"/tmp/report.html",
);
}
}
@@ -0,0 +1,192 @@
use comfy_table::{Cell, Table, modifiers::UTF8_ROUND_CORNERS, presets::UTF8_FULL};
use statrs::statistics::{Data, Distribution, Max, Min, OrderStatistics};
use std::collections::HashMap;
#[derive(Debug)]
pub struct InstructionStats {
pub name: String,
pub count: usize,
pub total_time: u128,
pub min: f64,
pub max: f64,
pub mean: f64,
pub median: f64,
pub std_dev: f64,
pub cv: f64,
pub p95: f64,
pub p99: f64,
pub skewness: f64,
pub kurtosis: f64,
}
fn format_time(ns: f64) -> String {
if ns.is_nan() {
return "NaN".to_string();
}
if ns >= 1_000_000_000.0 {
format!("{:.2} s", ns / 1_000_000_000.0)
} else if ns >= 1_000_000.0 {
format!("{:.2} ms", ns / 1_000_000.0)
} else if ns >= 1_000.0 {
format!("{:.2} µs", ns / 1_000.0)
} else {
format!("{:.2} ns", ns)
}
}
fn calculate_skewness_kurtosis(times: &[f64], mean: f64, std_dev: f64) -> (f64, f64) {
let n = times.len() as f64;
if n < 4.0 || std_dev == 0.0 {
return (f64::NAN, f64::NAN);
}
let mut sum_m3 = 0.0;
let mut sum_m4 = 0.0;
for &x in times {
let deviation = x - mean;
sum_m3 += deviation.powi(3);
sum_m4 += deviation.powi(4);
}
let m3 = sum_m3 / n;
let m4 = sum_m4 / n;
let skewness = m3 / std_dev.powi(3);
let kurtosis = (m4 / std_dev.powi(4)) - 3.0;
(skewness, kurtosis)
}
pub fn analyze_timings(timings: &HashMap<String, Vec<(u128, u128)>>) -> Vec<InstructionStats> {
let mut results = Vec::new();
for (instruction, times) in timings {
let count = times.len();
if count == 0 {
continue;
}
// Extract ONLY the duration (the second element of the tuple) for stats
let durations: Vec<u128> = times.iter().map(|&(_, duration)| duration).collect();
let total_time: u128 = durations.iter().sum();
let f64_times: Vec<f64> = durations.iter().map(|&t| t as f64).collect();
let mut data = Data::new(f64_times.clone());
let mean = data.mean().unwrap_or(0.0);
let std_dev = data.std_dev().unwrap_or(0.0);
let cv = if mean > 0.0 { std_dev / mean } else { 0.0 };
let (skewness, kurtosis) = calculate_skewness_kurtosis(&f64_times, mean, std_dev);
results.push(InstructionStats {
name: instruction.clone(),
count,
total_time,
min: data.min(),
max: data.max(),
mean,
median: data.median(),
std_dev,
cv,
p95: data.percentile(95),
p99: data.percentile(99),
skewness,
kurtosis,
});
}
results.sort_by(|a, b| b.mean.partial_cmp(&a.mean).unwrap());
results
}
pub fn print_textual_report(stats: &[InstructionStats]) {
let mut table = Table::new();
table
.load_preset(UTF8_FULL)
.apply_modifier(UTF8_ROUND_CORNERS)
.set_header(vec![
"Instruction",
"Count",
"Total Time",
"Mean",
"Median",
"Min",
"Max",
"P95",
"P99",
"StdDev",
"CV",
"Skewness",
"Kurtosis",
]);
for stat in stats {
table.add_row(vec![
Cell::new(&stat.name),
Cell::new(stat.count.to_string()),
Cell::new(format_time(stat.total_time as f64)), // Cast u128 to f64 for formatting
Cell::new(format_time(stat.mean)),
Cell::new(format_time(stat.median)),
Cell::new(format_time(stat.min)),
Cell::new(format_time(stat.max)),
Cell::new(format_time(stat.p95)),
Cell::new(format_time(stat.p99)),
Cell::new(format_time(stat.std_dev)),
Cell::new(format!("{:.3}", stat.cv)),
Cell::new(format!("{:.2}", stat.skewness)),
Cell::new(format!("{:.2}", stat.kurtosis)),
]);
}
println!("{table}");
}
pub fn generate_interactive_report(
timings: &HashMap<String, Vec<(u128, u128)>>,
instructions_to_plot: &[&str], // <-- NEW: Only plot these
file_path: &str,
) {
use plotly::common::{Mode, Marker, Line};
use plotly::layout::{Axis, Layout};
use plotly::{Plot, Scatter};
use std::collections::HashMap;
let mut plot = Plot::new();
for &instruction_name in instructions_to_plot {
// Only proceed if the instruction exists in our timings map
if let Some(times) = timings.get(instruction_name) {
let x_axis: Vec<f64> = times.iter().map(|&(ts, _)| ts as f64).collect();
let y_axis: Vec<f64> = times.iter().map(|&(_, dur)| dur as f64).collect();
let text_array: Vec<String> = times.iter()
.map(|&(_, dur)| format_time(dur as f64))
.collect();
let trace = Scatter::new(x_axis, y_axis)
.name(instruction_name)
.mode(Mode::LinesMarkers)
.marker(Marker::new().size(4).opacity(0.6))
.line(Line::new().width(1.0))
.text_array(text_array)
.hover_info(plotly::common::HoverInfo::All);
plot.add_trace(trace);
}
}
let layout = Layout::new()
.title(plotly::common::Title::new("Simulator Timeline: Top Offenders"))
.x_axis(Axis::new().title(plotly::common::Title::new("Absolute Time (ns)")))
.y_axis(Axis::new().title(plotly::common::Title::new("Execution Duration")));
plot.set_layout(layout);
plot.write_html(file_path);
println!("🌐 Interactive timeline saved to {}", file_path);
}
@@ -0,0 +1,364 @@
use crate::{
cpu::CPU,
instruction_set::instruction_data::InstructionData,
memory_manager::{
MemoryStorable,
type_traits::{FromFloat, UpcastDestTraits, UpcastSlice},
},
tracing::Trace,
utility::{add_offset_r1, add_offset_rd},
};
use std::io::Write;
use std::time::Instant;
#[cfg(feature = "profile_time")]
impl Trace {
///////////////////////////////////////////////////////////////
/////////////////Scalar/register Instructions//////////////////
///////////////////////////////////////////////////////////////
fn pre_impl(&mut self, cores: &mut CPU, data: InstructionData) {
let (core_indx, rd, imm) = data.get_core_rd_imm();
let core_indx = core_indx as usize;
if self.core_start_time.get(&core_indx).unwrap().is_none() {
self.core_start_time.insert(core_indx, Some(Instant::now()));
}
}
fn post_impl(&mut self, cores: &mut CPU, data: InstructionData, name: &'static str) {
let (core_indx, rd, imm) = data.get_core_rd_imm();
let core_indx = core_indx as usize;
let Self {
instruction_times,
core_start_time,
start_time,
} = self;
let now = Instant::now();
instruction_times
.get_mut(name)
.unwrap()
.push((now.duration_since(*start_time).as_nanos(), now.duration_since(core_start_time[&core_indx].unwrap()).as_nanos()));
self.core_start_time.insert(core_indx, None);
}
pub fn pre_sldi(&mut self, cores: &mut CPU, data: InstructionData) {
self.pre_impl(cores, data);
}
pub fn post_sldi(&mut self, cores: &mut CPU, data: InstructionData) {
self.post_impl(cores, data, "sldi");
}
pub fn pre_sld(&mut self, cores: &mut CPU, data: InstructionData) {
self.pre_impl(cores, data);
}
pub fn post_sld(&mut self, cores: &mut CPU, data: InstructionData) {
self.post_impl(cores, data, "sld");
}
pub fn pre_sadd(&mut self, cores: &mut CPU, data: InstructionData) {
self.pre_impl(cores, data);
}
pub fn post_sadd(&mut self, cores: &mut CPU, data: InstructionData) {
self.post_impl(cores, data, "sadd");
}
pub fn pre_ssub(&mut self, cores: &mut CPU, data: InstructionData) {
self.pre_impl(cores, data);
}
pub fn post_ssub(&mut self, cores: &mut CPU, data: InstructionData) {
self.post_impl(cores, data, "ssub");
}
pub fn pre_smul(&mut self, cores: &mut CPU, data: InstructionData) {
self.pre_impl(cores, data);
}
pub fn post_smul(&mut self, cores: &mut CPU, data: InstructionData) {
self.post_impl(cores, data, "smul");
}
pub fn pre_saddi(&mut self, cores: &mut CPU, data: InstructionData) {
self.pre_impl(cores, data);
}
pub fn post_saddi(&mut self, cores: &mut CPU, data: InstructionData) {
self.post_impl(cores, data, "saddi");
}
pub fn pre_smuli(&mut self, cores: &mut CPU, data: InstructionData) {
self.pre_impl(cores, data);
}
pub fn post_smuli(&mut self, cores: &mut CPU, data: InstructionData) {
self.post_impl(cores, data, "smuli");
}
/////////////////////////////////////////////////////////////////
///////////////////Matrix/vector Instructions////////////////////
/////////////////////////////////////////////////////////////////
pub fn pre_setbw(&mut self, cores: &mut CPU, data: InstructionData) {
self.pre_impl(cores, data);
}
pub fn post_setbw(&mut self, cores: &mut CPU, data: InstructionData) {
self.post_impl(cores, data, "setbw");
}
pub fn pre_mvm<F, M, T>(&mut self, cores: &mut CPU, data: InstructionData)
where
[F]: UpcastSlice<T> + UpcastSlice<M>,
[M]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
M: UpcastDestTraits<M> + MemoryStorable + FromFloat,
F: UpcastDestTraits<F> + MemoryStorable,
{
self.pre_impl(cores, data);
}
pub fn post_mvm<F, M, T>(&mut self, cores: &mut CPU, data: InstructionData)
where
[F]: UpcastSlice<T> + UpcastSlice<M>,
[M]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
M: UpcastDestTraits<M> + MemoryStorable + FromFloat,
F: UpcastDestTraits<F> + MemoryStorable,
{
self.post_impl(cores, data, "mvmul");
}
pub fn pre_vvadd<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable,
{
self.pre_impl(cores, data);
}
pub fn post_vvadd<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable,
{
self.post_impl(cores, data, "vvadd");
}
pub fn pre_vvsub<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable,
{
self.pre_impl(cores, data);
}
pub fn post_vvsub<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable,
{
self.post_impl(cores, data, "vvsub");
}
pub fn pre_vvmul<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable,
{
self.pre_impl(cores, data);
}
pub fn post_vvmul<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable,
{
self.post_impl(cores, data, "vvmul");
}
pub fn pre_vvdmul<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable,
{
self.pre_impl(cores, data);
}
pub fn post_vvdmul<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable,
{
self.post_impl(cores, data, "vvdmul");
}
pub fn pre_vvmax<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable,
{
self.pre_impl(cores, data);
}
pub fn post_vvmax<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable,
{
self.post_impl(cores, data, "vvmax");
}
pub fn pre_vavg<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable,
{
self.pre_impl(cores, data);
}
pub fn post_vavg<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable,
{
self.post_impl(cores, data, "vavg");
}
pub fn pre_vrelu<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable + From<f32>,
{
self.pre_impl(cores, data);
}
pub fn post_vrelu<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable + From<f32>,
{
self.post_impl(cores, data, "vrelu");
}
pub fn pre_vtanh<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable + From<f32>,
{
self.pre_impl(cores, data);
}
pub fn post_vtanh<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable + From<f32>,
{
self.post_impl(cores, data, "vtanh");
}
pub fn pre_vsigm<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable + From<f32>,
{
self.pre_impl(cores, data);
}
pub fn post_vsigm<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable + From<f32>,
{
self.post_impl(cores, data, "vsigm");
}
pub fn pre_vsoftmax<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable + From<f32>,
{
self.pre_impl(cores, data);
}
pub fn post_vsoftmax<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable + From<f32>,
{
self.post_impl(cores, data, "vsoftmax");
}
/////////////////////////////////////////////////////////////////
/////Communication/synchronization Instructions/////////////////
/////////////////////////////////////////////////////////////////
pub fn pre_ld(&mut self, cores: &mut CPU, data: InstructionData) {
self.pre_impl(cores, data);
}
pub fn post_ld(&mut self, cores: &mut CPU, data: InstructionData) {
self.post_impl(cores, data, "ld");
}
pub fn pre_st(&mut self, cores: &mut CPU, data: InstructionData) {
self.pre_impl(cores, data);
}
pub fn post_st(&mut self, cores: &mut CPU, data: InstructionData) {
self.post_impl(cores, data, "st");
}
pub fn pre_lldi(&mut self, cores: &mut CPU, data: InstructionData) {
self.pre_impl(cores, data);
}
pub fn post_lldi(&mut self, cores: &mut CPU, data: InstructionData) {
self.post_impl(cores, data, "lldi");
}
pub fn pre_lmv(&mut self, cores: &mut CPU, data: InstructionData) {
self.pre_impl(cores, data);
}
pub fn post_lmv(&mut self, cores: &mut CPU, data: InstructionData) {
self.post_impl(cores, data, "lmv");
}
pub fn pre_send(&mut self, cores: &mut CPU, data: InstructionData) {
self.pre_impl(cores, data);
}
pub fn post_send(&mut self, cores: &mut CPU, data: InstructionData) {
self.post_impl(cores, data, "send");
}
pub fn pre_recv(&mut self, cores: &mut CPU, data: InstructionData) {
self.pre_impl(cores, data);
}
pub fn post_recv(&mut self, cores: &mut CPU, data: InstructionData) {
self.post_impl(cores, data, "recv");
}
}
@@ -0,0 +1,28 @@
use std::{fs::File, path::PathBuf};
pub mod pretty_print;
pub mod tracing_isa;
pub struct Trace {
out_files: Vec<File>,
}
impl Trace {
pub fn new() -> Self {
Self {
out_files: Vec::new(),
}
}
pub fn init(&mut self, num_core: usize, mut path: PathBuf) {
path.pop();
for i in 0..num_core {
path.push(format!("TraceCore{}", i));
let file = File::create(&path).expect("Can not create file");
self.out_files.push(file);
path.pop();
}
}
}
@@ -1,4 +1,4 @@
use crate::tracing::pretty_print;
use crate::{tracing::trace::pretty_print, utility::add_offset_r2};
use std::fs::File;
use crate::{
@@ -13,7 +13,6 @@ use crate::{
};
use std::io::Write;
#[cfg(feature = "tracing")]
impl Trace {
///////////////////////////////////////////////////////////////
/////////////////Scalar/register Instructions//////////////////
@@ -284,7 +283,6 @@ impl Trace {
M: UpcastDestTraits<M> + MemoryStorable + FromFloat,
F: UpcastDestTraits<F> + MemoryStorable,
{
use crate::tracing::pretty_print;
let (core_indx, rd, r1, mbiw, relu, group) = data.get_core_rd_r1_mbiw_immrelu_immgroup();
let file: &mut File = self
@@ -358,8 +356,6 @@ impl Trace {
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable,
{
use crate::{tracing::pretty_print, utility::add_offset_r2};
let (core_indx, rd, r1, r2, imm_len, offset_select, offset_value) =
data.get_core_rd_r1_r2_immlen_offset();
let file: &mut File = self
@@ -990,8 +986,6 @@ impl Trace {
/////////////////////////////////////////////////////////////////
pub fn ld_impl(&mut self, cores: &mut CPU, data: InstructionData, prefix: &'static str) {
use crate::tracing::pretty_print;
let (core, rd, r1, _, imm_len, offset_select, offset_value) =
data.get_core_rd_r1_r2_immlen_offset();
let file: &mut File = self
@@ -1044,8 +1038,6 @@ impl Trace {
}
pub fn st_impl(&mut self, cores: &mut CPU, data: InstructionData, prefix: &'static str) {
use crate::tracing::pretty_print;
let (core, rd, r1, _, imm_len, offset_select, offset_value) =
data.get_core_rd_r1_r2_immlen_offset();
let file: &mut File = self
@@ -1138,7 +1130,6 @@ impl Trace {
}
fn lmv_impl (&mut self, cores: &mut CPU, data: InstructionData, prefix: &'static str) {
use crate::tracing::pretty_print;
let (core, rd, r1, _, imm_len, offset_select, offset_value) =
data.get_core_rd_r1_r2_immlen_offset();
+1
View File
@@ -68,5 +68,6 @@ add_pim_library(OMPIMAccel
OMSpatialToPim
OMPimCommon
OMPimBufferization
OMPimStaticMemoryCoalescing
MLIRTensorInferTypeOpInterfaceImpl
)
+10 -1
View File
@@ -1,5 +1,14 @@
add_pim_library(OMPimCommon
PimCommon.cpp
IR/AddressAnalysis.cpp
IR/CoreBlockUtils.cpp
IR/EntryPointUtils.cpp
IR/ShapeUtils.cpp
IR/SubviewUtils.cpp
IR/WeightUtils.cpp
Support/DebugDump.cpp
Support/Diagnostics.cpp
Support/FileSystemUtils.cpp
Support/ReportUtils.cpp
EXCLUDE_FROM_OM_LIBS
+259
View File
@@ -0,0 +1,259 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
namespace onnx_mlir {
mlir::memref::GlobalOp lookupGlobalForGetGlobal(mlir::ModuleOp moduleOp, mlir::memref::GetGlobalOp getGlobalOp) {
if (!moduleOp || !getGlobalOp)
return {};
return moduleOp.lookupSymbol<mlir::memref::GlobalOp>(getGlobalOp.getName());
}
namespace {
mlir::Value resolveAlias(mlir::Value value, const StaticValueKnowledge* knowledge) {
if (!knowledge)
return value;
auto iter = knowledge->aliases.find(value);
while (iter != knowledge->aliases.end()) {
value = iter->second;
iter = knowledge->aliases.find(value);
}
return value;
}
mlir::Value resolveLoopCarriedAliasImpl(mlir::Value value, const StaticValueKnowledge* knowledge) {
value = resolveAlias(value, knowledge);
if (mlir::isa<mlir::BlockArgument>(value))
return value;
mlir::Operation* definingOp = value.getDefiningOp();
if (!definingOp)
return value;
if (auto dpsDefiningOp = mlir::dyn_cast<mlir::DestinationStyleOpInterface>(definingOp)) {
if (auto result = mlir::dyn_cast<mlir::OpResult>(value))
if (mlir::OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(result))
return resolveLoopCarriedAliasImpl(tiedOperand->get(), knowledge);
}
if (auto castOp = mlir::dyn_cast<mlir::memref::CastOp>(definingOp))
return resolveLoopCarriedAliasImpl(castOp.getSource(), knowledge);
if (auto collapseOp = mlir::dyn_cast<mlir::memref::CollapseShapeOp>(definingOp))
return resolveLoopCarriedAliasImpl(collapseOp.getSrc(), knowledge);
if (auto expandOp = mlir::dyn_cast<mlir::memref::ExpandShapeOp>(definingOp))
return resolveLoopCarriedAliasImpl(expandOp.getSrc(), knowledge);
return value;
}
llvm::FailureOr<int64_t> resolveOpFoldResult(mlir::OpFoldResult ofr, const StaticValueKnowledge* knowledge);
llvm::FailureOr<int64_t> resolveIndexValueImpl(mlir::Value value, const StaticValueKnowledge* knowledge) {
value = resolveAlias(value, knowledge);
if (knowledge) {
auto iter = knowledge->indexValues.find(value);
if (iter != knowledge->indexValues.end())
return iter->second;
}
auto constantOp = value.getDefiningOp<mlir::arith::ConstantOp>();
if (constantOp) {
if (auto integerAttr = mlir::dyn_cast<mlir::IntegerAttr>(constantOp.getValue()))
return integerAttr.getInt();
}
mlir::Operation* definingOp = value.getDefiningOp();
if (!definingOp)
return mlir::failure();
if (auto indexCastOp = mlir::dyn_cast<mlir::arith::IndexCastOp>(definingOp))
return resolveIndexValueImpl(indexCastOp.getIn(), knowledge);
if (auto addOp = mlir::dyn_cast<mlir::arith::AddIOp>(definingOp)) {
auto lhs = resolveIndexValueImpl(addOp.getLhs(), knowledge);
auto rhs = resolveIndexValueImpl(addOp.getRhs(), knowledge);
if (failed(lhs) || failed(rhs))
return mlir::failure();
return *lhs + *rhs;
}
if (auto subOp = mlir::dyn_cast<mlir::arith::SubIOp>(definingOp)) {
auto lhs = resolveIndexValueImpl(subOp.getLhs(), knowledge);
auto rhs = resolveIndexValueImpl(subOp.getRhs(), knowledge);
if (failed(lhs) || failed(rhs))
return mlir::failure();
return *lhs - *rhs;
}
if (auto mulOp = mlir::dyn_cast<mlir::arith::MulIOp>(definingOp)) {
auto lhs = resolveIndexValueImpl(mulOp.getLhs(), knowledge);
auto rhs = resolveIndexValueImpl(mulOp.getRhs(), knowledge);
if (failed(lhs) || failed(rhs))
return mlir::failure();
return *lhs * *rhs;
}
if (auto divOp = mlir::dyn_cast<mlir::arith::DivUIOp>(definingOp)) {
auto lhs = resolveIndexValueImpl(divOp.getLhs(), knowledge);
auto rhs = resolveIndexValueImpl(divOp.getRhs(), knowledge);
if (failed(lhs) || failed(rhs) || *rhs == 0)
return mlir::failure();
return static_cast<int64_t>(static_cast<uint64_t>(*lhs) / static_cast<uint64_t>(*rhs));
}
if (auto remOp = mlir::dyn_cast<mlir::arith::RemUIOp>(definingOp)) {
auto lhs = resolveIndexValueImpl(remOp.getLhs(), knowledge);
auto rhs = resolveIndexValueImpl(remOp.getRhs(), knowledge);
if (failed(lhs) || failed(rhs) || *rhs == 0)
return mlir::failure();
return static_cast<int64_t>(static_cast<uint64_t>(*lhs) % static_cast<uint64_t>(*rhs));
}
return mlir::failure();
}
llvm::FailureOr<int64_t> resolveOpFoldResult(mlir::OpFoldResult ofr, const StaticValueKnowledge* knowledge) {
if (auto attr = mlir::dyn_cast<mlir::Attribute>(ofr)) {
auto integerAttr = mlir::dyn_cast<mlir::IntegerAttr>(attr);
if (!integerAttr)
return mlir::failure();
return integerAttr.getInt();
}
return resolveIndexValueImpl(mlir::cast<mlir::Value>(ofr), knowledge);
}
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddressImpl(mlir::Value value,
const StaticValueKnowledge* knowledge) {
int64_t byteOffset = 0;
value = resolveAlias(value, knowledge);
while (true) {
if (mlir::isa<mlir::BlockArgument>(value))
return ResolvedContiguousAddress {value, byteOffset};
mlir::Operation* definingOp = value.getDefiningOp();
if (!definingOp)
return mlir::failure();
if (auto dpsDefiningOp = mlir::dyn_cast<mlir::DestinationStyleOpInterface>(definingOp)) {
mlir::OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(mlir::dyn_cast<mlir::OpResult>(value));
if (!tiedOperand)
return mlir::failure();
value = resolveAlias(tiedOperand->get(), knowledge);
continue;
}
if (auto forOp = mlir::dyn_cast<mlir::scf::ForOp>(definingOp)) {
auto result = mlir::dyn_cast<mlir::OpResult>(value);
if (!result)
return mlir::failure();
auto yieldOp = mlir::cast<mlir::scf::YieldOp>(forOp.getBody()->getTerminator());
mlir::Value yieldedValue = resolveLoopCarriedAliasImpl(yieldOp.getOperand(result.getResultNumber()), knowledge);
if (auto blockArgument = mlir::dyn_cast<mlir::BlockArgument>(yieldedValue)) {
if (blockArgument.getOwner() == forOp.getBody() && blockArgument.getArgNumber() > 0
&& static_cast<unsigned>(blockArgument.getArgNumber() - 1) < forOp.getInitArgs().size()) {
value = resolveAlias(forOp.getInitArgs()[blockArgument.getArgNumber() - 1], knowledge);
continue;
}
}
value = yieldedValue;
continue;
}
if (auto subviewOp = mlir::dyn_cast<mlir::memref::SubViewOp>(definingOp)) {
auto sourceType = mlir::dyn_cast<mlir::MemRefType>(subviewOp.getSource().getType());
auto subviewType = mlir::dyn_cast<mlir::MemRefType>(subviewOp.getType());
if (!sourceType || !subviewType || !sourceType.hasStaticShape() || !subviewType.hasStaticShape())
return mlir::failure();
llvm::SmallVector<int64_t> offsets;
llvm::SmallVector<int64_t> sizes;
llvm::SmallVector<int64_t> strides;
offsets.reserve(subviewOp.getMixedOffsets().size());
sizes.reserve(subviewOp.getMixedSizes().size());
strides.reserve(subviewOp.getMixedStrides().size());
for (mlir::OpFoldResult offset : subviewOp.getMixedOffsets()) {
auto resolvedOffset = resolveOpFoldResult(offset, knowledge);
if (failed(resolvedOffset))
return mlir::failure();
offsets.push_back(*resolvedOffset);
}
for (mlir::OpFoldResult size : subviewOp.getMixedSizes()) {
auto resolvedSize = resolveOpFoldResult(size, knowledge);
if (failed(resolvedSize))
return mlir::failure();
sizes.push_back(*resolvedSize);
}
for (mlir::OpFoldResult stride : subviewOp.getMixedStrides()) {
auto resolvedStride = resolveOpFoldResult(stride, knowledge);
if (failed(resolvedStride))
return mlir::failure();
strides.push_back(*resolvedStride);
}
if (!isMemoryContiguous(sourceType.getShape(), offsets, sizes, strides))
return mlir::failure();
auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
byteOffset += linearizeIndex(offsets, sourceStrides) * subviewType.getElementTypeBitWidth() / 8;
value = resolveAlias(subviewOp.getSource(), knowledge);
continue;
}
if (auto castOp = mlir::dyn_cast<mlir::memref::CastOp>(definingOp)) {
value = resolveAlias(castOp.getSource(), knowledge);
continue;
}
if (auto collapseOp = mlir::dyn_cast<mlir::memref::CollapseShapeOp>(definingOp)) {
value = resolveAlias(collapseOp.getSrc(), knowledge);
continue;
}
if (auto expandOp = mlir::dyn_cast<mlir::memref::ExpandShapeOp>(definingOp)) {
value = resolveAlias(expandOp.getSrc(), knowledge);
continue;
}
if (mlir::isa<mlir::memref::AllocOp, mlir::memref::GetGlobalOp>(definingOp))
return ResolvedContiguousAddress {value, byteOffset};
return mlir::failure();
}
}
} // namespace
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value) { return resolveIndexValueImpl(value, nullptr); }
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value, const StaticValueKnowledge& knowledge) {
return resolveIndexValueImpl(value, &knowledge);
}
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value) {
return resolveContiguousAddressImpl(value, nullptr);
}
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value,
const StaticValueKnowledge& knowledge) {
return resolveContiguousAddressImpl(value, &knowledge);
}
mlir::Value resolveLoopCarriedAlias(mlir::Value value, const StaticValueKnowledge& knowledge) {
return resolveLoopCarriedAliasImpl(value, &knowledge);
}
} // namespace onnx_mlir
+43
View File
@@ -0,0 +1,43 @@
#pragma once
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/DenseMap.h"
namespace onnx_mlir {
/// Describes a value as a base addressable object plus a statically known
/// byte offset after peeling aliases, casts, and contiguous subviews.
struct ResolvedContiguousAddress {
mlir::Value base;
int64_t byteOffset = 0;
};
/// Records compile-time facts used when interpreting address arithmetic and
/// loop-carried aliases inside PIM regions.
struct StaticValueKnowledge {
llvm::DenseMap<mlir::Value, int64_t> indexValues;
llvm::DenseMap<mlir::Value, mlir::Value> aliases;
StaticValueKnowledge() {}
};
mlir::memref::GlobalOp lookupGlobalForGetGlobal(mlir::ModuleOp moduleOp, mlir::memref::GetGlobalOp getGlobalOp);
/// Resolves a value to contiguous backing storage when that storage can be
/// proven statically from aliases, DPS ties, casts, and subviews.
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value);
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value,
const StaticValueKnowledge& knowledge);
/// Statically evaluates index-like SSA values, including simple integer
/// arithmetic and loop facts recorded in `knowledge`.
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value);
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value, const StaticValueKnowledge& knowledge);
/// Follows alias, view, and DPS chains to recover the backing value of a
/// loop-carried memref/result.
mlir::Value resolveLoopCarriedAlias(mlir::Value value, const StaticValueKnowledge& knowledge);
} // namespace onnx_mlir
+745
View File
@@ -0,0 +1,745 @@
#ifndef ONNX_MLIR_PIM_COMPACT_ASM_UTILS_HPP
#define ONNX_MLIR_PIM_COMPACT_ASM_UTILS_HPP
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/LogicalResult.h"
namespace onnx_mlir {
namespace compact_asm {
using namespace mlir;
enum class ListDelimiter {
Square,
Paren
};
inline ParseResult parseOpenDelimiter(OpAsmParser& parser, ListDelimiter delimiter) {
if (delimiter == ListDelimiter::Square)
return parser.parseLSquare();
return parser.parseLParen();
}
inline ParseResult parseOptionalCloseDelimiter(OpAsmParser& parser, ListDelimiter delimiter) {
if (delimiter == ListDelimiter::Square)
return parser.parseOptionalRSquare();
return parser.parseOptionalRParen();
}
template <typename StreamT>
inline void printOpenDelimiter(StreamT& stream, ListDelimiter delimiter) {
stream << (delimiter == ListDelimiter::Square ? "[" : "(");
}
template <typename StreamT>
inline void printCloseDelimiter(StreamT& stream, ListDelimiter delimiter) {
stream << (delimiter == ListDelimiter::Square ? "]" : ")");
}
template <typename EntryT, typename ParseEntryFn>
inline ParseResult parseCompressedRepeatedList(OpAsmParser& parser,
ListDelimiter delimiter,
SmallVectorImpl<EntryT>& entries,
ParseEntryFn parseEntry) {
if (parseOpenDelimiter(parser, delimiter))
return failure();
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
return success();
while (true) {
EntryT entry;
if (parseEntry(entry))
return failure();
int64_t repeatCount = 1;
if (succeeded(parser.parseOptionalKeyword("x"))) {
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
}
for (int64_t index = 0; index < repeatCount; ++index)
entries.push_back(entry);
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
return success();
if (parser.parseComma())
return failure();
}
}
template <typename IntT>
inline ParseResult
parseCompressedIntegerEntries(OpAsmParser& parser, ListDelimiter delimiter, SmallVectorImpl<IntT>& values) {
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
return success();
while (true) {
if (succeeded(parser.parseOptionalLParen())) {
SmallVector<IntT> subgroup;
if (parseCompressedIntegerEntries(parser, ListDelimiter::Paren, subgroup))
return failure();
int64_t repeatCount = 1;
if (succeeded(parser.parseOptionalKeyword("x"))) {
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
}
for (int64_t repeat = 0; repeat < repeatCount; ++repeat)
llvm::append_range(values, subgroup);
}
else {
int64_t first = 0;
if (parser.parseInteger(first))
return failure();
if (succeeded(parser.parseOptionalKeyword("to"))) {
int64_t last = 0;
if (parser.parseInteger(last) || last < first)
return parser.emitError(parser.getCurrentLocation(), "invalid ascending range");
int64_t step = 1;
if (succeeded(parser.parseOptionalKeyword("by"))) {
if (parser.parseInteger(step) || step <= 0)
return parser.emitError(parser.getCurrentLocation(), "step after 'by' must be positive");
}
int64_t repeatCount = 1;
if (succeeded(parser.parseOptionalKeyword("x"))) {
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
}
if ((last - first) % step != 0) {
return parser.emitError(parser.getCurrentLocation(),
"range end must be reachable from start using the given step");
}
for (int64_t value = first; value <= last; value += step)
for (int64_t index = 0; index < repeatCount; ++index)
values.push_back(static_cast<IntT>(value));
}
else {
int64_t repeatCount = 1;
if (succeeded(parser.parseOptionalKeyword("x"))) {
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
}
for (int64_t index = 0; index < repeatCount; ++index)
values.push_back(static_cast<IntT>(first));
}
}
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
return success();
if (parser.parseComma())
return failure();
}
}
template <typename IntT>
inline ParseResult
parseCompressedIntegerSequence(OpAsmParser& parser, ListDelimiter delimiter, SmallVectorImpl<IntT>& values) {
if (parseOpenDelimiter(parser, delimiter))
return failure();
return parseCompressedIntegerEntries(parser, delimiter, values);
}
template <typename RangeT, typename PrintEntryFn>
inline void printCompressedEqualRuns(OpAsmPrinter& printer, RangeT entries, PrintEntryFn printEntry) {
for (size_t index = 0; index < entries.size();) {
size_t runEnd = index + 1;
while (runEnd < entries.size() && entries[runEnd] == entries[index])
++runEnd;
if (index != 0)
printer << ", ";
printEntry(entries[index]);
size_t runLength = runEnd - index;
if (runLength > 1)
printer << " x" << runLength;
index = runEnd;
}
}
template <typename StreamT, typename IntT>
inline void printCompressedIntegerEntries(StreamT& stream, ArrayRef<IntT> values) {
struct FlatCompression {
enum class Kind {
Single,
EqualRun,
Progression
};
Kind kind = Kind::Single;
size_t covered = 1;
size_t repeatCount = 1;
size_t progressionValueCount = 1;
int64_t step = 1;
IntT firstValue {};
IntT lastValue {};
};
auto computeFlatCompression = [&](size_t start) {
FlatCompression compression;
compression.firstValue = values[start];
compression.lastValue = values[start];
auto findEqualRunEnd = [&](size_t runStart) {
size_t runEnd = runStart + 1;
while (runEnd < values.size() && values[runEnd] == values[runStart])
++runEnd;
return runEnd;
};
size_t firstRunEnd = findEqualRunEnd(start);
compression.repeatCount = firstRunEnd - start;
size_t progressionEnd = firstRunEnd;
int64_t step = 0;
IntT lastValue = values[start];
if (firstRunEnd < values.size()) {
size_t secondRunEnd = findEqualRunEnd(firstRunEnd);
step = static_cast<int64_t>(values[firstRunEnd]) - static_cast<int64_t>(values[start]);
if (step > 0 && secondRunEnd - firstRunEnd == compression.repeatCount) {
progressionEnd = secondRunEnd;
lastValue = values[firstRunEnd];
size_t currentRunStart = secondRunEnd;
while (currentRunStart < values.size()) {
size_t currentRunEnd = findEqualRunEnd(currentRunStart);
if (currentRunEnd - currentRunStart != compression.repeatCount)
break;
if (static_cast<int64_t>(values[currentRunStart]) != static_cast<int64_t>(lastValue) + step)
break;
lastValue = values[currentRunStart];
progressionEnd = currentRunEnd;
currentRunStart = currentRunEnd;
}
}
else {
step = 0;
}
}
compression.covered = 1;
if (progressionEnd > firstRunEnd) {
size_t progressionValueCount = (progressionEnd - start) / compression.repeatCount;
if (progressionValueCount >= 3) {
compression.kind = FlatCompression::Kind::Progression;
compression.covered = progressionEnd - start;
compression.progressionValueCount = progressionValueCount;
compression.step = step;
compression.lastValue = lastValue;
return compression;
}
}
if (compression.repeatCount > 1) {
compression.kind = FlatCompression::Kind::EqualRun;
compression.covered = compression.repeatCount;
return compression;
}
return compression;
};
auto findRepeatedSublist = [&](size_t start) {
size_t bestLength = 0;
size_t bestRepeatCount = 1;
size_t remaining = values.size() - start;
for (size_t length = 2; length * 2 <= remaining; ++length) {
size_t repeatCount = 1;
ArrayRef<IntT> candidate = values.slice(start, length);
while (start + (repeatCount + 1) * length <= values.size()
&& llvm::equal(candidate, values.slice(start + repeatCount * length, length))) {
++repeatCount;
}
if (repeatCount <= 1)
continue;
size_t covered = length * repeatCount;
size_t bestCovered = bestLength * bestRepeatCount;
if (covered > bestCovered || (covered == bestCovered && length < bestLength)) {
bestLength = length;
bestRepeatCount = repeatCount;
}
}
return std::pair(bestLength, bestRepeatCount);
};
for (size_t index = 0; index < values.size();) {
if (index != 0)
stream << ", ";
FlatCompression flat = computeFlatCompression(index);
auto [sublistLength, sublistRepeatCount] = findRepeatedSublist(index);
size_t repeatedSublistCoverage = sublistLength * sublistRepeatCount;
if (sublistRepeatCount > 1 && sublistLength > 1 && repeatedSublistCoverage > flat.covered) {
printOpenDelimiter(stream, ListDelimiter::Paren);
printCompressedIntegerEntries(stream, values.slice(index, sublistLength));
printCloseDelimiter(stream, ListDelimiter::Paren);
stream << " x" << sublistRepeatCount;
index += repeatedSublistCoverage;
continue;
}
switch (flat.kind) {
case FlatCompression::Kind::Progression:
stream << flat.firstValue << " to " << flat.lastValue;
if (flat.step != 1)
stream << " by " << flat.step;
if (flat.repeatCount > 1)
stream << " x" << flat.repeatCount;
index += flat.covered;
break;
case FlatCompression::Kind::EqualRun:
stream << flat.firstValue << " x" << flat.repeatCount;
index += flat.covered;
break;
case FlatCompression::Kind::Single:
stream << flat.firstValue;
index += flat.covered;
break;
}
}
}
template <typename StreamT, typename IntT>
inline void printCompressedIntegerSequence(StreamT& stream, ArrayRef<IntT> values, ListDelimiter delimiter) {
printOpenDelimiter(stream, delimiter);
printCompressedIntegerEntries(stream, values);
printCloseDelimiter(stream, delimiter);
}
template <typename IntT>
inline ParseResult parseCompressedIntegerList(OpAsmParser& parser, SmallVectorImpl<IntT>& values) {
return parseCompressedIntegerSequence(parser, ListDelimiter::Square, values);
}
template <typename IntT>
inline void printCompressedIntegerList(OpAsmPrinter& printer, ArrayRef<IntT> values) {
printCompressedIntegerSequence(printer, values, ListDelimiter::Square);
}
inline void printCompressedValueSequence(OpAsmPrinter& printer, ValueRange values) {
for (size_t index = 0; index < values.size();) {
size_t equalRunEnd = index + 1;
while (equalRunEnd < values.size() && values[equalRunEnd] == values[index])
++equalRunEnd;
if (index != 0)
printer << ", ";
if (equalRunEnd - index > 1) {
printer.printOperand(values[index]);
printer << " x" << (equalRunEnd - index);
index = equalRunEnd;
continue;
}
size_t rangeEnd = index + 1;
if (auto firstResult = dyn_cast<OpResult>(values[index])) {
while (rangeEnd < values.size()) {
auto nextResult = dyn_cast<OpResult>(values[rangeEnd]);
if (!nextResult || nextResult.getOwner() != firstResult.getOwner()
|| nextResult.getResultNumber() != firstResult.getResultNumber() + (rangeEnd - index)) {
break;
}
++rangeEnd;
}
}
else if (auto firstArg = dyn_cast<BlockArgument>(values[index])) {
while (rangeEnd < values.size()) {
auto nextArg = dyn_cast<BlockArgument>(values[rangeEnd]);
if (!nextArg || nextArg.getOwner() != firstArg.getOwner()
|| nextArg.getArgNumber() != firstArg.getArgNumber() + (rangeEnd - index)) {
break;
}
++rangeEnd;
}
}
printer.printOperand(values[index]);
if (rangeEnd - index >= 3) {
printer << " to ";
printer.printOperand(values[rangeEnd - 1]);
}
else if (rangeEnd - index == 2) {
printer << ", ";
printer.printOperand(values[index + 1]);
}
index = rangeEnd;
}
}
inline void printCompressedTypeSequence(OpAsmPrinter& printer, TypeRange types) {
printCompressedEqualRuns(printer, types, [&](Type type) { printer.printType(type); });
}
inline void printCompressedValueList(OpAsmPrinter& printer, ValueRange values, ListDelimiter delimiter) {
printOpenDelimiter(printer, delimiter);
printCompressedValueSequence(printer, values);
printCloseDelimiter(printer, delimiter);
}
inline void printCompressedTypeList(OpAsmPrinter& printer, TypeRange types, ListDelimiter delimiter) {
printOpenDelimiter(printer, delimiter);
printCompressedTypeSequence(printer, types);
printCloseDelimiter(printer, delimiter);
}
inline ParseResult parseCompressedTypeSequence(OpAsmParser& parser, SmallVectorImpl<Type>& types, bool allowEmpty) {
Type firstType;
OptionalParseResult firstTypeResult = parser.parseOptionalType(firstType);
if (!firstTypeResult.has_value()) {
if (allowEmpty)
return success();
return parser.emitError(parser.getCurrentLocation(), "expected type");
}
if (failed(*firstTypeResult))
return failure();
auto appendType = [&](Type type) -> ParseResult {
int64_t repeatCount = 1;
if (succeeded(parser.parseOptionalKeyword("x"))) {
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
}
for (int64_t index = 0; index < repeatCount; ++index)
types.push_back(type);
return success();
};
if (appendType(firstType))
return failure();
while (succeeded(parser.parseOptionalComma())) {
Type nextType;
if (parser.parseType(nextType) || appendType(nextType))
return failure();
}
return success();
}
inline ParseResult parseCompressedOperandEntryWithFirst(OpAsmParser& parser,
OpAsmParser::UnresolvedOperand firstOperand,
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
if (succeeded(parser.parseOptionalKeyword("to"))) {
OpAsmParser::UnresolvedOperand lastOperand;
if (parser.parseOperand(lastOperand))
return failure();
if (firstOperand.name != lastOperand.name || firstOperand.number > lastOperand.number)
return parser.emitError(parser.getCurrentLocation(), "invalid operand range");
for (unsigned number = firstOperand.number; number <= lastOperand.number; ++number)
operands.push_back({firstOperand.location, firstOperand.name, number});
return success();
}
int64_t repeatCount = 1;
if (succeeded(parser.parseOptionalKeyword("x"))) {
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
}
for (int64_t index = 0; index < repeatCount; ++index)
operands.push_back(firstOperand);
return success();
}
inline ParseResult parseOneCompressedOperandEntry(OpAsmParser& parser,
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
OpAsmParser::UnresolvedOperand firstOperand;
if (parser.parseOperand(firstOperand))
return failure();
return parseCompressedOperandEntryWithFirst(parser, firstOperand, operands);
}
inline ParseResult parseCompressedOperandList(OpAsmParser& parser,
ListDelimiter delimiter,
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
if (parseOpenDelimiter(parser, delimiter))
return failure();
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
return success();
while (true) {
if (parseOneCompressedOperandEntry(parser, operands))
return failure();
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
return success();
if (parser.parseComma())
return failure();
}
}
inline ParseResult parseCompressedOperandSequence(OpAsmParser& parser,
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
if (parseOneCompressedOperandEntry(parser, operands))
return failure();
while (succeeded(parser.parseOptionalComma()))
if (parseOneCompressedOperandEntry(parser, operands))
return failure();
return success();
}
inline ParseResult parseCompressedTypeList(OpAsmParser& parser, ListDelimiter delimiter, SmallVectorImpl<Type>& types) {
if (parseOpenDelimiter(parser, delimiter))
return failure();
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
return success();
if (parseCompressedTypeSequence(parser, types, /*allowEmpty=*/false))
return failure();
return parseOptionalCloseDelimiter(parser, delimiter);
}
inline bool hasRepeatedTuple(ValueRange values, size_t tupleSize) {
if (tupleSize == 0 || values.empty() || values.size() % tupleSize != 0)
return false;
SmallVector<Value> valueVec(values.begin(), values.end());
ArrayRef<Value> tuple(valueVec.data(), tupleSize);
for (size_t index = tupleSize; index < values.size(); index += tupleSize)
if (!llvm::equal(tuple, ArrayRef<Value>(valueVec).slice(index, tupleSize)))
return false;
return true;
}
inline bool hasRepeatedTuple(TypeRange types, size_t tupleSize) {
if (tupleSize == 0 || types.empty() || types.size() % tupleSize != 0)
return false;
SmallVector<Type> typeVec(types.begin(), types.end());
ArrayRef<Type> tuple(typeVec.data(), tupleSize);
for (size_t index = tupleSize; index < types.size(); index += tupleSize)
if (!llvm::equal(tuple, ArrayRef<Type>(typeVec).slice(index, tupleSize)))
return false;
return true;
}
inline void printValueTupleRun(OpAsmPrinter& printer, ValueRange values, size_t tupleSize, ListDelimiter delimiter) {
printOpenDelimiter(printer, delimiter);
printOpenDelimiter(printer, ListDelimiter::Paren);
for (size_t index = 0; index < tupleSize; ++index) {
if (index != 0)
printer << ", ";
printer.printOperand(values[index]);
}
printCloseDelimiter(printer, ListDelimiter::Paren);
printer << " x" << (values.size() / tupleSize);
printCloseDelimiter(printer, delimiter);
}
inline void printTypeTupleRun(OpAsmPrinter& printer, TypeRange types, size_t tupleSize, ListDelimiter delimiter) {
printOpenDelimiter(printer, delimiter);
printOpenDelimiter(printer, ListDelimiter::Paren);
for (size_t index = 0; index < tupleSize; ++index) {
if (index != 0)
printer << ", ";
printer.printType(types[index]);
}
printCloseDelimiter(printer, ListDelimiter::Paren);
printer << " x" << (types.size() / tupleSize);
printCloseDelimiter(printer, delimiter);
}
inline ParseResult parseCompressedOrTupleOperandList(OpAsmParser& parser,
ListDelimiter delimiter,
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
if (parseOpenDelimiter(parser, delimiter))
return failure();
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
return success();
if (succeeded(parser.parseOptionalLParen())) {
SmallVector<OpAsmParser::UnresolvedOperand> tupleOperands;
if (parseCompressedOperandSequence(parser, tupleOperands) || parser.parseRParen())
return failure();
int64_t repeatCount = 1;
if (succeeded(parser.parseOptionalKeyword("x"))) {
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
}
for (int64_t repeat = 0; repeat < repeatCount; ++repeat)
llvm::append_range(operands, tupleOperands);
while (succeeded(parser.parseOptionalComma())) {
if (parser.parseLParen())
return failure();
tupleOperands.clear();
if (parseCompressedOperandSequence(parser, tupleOperands) || parser.parseRParen())
return failure();
repeatCount = 1;
if (succeeded(parser.parseOptionalKeyword("x"))) {
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
}
for (int64_t repeat = 0; repeat < repeatCount; ++repeat)
llvm::append_range(operands, tupleOperands);
}
return parseOptionalCloseDelimiter(parser, delimiter);
}
while (true) {
if (parseOneCompressedOperandEntry(parser, operands))
return failure();
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
return success();
if (parser.parseComma())
return failure();
}
}
inline ParseResult
parseCompressedOrTupleTypeList(OpAsmParser& parser, ListDelimiter delimiter, SmallVectorImpl<Type>& types) {
if (parseOpenDelimiter(parser, delimiter))
return failure();
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
return success();
if (succeeded(parser.parseOptionalLParen())) {
SmallVector<Type> tupleTypes;
if (parseCompressedTypeSequence(parser, tupleTypes, /*allowEmpty=*/false) || parser.parseRParen())
return failure();
int64_t repeatCount = 1;
if (succeeded(parser.parseOptionalKeyword("x"))) {
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
}
for (int64_t repeat = 0; repeat < repeatCount; ++repeat)
llvm::append_range(types, tupleTypes);
while (succeeded(parser.parseOptionalComma())) {
if (parser.parseLParen())
return failure();
tupleTypes.clear();
if (parseCompressedTypeSequence(parser, tupleTypes, /*allowEmpty=*/false) || parser.parseRParen())
return failure();
repeatCount = 1;
if (succeeded(parser.parseOptionalKeyword("x"))) {
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
}
for (int64_t repeat = 0; repeat < repeatCount; ++repeat)
llvm::append_range(types, tupleTypes);
}
return parseOptionalCloseDelimiter(parser, delimiter);
}
while (true) {
Type type;
if (parser.parseType(type))
return failure();
int64_t repeatCount = 1;
if (succeeded(parser.parseOptionalKeyword("x"))) {
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
}
for (int64_t repeat = 0; repeat < repeatCount; ++repeat)
types.push_back(type);
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
return success();
if (parser.parseComma())
return failure();
}
}
inline void printArgumentBindings(OpAsmPrinter& printer, Block& block, ValueRange operands) {
if (block.getNumArguments() == 0) {
printer << "() = ()";
return;
}
if (block.getNumArguments() == 1) {
printer.printOperand(block.getArgument(0));
printer << " = ";
printCompressedValueList(printer, operands, ListDelimiter::Paren);
return;
}
printCompressedValueList(printer, ValueRange(block.getArguments()), ListDelimiter::Paren);
printer << " = ";
printCompressedValueList(printer, operands, ListDelimiter::Paren);
}
inline ParseResult parseCompressedArgumentEntryWithFirst(OpAsmParser& parser,
OpAsmParser::Argument firstArgument,
SmallVectorImpl<OpAsmParser::Argument>& arguments) {
if (succeeded(parser.parseOptionalKeyword("to"))) {
OpAsmParser::Argument lastArgument;
if (parser.parseArgument(lastArgument))
return failure();
if (firstArgument.ssaName.name != lastArgument.ssaName.name
|| firstArgument.ssaName.number > lastArgument.ssaName.number) {
return parser.emitError(parser.getCurrentLocation(), "invalid argument range");
}
for (unsigned number = firstArgument.ssaName.number; number <= lastArgument.ssaName.number; ++number) {
OpAsmParser::Argument argument;
argument.ssaName = {firstArgument.ssaName.location, firstArgument.ssaName.name, number};
arguments.push_back(argument);
}
return success();
}
arguments.push_back(firstArgument);
return success();
}
inline ParseResult parseOneCompressedArgumentEntry(OpAsmParser& parser,
SmallVectorImpl<OpAsmParser::Argument>& arguments) {
OpAsmParser::Argument firstArgument;
if (parser.parseArgument(firstArgument))
return failure();
return parseCompressedArgumentEntryWithFirst(parser, firstArgument, arguments);
}
inline void applyArgumentTypes(ArrayRef<Type> inputTypes, SmallVectorImpl<OpAsmParser::Argument>& arguments) {
for (auto [argument, inputType] : llvm::zip_equal(arguments, inputTypes))
argument.type = inputType;
}
inline ParseResult parseArgumentBindings(OpAsmParser& parser,
SmallVectorImpl<OpAsmParser::Argument>& arguments,
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
if (succeeded(parser.parseOptionalLParen())) {
if (succeeded(parser.parseOptionalRParen())) {
if (parser.parseEqual() || parseCompressedOperandList(parser, ListDelimiter::Paren, operands))
return failure();
return success();
}
OpAsmParser::Argument firstArgument;
if (parser.parseArgument(firstArgument) || parseCompressedArgumentEntryWithFirst(parser, firstArgument, arguments))
return failure();
while (succeeded(parser.parseOptionalComma()))
if (parseOneCompressedArgumentEntry(parser, arguments))
return failure();
if (parser.parseRParen() || parser.parseEqual()
|| parseCompressedOperandList(parser, ListDelimiter::Paren, operands)) {
return failure();
}
return success();
}
OpAsmParser::Argument argument;
if (parser.parseArgument(argument) || parser.parseEqual()
|| parseCompressedOperandList(parser, ListDelimiter::Paren, operands)) {
return failure();
}
arguments.push_back(argument);
return success();
}
} // namespace compact_asm
} // namespace onnx_mlir
#endif
+67
View File
@@ -0,0 +1,67 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "src/Accelerators/PIM/Common/IR/CoreBlockUtils.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
namespace onnx_mlir {
bool isCoreStaticAddressOp(mlir::Operation* op) {
return mlir::isa<mlir::arith::ConstantOp,
mlir::arith::AddIOp,
mlir::arith::SubIOp,
mlir::arith::MulIOp,
mlir::arith::DivUIOp,
mlir::arith::RemUIOp,
mlir::arith::IndexCastOp,
mlir::memref::AllocOp,
mlir::memref::SubViewOp,
mlir::memref::CastOp,
mlir::memref::CollapseShapeOp,
mlir::memref::ExpandShapeOp>(op);
}
mlir::LogicalResult
walkPimCoreBlock(mlir::Block& block,
const StaticValueKnowledge& knowledge,
llvm::function_ref<mlir::LogicalResult(mlir::Operation&, const StaticValueKnowledge&)> callback) {
bool hasFailure = false;
for (mlir::Operation& op : block) {
if (mlir::isa<pim::PimHaltOp, mlir::scf::YieldOp>(op) || isCoreStaticAddressOp(&op))
continue;
if (auto forOp = mlir::dyn_cast<mlir::scf::ForOp>(op)) {
mlir::Block& loopBody = forOp.getRegion().front();
auto lowerBound = resolveIndexValue(forOp.getLowerBound(), knowledge);
auto upperBound = resolveIndexValue(forOp.getUpperBound(), knowledge);
auto step = resolveIndexValue(forOp.getStep(), knowledge);
if (failed(lowerBound) || failed(upperBound) || failed(step) || *step <= 0) {
forOp.emitOpError("requires statically evaluable scf.for bounds for PIM codegen");
hasFailure = true;
continue;
}
llvm::SmallVector<mlir::Value> iterValues(forOp.getInitArgs().begin(), forOp.getInitArgs().end());
for (int64_t inductionValue = *lowerBound; inductionValue < *upperBound; inductionValue += *step) {
StaticValueKnowledge loopKnowledge = knowledge;
loopKnowledge.indexValues[forOp.getInductionVar()] = inductionValue;
for (auto [iterArg, iterValue] : llvm::zip_equal(forOp.getRegionIterArgs(), iterValues))
loopKnowledge.aliases[iterArg] = iterValue;
if (failed(walkPimCoreBlock(loopBody, loopKnowledge, callback)))
hasFailure = true;
auto yieldOp = mlir::cast<mlir::scf::YieldOp>(loopBody.getTerminator());
for (auto [index, yieldedValue] : llvm::enumerate(yieldOp.getOperands()))
iterValues[index] = resolveLoopCarriedAlias(yieldedValue, loopKnowledge);
}
continue;
}
if (failed(callback(op, knowledge)))
hasFailure = true;
}
return mlir::success(!hasFailure);
}
} // namespace onnx_mlir
+24
View File
@@ -0,0 +1,24 @@
#pragma once
#include "mlir/IR/Block.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/STLFunctionalExtras.h"
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
namespace onnx_mlir {
/// Returns true for ops in a `pim.core` body that only participate in static
/// address or index computation and therefore do not emit PIM instructions.
bool isCoreStaticAddressOp(mlir::Operation* op);
/// Walks a `pim.core` body, statically unrolling nested `scf.for` loops when
/// their bounds are known and invoking `callback` only on instruction-emitting
/// operations.
mlir::LogicalResult
walkPimCoreBlock(mlir::Block& block,
const StaticValueKnowledge& knowledge,
llvm::function_ref<mlir::LogicalResult(mlir::Operation&, const StaticValueKnowledge&)> callback);
} // namespace onnx_mlir
+45
View File
@@ -0,0 +1,45 @@
#include "src/Accelerators/PIM/Common/IR/EntryPointUtils.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
namespace onnx_mlir {
llvm::FailureOr<mlir::func::FuncOp> getPimEntryFunc(mlir::ModuleOp moduleOp) {
if (!moduleOp)
return mlir::failure();
llvm::SmallVector<mlir::ONNXEntryPointOp> entryPoints(moduleOp.getOps<mlir::ONNXEntryPointOp>());
if (entryPoints.size() > 1) {
moduleOp.emitError("PIM pipeline requires a single ONNX entry point, but found ") << entryPoints.size();
return mlir::failure();
}
if (!entryPoints.empty()) {
auto entryPointAttr =
entryPoints.front()->getAttrOfType<mlir::SymbolRefAttr>(mlir::ONNXEntryPointOp::getEntryPointFuncAttrName());
if (!entryPointAttr) {
entryPoints.front().emitOpError("is missing the entry point function attribute");
return mlir::failure();
}
auto entryFunc = moduleOp.lookupSymbol<mlir::func::FuncOp>(entryPointAttr.getLeafReference().getValue());
if (!entryFunc) {
entryPoints.front().emitOpError("references an unknown entry function ")
<< entryPointAttr.getLeafReference().getValue();
return mlir::failure();
}
return entryFunc;
}
if (auto mainGraphFunc = moduleOp.lookupSymbol<mlir::func::FuncOp>("main_graph"))
return mainGraphFunc;
llvm::SmallVector<mlir::func::FuncOp> nonExternalFuncs;
for (auto funcOp : moduleOp.getOps<mlir::func::FuncOp>())
if (!funcOp.isExternal())
nonExternalFuncs.push_back(funcOp);
if (nonExternalFuncs.size() == 1)
return nonExternalFuncs.front();
moduleOp.emitError("could not resolve a unique PIM entry function");
return mlir::failure();
}
} // namespace onnx_mlir
+13
View File
@@ -0,0 +1,13 @@
#pragma once
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
namespace onnx_mlir {
/// Resolves the function the PIM pipeline should treat as its entry point.
/// Prefers ONNX entry-point metadata, then `main_graph`, then the only
/// non-external function if the module is otherwise unambiguous.
llvm::FailureOr<mlir::func::FuncOp> getPimEntryFunc(mlir::ModuleOp moduleOp);
} // namespace onnx_mlir
+89
View File
@@ -0,0 +1,89 @@
#include "llvm/ADT/STLExtras.h"
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
namespace onnx_mlir {
llvm::SmallVector<int64_t> computeRowMajorStrides(llvm::ArrayRef<int64_t> shape) {
llvm::SmallVector<int64_t> strides(shape.size(), 1);
for (int64_t dim = static_cast<int64_t>(shape.size()) - 2; dim >= 0; --dim)
strides[dim] = strides[dim + 1] * shape[dim + 1];
return strides;
}
llvm::SmallVector<int64_t>
delinearizeIndex(int64_t linearIndex, llvm::ArrayRef<int64_t> shape, llvm::ArrayRef<int64_t> strides) {
llvm::SmallVector<int64_t> indices(shape.size(), 0);
for (auto [dim, stride] : llvm::enumerate(strides)) {
indices[dim] = linearIndex / stride;
linearIndex %= stride;
}
return indices;
}
int64_t linearizeIndex(llvm::ArrayRef<int64_t> indices, llvm::ArrayRef<int64_t> strides) {
int64_t linearIndex = 0;
for (auto [index, stride] : llvm::zip_equal(indices, strides))
linearIndex += index * stride;
return linearIndex;
}
int64_t getNumElements(llvm::ArrayRef<int64_t> shape) {
int64_t numElements = 1;
for (int64_t dim : shape)
numElements *= dim;
return numElements;
}
bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape,
llvm::ArrayRef<int64_t> offsets,
llvm::ArrayRef<int64_t> sizes,
llvm::ArrayRef<int64_t> strides) {
if (std::any_of(strides.begin(), strides.end(), [](int64_t stride) -> bool { return stride != 1; }))
return false;
auto offsetsAndSizesAndShape = llvm::zip_equal(llvm::make_range(offsets.rbegin(), offsets.rend()),
llvm::make_range(sizes.rbegin(), sizes.rend()),
llvm::make_range(srcShape.rbegin(), srcShape.rend()));
auto firstNonZeroOffset = std::find_if(
offsetsAndSizesAndShape.begin(), offsetsAndSizesAndShape.end(), [&](auto offsetAndSizeAndShape) -> bool {
auto [offset, _size, _dimension] = offsetAndSizeAndShape;
return offset != 0;
});
if (firstNonZeroOffset != offsetsAndSizesAndShape.end()) {
auto [offset, size, dimension] = *firstNonZeroOffset;
if (size > dimension - offset)
return false;
++firstNonZeroOffset;
if (std::any_of(firstNonZeroOffset, offsetsAndSizesAndShape.end(), [](auto offsetAndSizeAndShape) -> bool {
auto [_offset, size, _dimension] = offsetAndSizeAndShape;
return size != 1;
}))
return false;
}
auto sizesAndShape = llvm::zip_equal(llvm::make_range(sizes.rbegin(), sizes.rend()),
llvm::make_range(srcShape.rbegin(), srcShape.rend()));
auto firstDifferentSize = std::find_if(sizesAndShape.begin(), sizesAndShape.end(), [&](auto sizeAndShape) -> bool {
auto [size, dimension] = sizeAndShape;
return size != dimension;
});
if (firstDifferentSize != sizesAndShape.end()) {
++firstDifferentSize;
if (std::any_of(firstDifferentSize, sizesAndShape.end(), [](auto sizeAndShape) -> bool {
auto [size, _dimension] = sizeAndShape;
return size != 1;
}))
return false;
}
return true;
}
} // namespace onnx_mlir
+22
View File
@@ -0,0 +1,22 @@
#pragma once
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
namespace onnx_mlir {
llvm::SmallVector<int64_t> computeRowMajorStrides(llvm::ArrayRef<int64_t> shape);
llvm::SmallVector<int64_t>
delinearizeIndex(int64_t linearIndex, llvm::ArrayRef<int64_t> shape, llvm::ArrayRef<int64_t> strides);
int64_t linearizeIndex(llvm::ArrayRef<int64_t> indices, llvm::ArrayRef<int64_t> strides);
int64_t getNumElements(llvm::ArrayRef<int64_t> shape);
bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape,
llvm::ArrayRef<int64_t> offsets,
llvm::ArrayRef<int64_t> sizes,
llvm::ArrayRef<int64_t> strides);
} // namespace onnx_mlir
+85
View File
@@ -0,0 +1,85 @@
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
using namespace mlir;
namespace onnx_mlir {
Value stripMemRefCasts(Value value) {
while (auto castOp = value.getDefiningOp<memref::CastOp>())
value = castOp.getSource();
return value;
}
Value stripMemRefViewOps(Value value) {
while (true) {
if (auto castOp = value.getDefiningOp<memref::CastOp>()) {
value = castOp.getSource();
continue;
}
if (auto collapseOp = value.getDefiningOp<memref::CollapseShapeOp>()) {
value = collapseOp.getSrc();
continue;
}
if (auto expandOp = value.getDefiningOp<memref::ExpandShapeOp>()) {
value = expandOp.getSrc();
continue;
}
return value;
}
}
bool hasAllStaticSubviewParts(memref::SubViewOp subview) {
return llvm::all_of(subview.getStaticOffsets(), [](int64_t value) { return !ShapedType::isDynamic(value); })
&& llvm::all_of(subview.getStaticSizes(), [](int64_t value) { return !ShapedType::isDynamic(value); })
&& llvm::all_of(subview.getStaticStrides(), [](int64_t value) { return !ShapedType::isDynamic(value); });
}
FailureOr<StaticSubviewInfo> getStaticSubviewInfo(Value value) {
value = stripMemRefViewOps(value);
auto subviewOp = value.getDefiningOp<memref::SubViewOp>();
if (!subviewOp)
return failure();
auto source = stripMemRefCasts(subviewOp.getSource());
auto sourceType = dyn_cast<MemRefType>(source.getType());
auto subviewType = dyn_cast<MemRefType>(subviewOp.getType());
if (!sourceType || !subviewType || !sourceType.hasStaticShape() || !subviewType.hasStaticShape())
return failure();
StaticSubviewInfo info;
info.source = source;
info.sourceShape.assign(sourceType.getShape().begin(), sourceType.getShape().end());
SmallVector<OpFoldResult> mixedOffsets = subviewOp.getMixedOffsets();
info.offsets.assign(mixedOffsets.begin(), mixedOffsets.end());
for (OpFoldResult size : subviewOp.getMixedSizes()) {
auto staticSize = getConstantIntValue(size);
if (!staticSize)
return failure();
info.sizes.push_back(*staticSize);
}
for (OpFoldResult stride : subviewOp.getMixedStrides()) {
auto staticStride = getConstantIntValue(stride);
if (!staticStride)
return failure();
info.strides.push_back(*staticStride);
}
return info;
}
FailureOr<SmallVector<int64_t>> getStaticSubviewOffsets(const StaticSubviewInfo& info) {
SmallVector<int64_t> staticOffsets;
staticOffsets.reserve(info.offsets.size());
for (OpFoldResult offset : info.offsets) {
auto staticOffset = getConstantIntValue(offset);
if (!staticOffset)
return failure();
staticOffsets.push_back(*staticOffset);
}
return staticOffsets;
}
} // namespace onnx_mlir
+30
View File
@@ -0,0 +1,30 @@
#pragma once
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/LogicalResult.h"
namespace onnx_mlir {
struct StaticSubviewInfo {
mlir::Value source;
llvm::SmallVector<int64_t> sourceShape;
llvm::SmallVector<mlir::OpFoldResult> offsets;
llvm::SmallVector<int64_t> sizes;
llvm::SmallVector<int64_t> strides;
};
mlir::Value stripMemRefCasts(mlir::Value value);
mlir::Value stripMemRefViewOps(mlir::Value value);
bool hasAllStaticSubviewParts(mlir::memref::SubViewOp subview);
llvm::FailureOr<StaticSubviewInfo> getStaticSubviewInfo(mlir::Value value);
/// Returns the offsets in `info` as int64_t, failing if any offset is dynamic.
llvm::FailureOr<llvm::SmallVector<int64_t>> getStaticSubviewOffsets(const StaticSubviewInfo& info);
} // namespace onnx_mlir
+108
View File
@@ -0,0 +1,108 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallSet.h"
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
namespace onnx_mlir {
bool hasWeightAlways(mlir::Operation* op) { return op && op->getAttr(PimWeightAlwaysAttrName) != nullptr; }
void markWeightAlways(mlir::Operation* op) {
assert(op && "expected valid op");
op->setAttr(PimWeightAlwaysAttrName, mlir::UnitAttr::get(op->getContext()));
}
namespace {
template <typename MVMOpTy, typename VMMOpTy, typename ParentOpTy>
bool hasMvmVmmWeightUse(ParentOpTy parentOp, unsigned weightIndex) {
bool found = false;
parentOp.walk([&](mlir::Operation* op) {
if (auto mvmOp = mlir::dyn_cast<MVMOpTy>(op))
found |= mvmOp.getWeightIndex() == weightIndex;
else if (auto vmmOp = mlir::dyn_cast<VMMOpTy>(op))
found |= vmmOp.getWeightIndex() == weightIndex;
});
return found;
}
template <typename MVMOpTy, typename VMMOpTy, typename ParentOpTy>
void walkMvmVmmWeightUses(ParentOpTy parentOp, llvm::function_ref<void(mlir::OpOperand&)> callback) {
auto weights = parentOp.getWeights();
llvm::SmallSet<unsigned, 8> visited;
auto walkWeightIndex = [&](unsigned weightIndex) {
if (weightIndex < weights.size() && visited.insert(weightIndex).second)
callback(parentOp->getOpOperand(weightIndex));
};
parentOp.walk([&](MVMOpTy op) { walkWeightIndex(op.getWeightIndex()); });
parentOp.walk([&](VMMOpTy op) { walkWeightIndex(op.getWeightIndex()); });
}
} // namespace
bool isSpatialMvmVmmWeightUse(mlir::OpOperand& use) {
mlir::Operation* user = use.getOwner();
unsigned operandIndex = use.getOperandNumber();
auto computeOp = mlir::dyn_cast<spatial::SpatCompute>(user);
if (!computeOp || operandIndex >= computeOp.getWeights().size())
return false;
return hasMvmVmmWeightUse<spatial::SpatMVMOp, spatial::SpatVMMOp>(computeOp, operandIndex);
}
bool hasOnlySpatialMvmVmmWeightUses(mlir::Value value) {
llvm::SmallPtrSet<mlir::Value, 8> visited;
auto walkUses = [&](mlir::Value currentValue, auto& self) -> bool {
if (!visited.insert(currentValue).second)
return true;
if (currentValue.use_empty())
return false;
return llvm::all_of(currentValue.getUses(), [&](mlir::OpOperand& use) {
if (isSpatialMvmVmmWeightUse(use))
return true;
mlir::Operation* user = use.getOwner();
if (auto extractSliceOp = mlir::dyn_cast<mlir::tensor::ExtractSliceOp>(user))
return extractSliceOp.getSource() == currentValue && self(extractSliceOp.getResult(), self);
if (auto expandShapeOp = mlir::dyn_cast<mlir::tensor::ExpandShapeOp>(user))
return expandShapeOp.getSrc() == currentValue && self(expandShapeOp.getResult(), self);
if (auto collapseShapeOp = mlir::dyn_cast<mlir::tensor::CollapseShapeOp>(user))
return collapseShapeOp.getSrc() == currentValue && self(collapseShapeOp.getResult(), self);
if (auto transposeOp = mlir::dyn_cast<mlir::ONNXTransposeOp>(user))
return transposeOp.getData() == currentValue && self(transposeOp.getResult(), self);
return false;
});
};
return walkUses(value, walkUses);
}
void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref<void(mlir::OpOperand&)> callback) {
assert(root && "expected valid root op");
root->walk([&](pim::PimCoreOp coreOp) {
coreOp.walk([&](pim::PimVMMOp vmmOp) {
auto weights = coreOp.getWeights();
unsigned weightIndex = vmmOp.getWeightIndex();
if (weightIndex < weights.size())
callback(coreOp->getOpOperand(weightIndex));
});
});
root->walk([&](pim::PimCoreBatchOp coreBatchOp) {
auto weights = coreBatchOp.getWeights();
for (auto weight : weights)
for (mlir::OpOperand& use : weight.getUses())
if (use.getOwner() == coreBatchOp.getOperation())
callback(use);
});
}
} // namespace onnx_mlir
+29
View File
@@ -0,0 +1,29 @@
#pragma once
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/STLFunctionalExtras.h"
#include "llvm/ADT/StringRef.h"
inline constexpr llvm::StringRef PimWeightAlwaysAttrName = "weightAlways";
namespace onnx_mlir {
bool hasWeightAlways(mlir::Operation* op);
/// Tags an op as producing a value that should stay materialized as a reusable
/// weight across later PIM lowering/codegen stages.
void markWeightAlways(mlir::Operation* op);
bool isSpatialMvmVmmWeightUse(mlir::OpOperand& use);
/// Returns true when a value flows only into Spatial weighted MVM/VMM operands,
/// allowing later passes to preserve it as a dedicated weight-like object.
bool hasOnlySpatialMvmVmmWeightUses(mlir::Value value);
/// Visits weight operands consumed by Pim core ops/core batches so downstream
/// passes can identify globals that must remain weight-backed.
void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref<void(mlir::OpOperand&)> callback);
} // namespace onnx_mlir
-546
View File
@@ -1,546 +0,0 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "llvm/Support/raw_os_ostream.h"
#include <filesystem>
#include <fstream>
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Compiler/CompilerOptions.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
std::string getOutputDir() {
if (outputBaseName.empty() || outputBaseName == "-")
return {};
size_t lastSlash = outputBaseName.find_last_of('/');
if (lastSlash == std::string::npos)
return ".";
return outputBaseName.substr(0, lastSlash);
}
void createDirectory(const std::string& directory) {
std::error_code errorCode;
std::filesystem::create_directories(directory, errorCode);
assert(!errorCode && ("Failed to create directory: " + errorCode.message()).data());
}
void dumpModule(ModuleOp moduleOp, const std::string& name) {
std::string outputDir = getOutputDir();
if (outputDir.empty())
return;
std::string dialectsDir = outputDir + "/dialects";
createDirectory(dialectsDir);
std::fstream file(dialectsDir + "/" + name + ".mlir", std::ios::out);
llvm::raw_os_ostream os(file);
os << *moduleOp;
os.flush();
file.close();
}
FailureOr<func::FuncOp> getPimEntryFunc(ModuleOp moduleOp) {
if (!moduleOp)
return failure();
SmallVector<ONNXEntryPointOp> entryPoints(moduleOp.getOps<ONNXEntryPointOp>());
if (entryPoints.size() > 1) {
moduleOp.emitError("PIM pipeline requires a single ONNX entry point, but found ") << entryPoints.size();
return failure();
}
if (!entryPoints.empty()) {
auto entryPointAttr =
entryPoints.front()->getAttrOfType<SymbolRefAttr>(ONNXEntryPointOp::getEntryPointFuncAttrName());
if (!entryPointAttr) {
entryPoints.front().emitOpError("is missing the entry point function attribute");
return failure();
}
auto entryFunc = moduleOp.lookupSymbol<func::FuncOp>(entryPointAttr.getLeafReference().getValue());
if (!entryFunc) {
entryPoints.front().emitOpError("references an unknown entry function ")
<< entryPointAttr.getLeafReference().getValue();
return failure();
}
return entryFunc;
}
if (auto mainGraphFunc = moduleOp.lookupSymbol<func::FuncOp>("main_graph"))
return mainGraphFunc;
SmallVector<func::FuncOp> nonExternalFuncs;
for (auto funcOp : moduleOp.getOps<func::FuncOp>())
if (!funcOp.isExternal())
nonExternalFuncs.push_back(funcOp);
if (nonExternalFuncs.size() == 1)
return nonExternalFuncs.front();
moduleOp.emitError("could not resolve a unique PIM entry function");
return failure();
}
bool hasWeightAlways(Operation* op) { return op && op->getAttr(PimWeightAlwaysAttrName) != nullptr; }
void markWeightAlways(Operation* op) {
assert(op && "expected valid op");
op->setAttr(PimWeightAlwaysAttrName, UnitAttr::get(op->getContext()));
}
memref::GlobalOp lookupGlobalForGetGlobal(ModuleOp moduleOp, memref::GetGlobalOp getGlobalOp) {
if (!moduleOp || !getGlobalOp)
return {};
return moduleOp.lookupSymbol<memref::GlobalOp>(getGlobalOp.getName());
}
FailureOr<Operation*> getOtherEndOfChannel(Operation* op, bool opIsReceive, RewriterBase& rewriter) {
auto channelNewOp = op->getOperand(0).getDefiningOp<spatial::SpatChannelNewOp>();
if (!channelNewOp) {
op->emitError("User of Channel must have the first operand created by ChannelNewOp.");
return failure();
}
// channelNewOp should have two users: `op` and a
// `ChannelSendOp`/`ChannelReceiveOp`
auto channelUsers = channelNewOp->getUsers();
auto usersIterator = channelUsers.begin();
auto firstUser = *usersIterator;
usersIterator++;
if (usersIterator == channelUsers.end()) {
op->emitError("Operand generated by ChannelNewOp must have two users, "
"only one found.");
channelNewOp->dump();
op->dump();
channelNewOp->getParentOp()->dump();
return failure();
}
auto secondUser = *usersIterator;
usersIterator++;
if (usersIterator != channelUsers.end()) {
op->emitError("Operand generated by ChannelNewOp must have two users, "
"more than two found.");
return failure();
}
Operation* notOpUser;
if (firstUser == op) {
notOpUser = secondUser;
}
else if (secondUser == op) {
notOpUser = firstUser;
}
else {
op->emitError("Operand generated by ChannelNewOp must have two users, "
"and one of them must be me, but"
"none of them is actually me.");
return failure();
}
if (opIsReceive) {
if (!isa<spatial::SpatChannelSendOp>(notOpUser)) {
op->emitError("Operand generated by ChannelNewOp has two user, one is "
"me, the other is not a ChannelSendOp.");
return failure();
}
return notOpUser;
}
else {
if (!isa<spatial::SpatChannelReceiveOp>(notOpUser)) {
op->emitError("Operand generated by ChannelNewOp has two user, one is "
"me, the other is not a ChannelReceiveOp.");
return failure();
}
return notOpUser;
}
}
SmallVector<int64_t> computeRowMajorStrides(ArrayRef<int64_t> shape) {
SmallVector<int64_t> strides(shape.size(), 1);
for (int64_t dim = static_cast<int64_t>(shape.size()) - 2; dim >= 0; --dim)
strides[dim] = strides[dim + 1] * shape[dim + 1];
return strides;
}
SmallVector<int64_t> delinearizeIndex(int64_t linearIndex, ArrayRef<int64_t> shape, ArrayRef<int64_t> strides) {
SmallVector<int64_t> indices(shape.size(), 0);
for (auto [dim, stride] : llvm::enumerate(strides)) {
indices[dim] = linearIndex / stride;
linearIndex %= stride;
}
return indices;
}
int64_t linearizeIndex(ArrayRef<int64_t> indices, ArrayRef<int64_t> strides) {
int64_t linearIndex = 0;
for (auto [index, stride] : llvm::zip_equal(indices, strides))
linearIndex += index * stride;
return linearIndex;
}
int64_t getNumElements(ArrayRef<int64_t> shape) {
int64_t numElements = 1;
for (int64_t dim : shape)
numElements *= dim;
return numElements;
}
bool isMemoryContiguous(ArrayRef<int64_t> srcShape,
ArrayRef<int64_t> offsets,
ArrayRef<int64_t> sizes,
ArrayRef<int64_t> strides) {
if (std::any_of(strides.begin(), strides.end(), [](int64_t stride) -> bool { return stride != 1; }))
return false;
auto offsetsAndSizesAndShape = llvm::zip_equal(llvm::make_range(offsets.rbegin(), offsets.rend()),
llvm::make_range(sizes.rbegin(), sizes.rend()),
llvm::make_range(srcShape.rbegin(), srcShape.rend()));
auto firstNonZeroOffset = std::find_if(
offsetsAndSizesAndShape.begin(), offsetsAndSizesAndShape.end(), [&](auto offsetAndSizeAndShape) -> bool {
auto [offset, _size, _dimension] = offsetAndSizeAndShape;
return offset != 0;
});
if (firstNonZeroOffset != offsetsAndSizesAndShape.end()) {
auto [offset, size, dimension] = *firstNonZeroOffset;
if (size > dimension - offset)
return false;
++firstNonZeroOffset;
if (std::any_of(firstNonZeroOffset, offsetsAndSizesAndShape.end(), [](auto offsetAndSizeAndShape) -> bool {
auto [_offset, size, _dimension] = offsetAndSizeAndShape;
return size != 1;
}))
return false;
}
auto sizesAndShape = llvm::zip_equal(llvm::make_range(sizes.rbegin(), sizes.rend()),
llvm::make_range(srcShape.rbegin(), srcShape.rend()));
auto firstDifferentSize = std::find_if(sizesAndShape.begin(), sizesAndShape.end(), [&](auto sizeAndShape) -> bool {
auto [size, dimension] = sizeAndShape;
return size != dimension;
});
if (firstDifferentSize != sizesAndShape.end()) {
++firstDifferentSize;
if (std::any_of(firstDifferentSize, sizesAndShape.end(), [](auto sizeAndShape) -> bool {
auto [size, _dimension] = sizeAndShape;
return size != 1;
}))
return false;
}
return true;
}
static Value resolveAlias(Value value, const StaticValueKnowledge* knowledge) {
if (!knowledge)
return value;
auto iter = knowledge->aliases.find(value);
while (iter != knowledge->aliases.end()) {
value = iter->second;
iter = knowledge->aliases.find(value);
}
return value;
}
// Walks through view-like ops and DPS tied operands to find the "underlying" memref value
// behind an scf.for iter-arg. Used both when resolving a contiguous address inside a loop
// and when propagating yielded values across iterations during static unrolling.
static Value resolveLoopCarriedAliasImpl(Value value, const StaticValueKnowledge* knowledge) {
value = resolveAlias(value, knowledge);
if (auto blockArgument = dyn_cast<BlockArgument>(value))
return value;
Operation* definingOp = value.getDefiningOp();
if (!definingOp)
return value;
if (auto dpsDefiningOp = dyn_cast<DestinationStyleOpInterface>(definingOp)) {
if (auto result = dyn_cast<OpResult>(value))
if (OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(result))
return resolveLoopCarriedAliasImpl(tiedOperand->get(), knowledge);
}
if (auto castOp = dyn_cast<memref::CastOp>(definingOp))
return resolveLoopCarriedAliasImpl(castOp.getSource(), knowledge);
if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(definingOp))
return resolveLoopCarriedAliasImpl(collapseOp.getSrc(), knowledge);
if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(definingOp))
return resolveLoopCarriedAliasImpl(expandOp.getSrc(), knowledge);
return value;
}
static FailureOr<int64_t> resolveOpFoldResult(OpFoldResult ofr, const StaticValueKnowledge* knowledge);
static FailureOr<int64_t> resolveIndexValueImpl(Value value, const StaticValueKnowledge* knowledge) {
value = resolveAlias(value, knowledge);
if (knowledge) {
auto iter = knowledge->indexValues.find(value);
if (iter != knowledge->indexValues.end())
return iter->second;
}
auto constantOp = value.getDefiningOp<arith::ConstantOp>();
if (constantOp) {
if (auto integerAttr = dyn_cast<IntegerAttr>(constantOp.getValue()))
return integerAttr.getInt();
}
Operation* definingOp = value.getDefiningOp();
if (!definingOp)
return failure();
if (auto indexCastOp = dyn_cast<arith::IndexCastOp>(definingOp))
return resolveIndexValueImpl(indexCastOp.getIn(), knowledge);
if (auto addOp = dyn_cast<arith::AddIOp>(definingOp)) {
auto lhs = resolveIndexValueImpl(addOp.getLhs(), knowledge);
auto rhs = resolveIndexValueImpl(addOp.getRhs(), knowledge);
if (failed(lhs) || failed(rhs))
return failure();
return *lhs + *rhs;
}
if (auto subOp = dyn_cast<arith::SubIOp>(definingOp)) {
auto lhs = resolveIndexValueImpl(subOp.getLhs(), knowledge);
auto rhs = resolveIndexValueImpl(subOp.getRhs(), knowledge);
if (failed(lhs) || failed(rhs))
return failure();
return *lhs - *rhs;
}
if (auto mulOp = dyn_cast<arith::MulIOp>(definingOp)) {
auto lhs = resolveIndexValueImpl(mulOp.getLhs(), knowledge);
auto rhs = resolveIndexValueImpl(mulOp.getRhs(), knowledge);
if (failed(lhs) || failed(rhs))
return failure();
return *lhs * *rhs;
}
if (auto divOp = dyn_cast<arith::DivUIOp>(definingOp)) {
auto lhs = resolveIndexValueImpl(divOp.getLhs(), knowledge);
auto rhs = resolveIndexValueImpl(divOp.getRhs(), knowledge);
if (failed(lhs) || failed(rhs) || *rhs == 0)
return failure();
return static_cast<int64_t>(static_cast<uint64_t>(*lhs) / static_cast<uint64_t>(*rhs));
}
if (auto remOp = dyn_cast<arith::RemUIOp>(definingOp)) {
auto lhs = resolveIndexValueImpl(remOp.getLhs(), knowledge);
auto rhs = resolveIndexValueImpl(remOp.getRhs(), knowledge);
if (failed(lhs) || failed(rhs) || *rhs == 0)
return failure();
return static_cast<int64_t>(static_cast<uint64_t>(*lhs) % static_cast<uint64_t>(*rhs));
}
return failure();
}
static FailureOr<int64_t> resolveOpFoldResult(OpFoldResult ofr, const StaticValueKnowledge* knowledge) {
if (auto attr = dyn_cast<Attribute>(ofr)) {
auto integerAttr = dyn_cast<IntegerAttr>(attr);
if (!integerAttr)
return failure();
return integerAttr.getInt();
}
return resolveIndexValueImpl(cast<Value>(ofr), knowledge);
}
static FailureOr<ResolvedContiguousAddress> resolveContiguousAddressImpl(Value value,
const StaticValueKnowledge* knowledge) {
int64_t byteOffset = 0;
value = resolveAlias(value, knowledge);
while (true) {
if (isa<BlockArgument>(value))
return ResolvedContiguousAddress {value, byteOffset};
Operation* definingOp = value.getDefiningOp();
if (!definingOp)
return failure();
if (auto dpsDefiningOp = dyn_cast<DestinationStyleOpInterface>(definingOp)) {
OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(dyn_cast<OpResult>(value));
if (!tiedOperand)
return failure();
value = resolveAlias(tiedOperand->get(), knowledge);
continue;
}
if (auto forOp = dyn_cast<scf::ForOp>(definingOp)) {
auto result = dyn_cast<OpResult>(value);
if (!result)
return failure();
// Trace the loop carry back to its underlying memref, then if that memref is the
// loop's own iter-arg we know the base comes from the corresponding init arg
// (every iteration yields the same backing memory in the DPS sense).
auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
Value yieldedValue = resolveLoopCarriedAliasImpl(yieldOp.getOperand(result.getResultNumber()), knowledge);
if (auto blockArgument = dyn_cast<BlockArgument>(yieldedValue)) {
if (blockArgument.getOwner() == forOp.getBody() && blockArgument.getArgNumber() > 0
&& static_cast<unsigned>(blockArgument.getArgNumber() - 1) < forOp.getInitArgs().size()) {
value = resolveAlias(forOp.getInitArgs()[blockArgument.getArgNumber() - 1], knowledge);
continue;
}
}
value = yieldedValue;
continue;
}
if (auto subviewOp = dyn_cast<memref::SubViewOp>(definingOp)) {
auto sourceType = dyn_cast<MemRefType>(subviewOp.getSource().getType());
auto subviewType = dyn_cast<MemRefType>(subviewOp.getType());
if (!sourceType || !subviewType || !sourceType.hasStaticShape() || !subviewType.hasStaticShape())
return failure();
SmallVector<int64_t> offsets;
SmallVector<int64_t> sizes;
SmallVector<int64_t> strides;
offsets.reserve(subviewOp.getMixedOffsets().size());
sizes.reserve(subviewOp.getMixedSizes().size());
strides.reserve(subviewOp.getMixedStrides().size());
for (OpFoldResult offset : subviewOp.getMixedOffsets()) {
auto resolvedOffset = resolveOpFoldResult(offset, knowledge);
if (failed(resolvedOffset))
return failure();
offsets.push_back(*resolvedOffset);
}
for (OpFoldResult size : subviewOp.getMixedSizes()) {
auto resolvedSize = resolveOpFoldResult(size, knowledge);
if (failed(resolvedSize))
return failure();
sizes.push_back(*resolvedSize);
}
for (OpFoldResult stride : subviewOp.getMixedStrides()) {
auto resolvedStride = resolveOpFoldResult(stride, knowledge);
if (failed(resolvedStride))
return failure();
strides.push_back(*resolvedStride);
}
if (!isMemoryContiguous(sourceType.getShape(), offsets, sizes, strides))
return failure();
auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
byteOffset += linearizeIndex(offsets, sourceStrides) * subviewType.getElementTypeBitWidth() / 8;
value = resolveAlias(subviewOp.getSource(), knowledge);
continue;
}
if (auto castOp = dyn_cast<memref::CastOp>(definingOp)) {
value = resolveAlias(castOp.getSource(), knowledge);
continue;
}
if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(definingOp)) {
value = resolveAlias(collapseOp.getSrc(), knowledge);
continue;
}
if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(definingOp)) {
value = resolveAlias(expandOp.getSrc(), knowledge);
continue;
}
if (isa<memref::AllocOp, memref::GetGlobalOp>(definingOp))
return ResolvedContiguousAddress {value, byteOffset};
return failure();
}
}
FailureOr<int64_t> resolveIndexValue(Value value) { return resolveIndexValueImpl(value, nullptr); }
FailureOr<int64_t> resolveIndexValue(Value value, const StaticValueKnowledge& knowledge) {
return resolveIndexValueImpl(value, &knowledge);
}
FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(Value value) {
return resolveContiguousAddressImpl(value, nullptr);
}
FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(Value value, const StaticValueKnowledge& knowledge) {
return resolveContiguousAddressImpl(value, &knowledge);
}
Value resolveLoopCarriedAlias(Value value, const StaticValueKnowledge& knowledge) {
return resolveLoopCarriedAliasImpl(value, &knowledge);
}
bool isCoreStaticAddressOp(Operation* op) {
return isa<arith::ConstantOp,
arith::AddIOp,
arith::SubIOp,
arith::MulIOp,
arith::DivUIOp,
arith::RemUIOp,
arith::IndexCastOp,
memref::AllocOp,
memref::SubViewOp,
memref::CastOp,
memref::CollapseShapeOp,
memref::ExpandShapeOp>(op);
}
LogicalResult walkPimCoreBlock(Block& block,
const StaticValueKnowledge& knowledge,
llvm::function_ref<LogicalResult(Operation&, const StaticValueKnowledge&)> callback) {
bool hasFailure = false;
for (Operation& op : block) {
if (isa<pim::PimHaltOp, scf::YieldOp>(op) || isCoreStaticAddressOp(&op))
continue;
if (auto forOp = dyn_cast<scf::ForOp>(op)) {
Block& loopBody = forOp.getRegion().front();
auto lowerBound = resolveIndexValue(forOp.getLowerBound(), knowledge);
auto upperBound = resolveIndexValue(forOp.getUpperBound(), knowledge);
auto step = resolveIndexValue(forOp.getStep(), knowledge);
if (failed(lowerBound) || failed(upperBound) || failed(step) || *step <= 0) {
forOp.emitOpError("requires statically evaluable scf.for bounds for PIM codegen");
hasFailure = true;
continue;
}
SmallVector<Value> iterValues(forOp.getInitArgs().begin(), forOp.getInitArgs().end());
for (int64_t inductionValue = *lowerBound; inductionValue < *upperBound; inductionValue += *step) {
StaticValueKnowledge loopKnowledge = knowledge;
loopKnowledge.indexValues[forOp.getInductionVar()] = inductionValue;
for (auto [iterArg, iterValue] : llvm::zip_equal(forOp.getRegionIterArgs(), iterValues))
loopKnowledge.aliases[iterArg] = iterValue;
if (failed(walkPimCoreBlock(loopBody, loopKnowledge, callback)))
hasFailure = true;
auto yieldOp = cast<scf::YieldOp>(loopBody.getTerminator());
for (auto [index, yieldedValue] : llvm::enumerate(yieldOp.getOperands()))
iterValues[index] = resolveLoopCarriedAlias(yieldedValue, loopKnowledge);
}
continue;
}
if (failed(callback(op, knowledge)))
hasFailure = true;
}
return success(!hasFailure);
}
} // namespace onnx_mlir
+10 -70
View File
@@ -7,82 +7,22 @@
#include "mlir/IR/Value.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLFunctionalExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
#include "src/Accelerators/PIM/Common/IR/CoreBlockUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/EntryPointUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
#include "src/Accelerators/PIM/Common/Support/DebugDump.hpp"
#include "src/Accelerators/PIM/Common/Support/FileSystemUtils.hpp"
#include "src/Compiler/CompilerOptions.hpp"
inline constexpr llvm::StringRef PimWeightAlwaysAttrName = "weightAlways";
namespace onnx_mlir {
struct ResolvedContiguousAddress {
mlir::Value base;
int64_t byteOffset = 0;
};
struct StaticValueKnowledge {
llvm::DenseMap<mlir::Value, int64_t> indexValues;
llvm::DenseMap<mlir::Value, mlir::Value> aliases;
StaticValueKnowledge() {}
};
std::string getOutputDir();
void createDirectory(const std::string& directory);
void dumpModule(mlir::ModuleOp moduleOp, const std::string& name);
llvm::FailureOr<mlir::func::FuncOp> getPimEntryFunc(mlir::ModuleOp moduleOp);
bool hasWeightAlways(mlir::Operation* op);
void markWeightAlways(mlir::Operation* op);
mlir::memref::GlobalOp lookupGlobalForGetGlobal(mlir::ModuleOp moduleOp, mlir::memref::GetGlobalOp getGlobalOp);
llvm::FailureOr<mlir::Operation*>
getOtherEndOfChannel(mlir::Operation* op, bool opIsReceive, mlir::RewriterBase& rewriter);
llvm::SmallVector<int64_t> computeRowMajorStrides(llvm::ArrayRef<int64_t> shape);
llvm::SmallVector<int64_t>
delinearizeIndex(int64_t linearIndex, llvm::ArrayRef<int64_t> shape, llvm::ArrayRef<int64_t> strides);
int64_t linearizeIndex(llvm::ArrayRef<int64_t> indices, llvm::ArrayRef<int64_t> strides);
int64_t getNumElements(llvm::ArrayRef<int64_t> shape);
bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape,
llvm::ArrayRef<int64_t> offsets,
llvm::ArrayRef<int64_t> sizes,
llvm::ArrayRef<int64_t> strides);
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value);
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value,
const StaticValueKnowledge& knowledge);
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value);
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value, const StaticValueKnowledge& knowledge);
/// Follows alias and view/DPS chains using `knowledge` to find the value an scf.for
/// iter-arg is ultimately backed by. Used when interpreting scf.for loop carries.
mlir::Value resolveLoopCarriedAlias(mlir::Value value, const StaticValueKnowledge& knowledge);
/// Returns true for ops inside a pim.core body that do not emit any PIM instruction and
/// only contribute to static addressing or index computations (arith integer math,
/// memref view ops, memref.alloc, arith.constant).
bool isCoreStaticAddressOp(mlir::Operation* op);
/// Walks `block` (the body of a pim.core region or an scf.for nested in it), statically
/// unrolling any scf.for with resolvable bounds using `knowledge`. For each remaining op
/// that is not skipped (pim.halt, scf.yield, or isCoreStaticAddressOp), `callback` is
/// invoked with the op and the in-scope knowledge. The walker keeps going after a callback
/// failure so callers can collect multiple diagnostics, but propagates the overall result.
mlir::LogicalResult
walkPimCoreBlock(mlir::Block& block,
const StaticValueKnowledge& knowledge,
llvm::function_ref<mlir::LogicalResult(mlir::Operation&, const StaticValueKnowledge&)> callback);
inline constexpr llvm::StringLiteral kCoreIdAttrName = "coreId";
inline constexpr llvm::StringLiteral kCoreIdsAttrName = "coreIds";
} // namespace onnx_mlir
+27
View File
@@ -0,0 +1,27 @@
#include "llvm/Support/raw_os_ostream.h"
#include <fstream>
#include "src/Accelerators/PIM/Common/Support/DebugDump.hpp"
#include "src/Accelerators/PIM/Common/Support/FileSystemUtils.hpp"
namespace onnx_mlir {
void dumpModule(mlir::ModuleOp moduleOp, const std::string& name) {
std::string outputDir = getOutputDir();
if (outputDir.empty())
return;
std::string dialectsDir = outputDir + "/dialects";
createDirectory(dialectsDir);
std::fstream file(dialectsDir + "/" + name + ".mlir", std::ios::out);
llvm::raw_os_ostream os(file);
mlir::OpPrintingFlags flags;
flags.elideLargeElementsAttrs();
moduleOp.print(os, flags);
os.flush();
file.close();
}
} // namespace onnx_mlir
+13
View File
@@ -0,0 +1,13 @@
#pragma once
#include "mlir/IR/BuiltinOps.h"
#include <string>
namespace onnx_mlir {
/// Emits a MLIR snapshot under the current compiler output
/// directory for pass-level debugging.
void dumpModule(mlir::ModuleOp moduleOp, const std::string& name);
} // namespace onnx_mlir
+41
View File
@@ -0,0 +1,41 @@
#include "llvm/ADT/STLExtras.h"
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
namespace onnx_mlir::pim {
mlir::InFlightDiagnostic emitUnsupportedStaticShapeDiagnostic(mlir::Operation* op, llvm::StringRef valueDescription) {
return op->emitOpError() << "requires statically shaped " << valueDescription;
}
mlir::InFlightDiagnostic emitUnsupportedRankDiagnostic(mlir::Operation* op,
llvm::StringRef valueDescription,
int64_t actualRank,
llvm::ArrayRef<int64_t> supportedRanks) {
auto diag = op->emitOpError() << "has unsupported rank " << actualRank << " for " << valueDescription;
if (supportedRanks.empty())
return diag;
diag << "; supported rank";
if (supportedRanks.size() != 1)
diag << 's';
diag << ' ';
llvm::interleaveComma(supportedRanks, diag, [&](int64_t rank) { diag << rank; });
return diag;
}
mlir::InFlightDiagnostic
emitMissingSymbolDiagnostic(mlir::Operation* op, llvm::StringRef symbolKind, llvm::StringRef symbolName) {
return op->emitOpError() << "references missing " << symbolKind << " `" << symbolName << "`";
}
mlir::LogicalResult emitFileSystemError(mlir::Location loc,
llvm::StringRef action,
llvm::StringRef path,
const std::error_code& errorCode) {
mlir::emitError(loc) << "failed to " << action << " `" << path << "`: " << errorCode.message();
return mlir::failure();
}
} // namespace onnx_mlir::pim
+38
View File
@@ -0,0 +1,38 @@
#pragma once
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Operation.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/StringRef.h"
#include <system_error>
namespace onnx_mlir::pim {
/// Emits a consistent diagnostic for target paths that require static shapes.
mlir::InFlightDiagnostic emitUnsupportedStaticShapeDiagnostic(mlir::Operation* op, llvm::StringRef valueDescription);
/// Emits a consistent diagnostic for unsupported ranks while listing the ranks
/// accepted by the current lowering/codegen path.
mlir::InFlightDiagnostic emitUnsupportedRankDiagnostic(mlir::Operation* op,
llvm::StringRef valueDescription,
int64_t actualRank,
llvm::ArrayRef<int64_t> supportedRanks);
/// Emits a consistent diagnostic for missing symbol/global references.
mlir::InFlightDiagnostic
emitMissingSymbolDiagnostic(mlir::Operation* op, llvm::StringRef symbolKind, llvm::StringRef symbolName);
/// Converts a filesystem error into an MLIR failure diagnostic anchored at
/// the relevant IR location.
mlir::LogicalResult
emitFileSystemError(mlir::Location loc, llvm::StringRef action, llvm::StringRef path, const std::error_code& errorCode);
template <typename T>
mlir::LogicalResult failureOrToLogicalResult(const llvm::FailureOr<T>& value) {
return mlir::success(succeeded(value));
}
} // namespace onnx_mlir::pim
@@ -0,0 +1,24 @@
#include <filesystem>
#include "src/Accelerators/PIM/Common/Support/FileSystemUtils.hpp"
#include "src/Compiler/CompilerOptions.hpp"
namespace onnx_mlir {
std::string getOutputDir() {
if (outputBaseName.empty() || outputBaseName == "-")
return {};
size_t lastSlash = outputBaseName.find_last_of('/');
if (lastSlash == std::string::npos)
return ".";
return outputBaseName.substr(0, lastSlash);
}
void createDirectory(const std::string& directory) {
std::error_code errorCode;
std::filesystem::create_directories(directory, errorCode);
assert(!errorCode && ("Failed to create directory: " + errorCode.message()).data());
}
} // namespace onnx_mlir
@@ -0,0 +1,13 @@
#pragma once
#include <string>
namespace onnx_mlir {
/// Returns the directory that should hold PIM artifacts/debug dumps for the
/// current compiler invocation.
std::string getOutputDir();
void createDirectory(const std::string& directory);
} // namespace onnx_mlir
+63
View File
@@ -0,0 +1,63 @@
#include "src/Accelerators/PIM/Common/Support/ReportUtils.hpp"
#include "llvm/Support/Format.h"
#include "src/Accelerators/PIM/Common/Support/FileSystemUtils.hpp"
namespace onnx_mlir {
std::fstream openReportFile(const std::string& name) {
std::string outputDir = getOutputDir();
if (outputDir.empty())
return {};
std::string reportsDir = outputDir + "/reports";
createDirectory(reportsDir);
return std::fstream(reportsDir + "/" + name + ".txt", std::ios::out);
}
std::string formatReportMemory(uint64_t bytes) {
const char* units[] = {"B", "KB", "MB", "GB", "TB", "PB", "EB"};
int i = 0;
double size = static_cast<double>(bytes);
while (size >= 1024 && i < 6) {
size /= 1024;
i++;
}
std::string out;
llvm::raw_string_ostream rss(out);
rss << llvm::format("%.2f ", size) << units[i];
return rss.str();
}
void printReportFlatFields(llvm::raw_ostream& os, llvm::ArrayRef<ReportField> fields) {
for (const ReportField& field : fields)
os << "\t" << field.label << ": " << field.value << "\n";
}
void printReportFieldBlock(llvm::raw_ostream& os, llvm::StringRef title, llvm::ArrayRef<ReportField> fields) {
os << "\t" << title << ":\n";
for (const ReportField& field : fields)
os << "\t " << field.label << ": " << field.value << "\n";
}
void printReportTotalsBlock(llvm::raw_ostream& os, llvm::ArrayRef<ReportField> fields) {
os << "Totals:\n";
for (const ReportField& field : fields)
os << "\t" << field.label << ": " << field.value << "\n";
}
void printReportPerCoreAndTotalFields(llvm::raw_ostream& os,
llvm::ArrayRef<ReportField> perCoreFields,
llvm::ArrayRef<ReportField> totalFields) {
printReportFieldBlock(os, "Per core", perCoreFields);
printReportFieldBlock(os, "Total", totalFields);
}
void printReportEntrySeparator(llvm::raw_ostream& os, bool hasNextEntry) {
if (hasNextEntry)
os << "\n";
}
} // namespace onnx_mlir
+48
View File
@@ -0,0 +1,48 @@
#pragma once
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/Support/raw_ostream.h"
#include <cstdint>
#include <fstream>
#include <limits>
#include <string>
namespace onnx_mlir {
std::fstream openReportFile(const std::string& name);
std::string formatReportMemory(uint64_t bytes);
struct ReportField {
std::string label;
std::string value;
};
void printReportFlatFields(llvm::raw_ostream& os, llvm::ArrayRef<ReportField> fields);
void printReportFieldBlock(llvm::raw_ostream& os, llvm::StringRef title, llvm::ArrayRef<ReportField> fields);
void printReportTotalsBlock(llvm::raw_ostream& os, llvm::ArrayRef<ReportField> fields);
void printReportPerCoreAndTotalFields(llvm::raw_ostream& os,
llvm::ArrayRef<ReportField> perCoreFields,
llvm::ArrayRef<ReportField> totalFields);
void printReportEntrySeparator(llvm::raw_ostream& os, bool hasNextEntry);
template <typename EntryTy>
int32_t getFirstReportCoreId(const EntryTy& entry) {
if (entry.coreIds.empty())
return std::numeric_limits<int32_t>::max();
return entry.coreIds.front();
}
template <typename EntryRange>
void sortReportEntriesByFirstCore(EntryRange& entries) {
llvm::stable_sort(entries, [](const auto& lhs, const auto& rhs) {
int32_t lhsFirstCore = getFirstReportCoreId(lhs);
int32_t rhsFirstCore = getFirstReportCoreId(rhs);
if (lhsFirstCore != rhsFirstCore)
return lhsFirstCore < rhsFirstCore;
return lhs.id < rhs.id;
});
}
} // namespace onnx_mlir
+4
View File
@@ -15,7 +15,10 @@ add_pim_library(OMPimCompilerOptions
add_pim_library(OMPimCompilerUtils
PimCompilerUtils.cpp
PimArtifactWriter.cpp
PimBatchEmission.cpp
PimCodeGen.cpp
PimWeightEmitter.cpp
EXCLUDE_FROM_OM_LIBS
@@ -26,6 +29,7 @@ add_pim_library(OMPimCompilerUtils
OMPimCompilerOptions
OMPimCommon
OMPimBufferization
OMPimStaticMemoryCoalescing
OMPimPasses
OMONNXToSpatial
OMSpatialToPim
+123
View File
@@ -0,0 +1,123 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/raw_ostream.h"
#include <algorithm>
#include <cassert>
#include <cstring>
#include <vector>
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
#include "src/Accelerators/PIM/Compiler/PimArtifactWriter.hpp"
#include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
using namespace llvm;
using namespace mlir;
namespace onnx_mlir {
OnnxMlirCompilerErrorCodes writeHostCoreJson(StringRef outputDirPath) {
std::error_code errorCode;
std::string outputHostCorePath = outputDirPath.str() + "/core_0.json";
raw_fd_ostream hostFileStream(outputHostCorePath, errorCode);
if (errorCode) {
errs() << "Error while opening host core file `" << outputHostCorePath << "`: " << errorCode.message() << '\n';
return InvalidOutputFileAccess;
}
// The host core json contains two no-op-like instructions to satisfy pimsim-nn.
hostFileStream << "[{\"imm\":0,\"op\":\"sldi\",\"rd\":0},{\"imm\":0,\"op\":\"sldi\",\"rd\":0}]";
hostFileStream.close();
return CompilerSuccess;
}
OnnxMlirCompilerErrorCodes
writeMemoryBinary(ModuleOp moduleOp, func::FuncOp funcOp, PimAcceleratorMemory& memory, StringRef outputDirPath) {
auto memoryFilePath = (outputDirPath + "/memory.bin").str();
std::error_code errorCode;
raw_fd_ostream memoryFileStream(memoryFilePath, errorCode, sys::fs::OF_None);
if (errorCode) {
errs() << "Error while opening memory file " << memoryFilePath << ": " << errorCode.message() << '\n';
return InvalidOutputFileAccess;
}
std::vector<char> memoryBuffer(memory.hostMem.getFirstAvailableAddress(), 0);
SmallPtrSet<Operation*, 16> writtenGlobals;
funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
if (hasWeightAlways(getGlobalOp))
return;
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
if (!globalOp)
return;
if (!writtenGlobals.insert(globalOp.getOperation()).second)
return;
auto initialValue = globalOp.getInitialValue();
if (!initialValue)
return;
auto denseAttr = dyn_cast<DenseElementsAttr>(*initialValue);
if (!denseAttr)
return;
MemEntry memEntry = memory.hostMem.getMemEntry(getGlobalOp.getResult());
ArrayRef<char> rawData = denseAttr.getRawData();
char* dst = memoryBuffer.data() + memEntry.address;
if (denseAttr.isSplat()) {
size_t elementSize = rawData.size();
assert(elementSize * getGlobalOp.getType().getNumElements() == memEntry.size && "Data size mismatch");
for (size_t offset = 0; offset < memEntry.size; offset += elementSize)
std::memcpy(dst + offset, rawData.data(), std::min(elementSize, memEntry.size - offset));
}
else {
assert(rawData.size() == memEntry.size && "Data size mismatch");
std::memcpy(dst, rawData.data(), rawData.size());
}
});
memoryFileStream.write(memoryBuffer.data(), memoryBuffer.size());
memoryFileStream.close();
return CompilerSuccess;
}
OnnxMlirCompilerErrorCodes writeConfigJson(func::FuncOp funcOp,
PimAcceleratorMemory& memory,
size_t maxCoreId,
json::Object xbarsPerArrayGroup,
StringRef outputDirPath) {
json::Object configJson;
configJson["core_cnt"] = maxCoreId + 1;
configJson["adc_count"] = 16;
configJson["cell_precision"] = 2;
configJson["xbar_array_count"] = crossbarCountInCore.getValue();
configJson["xbar_size"] = {crossbarSize.getValue(), crossbarSize.getValue()};
configJson["array_group_map"] = std::move(xbarsPerArrayGroup);
json::Array inputsAddresses;
for (BlockArgument input : funcOp.getArguments())
inputsAddresses.push_back(memory.getValueAddress(input));
configJson["inputs_addresses"] = std::move(inputsAddresses);
json::Array outputsAddresses;
for (func::ReturnOp returnOp : funcOp.getOps<func::ReturnOp>())
for (mlir::Value output : returnOp.getOperands())
outputsAddresses.push_back(memory.getValueAddress(output));
configJson["outputs_addresses"] = std::move(outputsAddresses);
auto configPath = (outputDirPath + "/config.json").str();
std::error_code errorCode;
raw_fd_ostream jsonOS(configPath, errorCode);
if (errorCode) {
errs() << "Error while opening config file: " << errorCode.message() << '\n';
return InvalidOutputFileAccess;
}
jsonOS << json::Value(std::move(configJson)) << '\n';
jsonOS.close();
return CompilerSuccess;
}
} // namespace onnx_mlir
+26
View File
@@ -0,0 +1,26 @@
#pragma once
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/JSON.h"
#include "onnx-mlir/Compiler/OMCompilerTypes.h"
namespace onnx_mlir {
class PimAcceleratorMemory;
OnnxMlirCompilerErrorCodes writeHostCoreJson(llvm::StringRef outputDirPath);
OnnxMlirCompilerErrorCodes writeMemoryBinary(mlir::ModuleOp moduleOp,
mlir::func::FuncOp funcOp,
PimAcceleratorMemory& memory,
llvm::StringRef outputDirPath);
OnnxMlirCompilerErrorCodes writeConfigJson(mlir::func::FuncOp funcOp,
PimAcceleratorMemory& memory,
size_t maxCoreId,
llvm::json::Object xbarsPerArrayGroup,
llvm::StringRef outputDirPath);
} // namespace onnx_mlir
+136
View File
@@ -0,0 +1,136 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IRMapping.h"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
static SmallVector<int32_t> getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp) {
auto coreIdsAttr = coreBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName);
assert(coreIdsAttr && "pim.core_batch requires coreIds array attribute");
return SmallVector<int32_t>(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end());
}
static SmallVector<int32_t> getLaneChunkCoreIds(ArrayRef<int32_t> coreIds, size_t laneCount, unsigned lane) {
SmallVector<int32_t> laneCoreIds;
laneCoreIds.reserve(coreIds.size() / laneCount);
for (size_t chunkIndex = 0; chunkIndex < coreIds.size() / laneCount; ++chunkIndex)
laneCoreIds.push_back(coreIds[chunkIndex * laneCount + lane]);
return laneCoreIds;
}
static void scalarizeBatchOpsInCore(pim::PimCoreOp scalarCore, size_t laneCount, unsigned lane) {
IRRewriter rewriter(scalarCore.getContext());
SmallVector<Operation*> batchOps;
scalarCore.walk([&](Operation* op) {
if (isa<pim::PimSendBatchOp,
pim::PimSendTensorBatchOp,
pim::PimReceiveBatchOp,
pim::PimReceiveTensorBatchOp,
pim::PimMemCopyHostToDevBatchOp>(op)) {
batchOps.push_back(op);
}
});
for (Operation* op : batchOps) {
rewriter.setInsertionPoint(op);
if (auto sendBatchOp = dyn_cast<pim::PimSendBatchOp>(op)) {
pim::PimSendOp::create(rewriter,
sendBatchOp.getLoc(),
sendBatchOp.getInput(),
sendBatchOp.getSizeAttr(),
rewriter.getI32IntegerAttr(sendBatchOp.getTargetCoreIds()[lane]));
rewriter.eraseOp(op);
continue;
}
if (auto sendTensorBatchOp = dyn_cast<pim::PimSendTensorBatchOp>(op)) {
pim::PimSendTensorOp::create(
rewriter,
sendTensorBatchOp.getLoc(),
sendTensorBatchOp.getInput(),
rewriter.getDenseI32ArrayAttr(getLaneChunkCoreIds(sendTensorBatchOp.getTargetCoreIds(), laneCount, lane)));
rewriter.eraseOp(op);
continue;
}
if (auto receiveBatchOp = dyn_cast<pim::PimReceiveBatchOp>(op)) {
auto scalarReceive =
pim::PimReceiveOp::create(rewriter,
receiveBatchOp.getLoc(),
receiveBatchOp.getOutput().getType(),
receiveBatchOp.getOutputBuffer(),
receiveBatchOp.getSizeAttr(),
rewriter.getI32IntegerAttr(receiveBatchOp.getSourceCoreIds()[lane]));
rewriter.replaceOp(op, scalarReceive->getResults());
continue;
}
if (auto receiveTensorBatchOp = dyn_cast<pim::PimReceiveTensorBatchOp>(op)) {
auto scalarReceive = pim::PimReceiveTensorOp::create(
rewriter,
receiveTensorBatchOp.getLoc(),
receiveTensorBatchOp.getOutput().getType(),
receiveTensorBatchOp.getOutputBuffer(),
rewriter.getDenseI32ArrayAttr(getLaneChunkCoreIds(receiveTensorBatchOp.getSourceCoreIds(), laneCount, lane)));
rewriter.replaceOp(op, scalarReceive->getResults());
continue;
}
auto memcpBatchOp = cast<pim::PimMemCopyHostToDevBatchOp>(op);
auto scalarCopy = pim::PimMemCopyHostToDevOp::create(rewriter,
memcpBatchOp.getLoc(),
memcpBatchOp.getOutput().getType(),
memcpBatchOp.getDeviceTarget(),
memcpBatchOp.getHostSource(),
memcpBatchOp.getDeviceTargetOffsetAttr(),
memcpBatchOp.getHostSourceOffsetAttr(),
memcpBatchOp.getSizeAttr());
rewriter.replaceOp(op, scalarCopy->getResults());
}
}
} // namespace
LogicalResult withScalarCoreFromBatchLane(pim::PimCoreBatchOp coreBatchOp,
unsigned lane,
llvm::function_ref<LogicalResult(pim::PimCoreOp)> callback) {
OwningOpRef<ModuleOp> scratchModule = ModuleOp::create(coreBatchOp.getLoc());
OpBuilder builder(scratchModule->getContext());
builder.setInsertionPointToStart(scratchModule->getBody());
size_t laneCount = static_cast<size_t>(coreBatchOp.getLaneCount());
size_t weightsPerLane = coreBatchOp.getWeights().size() / laneCount;
SmallVector<Value> laneWeights;
laneWeights.reserve(weightsPerLane);
for (size_t weightIndex = 0; weightIndex < weightsPerLane; ++weightIndex)
laneWeights.push_back(coreBatchOp.getWeights()[lane * weightsPerLane + weightIndex]);
auto coreIds = getBatchCoreIds(coreBatchOp);
auto scalarCore = pim::PimCoreOp::create(
builder, coreBatchOp.getLoc(), ValueRange(laneWeights), builder.getI32IntegerAttr(coreIds[lane]));
Block* block = builder.createBlock(&scalarCore.getBody(), scalarCore.getBody().end());
IRMapping mapper;
if (coreBatchOp.getBody().front().getNumArguments() == 1)
mapper.map(coreBatchOp.getBody().front().getArgument(0), coreBatchOp.getInputs()[lane]);
builder.setInsertionPointToEnd(block);
for (Operation& op : coreBatchOp.getBody().front()) {
Operation* cloned = builder.clone(op, mapper);
for (auto [originalResult, clonedResult] : llvm::zip(op.getResults(), cloned->getResults()))
mapper.map(originalResult, clonedResult);
}
if (block->empty() || !isa<pim::PimHaltOp>(block->back()))
pim::PimHaltOp::create(builder, coreBatchOp.getLoc());
scalarizeBatchOpsInCore(scalarCore, laneCount, lane);
return callback(scalarCore);
}
} // namespace onnx_mlir
+13
View File
@@ -0,0 +1,13 @@
#pragma once
#include "llvm/ADT/STLFunctionalExtras.h"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
namespace onnx_mlir {
mlir::LogicalResult withScalarCoreFromBatchLane(pim::PimCoreBatchOp coreBatchOp,
unsigned lane,
llvm::function_ref<mlir::LogicalResult(pim::PimCoreOp)> callback);
} // namespace onnx_mlir
+432 -343
View File
@@ -1,30 +1,49 @@
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/Verifier.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/Format.h"
#include "llvm/Support/JSON.h"
#include "llvm/Support/raw_ostream.h"
#include <absl/types/compare.h>
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <fstream>
#include <string>
#include <utility>
#include "Common/IR/CompactAsmUtils.hpp"
#include "Common/PimCommon.hpp"
#include "Conversion/ONNXToSpatial/Common.hpp"
#include "Common/Support/ReportUtils.hpp"
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Compiler/PimArtifactWriter.hpp"
#include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp"
#include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Compiler/PimWeightEmitter.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace llvm;
using namespace mlir;
using namespace onnx_mlir;
using namespace onnx_mlir::compact_asm;
static size_t getValueSizeInBytes(mlir::Value value) {
auto type = cast<ShapedType>(value.getType());
return type.getNumElements() * type.getElementTypeBitWidth() / 8;
}
MemEntry* PimMemory::gatherMemEntry(mlir::Value value) {
auto type = cast<ShapedType>(value.getType());
@@ -47,15 +66,29 @@ void PimMemory::allocateMemoryForValue(mlir::Value value, MemEntry& memEntry) {
if (size_t remainder = firstAvailableAddress % minAlignment)
firstAvailableAddress += minAlignment - remainder;
ownedMemEntriesMap[value] = memEntry;
globalMemEntriesMap[value] = memEntry;
}
void PimMemory::allocateHost(ModuleOp moduleOp, func::FuncOp funcOp) {
SmallDenseMap<memref::GlobalOp, mlir::Value, 8> globalConstants;
SmallVector<std::pair<mlir::Value, mlir::Value>, 16> globalAliases;
SmallVector<mlir::Value> args;
for (mlir::Value arg : funcOp.getArguments()) {
gatherMemEntry(arg);
args.push_back(arg);
}
funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
if (!hasWeightAlways(getGlobalOp)) {
auto globalMemrefOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
if (globalMemrefOp.getName().starts_with("arg")) {
StringRef indexStr = globalMemrefOp.getName().substr(4);
int index = 0;
llvm::to_integer(indexStr, index, 10);
globalAliases.push_back({getGlobalOp.getResult(), args[index]});
}
auto [iter, inserted] = globalConstants.try_emplace(globalMemrefOp, getGlobalOp.getResult());
if (inserted)
gatherMemEntry(getGlobalOp.getResult());
@@ -64,9 +97,6 @@ void PimMemory::allocateHost(ModuleOp moduleOp, func::FuncOp funcOp) {
}
});
for (mlir::Value arg : funcOp.getArguments())
gatherMemEntry(arg);
funcOp.walk([&](memref::AllocOp allocOp) {
if (!allocOp->getParentOfType<pim::PimCoreOp>())
gatherMemEntry(allocOp.getResult());
@@ -84,6 +114,64 @@ void PimMemory::allocateCore(Operation* op) {
allocateGatheredMemory();
}
static void printHostMemoryReportRow(raw_ostream& os, const MemoryReportRow& row) {
llvm::SmallVector<ReportField, 2> fields = {
{"Number of globals", std::to_string(row.numGlobal)},
{"Global memory", formatReportMemory(row.sizeGlobal)}};
printReportFlatFields(os, fields);
}
static void printCoreMemoryReportRow(raw_ostream& os, const MemoryReportEntry& entry) {
llvm::SmallVector<ReportField, 2> fields = {
{"Number of allocas", std::to_string(entry.row.numAlloca)},
{"Allocated memory", formatReportMemory(entry.row.sizeAlloca)}};
printReportFlatFields(os, fields);
}
static void printBatchMemoryReportRow(raw_ostream& os, const MemoryReportEntry& entry) {
llvm::SmallVector<ReportField, 2> perCoreFields = {
{"Number of allocas", std::to_string(entry.row.numAlloca)},
{"Allocated memory", formatReportMemory(entry.row.sizeAlloca)}};
llvm::SmallVector<ReportField, 2> totalFields = {
{"Number of allocas", std::to_string(entry.totalAllocaCount)},
{"Batch memory", formatReportMemory(entry.totalAllocaBytes)}};
printReportPerCoreAndTotalFields(os, perCoreFields, totalFields);
}
static MemoryReportRow addMemoryReportRows(const MemoryReportRow& lhs, const MemoryReportRow& rhs) {
MemoryReportRow result = lhs;
result.numAlloca += rhs.numAlloca;
result.sizeAlloca += rhs.sizeAlloca;
result.numGlobal += rhs.numGlobal;
result.sizeGlobal += rhs.sizeGlobal;
return result;
}
MemoryReportRow PimMemory::getReportRow() const {
MemoryReportRow row;
for (auto& [val, memEntry] : ownedMemEntriesMap) {
if (auto op = val.getDefiningOp()) {
if (isa<memref::AllocOp>(op)) {
row.numAlloca++;
row.sizeAlloca += memEntry.size;
}
if (isa<memref::GetGlobalOp>(op)) {
row.numGlobal++;
row.sizeGlobal += memEntry.size;
}
}
}
return row;
}
void PimMemory::remove(mlir::Value val) {
if (auto removeIter = ownedMemEntriesMap.find(val); removeIter != ownedMemEntriesMap.end())
ownedMemEntriesMap.erase(removeIter);
if (auto removeIter = globalMemEntriesMap.find(val); removeIter != globalMemEntriesMap.end())
globalMemEntriesMap.erase(removeIter);
}
MemEntry PimMemory::getMemEntry(mlir::Value value) const {
auto iter = globalMemEntriesMap.find(value);
assert("Missing memEntry for value" && iter != globalMemEntriesMap.end());
@@ -124,6 +212,106 @@ size_t PimAcceleratorMemory::getValueAddress(mlir::Value value, const StaticValu
return iter->second.address + resolvedAddress->byteOffset;
}
void PimAcceleratorMemory::reportHost() { hostReportRow = hostMem.getReportRow(); }
void PimAcceleratorMemory::recordCoreReport(size_t coreId, const MemoryReportRow& row) {
reportEntries.push_back({MemoryReportEntry::Kind::Core,
coreId,
{static_cast<int32_t>(coreId)},
row,
row.numAlloca,
row.sizeAlloca});
}
void PimAcceleratorMemory::recordBatchReport(uint64_t batchId,
ArrayRef<int32_t> coreIds,
const MemoryReportRow& perCoreRow,
uint64_t totalAllocaCount,
uint64_t totalAllocaBytes) {
MemoryReportEntry entry;
entry.kind = MemoryReportEntry::Kind::Batch;
entry.id = batchId;
llvm::append_range(entry.coreIds, coreIds);
entry.row = perCoreRow;
entry.totalAllocaCount = totalAllocaCount;
entry.totalAllocaBytes = totalAllocaBytes;
reportEntries.push_back(std::move(entry));
}
void PimAcceleratorMemory::flushReport() {
if (!fileReport.is_open())
return;
llvm::raw_os_ostream os(fileReport);
uint64_t totalGlobalMemory = hostReportRow.has_value() ? hostReportRow->sizeGlobal : 0;
uint64_t totalCoresMemory = 0;
for (const MemoryReportEntry& entry : reportEntries)
totalCoresMemory += entry.totalAllocaBytes;
llvm::SmallVector<ReportField, 2> totalFields = {
{"Global memory", formatReportMemory(totalGlobalMemory)},
{"Cores memory", formatReportMemory(totalCoresMemory)}};
printReportTotalsBlock(os, totalFields);
if (hostReportRow.has_value()) {
os << "\nHost:\n";
printHostMemoryReportRow(os, *hostReportRow);
}
if (!reportEntries.empty()) {
if (hostReportRow.has_value())
os << "\n";
sortReportEntriesByFirstCore(reportEntries);
for (size_t index = 0; index < reportEntries.size();) {
size_t runEnd = index + 1;
while (runEnd < reportEntries.size() && reportEntries[runEnd].kind == reportEntries[index].kind
&& reportEntries[runEnd].row == reportEntries[index].row
&& reportEntries[runEnd].totalAllocaCount == reportEntries[index].totalAllocaCount
&& reportEntries[runEnd].totalAllocaBytes == reportEntries[index].totalAllocaBytes) {
++runEnd;
}
if (reportEntries[index].kind == MemoryReportEntry::Kind::Batch) {
os << "Batch ";
for (size_t batchIndex = index; batchIndex < runEnd; ++batchIndex) {
if (batchIndex != index)
os << ",\n ";
os << reportEntries[batchIndex].id << " (cores ";
printCompressedIntegerEntries(os, ArrayRef<int32_t>(reportEntries[batchIndex].coreIds));
os << ")";
}
}
else {
llvm::SmallVector<int32_t, 8> coreIds;
for (size_t coreIndex = index; coreIndex < runEnd; ++coreIndex)
coreIds.push_back(reportEntries[coreIndex].coreIds.front());
os << "Core ";
printCompressedIntegerEntries(os, ArrayRef<int32_t>(coreIds));
}
os << ":\n";
if (reportEntries[index].kind == MemoryReportEntry::Kind::Batch)
printBatchMemoryReportRow(os, reportEntries[index]);
else
printCoreMemoryReportRow(os, reportEntries[index]);
printReportEntrySeparator(os, runEnd < reportEntries.size());
index = runEnd;
}
}
os.flush();
fileReport.close();
}
void PimAcceleratorMemory::clean(mlir::Operation* op) {
for (auto value : op->getResults()) {
hostMem.remove(value);
for (auto& device : deviceMem)
device.second.remove(value);
}
}
json::Object PimCodeGen::createEmptyOffset() {
json::Object offset;
offset["offset_select"] = 0;
@@ -131,6 +319,12 @@ json::Object PimCodeGen::createEmptyOffset() {
return offset;
}
size_t PimCodeGen::remapCoreId(size_t coreId) const {
auto it = emittedCoreIds.find(coreId);
assert(it != emittedCoreIds.end() && "Missing emitted core id remapping");
return it->second;
}
static json::Object createRs1OnlyOffset() {
json::Object offset;
offset["offset_select"] = 1;
@@ -190,7 +384,7 @@ void PimCodeGen::emitCommunicationOp(StringRef opName, size_t bufferAddr, size_t
json::Object json;
json["op"] = opName;
json["rd"] = 0;
json["core"] = coreId;
json["core"] = remapCoreId(coreId);
json["size"] = size;
json["offset"] = createEmptyOffset();
emitInstruction(std::move(json));
@@ -242,10 +436,62 @@ void PimCodeGen::codeGenReceiveOp(pim::PimReceiveOp receiveOp, const StaticValue
"recv", addressOf(receiveOp.getOutputBuffer(), knowledge), receiveOp.getSourceCoreId(), receiveOp.getSize());
}
void PimCodeGen::codeGenReceiveTensorOp(pim::PimReceiveTensorOp receiveTensorOp,
const StaticValueKnowledge& knowledge) const {
size_t outputAddr = addressOf(receiveTensorOp.getOutputBuffer(), knowledge);
size_t chunkSize = getValueSizeInBytes(receiveTensorOp.getOutputBuffer()) / receiveTensorOp.getSourceCoreIds().size();
for (auto [chunkIndex, sourceCoreId] : llvm::enumerate(receiveTensorOp.getSourceCoreIds()))
emitCommunicationOp("recv", outputAddr + chunkIndex * chunkSize, sourceCoreId, chunkSize);
}
void PimCodeGen::codeGenSendOp(pim::PimSendOp sendOp, const StaticValueKnowledge& knowledge) const {
emitCommunicationOp("send", addressOf(sendOp.getInput(), knowledge), sendOp.getTargetCoreId(), sendOp.getSize());
}
void PimCodeGen::codeGenSendTensorOp(pim::PimSendTensorOp sendTensorOp, const StaticValueKnowledge& knowledge) const {
size_t inputAddr = addressOf(sendTensorOp.getInput(), knowledge);
size_t chunkSize = getValueSizeInBytes(sendTensorOp.getInput()) / sendTensorOp.getTargetCoreIds().size();
for (auto [chunkIndex, targetCoreId] : llvm::enumerate(sendTensorOp.getTargetCoreIds()))
emitCommunicationOp("send", inputAddr + chunkIndex * chunkSize, targetCoreId, chunkSize);
}
void PimCodeGen::codeGenConcatOp(pim::PimConcatOp concatOp, const StaticValueKnowledge& knowledge) const {
auto outputType = cast<ShapedType>(concatOp.getOutputBuffer().getType());
assert(outputType.hasStaticShape() && "concat codegen requires static output shape");
int64_t axis = concatOp.getAxis();
ArrayRef<int64_t> outputShape = outputType.getShape();
size_t elementSize = outputType.getElementTypeBitWidth() / 8;
size_t outputAddr = addressOf(concatOp.getOutputBuffer(), knowledge);
size_t outerCount = 1;
for (int64_t dim = 0; dim < axis; ++dim)
outerCount *= static_cast<size_t>(outputShape[dim]);
size_t innerCount = 1;
for (size_t dim = static_cast<size_t>(axis) + 1; dim < outputShape.size(); ++dim)
innerCount *= static_cast<size_t>(outputShape[dim]);
size_t outputConcatDim = static_cast<size_t>(outputShape[axis]);
size_t concatOffset = 0;
for (mlir::Value input : concatOp.getInputs()) {
auto inputType = cast<ShapedType>(input.getType());
assert(inputType.hasStaticShape() && "concat codegen requires static input shapes");
size_t inputConcatDim = static_cast<size_t>(inputType.getDimSize(axis));
size_t blockSizeInBytes = inputConcatDim * innerCount * elementSize;
size_t inputAddr = addressOf(input, knowledge);
for (size_t outerIndex = 0; outerIndex < outerCount; ++outerIndex) {
size_t dstOffset = (outerIndex * outputConcatDim + concatOffset) * innerCount * elementSize;
size_t srcOffset = outerIndex * inputConcatDim * innerCount * elementSize;
emitMemCopyOp("lmv", outputAddr, dstOffset, inputAddr, srcOffset, blockSizeInBytes, "len");
}
concatOffset += inputConcatDim;
}
}
template <typename MVMTy>
void PimCodeGen::codeGenMVMLikeOp(size_t mvmId,
MVMTy mvmLikeOp,
@@ -256,11 +502,6 @@ void PimCodeGen::codeGenMVMLikeOp(size_t mvmId,
// TODO: save weights somewhere (if transposeMatrix=true, transpose the weight matrix)
}
static size_t getValueSizeInBytes(mlir::Value value) {
auto type = cast<ShapedType>(value.getType());
return type.getNumElements() * type.getElementTypeBitWidth() / 8;
}
void PimCodeGen::codeGenVVAddOp(pim::PimVVAddOp vvaddOp, const StaticValueKnowledge& knowledge) const {
auto outputBufferAddr = addressOf(vvaddOp.getOutputBuffer(), knowledge);
auto lhsAddr = addressOf(vvaddOp.getLhs(), knowledge);
@@ -412,6 +653,8 @@ void PimCodeGen::codeGenVSoftmaxOp(pim::PimVSoftmaxOp vsoftmaxOp, const StaticVa
emitInstruction(std::move(json));
}
void PimCodeGen::codeGetGlobalOp(memref::GetGlobalOp getGlobalOp, const StaticValueKnowledge& knowledge) const {}
void PimCodeGen::codeGenTransposeOp(pim::PimTransposeOp transposeOp, const StaticValueKnowledge& knowledge) const {
auto srcAddr = addressOf(transposeOp.getInput(), knowledge);
auto dstAddr = addressOf(transposeOp.getOutputBuffer(), knowledge);
@@ -474,67 +717,59 @@ std::string getMemorySizeAsString(size_t size) {
return std::to_string(size) + " Bytes";
}
static SmallVector<unsigned, 8> getUsedWeightIndices(pim::PimCoreOp coreOp) {
static SmallVector<unsigned, 8> getUsedWeightIndices(Block& block) {
SmallVector<unsigned, 8> indices;
auto addIndex = [&](unsigned weightIndex) {
if (!llvm::is_contained(indices, weightIndex))
indices.push_back(weightIndex);
};
coreOp.walk([&](pim::PimMVMOp mvmOp) { addIndex(mvmOp.getWeightIndex()); });
coreOp.walk([&](pim::PimVMMOp vmmOp) { addIndex(vmmOp.getWeightIndex()); });
block.walk([&](pim::PimVMMOp vmmOp) { addIndex(vmmOp.getWeightIndex()); });
llvm::sort(indices);
return indices;
}
/// Write global constant data into a binary memory image at their allocated addresses.
static OnnxMlirCompilerErrorCodes
writeMemoryBinary(ModuleOp moduleOp, func::FuncOp funcOp, PimAcceleratorMemory& memory, StringRef outputDirPath) {
auto memoryFilePath = (outputDirPath + "/memory.bin").str();
std::error_code errorCode;
raw_fd_ostream memoryFileStream(memoryFilePath, errorCode, sys::fs::OF_None);
if (errorCode) {
errs() << "Error while opening memory file " << memoryFilePath << ": " << errorCode.message() << '\n';
return InvalidOutputFileAccess;
}
static SmallVector<unsigned, 8> getUsedWeightIndices(pim::PimCoreOp coreOp) {
return getUsedWeightIndices(coreOp.getBody().front());
}
std::vector<char> memoryBuffer(memory.hostMem.getFirstAvailableAddress(), 0);
static SmallVector<int32_t> getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp) {
auto coreIdsAttr = coreBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName);
assert(coreIdsAttr && "pim.core_batch requires coreIds array attribute");
return SmallVector<int32_t>(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end());
}
SmallPtrSet<Operation*, 16> writtenGlobals;
funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
if (hasWeightAlways(getGlobalOp))
return;
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
if (!globalOp)
return;
if (!writtenGlobals.insert(globalOp.getOperation()).second)
return;
auto initialValue = globalOp.getInitialValue();
if (!initialValue)
return;
auto denseAttr = dyn_cast<DenseElementsAttr>(*initialValue);
if (!denseAttr)
static SmallVector<Operation*> collectTopLevelCoreLikeOps(func::FuncOp funcOp) {
SmallVector<Operation*> coreLikeOps;
for (Operation& op : funcOp.getBody().front())
if (dyn_cast<pim::PimCoreOp>(&op) || dyn_cast<pim::PimCoreBatchOp>(&op))
coreLikeOps.push_back(&op);
return coreLikeOps;
}
static void aliasMaterializedHostGlobals(ModuleOp moduleOp,
func::FuncOp funcOp,
pim::PimCoreOp coreOp,
PimAcceleratorMemory& memory) {
coreOp.walk([&](memref::GetGlobalOp getGlobalOp) {
if (hasWeightAlways(getGlobalOp) || memory.memEntriesMap.contains(getGlobalOp.getResult()))
return;
MemEntry memEntry = memory.hostMem.getMemEntry(getGlobalOp.getResult());
ArrayRef<char> rawData = denseAttr.getRawData();
char* dst = memoryBuffer.data() + memEntry.address;
auto targetGlobal = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
if (!targetGlobal)
return;
if (denseAttr.isSplat()) {
size_t elementSize = rawData.size();
assert(elementSize * getGlobalOp.getType().getNumElements() == memEntry.size && "Data size mismatch");
for (size_t offset = 0; offset < memEntry.size; offset += elementSize)
std::memcpy(dst + offset, rawData.data(), std::min(elementSize, memEntry.size - offset));
}
else {
assert(rawData.size() == memEntry.size && "Data size mismatch");
std::memcpy(dst, rawData.data(), rawData.size());
}
mlir::Value aliasedValue;
funcOp.walk([&](memref::GetGlobalOp candidate) {
if (aliasedValue || candidate == getGlobalOp || !memory.memEntriesMap.contains(candidate.getResult()))
return;
if (lookupGlobalForGetGlobal(moduleOp, candidate) == targetGlobal)
aliasedValue = candidate.getResult();
});
if (aliasedValue)
memory.memEntriesMap[getGlobalOp.getResult()] = memory.memEntriesMap[aliasedValue];
});
memoryFileStream.write(memoryBuffer.data(), memoryBuffer.size());
memoryFileStream.close();
return CompilerSuccess;
}
/// Dispatch all operations in a core region to the appropriate code generator.
@@ -553,12 +788,16 @@ static int64_t codeGenCoreOps(Block& block, PimCodeGen& coreCodeGen) {
coreCodeGen.codeGenLmvOp(lmvOp, knowledge);
else if (auto receiveOp = dyn_cast<pim::PimReceiveOp>(op))
coreCodeGen.codeGenReceiveOp(receiveOp, knowledge);
else if (auto receiveTensorOp = dyn_cast<pim::PimReceiveTensorOp>(op))
coreCodeGen.codeGenReceiveTensorOp(receiveTensorOp, knowledge);
else if (auto sendOp = dyn_cast<pim::PimSendOp>(op))
coreCodeGen.codeGenSendOp(sendOp, knowledge);
else if (auto sendTensorOp = dyn_cast<pim::PimSendTensorOp>(op))
coreCodeGen.codeGenSendTensorOp(sendTensorOp, knowledge);
else if (auto concatOp = dyn_cast<pim::PimConcatOp>(op))
coreCodeGen.codeGenConcatOp(concatOp, knowledge);
else if (auto vmmOp = dyn_cast<pim::PimVMMOp>(op))
coreCodeGen.codeGenMVMLikeOp<pim::PimVMMOp>(vmmOp.getWeightIndex(), vmmOp, true, knowledge);
else if (auto mvmOp = dyn_cast<pim::PimMVMOp>(op))
coreCodeGen.codeGenMVMLikeOp<pim::PimMVMOp>(mvmOp.getWeightIndex(), mvmOp, false, knowledge);
else if (auto transposeOp = dyn_cast<pim::PimTransposeOp>(op))
coreCodeGen.codeGenTransposeOp(transposeOp, knowledge);
else if (auto vvaddOp = dyn_cast<pim::PimVVAddOp>(op))
@@ -581,9 +820,15 @@ static int64_t codeGenCoreOps(Block& block, PimCodeGen& coreCodeGen) {
coreCodeGen.codeGenVSigmOp(vsigmOp, knowledge);
else if (auto vsoftmaxOp = dyn_cast<pim::PimVSoftmaxOp>(op))
coreCodeGen.codeGenVSoftmaxOp(vsoftmaxOp, knowledge);
else if (auto getGlobalOp = dyn_cast<memref::GetGlobalOp>(op))
coreCodeGen.codeGetGlobalOp(getGlobalOp, knowledge);
else {
op.emitError("Unsupported codegen for this operation");
op.dump();
InFlightDiagnostic diag = op.emitError()
<< "unsupported codegen for op '" << op.getName().getStringRef() << "'";
if (auto coreOp = op.getParentOfType<pim::PimCoreOp>())
diag << " inside pim.core " << coreOp.getCoreId();
else if (auto coreBatchOp = op.getParentOfType<pim::PimCoreBatchOp>())
diag << " inside pim.core_batch with laneCount " << coreBatchOp.getLaneCount();
return failure();
}
processedOperations++;
@@ -592,225 +837,6 @@ static int64_t codeGenCoreOps(Block& block, PimCodeGen& coreCodeGen) {
return failed(result) ? -1 : static_cast<int64_t>(processedOperations);
}
/// Write crossbar weight matrices as padded binary files for a single core.
static OnnxMlirCompilerErrorCodes writeCrossbarWeights(ModuleOp moduleOp,
pim::PimCoreOp coreOp,
StringRef coreWeightsDirPath,
json::Array& xbarsPerGroup) {
int64_t xbarSize = crossbarSize.getValue();
std::error_code errorCode;
size_t weightIndex = 0;
for (auto weight : coreOp.getWeights()) {
xbarsPerGroup.push_back(weightIndex);
auto getGlobalOp = weight.getDefiningOp<memref::GetGlobalOp>();
if (!getGlobalOp) {
coreOp.emitWarning("Weight is not from a memref.get_global at index " + std::to_string(weightIndex));
weightIndex++;
continue;
}
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
if (!globalOp) {
coreOp.emitWarning("Could not find memref.global for weight at index " + std::to_string(weightIndex));
weightIndex++;
continue;
}
auto initialValue = globalOp.getInitialValue();
if (!initialValue) {
coreOp.emitWarning("memref.global has no initial value at index " + std::to_string(weightIndex));
weightIndex++;
continue;
}
auto denseAttr = dyn_cast<DenseElementsAttr>(*initialValue);
if (!denseAttr) {
coreOp.emitWarning("memref.global initial value is not dense at index " + std::to_string(weightIndex));
weightIndex++;
continue;
}
auto type = denseAttr.getType();
auto shape = type.getShape();
assert(isMatrixShape(shape) && "Weight matrix must be 2-dimensional");
int64_t numRows = shape[0];
int64_t numCols = shape[1];
assert(numRows <= xbarSize && numCols <= xbarSize && "Weight dimensions must not exceed crossbar size");
size_t elementByteWidth = type.getElementType().getIntOrFloatBitWidth() / 8;
auto weightFilePath = (coreWeightsDirPath + "/crossbar_" + std::to_string(weightIndex) + ".bin").str();
raw_fd_ostream weightFileStream(weightFilePath, errorCode, sys::fs::OF_None);
if (errorCode) {
errs() << "Error while opening weight file `" << weightFilePath << "`: " << errorCode.message() << '\n';
return InvalidOutputFileAccess;
}
uint64_t zero = 0;
for (int64_t row = 0; row < xbarSize; row++) {
for (int64_t col = 0; col < xbarSize; col++) {
if (row < numRows && col < numCols) {
int64_t index = row * numCols + col;
APInt bits = denseAttr.getValues<APFloat>()[index].bitcastToAPInt();
uint64_t word = bits.getZExtValue();
weightFileStream.write(reinterpret_cast<const char*>(&word), elementByteWidth);
}
else {
weightFileStream.write(reinterpret_cast<const char*>(&zero), elementByteWidth);
}
}
}
weightFileStream.close();
weightIndex++;
}
return CompilerSuccess;
}
llvm::DenseMap<pim::PimCoreOp, llvm::DenseMap<mlir::Value, std::string>>
createAndPopulateWeightFolder(func::FuncOp funcOp, StringRef outputDirPath) {
ModuleOp moduleOp = funcOp->getParentOfType<ModuleOp>();
auto coreWeightsDirPath = outputDirPath + "/weights";
auto error = sys::fs::create_directory(coreWeightsDirPath);
assert(!error && "Error creating weights directory");
size_t indexFileName = 0;
int64_t xbarSize = crossbarSize.getValue();
llvm::DenseMap<pim::PimCoreOp, llvm::DenseMap<mlir::Value, std::string>> mapCoreWeightToFileName;
llvm::DenseMap<memref::GlobalOp, std::string> mapGlobalOpToFileName;
for (pim::PimCoreOp coreOp : funcOp.getOps<pim::PimCoreOp>()) {
for (unsigned index : getUsedWeightIndices(coreOp)) {
if (index >= coreOp.getWeights().size()) {
coreOp.emitWarning("Weight index " + std::to_string(index) + " is out of range");
assert(index < coreOp.getWeights().size() && "Weight index is out of range");
}
mlir::Value weight = coreOp.getWeights()[index];
auto getGlobalOp = weight.getDefiningOp<memref::GetGlobalOp>();
if (!getGlobalOp) {
coreOp.emitWarning("Weight is not from a memref.get_global at index " + std::to_string(index));
assert(!getGlobalOp && "Weight is not from a memref.get_global");
}
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
if (!globalOp) {
coreOp.emitWarning("Could not find memref.global for weight at index " + std::to_string(index));
assert(!globalOp && "Could not find memref.global");
}
auto initialValue = globalOp.getInitialValue();
if (!initialValue) {
coreOp.emitWarning("memref.global has no initial value at index " + std::to_string(index));
assert(!initialValue && "memref.global has no initial value");
}
auto denseAttr = dyn_cast<DenseElementsAttr>(*initialValue);
if (!denseAttr) {
coreOp.emitWarning("memref.global initial value is not dense at index " + std::to_string(index));
assert(!denseAttr && "memref.global initial value is not dense");
}
if (mapGlobalOpToFileName.contains(globalOp)) {
auto& fileName = mapGlobalOpToFileName[globalOp];
std::pair<mlir::Value, std::string> weightToFile = {weight, fileName};
mapCoreWeightToFileName[coreOp].insert(weightToFile);
continue;
}
auto type = denseAttr.getType();
auto shape = type.getShape();
assert(isMatrixShape(shape) && "Weight matrix must be 2-dimensional");
int64_t numRows = shape[0];
int64_t numCols = shape[1];
assert(numRows <= xbarSize && numCols <= xbarSize && "Weight dimensions must not exceed crossbar size");
size_t elementByteWidth = type.getElementType().getIntOrFloatBitWidth() / 8;
std::string newFileName = "crossbar_" + std::to_string(indexFileName++) + ".bin";
auto weightFilePath = (coreWeightsDirPath + "/" + newFileName).str();
std::error_code errorCode;
raw_fd_ostream weightFileStream(weightFilePath, errorCode, sys::fs::OF_None);
if (errorCode) {
errs() << "Error while opening weight file `" << weightFilePath << "`: " << errorCode.message() << '\n';
assert(errorCode);
}
uint64_t zero = 0;
for (int64_t row = 0; row < xbarSize; row++) {
for (int64_t col = 0; col < xbarSize; col++) {
if (row < numRows && col < numCols) {
int64_t index = row * numCols + col;
APInt bits = denseAttr.getValues<APFloat>()[index].bitcastToAPInt();
uint64_t word = bits.getZExtValue();
weightFileStream.write(reinterpret_cast<const char*>(&word), elementByteWidth);
}
else {
weightFileStream.write(reinterpret_cast<const char*>(&zero), elementByteWidth);
}
}
}
weightFileStream.close();
mapGlobalOpToFileName.insert({globalOp, newFileName});
mapCoreWeightToFileName[coreOp].insert({weight, newFileName});
}
}
return mapCoreWeightToFileName;
}
/// Write the top-level PIM configuration JSON (core count, crossbar config, I/O addresses).
static OnnxMlirCompilerErrorCodes writeConfigJson(func::FuncOp funcOp,
PimAcceleratorMemory& memory,
size_t coreCount,
json::Object xbarsPerArrayGroup,
StringRef outputDirPath) {
json::Object configJson;
// +1 because pimsim-nn also considers the host as a core
configJson["core_cnt"] = coreCount + 1;
// TODO: Should this be based on the floating point type used in the model?
// The 2 following values determine the bitwidth of the vectors' elements: bitwidth = adc_count * cell_precision
// Number of ADC for MVM units
configJson["adc_count"] = 16;
// The bit precision of each ADC
configJson["cell_precision"] = 2;
// Crossbar configuration
configJson["xbar_array_count"] = crossbarCountInCore.getValue();
configJson["xbar_size"] = {crossbarSize.getValue(), crossbarSize.getValue()};
configJson["array_group_map"] = std::move(xbarsPerArrayGroup);
// Memory layout of inputs and outputs
json::Array inputsAddresses;
for (BlockArgument input : funcOp.getArguments())
inputsAddresses.push_back(memory.getValueAddress(input));
configJson["inputs_addresses"] = std::move(inputsAddresses);
json::Array outputsAddresses;
for (func::ReturnOp returnOp : funcOp.getOps<func::ReturnOp>())
for (mlir::Value output : returnOp.getOperands())
outputsAddresses.push_back(memory.getValueAddress(output));
configJson["outputs_addresses"] = std::move(outputsAddresses);
auto configPath = (outputDirPath + "/config.json").str();
std::error_code errorCode;
raw_fd_ostream jsonOS(configPath, errorCode);
if (errorCode) {
errs() << "Error while opening config file: " << errorCode.message() << '\n';
return InvalidOutputFileAccess;
}
jsonOS << json::Value(std::move(configJson)) << '\n';
jsonOS.close();
return CompilerSuccess;
}
OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::string& outputDirPath) {
if (!outputDirPath.empty()) {
if (auto error = sys::fs::create_directory(outputDirPath)) {
@@ -826,85 +852,148 @@ OnnxMlirCompilerErrorCodes onnx_mlir::compileToPimJson(ModuleOp& moduleOp, std::
PimAcceleratorMemory memory;
memory.hostMem.allocateHost(moduleOp, funcOp);
memory.reportHost();
if (auto err = writeMemoryBinary(moduleOp, funcOp, memory, outputDirPath))
return err;
// Write empty host core file
std::error_code errorCode;
auto outputHostCorePath = outputDirPath + "/core_0.json";
raw_fd_ostream hostFileStream(outputHostCorePath, errorCode);
if (errorCode) {
errs() << "Error while opening host core file `" << outputHostCorePath << "`: " << errorCode.message() << '\n';
return InvalidOutputFileAccess;
}
// The host core json contains 2 random instructions, just to make pimsim-nn happy
hostFileStream << "[{\"imm\":0,\"op\":\"sldi\",\"rd\":0},{\"imm\":0,\"op\":\"sldi\",\"rd\":0}]";
hostFileStream.close();
if (auto err = writeHostCoreJson(outputDirPath))
return err;
// For each core, specify the number of crossbar per array group.
// This implementation always assigns one crossbar per group.
json::Object xbarsPerArrayGroup;
size_t coreCount = 0;
size_t maxCoreId = 0;
uint64_t nextBatchReportId = 0;
// Create Weight Folder
auto mapCoreWeightToFileName = createAndPopulateWeightFolder(funcOp, outputDirPath);
for (auto coreOp : funcOp.getOps<pim::PimCoreOp>()) {
auto coreId = coreOp.getCoreId();
coreCount++;
SmallVector<Operation*> coreLikeOps = collectTopLevelCoreLikeOps(funcOp);
llvm::DenseMap<size_t, size_t> emittedCoreIds;
size_t nextEmittedCoreId = 1;
std::error_code errorCode;
auto outputCorePath = outputDirPath + "/core_" + std::to_string(coreId) + ".json";
raw_fd_ostream coreFileStream(outputCorePath, errorCode);
if (errorCode) {
errs() << "Error while opening core file `" << outputCorePath << "`: " << errorCode.message() << '\n';
return InvalidOutputFileAccess;
}
coreFileStream << '[';
PimCodeGen coreCodeGen(memory, coreFileStream);
memory.getOrCreateDeviceMem(coreId).allocateCore(coreOp);
int64_t processedOperations = codeGenCoreOps(coreOp.getBody().front(), coreCodeGen);
if (processedOperations < 0)
return CompilerFailure;
assert(processedOperations > 0);
// Remove trailing comma, close JSON array
coreFileStream.seek(coreFileStream.tell() - 1);
coreFileStream << ']';
coreFileStream.close();
// Write crossbar weights for this core
auto coreWeightsDirPath = outputDirPath + "/core_" + std::to_string(coreId);
if (auto error = sys::fs::create_directory(coreWeightsDirPath)) {
errs() << "Error creating core directory: " << coreWeightsDirPath << ": " << error.message() << '\n';
return InvalidOutputFileAccess;
for (Operation* op : coreLikeOps) {
if (auto coreOp = dyn_cast<pim::PimCoreOp>(op)) {
size_t originalCoreId = static_cast<size_t>(coreOp.getCoreId());
if (!emittedCoreIds.contains(originalCoreId))
emittedCoreIds[originalCoreId] = nextEmittedCoreId++;
continue;
}
auto& mapWeightToFile = mapCoreWeightToFileName[coreOp];
json::Array xbarsPerGroup;
for (unsigned index : getUsedWeightIndices(coreOp)) {
if (index >= coreOp.getWeights().size()) {
coreOp.emitWarning("Weight index " + std::to_string(index) + " is out of range");
assert(index < coreOp.getWeights().size() && "Weight index is out of range");
}
mlir::Value weight = coreOp.getWeights()[index];
xbarsPerGroup.push_back(index);
assert(mapWeightToFile.contains(weight) && "Weight was not materialized into a file!!");
auto& fileName = mapWeightToFile[weight];
if (auto error = sys::fs::create_link(outputDirPath + "/weights/" + fileName,
coreWeightsDirPath + "/crossbar_" + std::to_string(index) + ".bin")) {
errs() << "Error creating link file: " << (outputDirPath + "/weights/" + fileName) << " to "
<< (coreWeightsDirPath + "/crossbar_" + std::to_string(index) + ".bin") << "\nError:" << error.message()
<< '\n';
return InvalidOutputFileAccess;
}
auto coreBatchOp = cast<pim::PimCoreBatchOp>(op);
auto batchCoreIds = getBatchCoreIds(coreBatchOp);
for (unsigned lane = 0; lane < static_cast<unsigned>(coreBatchOp.getLaneCount()); ++lane) {
size_t originalCoreId = static_cast<size_t>(batchCoreIds[lane]);
if (!emittedCoreIds.contains(originalCoreId))
emittedCoreIds[originalCoreId] = nextEmittedCoreId++;
}
xbarsPerArrayGroup["core" + std::to_string(coreId)] = std::move(xbarsPerGroup);
}
return writeConfigJson(funcOp, memory, coreCount, std::move(xbarsPerArrayGroup), outputDirPath);
for (Operation* op : coreLikeOps) {
auto emitCore = [&](pim::PimCoreOp coreOp,
bool temporaryCore,
MemoryReportRow* reportRow = nullptr) -> OnnxMlirCompilerErrorCodes {
size_t originalCoreId = static_cast<size_t>(coreOp.getCoreId());
size_t coreId = emittedCoreIds.lookup(originalCoreId);
maxCoreId = std::max(maxCoreId, coreId);
std::error_code errorCode;
auto outputCorePath = outputDirPath + "/core_" + std::to_string(coreId) + ".json";
raw_fd_ostream coreFileStream(outputCorePath, errorCode);
if (errorCode) {
errs() << "Error while opening core file `" << outputCorePath << "`: " << errorCode.message() << '\n';
return InvalidOutputFileAccess;
}
coreFileStream << '[';
PimCodeGen coreCodeGen(memory, coreFileStream, emittedCoreIds);
aliasMaterializedHostGlobals(moduleOp, funcOp, coreOp, memory);
auto& deviceMemory = memory.getOrCreateDeviceMem(coreId);
deviceMemory.allocateCore(coreOp);
int64_t processedOperations = codeGenCoreOps(coreOp.getBody().front(), coreCodeGen);
if (processedOperations < 0)
return CompilerFailure;
assert(processedOperations > 0);
if (reportRow)
*reportRow = deviceMemory.getReportRow();
coreFileStream.seek(coreFileStream.tell() - 1);
coreFileStream << ']';
coreFileStream.close();
auto coreWeightsDirPath = outputDirPath + "/core_" + std::to_string(coreId);
if (auto error = sys::fs::create_directory(coreWeightsDirPath)) {
errs() << "Error creating core directory: " << coreWeightsDirPath << ": " << error.message() << '\n';
return InvalidOutputFileAccess;
}
auto& mapWeightToFile = mapCoreWeightToFileName[originalCoreId];
json::Array xbarsPerGroup;
for (unsigned index : getUsedWeightIndices(coreOp)) {
if (index >= coreOp.getWeights().size()) {
coreOp.emitWarning("Weight index " + std::to_string(index) + " is out of range");
assert(index < coreOp.getWeights().size() && "Weight index is out of range");
}
mlir::Value weight = coreOp.getWeights()[index];
xbarsPerGroup.push_back(index);
assert(mapWeightToFile.contains(weight) && "Weight was not materialized into a file!!");
auto& fileName = mapWeightToFile[weight];
if (auto error = sys::fs::create_link(outputDirPath + "/weights/" + fileName,
coreWeightsDirPath + "/crossbar_" + std::to_string(index) + ".bin")) {
errs() << "Error creating link file: " << (outputDirPath + "/weights/" + fileName) << " to "
<< (coreWeightsDirPath + "/crossbar_" + std::to_string(index) + ".bin")
<< "\nError:" << error.message() << '\n';
return InvalidOutputFileAccess;
}
}
xbarsPerArrayGroup["core" + std::to_string(coreId)] = std::move(xbarsPerGroup);
if (temporaryCore)
coreOp.walk([&memory](Operation* op) { memory.clean(op); });
return CompilerSuccess;
};
if (auto coreOp = dyn_cast<pim::PimCoreOp>(op)) {
MemoryReportRow coreRow;
if (auto err = emitCore(coreOp, false, &coreRow))
return err;
memory.recordCoreReport(emittedCoreIds.lookup(static_cast<size_t>(coreOp.getCoreId())), coreRow);
continue;
}
auto coreBatchOp = cast<pim::PimCoreBatchOp>(op);
auto batchCoreIds = getBatchCoreIds(coreBatchOp);
SmallVector<int32_t> reportedCoreIds;
reportedCoreIds.reserve(batchCoreIds.size());
MemoryReportRow batchRow;
std::optional<MemoryReportRow> batchPerCoreRow;
for (unsigned lane = 0; lane < static_cast<unsigned>(coreBatchOp.getLaneCount()); ++lane) {
OnnxMlirCompilerErrorCodes laneResult = CompilerSuccess;
if (failed(withScalarCoreFromBatchLane(coreBatchOp, lane, [&](pim::PimCoreOp coreOp) {
size_t originalCoreId = static_cast<size_t>(batchCoreIds[lane]);
size_t coreId = emittedCoreIds.lookup(originalCoreId);
reportedCoreIds.push_back(static_cast<int32_t>(coreId));
MemoryReportRow laneRow;
laneResult = emitCore(coreOp, true, &laneRow);
if (laneResult == CompilerSuccess) {
if (!batchPerCoreRow.has_value())
batchPerCoreRow = laneRow;
batchRow = addMemoryReportRows(batchRow, laneRow);
}
return laneResult == CompilerSuccess ? success() : failure();
})))
return laneResult == CompilerSuccess ? CompilerFailure : laneResult;
}
memory.recordBatchReport(nextBatchReportId++,
reportedCoreIds,
batchPerCoreRow.value_or(MemoryReportRow {}),
batchRow.numAlloca,
batchRow.sizeAlloca);
}
memory.flushReport();
return writeConfigJson(funcOp, memory, maxCoreId, std::move(xbarsPerArrayGroup), outputDirPath);
}
+60 -5
View File
@@ -1,10 +1,18 @@
#pragma once
#include "mlir/IR/Operation.h"
#include "llvm-project/clang/include/clang/Basic/LLVM.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/JSON.h"
#include "llvm/Support/raw_os_ostream.h"
#include <fstream>
#include <optional>
#include "onnx-mlir/Compiler/OMCompilerTypes.h"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/Support/ReportUtils.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
namespace onnx_mlir {
@@ -14,12 +22,37 @@ struct MemEntry {
size_t size;
};
struct MemoryReportRow {
uint64_t numAlloca = 0;
uint64_t sizeAlloca = 0;
uint64_t numGlobal = 0;
uint64_t sizeGlobal = 0;
bool operator==(const MemoryReportRow& other) const {
return numAlloca == other.numAlloca && sizeAlloca == other.sizeAlloca && numGlobal == other.numGlobal
&& sizeGlobal == other.sizeGlobal;
}
};
struct MemoryReportEntry {
enum class Kind {
Core,
Batch
};
Kind kind = Kind::Core;
uint64_t id = 0;
llvm::SmallVector<int32_t, 8> coreIds;
MemoryReportRow row;
uint64_t totalAllocaCount = 0;
uint64_t totalAllocaBytes = 0;
};
class PimMemory {
llvm::SmallVector<std::pair<MemEntry, mlir::Value>, 32> memEntries;
llvm::SmallDenseMap<mlir::Value, MemEntry, 32>& globalMemEntriesMap;
llvm::SmallDenseMap<mlir::Value, MemEntry, 32> ownedMemEntriesMap;
size_t maxSize = 0; // 0 for unbounded memory
size_t startAddress = 0;
size_t minAlignment = 4;
size_t firstAvailableAddress = 0;
@@ -33,6 +66,8 @@ public:
void allocateHost(mlir::ModuleOp moduleOp, mlir::func::FuncOp funcOp);
void allocateCore(mlir::Operation* op);
MemoryReportRow getReportRow() const;
void remove(mlir::Value val);
size_t getFirstAvailableAddress() const { return firstAvailableAddress; }
MemEntry getMemEntry(mlir::Value value) const;
@@ -45,23 +80,37 @@ public:
private:
llvm::SmallDenseMap<size_t, PimMemory> deviceMem;
std::fstream fileReport;
std::optional<MemoryReportRow> hostReportRow;
llvm::SmallVector<MemoryReportEntry, 32> reportEntries;
public:
PimAcceleratorMemory()
: hostMem(memEntriesMap) {}
: hostMem(memEntriesMap), fileReport(openReportFile("memory_report")) {}
PimMemory& getOrCreateDeviceMem(size_t id);
size_t getValueAddress(mlir::Value value, const StaticValueKnowledge& knowledge = {}) const;
void reportHost();
void recordCoreReport(size_t coreId, const MemoryReportRow& row);
void recordBatchReport(uint64_t batchId,
llvm::ArrayRef<int32_t> coreIds,
const MemoryReportRow& perCoreRow,
uint64_t totalAllocaCount,
uint64_t totalAllocaBytes);
void flushReport();
void clean(mlir::Operation* op);
};
class PimCodeGen {
PimAcceleratorMemory& memory;
llvm::raw_fd_ostream& coreFileStream;
const llvm::DenseMap<size_t, size_t>& emittedCoreIds;
size_t addressOf(mlir::Value value, const StaticValueKnowledge& knowledge) const {
return memory.getValueAddress(value, knowledge);
}
size_t remapCoreId(size_t coreId) const;
static llvm::json::Object createEmptyOffset();
void emitInstruction(llvm::json::Object instruction) const;
@@ -83,15 +132,20 @@ class PimCodeGen {
void emitMvmOp(size_t groupId, size_t rdAddr, size_t rdOffset, size_t rs1Addr, size_t rs1Offset) const;
public:
PimCodeGen(PimAcceleratorMemory& memory, llvm::raw_fd_ostream& coreJson)
: memory(memory), coreFileStream(coreJson) {}
PimCodeGen(PimAcceleratorMemory& memory,
llvm::raw_fd_ostream& coreJson,
const llvm::DenseMap<size_t, size_t>& emittedCoreIds)
: memory(memory), coreFileStream(coreJson), emittedCoreIds(emittedCoreIds) {}
void codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp, const StaticValueKnowledge& knowledge) const;
void codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp, const StaticValueKnowledge& knowledge) const;
void codeGenLmvOp(pim::PimMemCopyOp lmvOp, const StaticValueKnowledge& knowledge) const;
void codeGenReceiveOp(pim::PimReceiveOp receiveOp, const StaticValueKnowledge& knowledge) const;
void codeGenReceiveTensorOp(pim::PimReceiveTensorOp receiveTensorOp, const StaticValueKnowledge& knowledge) const;
void codeGenSendOp(pim::PimSendOp sendOp, const StaticValueKnowledge& knowledge) const;
void codeGenSendTensorOp(pim::PimSendTensorOp sendTensorOp, const StaticValueKnowledge& knowledge) const;
void codeGenConcatOp(pim::PimConcatOp concatOp, const StaticValueKnowledge& knowledge) const;
template <typename MVMTy>
void codeGenMVMLikeOp(size_t mvmId, MVMTy mvmLikeOp, bool transposeMatrix, const StaticValueKnowledge& knowledge);
@@ -106,6 +160,7 @@ public:
void codeGenVTanhOp(pim::PimVTanhOp vtanhOp, const StaticValueKnowledge& knowledge) const;
void codeGenVSigmOp(pim::PimVSigmOp vsigmOp, const StaticValueKnowledge& knowledge) const;
void codeGenVSoftmaxOp(pim::PimVSoftmaxOp vsoftmaxOp, const StaticValueKnowledge& knowledge) const;
void codeGetGlobalOp(mlir::memref::GetGlobalOp getGlobalOp, const StaticValueKnowledge& knowledge) const;
void codeGenTransposeOp(pim::PimTransposeOp transposeOp, const StaticValueKnowledge& knowledge) const;
};
+2 -15
View File
@@ -1,16 +1,3 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/
//===------------------------- PimCompilerOptions.cpp --------------------===//
//
// Copyright 2022 The IBM Research Authors.
//
// =============================================================================
//
// Compiler Options for PIM
//
//===----------------------------------------------------------------------===//
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#define DEBUG_TYPE "PimCompilerOptions"
@@ -41,7 +28,7 @@ llvm::cl::opt<size_t>
crossbarSize("crossbar-size", llvm::cl::desc("Width and heigth of a single crossbar"), llvm::cl::init(2));
llvm::cl::opt<size_t>
crossbarCountInCore("crossbar-count", llvm::cl::desc("Number of crossbars in each core"), llvm::cl::init(2));
crossbarCountInCore("crossbar-count", llvm::cl::desc("Number of crossbars in each core"), llvm::cl::init(256));
llvm::cl::opt<long> coresCount("core-count",
llvm::cl::desc("Number of cores in the chip. `-1` to use the minimum amount of cores."),
@@ -51,7 +38,7 @@ llvm::cl::opt<size_t> dcpCriticalWindowSize(
"dcp-critical-window-size",
llvm::cl::desc("Number of lowest-slack virtual nodes considered by each DCP coarsening iteration. "
"Use 0 to run the legacy full-graph DCP analysis."),
llvm::cl::init(1024));
llvm::cl::init(4000));
llvm::cl::opt<bool>
ignoreConcatError("ignore-concat-error",
+1
View File
@@ -41,6 +41,7 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
if (pimEmissionTarget >= EmitPimBufferized) {
pm.addPass(createPimBufferizationPass());
pm.addPass(createPimStaticMemoryCoalescingPass());
// pm.addPass(createCountInstructionPass());
pm.addPass(createMessagePass("Pim bufferized"));
}
+209
View File
@@ -0,0 +1,209 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/raw_ostream.h"
#include <cassert>
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
#include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp"
#include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Compiler/PimWeightEmitter.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace llvm;
using namespace mlir;
namespace onnx_mlir {
namespace {
struct DenseWeightView {
DenseElementsAttr denseAttr;
SmallVector<int64_t> shape;
SmallVector<int64_t> strides;
int64_t offset = 0;
};
FailureOr<DenseWeightView> resolveDenseWeightView(ModuleOp moduleOp, mlir::Value weight) {
SmallVector<memref::SubViewOp> subviews;
mlir::Value current = weight;
memref::GetGlobalOp getGlobalOp;
while (true) {
Operation* defOp = current.getDefiningOp();
if (!defOp)
return failure();
if ((getGlobalOp = dyn_cast<memref::GetGlobalOp>(defOp)))
break;
if (auto subview = dyn_cast<memref::SubViewOp>(defOp)) {
if (!hasAllStaticSubviewParts(subview))
return failure();
subviews.push_back(subview);
current = subview.getSource();
continue;
}
if (auto cast = dyn_cast<memref::CastOp>(defOp)) {
current = cast.getSource();
continue;
}
return failure();
}
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
if (!globalOp || !globalOp.getInitialValue())
return failure();
auto denseAttr = dyn_cast<DenseElementsAttr>(*globalOp.getInitialValue());
if (!denseAttr)
return failure();
DenseWeightView view;
view.denseAttr = denseAttr;
view.shape.assign(denseAttr.getType().getShape().begin(), denseAttr.getType().getShape().end());
view.strides = computeRowMajorStrides(view.shape);
for (memref::SubViewOp subview : llvm::reverse(subviews)) {
SmallVector<int64_t> nextStrides;
nextStrides.reserve(subview.getStaticStrides().size());
for (auto [offset, stride, sourceStride] :
llvm::zip_equal(subview.getStaticOffsets(), subview.getStaticStrides(), view.strides)) {
view.offset += offset * sourceStride;
nextStrides.push_back(stride * sourceStride);
}
view.shape.assign(subview.getStaticSizes().begin(), subview.getStaticSizes().end());
view.strides = std::move(nextStrides);
}
return view;
}
SmallVector<unsigned, 8> getUsedWeightIndices(Block& block) {
SmallVector<unsigned, 8> indices;
auto addIndex = [&](unsigned weightIndex) {
if (!llvm::is_contained(indices, weightIndex))
indices.push_back(weightIndex);
};
block.walk([&](pim::PimVMMOp vmmOp) { addIndex(vmmOp.getWeightIndex()); });
llvm::sort(indices);
return indices;
}
SmallVector<unsigned, 8> getUsedWeightIndices(pim::PimCoreOp coreOp) {
return getUsedWeightIndices(coreOp.getBody().front());
}
SmallVector<Operation*> collectTopLevelCoreLikeOps(func::FuncOp funcOp) {
SmallVector<Operation*> coreLikeOps;
for (Operation& op : funcOp.getBody().front())
if (dyn_cast<pim::PimCoreOp>(&op) || dyn_cast<pim::PimCoreBatchOp>(&op))
coreLikeOps.push_back(&op);
return coreLikeOps;
}
} // namespace
llvm::DenseMap<size_t, llvm::DenseMap<mlir::Value, std::string>>
createAndPopulateWeightFolder(func::FuncOp funcOp, StringRef outputDirPath) {
ModuleOp moduleOp = funcOp->getParentOfType<ModuleOp>();
auto coreWeightsDirPath = outputDirPath + "/weights";
auto error = sys::fs::create_directory(coreWeightsDirPath);
assert(!error && "Error creating weights directory");
size_t indexFileName = 0;
int64_t xbarSize = crossbarSize.getValue();
llvm::DenseMap<size_t, llvm::DenseMap<mlir::Value, std::string>> mapCoreWeightToFileName;
llvm::DenseMap<memref::GlobalOp, std::string> mapGlobalOpToFileName;
SmallVector<Operation*> coreLikeOps = collectTopLevelCoreLikeOps(funcOp);
for (Operation* op : coreLikeOps) {
auto processCore = [&](pim::PimCoreOp coreOp) {
size_t coreId = static_cast<size_t>(coreOp.getCoreId());
for (unsigned index : getUsedWeightIndices(coreOp)) {
if (index >= coreOp.getWeights().size()) {
coreOp.emitWarning("Weight index " + std::to_string(index) + " is out of range");
assert(index < coreOp.getWeights().size() && "Weight index is out of range");
}
mlir::Value weight = coreOp.getWeights()[index];
auto weightView = resolveDenseWeightView(moduleOp, weight);
if (failed(weightView)) {
coreOp.emitWarning("Weight is not from a memref.get_global at index " + std::to_string(index));
assert(succeeded(weightView) && "Weight is not from a dense memref.global view");
}
if (mapCoreWeightToFileName[coreId].contains(weight))
continue;
auto getGlobalOp = weight.getDefiningOp<memref::GetGlobalOp>();
auto globalOp = getGlobalOp ? lookupGlobalForGetGlobal(moduleOp, getGlobalOp) : memref::GlobalOp {};
if (globalOp && mapGlobalOpToFileName.contains(globalOp)) {
auto& fileName = mapGlobalOpToFileName[globalOp];
mapCoreWeightToFileName[coreId].insert({weight, fileName});
continue;
}
DenseElementsAttr denseAttr = weightView->denseAttr;
ArrayRef<int64_t> shape = weightView->shape;
assert(isMatrixShape(shape) && "Weight matrix must be 2-dimensional");
int64_t numRows = shape[0];
int64_t numCols = shape[1];
assert(numRows <= xbarSize && numCols <= xbarSize && "Weight dimensions must not exceed crossbar size");
size_t elementByteWidth = denseAttr.getElementType().getIntOrFloatBitWidth() / 8;
std::string newFileName = "crossbar_" + std::to_string(indexFileName++) + ".bin";
auto weightFilePath = (coreWeightsDirPath + "/" + newFileName).str();
std::error_code errorCode;
raw_fd_ostream weightFileStream(weightFilePath, errorCode, sys::fs::OF_None);
if (errorCode) {
errs() << "Error while opening weight file `" << weightFilePath << "`: " << errorCode.message() << '\n';
assert(errorCode);
}
uint64_t zero = 0;
for (int64_t row = 0; row < xbarSize; row++) {
for (int64_t col = 0; col < xbarSize; col++) {
if (row < numRows && col < numCols) {
int64_t elementIndex = weightView->offset + row * weightView->strides[0] + col * weightView->strides[1];
APInt bits = denseAttr.getValues<APFloat>()[elementIndex].bitcastToAPInt();
uint64_t word = bits.getZExtValue();
weightFileStream.write(reinterpret_cast<const char*>(&word), elementByteWidth);
}
else {
weightFileStream.write(reinterpret_cast<const char*>(&zero), elementByteWidth);
}
}
}
weightFileStream.close();
if (globalOp)
mapGlobalOpToFileName.insert({globalOp, newFileName});
mapCoreWeightToFileName[coreId].insert({weight, newFileName});
}
return success();
};
if (auto coreOp = dyn_cast<pim::PimCoreOp>(op)) {
(void) processCore(coreOp);
continue;
}
auto coreBatchOp = cast<pim::PimCoreBatchOp>(op);
for (unsigned lane = 0; lane < static_cast<unsigned>(coreBatchOp.getLaneCount()); ++lane)
if (failed(withScalarCoreFromBatchLane(coreBatchOp, lane, processCore)))
return mapCoreWeightToFileName;
}
return mapCoreWeightToFileName;
}
} // namespace onnx_mlir
+16
View File
@@ -0,0 +1,16 @@
#pragma once
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/StringRef.h"
#include <string>
namespace onnx_mlir {
llvm::DenseMap<size_t, llvm::DenseMap<mlir::Value, std::string>>
createAndPopulateWeightFolder(mlir::func::FuncOp funcOp, llvm::StringRef outputDirPath);
} // namespace onnx_mlir
@@ -3,6 +3,11 @@ mlir_tablegen(ONNXToSpatial.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}")
add_public_tablegen_target(ONNXToSpatialIncGen)
add_pim_library(OMONNXToSpatial
ConversionPatterns.cpp
HostFoldability.cpp
HostLegality.cpp
PrePatterns.cpp
PostPatterns.cpp
Patterns/Math/Conv.cpp
Patterns/Math/Elementwise.cpp
Patterns/Math/Gemm.cpp
@@ -18,7 +23,9 @@ add_pim_library(OMONNXToSpatial
Patterns/Tensor/Reshape.cpp
Patterns/Tensor/Split.cpp
ONNXToSpatialPass.cpp
Common.cpp
Common/ComputeRegionBuilder.cpp
Common/ShapeTilingUtils.cpp
Common/WeightMaterialization.cpp
EXCLUDE_FROM_OM_LIBS
-279
View File
@@ -1,279 +0,0 @@
#pragma once
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Transforms/DialectConversion.h"
#include <cassert>
#include <type_traits>
#include <utility>
#include "llvm/ADT/SmallPtrSet.h"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
namespace onnx_mlir {
template <class ShapedType>
inline auto getImageWidth(const ShapedType& shapedType) {
return shapedType.getDimSize(2);
}
template <class ShapedType>
inline auto getImageHeight(const ShapedType& shapedType) {
return shapedType.getDimSize(3);
}
template <class ShapedType>
inline auto getImageChannel(const ShapedType& shapedType) {
return shapedType.getDimSize(1);
}
template <class ShapedType>
inline auto getImageN(const ShapedType& shapedType) {
return shapedType.getDimSize(0);
}
template <class ShapedType>
inline auto getKernelWidth(const ShapedType& shapedType) {
return shapedType.getDimSize(2);
}
template <class ShapedType>
inline auto getKernelHeight(const ShapedType& shapedType) {
return shapedType.getDimSize(3);
}
template <class ShapedType>
inline auto getFilterCount(const ShapedType& shapedType) {
return shapedType.getDimSize(0);
}
using HSliceId = size_t;
using CoreId = size_t;
template <class A, class B, class C = std::common_type_t<A, B>>
constexpr C ceilIntegerDivide(A a, B b) {
static_assert(std::is_integral_v<A>, "A must be an integer type");
static_assert(std::is_integral_v<B>, "B must be an integer type");
C ac = static_cast<C>(a);
C bc = static_cast<C>(b);
return 1 + (ac - 1) / bc;
}
template <class A, class B, class C = std::common_type_t<A, B>>
constexpr std::pair<C, C> ceilIntegerDivideWithRemainder(A a, B b) {
static_assert(std::is_integral_v<A>, "A must be an integer type");
static_assert(std::is_integral_v<B>, "B must be an integer type");
C ac = static_cast<C>(a);
C bc = static_cast<C>(b);
return {ceilIntegerDivide(ac, bc), ac % bc};
}
template <class T>
bool isVectorShape(mlir::ArrayRef<T> shape) {
return shape.size() == 2 && (shape[0] == 1 || shape[1] == 1);
}
template <class T>
bool isMatrixShape(mlir::ArrayRef<T> shape) {
return shape.size() == 2;
}
template <class T>
bool isHVectorShape(mlir::ArrayRef<T> shape) {
return shape.size() == 2 && shape[0] == 1;
}
template <class T>
bool isVVectorShape(mlir::ArrayRef<T> shape) {
return shape.size() == 2 && shape[1] == 1;
}
template <class T>
T getVectorLength(mlir::ArrayRef<T> shape) {
assert(isVectorShape(shape));
return shape[0] != 1 ? shape[0] : shape[1];
}
inline auto getTensorShape(mlir::Value tensor) {
return mlir::cast<mlir::RankedTensorType>(tensor.getType()).getShape();
}
inline bool isWeightLikeComputeOperand(mlir::Value value) {
auto rankedType = mlir::dyn_cast<mlir::RankedTensorType>(value.getType());
if (!rankedType || !isMatrixShape(rankedType.getShape()))
return false;
llvm::SmallPtrSet<mlir::Operation*, 8> visited;
while (auto* definingOp = value.getDefiningOp()) {
if (!visited.insert(definingOp).second)
return false;
if (hasWeightAlways(definingOp))
return true;
if (auto extractSliceOp = mlir::dyn_cast<mlir::tensor::ExtractSliceOp>(definingOp)) {
value = extractSliceOp.getSource();
continue;
}
if (auto expandShapeOp = mlir::dyn_cast<mlir::tensor::ExpandShapeOp>(definingOp)) {
value = expandShapeOp.getSrc();
continue;
}
if (auto collapseShapeOp = mlir::dyn_cast<mlir::tensor::CollapseShapeOp>(definingOp)) {
value = collapseShapeOp.getSrc();
continue;
}
if (auto transposeOp = mlir::dyn_cast<mlir::ONNXTransposeOp>(definingOp)) {
value = transposeOp.getData();
continue;
}
return false;
}
return false;
}
namespace detail {
inline mlir::ValueRange getBlockArgs(mlir::Block* block) { return mlir::ValueRange(block->getArguments()); }
template <typename Fn, size_t... Is>
decltype(auto) invokeWithBlockArgs(Fn&& fn, mlir::Block* block, std::index_sequence<Is...>) {
return std::forward<Fn>(fn)(block->getArgument(Is)...);
}
template <typename Fn, size_t... Is>
decltype(auto) invokeWithValues(Fn&& fn, mlir::ArrayRef<mlir::Value> values, std::index_sequence<Is...>) {
return std::forward<Fn>(fn)(values[Is]...);
}
template <size_t>
using ValueArg = mlir::Value;
template <typename Fn, typename Seq>
struct InvokeWithBlockArgsResult;
template <typename Fn, size_t... Is>
struct InvokeWithBlockArgsResult<Fn, std::index_sequence<Is...>> {
using type = std::invoke_result_t<Fn, ValueArg<Is>...>;
};
template <typename Fn, typename Seq>
using InvokeWithBlockArgsResultT = typename InvokeWithBlockArgsResult<Fn, Seq>::type;
template <typename Fn>
using InvokeWithValueRangeResultT = std::invoke_result_t<Fn, mlir::ValueRange>;
} // namespace detail
template <size_t NumInputs, typename RewriterT, typename BodyFn>
auto createSpatCompute(RewriterT& rewriter,
mlir::Location loc,
mlir::TypeRange resultTypes,
mlir::ValueRange weights,
mlir::ValueRange inputs,
BodyFn&& body) {
assert(inputs.size() == NumInputs && "NumInputs must match the number of input values");
auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs);
auto* block = new mlir::Block();
for (mlir::Value input : inputs)
block->addArgument(input.getType(), loc);
computeOp.getBody().push_back(block);
rewriter.setInsertionPointToStart(block);
using BodyResult = detail::InvokeWithBlockArgsResultT<std::decay_t<BodyFn>, std::make_index_sequence<NumInputs>>;
if constexpr (std::is_same_v<BodyResult, mlir::LogicalResult>) {
auto bodyResult =
detail::invokeWithBlockArgs(std::forward<BodyFn>(body), block, std::make_index_sequence<NumInputs> {});
if (mlir::failed(bodyResult)) {
rewriter.setInsertionPointAfter(computeOp);
rewriter.eraseOp(computeOp);
return mlir::FailureOr<spatial::SpatCompute>(mlir::failure());
}
rewriter.setInsertionPointAfter(computeOp);
return mlir::FailureOr<spatial::SpatCompute>(computeOp);
}
else {
static_assert(std::is_same_v<BodyResult, void>, "createSpatCompute body must return void or mlir::LogicalResult");
detail::invokeWithBlockArgs(std::forward<BodyFn>(body), block, std::make_index_sequence<NumInputs> {});
rewriter.setInsertionPointAfter(computeOp);
return computeOp;
}
}
template <typename RewriterT, typename BodyFn>
auto createSpatCompute(RewriterT& rewriter,
mlir::Location loc,
mlir::TypeRange resultTypes,
mlir::ValueRange weights,
mlir::ValueRange inputs,
BodyFn&& body) {
auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs);
auto* block = new mlir::Block();
for (mlir::Value input : inputs)
block->addArgument(input.getType(), loc);
computeOp.getBody().push_back(block);
rewriter.setInsertionPointToStart(block);
using BodyResult = detail::InvokeWithValueRangeResultT<std::decay_t<BodyFn>>;
if constexpr (std::is_same_v<BodyResult, mlir::LogicalResult>) {
auto bodyResult = std::forward<BodyFn>(body)(detail::getBlockArgs(block));
if (mlir::failed(bodyResult)) {
rewriter.setInsertionPointAfter(computeOp);
rewriter.eraseOp(computeOp);
return mlir::FailureOr<spatial::SpatCompute>(mlir::failure());
}
rewriter.setInsertionPointAfter(computeOp);
return mlir::FailureOr<spatial::SpatCompute>(computeOp);
}
else {
static_assert(std::is_same_v<BodyResult, void>, "createSpatCompute body must return void or mlir::LogicalResult");
std::forward<BodyFn>(body)(detail::getBlockArgs(block));
rewriter.setInsertionPointAfter(computeOp);
return computeOp;
}
}
llvm::SmallVector<mlir::Value> sliceTensor(const mlir::Value& tensorToSlice,
size_t axis,
int64_t sliceSize,
mlir::ConversionPatternRewriter& rewriter,
mlir::Location loc);
llvm::SmallVector<mlir::Value> sliceVector(const mlir::Value& vectorToSlice,
int64_t sliceSize,
mlir::ConversionPatternRewriter& rewriter,
mlir::Location loc);
llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>> sliceVectorPerCrossbarPerCore(
const mlir::Value& vectorToSlice, mlir::ConversionPatternRewriter& rewriter, mlir::Location loc);
llvm::DenseMap<HSliceId, llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>>>
tileMatrix(mlir::Value& matrixToTile,
int64_t hSliceSize,
int64_t vSliceSize,
mlir::ConversionPatternRewriter& rewriter,
mlir::Location& loc);
mlir::tensor::SplatOp broadcastToVector(mlir::Value scalarToBroadcast,
int64_t length,
mlir::ConversionPatternRewriter& rewriter,
mlir::Location loc);
mlir::Value sumTensors(mlir::ArrayRef<mlir::Value> tensors, mlir::ConversionPatternRewriter& rewriter);
}; // namespace onnx_mlir
@@ -0,0 +1,7 @@
#pragma once
#include "ComputeRegionBuilder.hpp"
#include "ShapeTilingUtils.hpp"
#include "WeightMaterialization.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -0,0 +1,39 @@
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallVector.h"
#include "ComputeRegionBuilder.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
namespace onnx_mlir {
Value sumTensors(ArrayRef<Value> tensors, ConversionPatternRewriter& rewriter) {
if (tensors.size() == 1)
return tensors[0];
SmallVector<Value> tensors1 = {tensors.begin(), tensors.end()};
SmallVector<Value> tensors2;
tensors2.reserve(tensors.size() / 2);
auto* currTensors = &tensors1;
auto* nextTensors = &tensors2;
while (currTensors->size() > 1) {
for (size_t i = 0; i < currTensors->size() - 1; i += 2) {
Value a = (*currTensors)[i];
Value b = (*currTensors)[i + 1];
rewriter.setInsertionPointAfterValue(b);
auto addedValue = spatial::SpatVAddOp::create(rewriter, a.getLoc(), a.getType(), a, b);
nextTensors->push_back(addedValue);
}
if (currTensors->size() % 2 == 1)
nextTensors->push_back(currTensors->back());
std::swap(currTensors, nextTensors);
nextTensors->clear();
}
assert(currTensors->size() == 1 && "Expected a single input at this point.");
return (*currTensors)[0];
}
} // namespace onnx_mlir
@@ -0,0 +1,153 @@
#pragma once
#include "mlir/IR/Block.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Transforms/DialectConversion.h"
#include <cassert>
#include <cstddef>
#include <type_traits>
#include <utility>
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
namespace onnx_mlir {
namespace detail {
inline mlir::ValueRange getBlockArgs(mlir::Block* block) { return mlir::ValueRange(block->getArguments()); }
template <typename Fn, size_t... Is>
decltype(auto) invokeWithBlockArgs(Fn&& fn, mlir::Block* block, std::index_sequence<Is...>) {
return std::forward<Fn>(fn)(block->getArgument(Is)...);
}
template <typename Fn, size_t... Is>
decltype(auto) invokeWithValues(Fn&& fn, mlir::ArrayRef<mlir::Value> values, std::index_sequence<Is...>) {
return std::forward<Fn>(fn)(values[Is]...);
}
template <size_t>
using ValueArg = mlir::Value;
template <typename Fn, typename Seq>
struct InvokeWithBlockArgsResult;
template <typename Fn, size_t... Is>
struct InvokeWithBlockArgsResult<Fn, std::index_sequence<Is...>> {
using type = std::invoke_result_t<Fn, ValueArg<Is>...>;
};
template <typename Fn, typename Seq>
using InvokeWithBlockArgsResultT = typename InvokeWithBlockArgsResult<Fn, Seq>::type;
template <typename Fn>
using InvokeWithValueRangeResultT = std::invoke_result_t<Fn, mlir::ValueRange>;
} // namespace detail
template <typename RewriterT>
inline mlir::Value createSpatConcat(RewriterT& rewriter, mlir::Location loc, int64_t axis, mlir::ValueRange inputs) {
assert(!inputs.empty() && "spat.concat requires at least one input");
if (inputs.size() == 1)
return inputs.front();
auto firstType = mlir::cast<mlir::RankedTensorType>(inputs.front().getType());
auto outputShape = llvm::to_vector(firstType.getShape());
int64_t concatDimSize = 0;
bool concatDimDynamic = false;
for (mlir::Value input : inputs) {
auto inputType = mlir::cast<mlir::RankedTensorType>(input.getType());
assert(inputType.getRank() == firstType.getRank() && "spat.concat expects same-rank inputs");
if (mlir::ShapedType::isDynamic(inputType.getDimSize(axis)))
concatDimDynamic = true;
else
concatDimSize += inputType.getDimSize(axis);
}
outputShape[axis] = concatDimDynamic ? mlir::ShapedType::kDynamic : concatDimSize;
auto outputType = mlir::RankedTensorType::get(outputShape, firstType.getElementType(), firstType.getEncoding());
return spatial::SpatConcatOp::create(rewriter, loc, outputType, rewriter.getI64IntegerAttr(axis), inputs).getOutput();
}
/// Builds a `spat.compute` with a fixed number of SSA inputs and erases it if
/// the body callback reports failure.
template <size_t NumInputs, typename RewriterT, typename BodyFn>
auto createSpatCompute(RewriterT& rewriter,
mlir::Location loc,
mlir::TypeRange resultTypes,
mlir::ValueRange weights,
mlir::ValueRange inputs,
BodyFn&& body) {
assert(inputs.size() == NumInputs && "NumInputs must match the number of input values");
auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs);
auto* block = new mlir::Block();
for (mlir::Value input : inputs)
block->addArgument(input.getType(), loc);
computeOp.getBody().push_back(block);
rewriter.setInsertionPointToStart(block);
using BodyResult = detail::InvokeWithBlockArgsResultT<std::decay_t<BodyFn>, std::make_index_sequence<NumInputs>>;
if constexpr (std::is_same_v<BodyResult, void>) {
detail::invokeWithBlockArgs(std::forward<BodyFn>(body), block, std::make_index_sequence<NumInputs> {});
rewriter.setInsertionPointAfter(computeOp);
return computeOp;
}
else {
auto bodyResult =
detail::invokeWithBlockArgs(std::forward<BodyFn>(body), block, std::make_index_sequence<NumInputs> {});
if (mlir::failed(bodyResult)) {
rewriter.setInsertionPointAfter(computeOp);
rewriter.eraseOp(computeOp);
return mlir::FailureOr<spatial::SpatCompute>(mlir::failure());
}
rewriter.setInsertionPointAfter(computeOp);
return mlir::FailureOr<spatial::SpatCompute>(computeOp);
}
}
/// Builds a `spat.compute` whose body consumes the block arguments as a single
/// `ValueRange`, which is convenient for variadic reductions/concats.
template <typename RewriterT, typename BodyFn>
auto createSpatCompute(RewriterT& rewriter,
mlir::Location loc,
mlir::TypeRange resultTypes,
mlir::ValueRange weights,
mlir::ValueRange inputs,
BodyFn&& body) {
auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs);
auto* block = new mlir::Block();
for (mlir::Value input : inputs)
block->addArgument(input.getType(), loc);
computeOp.getBody().push_back(block);
rewriter.setInsertionPointToStart(block);
using BodyResult = detail::InvokeWithValueRangeResultT<std::decay_t<BodyFn>>;
if constexpr (std::is_same_v<BodyResult, void>) {
std::forward<BodyFn>(body)(detail::getBlockArgs(block));
rewriter.setInsertionPointAfter(computeOp);
return computeOp;
}
else {
auto bodyResult = std::forward<BodyFn>(body)(detail::getBlockArgs(block));
if (mlir::failed(bodyResult)) {
rewriter.setInsertionPointAfter(computeOp);
rewriter.eraseOp(computeOp);
return mlir::FailureOr<spatial::SpatCompute>(mlir::failure());
}
rewriter.setInsertionPointAfter(computeOp);
return mlir::FailureOr<spatial::SpatCompute>(computeOp);
}
}
mlir::Value sumTensors(mlir::ArrayRef<mlir::Value> tensors, mlir::ConversionPatternRewriter& rewriter);
} // namespace onnx_mlir
@@ -1,24 +1,12 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/Casting.h"
#include <cassert>
#include <optional>
#include <utility>
#include "Common.hpp"
#include "ShapeTilingUtils.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
using namespace mlir;
@@ -44,10 +32,29 @@ SmallVector<Value> sliceTensor(
for (int64_t i = 0; i < numSlices; i++) {
offsets[axis] = rewriter.getIndexAttr(i * sliceSize);
if (i == numSlices - 1 && lastSliceSize != 0)
int64_t currentSliceSize = sliceSize;
if (i == numSlices - 1 && lastSliceSize != 0) {
currentSliceSize = lastSliceSize;
sizes[axis] = rewriter.getIndexAttr(lastSliceSize);
}
Value slice = tensor::ExtractSliceOp::create(rewriter, loc, tensorToSlice, offsets, sizes, strides);
SmallVector<int64_t> sliceShape(shape.begin(), shape.end());
sliceShape[axis] = currentSliceSize;
auto sliceType =
RankedTensorType::get(sliceShape, cast<RankedTensorType>(tensorToSlice.getType()).getElementType());
Value slice;
if (isHostFoldableValue(tensorToSlice)) {
slice = tensor::ExtractSliceOp::create(rewriter, loc, tensorToSlice, offsets, sizes, strides);
}
else {
auto sliceCompute =
createSpatCompute<1>(rewriter, loc, TypeRange {sliceType}, {}, ValueRange {tensorToSlice}, [&](Value input) {
Value computedSlice = tensor::ExtractSliceOp::create(rewriter, loc, input, offsets, sizes, strides);
spatial::SpatYieldOp::create(rewriter, loc, computedSlice);
});
slice = sliceCompute.getResult(0);
}
slices.push_back(slice);
}
@@ -107,31 +114,4 @@ broadcastToVector(Value scalarToBroadcast, int64_t length, ConversionPatternRewr
return tensor::SplatOp::create(rewriter, loc, type, elementValue);
}
Value sumTensors(ArrayRef<Value> tensors, ConversionPatternRewriter& rewriter) {
if (tensors.size() == 1)
return tensors[0];
SmallVector<Value> tensors1 = {tensors.begin(), tensors.end()};
SmallVector<Value> tensors2;
tensors2.reserve(tensors.size() / 2);
auto* currTensors = &tensors1;
auto* nextTensors = &tensors2;
while (currTensors->size() > 1) {
for (size_t i = 0; i < currTensors->size() - 1; i += 2) {
Value a = (*currTensors)[i];
Value b = (*currTensors)[i + 1];
rewriter.setInsertionPointAfterValue(b);
auto addedValue = spatial::SpatVAddOp::create(rewriter, a.getLoc(), a.getType(), a, b);
nextTensors->push_back(addedValue);
}
if (currTensors->size() % 2 == 1)
nextTensors->push_back(currTensors->back());
std::swap(currTensors, nextTensors);
nextTensors->clear();
}
assert(currTensors->size() == 1 && "Expected a single input at this point.");
return (*currTensors)[0];
}
}; // namespace onnx_mlir
} // namespace onnx_mlir
@@ -0,0 +1,144 @@
#pragma once
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"
#include <cassert>
#include <cstddef>
#include <type_traits>
#include <utility>
namespace onnx_mlir {
template <class ShapedType>
inline auto getImageWidth(const ShapedType& shapedType) {
return shapedType.getDimSize(2);
}
template <class ShapedType>
inline auto getImageHeight(const ShapedType& shapedType) {
return shapedType.getDimSize(3);
}
template <class ShapedType>
inline auto getImageChannel(const ShapedType& shapedType) {
return shapedType.getDimSize(1);
}
template <class ShapedType>
inline auto getImageN(const ShapedType& shapedType) {
return shapedType.getDimSize(0);
}
template <class ShapedType>
inline auto getKernelWidth(const ShapedType& shapedType) {
return shapedType.getDimSize(2);
}
template <class ShapedType>
inline auto getKernelHeight(const ShapedType& shapedType) {
return shapedType.getDimSize(3);
}
template <class ShapedType>
inline auto getFilterCount(const ShapedType& shapedType) {
return shapedType.getDimSize(0);
}
using HSliceId = size_t;
using CoreId = size_t;
template <class A, class B, class C = std::common_type_t<A, B>>
constexpr C ceilIntegerDivide(A a, B b) {
static_assert(std::is_integral_v<A>, "A must be an integer type");
static_assert(std::is_integral_v<B>, "B must be an integer type");
C ac = static_cast<C>(a);
C bc = static_cast<C>(b);
return 1 + (ac - 1) / bc;
}
template <class A, class B, class C = std::common_type_t<A, B>>
constexpr std::pair<C, C> ceilIntegerDivideWithRemainder(A a, B b) {
static_assert(std::is_integral_v<A>, "A must be an integer type");
static_assert(std::is_integral_v<B>, "B must be an integer type");
C ac = static_cast<C>(a);
C bc = static_cast<C>(b);
return {ceilIntegerDivide(ac, bc), ac % bc};
}
template <class T>
bool isVectorShape(mlir::ArrayRef<T> shape) {
return shape.size() == 2 && (shape[0] == 1 || shape[1] == 1);
}
template <class T>
bool isMatrixShape(mlir::ArrayRef<T> shape) {
return shape.size() == 2;
}
template <class T>
bool isHVectorShape(mlir::ArrayRef<T> shape) {
return shape.size() == 2 && shape[0] == 1;
}
template <class T>
bool isVVectorShape(mlir::ArrayRef<T> shape) {
return shape.size() == 2 && shape[1] == 1;
}
template <class T>
T getVectorLength(mlir::ArrayRef<T> shape) {
assert(isVectorShape(shape));
return shape[0] != 1 ? shape[0] : shape[1];
}
inline auto getTensorShape(mlir::Value tensor) {
return mlir::cast<mlir::RankedTensorType>(tensor.getType()).getShape();
}
inline bool haveSameStaticShape(mlir::Value lhs, mlir::Value rhs) {
auto lhsType = mlir::dyn_cast<mlir::RankedTensorType>(lhs.getType());
auto rhsType = mlir::dyn_cast<mlir::RankedTensorType>(rhs.getType());
return lhsType && rhsType && lhsType.hasStaticShape() && rhsType.hasStaticShape()
&& lhsType.getShape() == rhsType.getShape();
}
/// Slices a statically shaped tensor along one axis into contiguous pieces of
/// at most `sliceSize` elements.
llvm::SmallVector<mlir::Value> sliceTensor(const mlir::Value& tensorToSlice,
size_t axis,
int64_t sliceSize,
mlir::ConversionPatternRewriter& rewriter,
mlir::Location loc);
llvm::SmallVector<mlir::Value> sliceVector(const mlir::Value& vectorToSlice,
int64_t sliceSize,
mlir::ConversionPatternRewriter& rewriter,
mlir::Location loc);
/// Partitions one logical vector into per-core crossbar-sized slices using the
/// current PIM target geometry.
llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>> sliceVectorPerCrossbarPerCore(
const mlir::Value& vectorToSlice, mlir::ConversionPatternRewriter& rewriter, mlir::Location loc);
/// Tiles a matrix first across output columns and then across input rows so it
/// can be assigned to crossbars grouped by core.
llvm::DenseMap<HSliceId, llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>>>
tileMatrix(mlir::Value& matrixToTile,
int64_t hSliceSize,
int64_t vSliceSize,
mlir::ConversionPatternRewriter& rewriter,
mlir::Location& loc);
mlir::tensor::SplatOp broadcastToVector(mlir::Value scalarToBroadcast,
int64_t length,
mlir::ConversionPatternRewriter& rewriter,
mlir::Location loc);
} // namespace onnx_mlir
@@ -0,0 +1,114 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "ShapeTilingUtils.hpp"
#include "WeightMaterialization.hpp"
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
bool isWeightLikeComputeOperand(Value value) {
auto rankedType = dyn_cast<RankedTensorType>(value.getType());
if (!rankedType || !isMatrixShape(rankedType.getShape()))
return false;
llvm::SmallPtrSet<Operation*, 8> visited;
while (auto* definingOp = value.getDefiningOp()) {
if (!visited.insert(definingOp).second)
return false;
if (isa<arith::ConstantOp, ONNXConstantOp>(definingOp) || hasWeightAlways(definingOp))
return true;
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(definingOp)) {
value = extractSliceOp.getSource();
continue;
}
if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(definingOp)) {
value = expandShapeOp.getSrc();
continue;
}
if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(definingOp)) {
value = collapseShapeOp.getSrc();
continue;
}
if (auto transposeOp = dyn_cast<ONNXTransposeOp>(definingOp)) {
value = transposeOp.getData();
continue;
}
return false;
}
return false;
}
FailureOr<Value> materializeWeightLikeValueInBlock(Value value, IRRewriter& rewriter, IRMapping& mapper) {
if (auto mapped = mapper.lookupOrNull(value))
return cast<Value>(mapped);
Operation* definingOp = value.getDefiningOp();
if (!definingOp)
return failure();
if (isa<arith::ConstantOp, ONNXConstantOp>(definingOp)) {
auto tensorType = dyn_cast<RankedTensorType>(value.getType());
if (!tensorType || !tensorType.hasStaticShape())
return failure();
SmallVector<OpFoldResult> offsets(tensorType.getRank(), rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> sizes;
SmallVector<OpFoldResult> strides(tensorType.getRank(), rewriter.getIndexAttr(1));
sizes.reserve(tensorType.getRank());
for (int64_t dim : tensorType.getShape())
sizes.push_back(rewriter.getIndexAttr(dim));
auto referencedValue =
tensor::ExtractSliceOp::create(rewriter, value.getLoc(), tensorType, value, offsets, sizes, strides);
mapper.map(value, referencedValue.getResult());
return referencedValue.getResult();
}
if (!isa<tensor::ExtractSliceOp, tensor::ExpandShapeOp, tensor::CollapseShapeOp, ONNXTransposeOp>(definingOp))
return failure();
IRMapping localMapper;
for (Value operand : definingOp->getOperands()) {
if (auto mapped = mapper.lookupOrNull(operand)) {
localMapper.map(operand, cast<Value>(mapped));
continue;
}
if (isWeightLikeComputeOperand(operand)) {
auto clonedOperand = materializeWeightLikeValueInBlock(operand, rewriter, mapper);
if (failed(clonedOperand))
return failure();
localMapper.map(operand, *clonedOperand);
continue;
}
localMapper.map(operand, operand);
}
Operation* clonedOp = rewriter.clone(*definingOp, localMapper);
for (auto [oldResult, newResult] : llvm::zip(definingOp->getResults(), clonedOp->getResults()))
mapper.map(oldResult, newResult);
auto mapped = mapper.lookupOrNull(value);
if (!mapped)
return failure();
return cast<Value>(mapped);
}
} // namespace onnx_mlir
@@ -0,0 +1,18 @@
#pragma once
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
namespace onnx_mlir {
/// Returns true when a matrix-valued compute operand is ultimately backed by a
/// weight-marked constant/view chain and can be promoted into weights.
bool isWeightLikeComputeOperand(mlir::Value value);
/// Rebuilds the view/transpose chain of a promoted weight operand inside a new
/// compute body while reusing already-materialized intermediate values.
llvm::FailureOr<mlir::Value>
materializeWeightLikeValueInBlock(mlir::Value value, mlir::IRRewriter& rewriter, mlir::IRMapping& mapper);
} // namespace onnx_mlir
@@ -0,0 +1,32 @@
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.hpp.inc"
} // namespace
void populateConversionPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx) {
patterns.add<removeLRN>(ctx);
populateElementwisePatterns(patterns, ctx);
populateGemmPatterns(patterns, ctx);
populateConvPatterns(patterns, ctx);
populatePoolPatterns(patterns, ctx);
populateReduceMeanPatterns(patterns, ctx);
populateReluPatterns(patterns, ctx);
populateSigmoidPatterns(patterns, ctx);
populateSoftmaxPatterns(patterns, ctx);
populateConcatPatterns(patterns, ctx);
populateGatherPatterns(patterns, ctx);
populateResizePatterns(patterns, ctx);
populateReshapePatterns(patterns, ctx);
populateSplitPatterns(patterns, ctx);
}
} // namespace onnx_mlir
@@ -5,6 +5,8 @@
namespace onnx_mlir {
void populateConversionPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateConvPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateElementwisePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
@@ -0,0 +1,75 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
static bool hasStaticUnitStrides(tensor::ExtractSliceOp extractSliceOp) {
return llvm::all_of(extractSliceOp.getStaticStrides(), [](int64_t stride) { return stride == 1; });
}
static bool isStaticTensorResult(Operation* op) {
return llvm::all_of(op->getResultTypes(), [](Type type) {
auto shapedType = dyn_cast<ShapedType>(type);
return shapedType && shapedType.hasStaticShape();
});
}
static bool isHostFoldableOpImpl(Operation* op, llvm::SmallPtrSetImpl<Operation*>& visited) {
if (!op || !visited.insert(op).second)
return false;
if (isa<arith::ConstantOp, ONNXConstantOp, ONNXNoneOp>(op))
return true;
if (!isStaticTensorResult(op))
return false;
if (auto transposeOp = dyn_cast<ONNXTransposeOp>(op))
return isHostFoldableValue(transposeOp.getData());
if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(op))
return isHostFoldableValue(collapseShapeOp.getSrc());
if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(op))
return isHostFoldableValue(expandShapeOp.getSrc());
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op))
return hasStaticUnitStrides(extractSliceOp) && isHostFoldableValue(extractSliceOp.getSource());
if (auto extractRowsOp = dyn_cast<spatial::SpatExtractRowsOp>(op))
return isHostFoldableValue(extractRowsOp.getInput());
if (auto concatOp = dyn_cast<spatial::SpatConcatOp>(op))
return llvm::all_of(concatOp.getInputs(), isHostFoldableValue);
return false;
}
} // namespace
bool isHostFoldableValue(Value value) {
auto* definingOp = value.getDefiningOp();
if (!definingOp)
return false;
llvm::SmallPtrSet<Operation*, 8> visited;
return isHostFoldableOpImpl(definingOp, visited);
}
bool isHostFoldableOp(Operation* op) {
llvm::SmallPtrSet<Operation*, 8> visited;
return isHostFoldableOpImpl(op, visited);
}
} // namespace onnx_mlir
@@ -0,0 +1,12 @@
#pragma once
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"
namespace onnx_mlir {
bool isHostFoldableValue(mlir::Value value);
bool isHostFoldableOp(mlir::Operation* op);
} // namespace onnx_mlir
@@ -0,0 +1,29 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostLegality.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
namespace onnx_mlir {
LogicalResult verifyONNXToSpatialHostLegality(func::FuncOp funcOp) {
bool hasFailure = false;
for (Operation& op : funcOp.getFunctionBody().front()) {
if (isa<func::ReturnOp, spatial::SpatCompute, spatial::SpatComputeBatch>(&op))
continue;
if (isHostFoldableOp(&op))
continue;
op.emitOpError("non-foldable top-level runtime op remains after ONNX-to-Spatial; lower it inside spat.compute");
hasFailure = true;
}
return success(!hasFailure);
}
} // namespace onnx_mlir
@@ -0,0 +1,10 @@
#pragma once
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Support/LogicalResult.h"
namespace onnx_mlir {
mlir::LogicalResult verifyONNXToSpatialHostLegality(mlir::func::FuncOp funcOp);
} // namespace onnx_mlir
@@ -1,3 +1,4 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -7,25 +8,18 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_os_ostream.h"
#include <fstream>
#include <iterator>
#include <utility>
#include "Common.hpp"
#include "Common/Common.hpp"
#include "Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostLegality.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PostPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PrePatterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.hpp"
#include "src/Accelerators/PIM/Pass/PIMPasses.h"
#include "src/Compiler/CompilerOptions.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -33,12 +27,8 @@ using namespace mlir;
namespace onnx_mlir {
bool haveSameStaticShape(Value lhs, Value rhs);
namespace {
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.hpp.inc"
struct ONNXToSpatialPass : PassWrapper<ONNXToSpatialPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ONNXToSpatialPass)
StringRef getArgument() const override { return "convert-onnx-to-spatial"; }
@@ -48,33 +38,64 @@ struct ONNXToSpatialPass : PassWrapper<ONNXToSpatialPass, OperationPass<ModuleOp
ONNXToSpatialPass(const ONNXToSpatialPass& pass) {}
void runOnOperation() override;
private:
void annotateWeightsConstants(func::FuncOp funcOp) const;
void encapsulateGlobalInstruction(func::FuncOp funcOp);
void mergeTriviallyConnectedComputes(func::FuncOp funcOp);
LogicalResult promoteConstantInputsToWeights(func::FuncOp funcOp);
};
} // namespace
static void populateEmptyFunction(func::FuncOp funcOp) {
IRRewriter rewriter(funcOp.getContext());
IRMapping mapper;
SmallVector<spatial::SpatCompute> computes(funcOp.getOps<spatial::SpatCompute>());
if (!computes.empty())
return;
auto returnOp = cast<func::ReturnOp>(funcOp.getFunctionBody().front().getTerminator());
rewriter.setInsertionPoint(returnOp);
SmallVector<Type> sourceTypes;
SmallVector<Location> sourceLocs;
sourceTypes.reserve(funcOp.getNumArguments());
sourceLocs.reserve(funcOp.getNumArguments());
for (Value source : funcOp.getArguments()) {
sourceTypes.push_back(source.getType());
sourceLocs.push_back(source.getLoc());
}
auto newCompute = spatial::SpatCompute::create(
rewriter, returnOp.getLoc(), returnOp.getOperandTypes(), funcOp.getArguments(), {}, {});
auto* newBlock = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), sourceTypes, sourceLocs);
for (auto [blockArg, computeArg] : llvm::zip(newBlock->getArguments(), newCompute.getOperands()))
mapper.map(computeArg, blockArg);
newCompute.getProperties().setOperandSegmentSizes({0, static_cast<int>(sourceTypes.size())});
rewriter.setInsertionPointToEnd(newBlock);
for (Operation& op : funcOp.getOps())
if (!isa<spatial::SpatCompute, func::ReturnOp>(&op))
rewriter.clone(op, mapper);
auto yield = spatial::SpatYieldOp::create(rewriter, funcOp.getLoc(), returnOp.getOperands());
for (size_t i = 0; i < yield.getNumOperands(); ++i)
yield.setOperand(i, mapper.lookupOrDefault(yield.getOperand(i)));
for (Operation& op : llvm::make_early_inc_range(funcOp.getOps()))
if (!isa<spatial::SpatCompute, func::ReturnOp>(&op)) {
op.dropAllUses();
rewriter.eraseOp(&op);
}
for (auto [index, computeResult] : llvm::enumerate(newCompute.getResults()))
returnOp.setOperand(index, computeResult);
}
void ONNXToSpatialPass::runOnOperation() {
ModuleOp moduleOp = getOperation();
MLIRContext* ctx = &getContext();
RewritePatternSet mergeActivationPatterns(ctx);
mergeActivationPatterns.add<onnxToArithConstant>(ctx);
mergeActivationPatterns.add<convAddToConvWithBiasLeft>(ctx);
mergeActivationPatterns.add<convAddToConvWithBiasRight>(ctx);
mergeActivationPatterns.add<matMulAddToGemm>(ctx);
mergeActivationPatterns.add<matMulToGemm>(ctx);
mergeActivationPatterns.add<removeFlattenSameShape>(ctx);
populateMatMulRewritePatterns(mergeActivationPatterns, ctx);
RewritePatternSet prePatterns(ctx);
populatePrePatterns(prePatterns, ctx);
if (failed(applyPatternsGreedily(moduleOp, std::move(prePatterns))))
moduleOp.emitWarning("failed to apply ONNX-to-Spatial pre-patterns; continuing");
if (failed(applyPatternsGreedily(moduleOp, std::move(mergeActivationPatterns))))
llvm::dbgs() << "Failed to merge activation patterns, continuing...\n";
IRRewriter rewriter(moduleOp);
auto entryFunc = getPimEntryFunc(moduleOp);
if (failed(entryFunc)) {
signalPassFailure();
@@ -87,8 +108,7 @@ void ONNXToSpatialPass::runOnOperation() {
tensor::TensorDialect,
arith::ArithDialect,
scf::SCFDialect>();
target.addDynamicallyLegalOp<ONNXMatMulOp>(
[](ONNXMatMulOp op) { return cast<ShapedType>(op.getY().getType()).getRank() != 2; });
target.addIllegalOp<ONNXMatMulOp>();
target.addIllegalOp<ONNXAddOp>();
target.addIllegalOp<ONNXDivOp>();
target.addIllegalOp<ONNXMulOp>();
@@ -107,37 +127,29 @@ void ONNXToSpatialPass::runOnOperation() {
target.addIllegalOp<ONNXReduceMeanV13Op>();
target.addIllegalOp<ONNXSplitOp>();
RewritePatternSet patterns(ctx);
patterns.add<removeLRN>(ctx);
populateElementwisePatterns(patterns, ctx);
populateGemmPatterns(patterns, ctx);
populateConvPatterns(patterns, ctx);
populatePoolPatterns(patterns, ctx);
populateReduceMeanPatterns(patterns, ctx);
populateReluPatterns(patterns, ctx);
populateSigmoidPatterns(patterns, ctx);
populateSoftmaxPatterns(patterns, ctx);
populateConcatPatterns(patterns, ctx);
populateGatherPatterns(patterns, ctx);
populateResizePatterns(patterns, ctx);
populateReshapePatterns(patterns, ctx);
populateSplitPatterns(patterns, ctx);
if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
RewritePatternSet conversionPatterns(ctx);
populateConversionPatterns(conversionPatterns, ctx);
if (failed(applyPartialConversion(moduleOp, target, std::move(conversionPatterns)))) {
signalPassFailure();
return;
}
RewritePatternSet earlyPostPatterns(ctx);
populateEarlyPostPatterns(earlyPostPatterns, ctx);
if (failed(applyPatternsGreedily(*entryFunc, std::move(earlyPostPatterns)))) {
signalPassFailure();
return;
}
// Count the number of compute ops and check they do not exceed the core count
if (coresCount != -1) {
int computeOpsCount = 0;
for (auto& op : entryFunc->getFunctionBody().front().getOperations())
for (Operation& op : entryFunc->getFunctionBody().front().getOperations())
if (isa<spatial::SpatCompute>(op))
computeOpsCount++;
if (computeOpsCount > coresCount) {
llvm::dbgs() << "Number of compute ops exceeds the core count\n";
entryFunc->emitError() << "number of compute ops (" << computeOpsCount << ") exceeds the core count ("
<< coresCount << ")";
signalPassFailure();
return;
}
@@ -146,337 +158,27 @@ void ONNXToSpatialPass::runOnOperation() {
PassManager cleanupPM(ctx);
cleanupPM.addPass(createCanonicalizerPass());
if (failed(cleanupPM.run(moduleOp)))
llvm::dbgs() << "Failed to run canonicalization cleanup, continuing...\n";
moduleOp.emitWarning("failed to run ONNX-to-Spatial canonicalization cleanup; continuing");
annotateWeightsConstants(*entryFunc);
encapsulateGlobalInstruction(*entryFunc);
if (failed(promoteConstantInputsToWeights(*entryFunc))) {
RewritePatternSet postPatterns(ctx);
populatePostPatterns(postPatterns, ctx);
if (failed(applyPatternsGreedily(*entryFunc, std::move(postPatterns)))) {
signalPassFailure();
return;
}
mergeTriviallyConnectedComputes(*entryFunc);
if (failed(verifyONNXToSpatialHostLegality(*entryFunc))) {
signalPassFailure();
return;
}
populateEmptyFunction(*entryFunc);
// Dump to file for debug
dumpModule(moduleOp, "spatial0");
}
template <typename T>
bool encapsulator(IRRewriter& rewriter, Location loc, Operation* inst, std::function<Value(T)> funcSource) {
if (T toRemoveOp = llvm::dyn_cast_if_present<T>(inst)) {
Value source = funcSource(toRemoveOp);
rewriter.setInsertionPointAfter(toRemoveOp);
if (isa_and_present<spatial::SpatCompute>(source.getDefiningOp())) {
auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), source);
auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), {source.getType()}, {loc});
newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) 1});
rewriter.setInsertionPointToEnd(BB);
IRMapping mapper;
mapper.map(source, BB->getArgument(0));
auto newInst = rewriter.clone(*inst, mapper);
spatial::SpatYieldOp::create(rewriter, loc, newInst->getResults());
inst->replaceAllUsesWith(newCompute->getResults());
inst->erase();
return true;
}
}
return false;
}
bool encapsulateConcat(IRRewriter& rewriter, Location loc, Operation* inst) {
if (auto toRemoveOp = llvm::dyn_cast_if_present<tensor::ConcatOp>(inst)) {
auto sources = toRemoveOp.getInputs();
rewriter.setInsertionPointAfter(toRemoveOp);
if (llvm::any_of(
sources, [](auto source) { return isa_and_present<spatial::SpatCompute>(source.getDefiningOp()); })) {
auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), sources);
SmallVector<Type> sourceTypes;
SmallVector<Location> sourceLoc;
for (auto source : sources) {
sourceTypes.push_back(source.getType());
sourceLoc.push_back(loc);
}
auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), sourceTypes, sourceLoc);
newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) sources.size()});
rewriter.setInsertionPointToEnd(BB);
IRMapping mapper;
for (auto [source, bbArg] : llvm::zip(sources, BB->getArguments()))
mapper.map(source, bbArg);
auto newConcat = rewriter.clone(*inst, mapper);
spatial::SpatYieldOp::create(rewriter, loc, newConcat->getResults());
inst->replaceAllUsesWith(newCompute->getResults());
inst->erase();
return true;
}
}
return false;
}
static FailureOr<Value> materializeWeightLikeValueInBlock(Value value, IRRewriter& rewriter, IRMapping& mapper) {
if (auto mapped = mapper.lookupOrNull(value))
return cast<Value>(mapped);
Operation* definingOp = value.getDefiningOp();
if (!definingOp)
return failure();
if (isa<arith::ConstantOp, ONNXConstantOp>(definingOp)) {
auto tensorType = dyn_cast<RankedTensorType>(value.getType());
if (!tensorType || !tensorType.hasStaticShape())
return failure();
SmallVector<OpFoldResult> offsets(tensorType.getRank(), rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> sizes;
SmallVector<OpFoldResult> strides(tensorType.getRank(), rewriter.getIndexAttr(1));
sizes.reserve(tensorType.getRank());
for (int64_t dim : tensorType.getShape())
sizes.push_back(rewriter.getIndexAttr(dim));
auto referencedValue =
tensor::ExtractSliceOp::create(rewriter, value.getLoc(), tensorType, value, offsets, sizes, strides);
mapper.map(value, referencedValue.getResult());
return referencedValue.getResult();
}
if (!isa<tensor::ExtractSliceOp, tensor::ExpandShapeOp, tensor::CollapseShapeOp, ONNXTransposeOp>(definingOp))
return failure();
IRMapping localMapper;
for (Value operand : definingOp->getOperands()) {
if (auto mapped = mapper.lookupOrNull(operand)) {
localMapper.map(operand, cast<Value>(mapped));
continue;
}
if (isWeightLikeComputeOperand(operand)) {
auto clonedOperand = materializeWeightLikeValueInBlock(operand, rewriter, mapper);
if (failed(clonedOperand))
return failure();
localMapper.map(operand, *clonedOperand);
continue;
}
localMapper.map(operand, operand);
}
Operation* clonedOp = rewriter.clone(*definingOp, localMapper);
for (auto [oldResult, newResult] : llvm::zip(definingOp->getResults(), clonedOp->getResults()))
mapper.map(oldResult, newResult);
auto mapped = mapper.lookupOrNull(value);
if (!mapped)
return failure();
return cast<Value>(mapped);
}
// TODO what we want to keep in global?
void ONNXToSpatialPass::encapsulateGlobalInstruction(func::FuncOp funcOp) {
Location loc = funcOp.getLoc();
IRRewriter rewriter(&getContext());
bool keep = true;
while (keep) {
keep = false;
for (auto& instruction : llvm::make_early_inc_range(funcOp.getOps())) {
keep |= encapsulator<tensor::ExtractSliceOp>(
rewriter, loc, &instruction, [](tensor::ExtractSliceOp extract) { return extract.getSource(); });
keep |= encapsulator<tensor::ExpandShapeOp>(
rewriter, loc, &instruction, [](tensor::ExpandShapeOp expand) { return expand.getSrc(); });
keep |= encapsulator<ONNXTransposeOp>(
rewriter, loc, &instruction, [](ONNXTransposeOp transpose) { return transpose.getData(); });
keep |= encapsulator<tensor::CollapseShapeOp>(
rewriter, loc, &instruction, [](tensor::CollapseShapeOp collapse) { return collapse.getSrc(); });
keep |= encapsulateConcat(rewriter, loc, &instruction);
}
}
}
void ONNXToSpatialPass::mergeTriviallyConnectedComputes(func::FuncOp funcOp) {
Location loc = funcOp.getLoc();
IRRewriter rewriter(&getContext());
SmallVector<spatial::SpatCompute> trivialComputes;
llvm::SmallSet<spatial::SpatCompute, 8> toErase;
for (auto compute : funcOp.getOps<spatial::SpatCompute>())
if (compute->hasOneUse()) {
auto& use = *compute->getUses().begin();
auto user = dyn_cast<spatial::SpatCompute>(use.getOwner());
if (user && user.getInputs().size() == 1 && use.getOperandNumber() >= user.getWeights().size())
trivialComputes.push_back(compute);
}
while (!trivialComputes.empty()) {
auto compute = trivialComputes.front();
if (compute.use_empty()) {
std::swap(trivialComputes.front(), trivialComputes.back());
trivialComputes.pop_back();
continue;
}
auto& computeUse = *compute->getUses().begin();
auto child = cast<spatial::SpatCompute>(computeUse.getOwner());
auto usedResult = cast<OpResult>(computeUse.get()).getResultNumber();
auto childArgIndex = computeUse.getOperandNumber() - child.getWeights().size();
rewriter.setInsertionPointAfter(compute.getOperation());
auto newCompute =
spatial::SpatCompute::create(rewriter, loc, child.getResultTypes(), compute.getOperands());
newCompute.getProperties().setOperandSegmentSizes(
{static_cast<int>(compute.getWeights().size()), static_cast<int>(compute.getInputs().size())});
IRMapping mapper;
auto weightMutableIter = newCompute.getWeightsMutable();
for (auto weight : child.getWeights()) {
auto founded = llvm::find(newCompute.getWeights(), weight);
if (founded == newCompute.getWeights().end()) {
weightMutableIter.append(weight);
auto last = weightMutableIter.end();
last = std::prev(last, 1);
mapper.map(weight, last->get());
}
else {
mapper.map(weight, *founded);
}
}
compute.getBodyRegion().cloneInto(&newCompute.getBodyRegion(), mapper);
auto newTerminator = newCompute.getBody().front().getTerminator();
mapper.map(child.getBody().front().getArgument(childArgIndex), newTerminator->getOperand(usedResult));
newTerminator->erase();
rewriter.setInsertionPoint(&newCompute.getBody().front(), newCompute.getBody().front().end());
for (auto& op : child.getBody().front()) {
auto newInst = rewriter.clone(op, mapper);
if (auto vmOp = llvm::dyn_cast<spatial::SpatWeightedMVMOp>(newInst)) {
auto oldIndex = vmOp.getWeightIndex();
auto newWeight = mapper.lookup(*std::next(child.getWeights().begin(), oldIndex));
auto newIndex = std::distance(newCompute.getWeights().begin(), llvm::find(newCompute.getWeights(), newWeight));
vmOp.setWeightIndex(newIndex);
}
if (auto vmOp = llvm::dyn_cast<spatial::SpatWeightedVMMOp>(newInst)) {
auto oldIndex = vmOp.getWeightIndex();
auto newWeight = mapper.lookup(*std::next(child.getWeights().begin(), oldIndex));
auto newIndex = std::distance(newCompute.getWeights().begin(), llvm::find(newCompute.getWeights(), newWeight));
vmOp.setWeightIndex(newIndex);
}
}
child.replaceAllUsesWith(newCompute);
toErase.insert(child);
std::swap(trivialComputes.front(), trivialComputes.back());
trivialComputes.pop_back();
toErase.insert(compute);
if (newCompute->hasOneUse()) {
auto& use = *newCompute->getUses().begin();
auto user = dyn_cast<spatial::SpatCompute>(use.getOwner());
if (user && user.getInputs().size() == 1 && use.getOperandNumber() >= user.getWeights().size())
trivialComputes.push_back(newCompute);
}
}
for (auto compute : toErase) {
for (Value result : compute->getResults())
result.dropAllUses();
compute.erase();
}
}
void ONNXToSpatialPass::annotateWeightsConstants(func::FuncOp funcOp) const {
funcOp.walk([&](arith::ConstantOp constantOp) {
bool isAlwaysWeight =
llvm::all_of(constantOp->getUsers(), [](auto user) -> bool { return isa<spatial::SpatCompute>(user); });
if (isAlwaysWeight)
markWeightAlways(constantOp);
});
}
LogicalResult ONNXToSpatialPass::promoteConstantInputsToWeights(func::FuncOp funcOp) {
IRRewriter rewriter(&getContext());
SmallVector<spatial::SpatCompute> computes(funcOp.getOps<spatial::SpatCompute>());
for (auto compute : computes) {
SmallVector<bool> promoteInput(compute.getInputs().size(), false);
bool needsRewrite = false;
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
if (!isWeightLikeComputeOperand(input))
continue;
promoteInput[inputIdx] = true;
needsRewrite = true;
}
if (!needsRewrite)
continue;
rewriter.setInsertionPointAfter(compute);
SmallVector<Value> newWeights(compute.getWeights().begin(), compute.getWeights().end());
SmallVector<Value> newInputs;
SmallVector<Type> newInputTypes;
SmallVector<Location> newInputLocs;
newWeights.reserve(compute.getWeights().size() + compute.getInputs().size());
newInputs.reserve(compute.getInputs().size());
newInputTypes.reserve(compute.getInputs().size());
newInputLocs.reserve(compute.getInputs().size());
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
if (promoteInput[inputIdx]) {
newWeights.push_back(input);
continue;
}
newInputs.push_back(input);
newInputTypes.push_back(input.getType());
newInputLocs.push_back(input.getLoc());
}
auto newCompute =
spatial::SpatCompute::create(rewriter, compute.getLoc(), compute.getResultTypes(), newWeights, newInputs);
auto* newBlock =
rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), newInputTypes, newInputLocs);
newCompute.getProperties().setOperandSegmentSizes(
{static_cast<int>(newWeights.size()), static_cast<int>(newInputs.size())});
rewriter.setInsertionPointToStart(newBlock);
IRMapping mapper;
auto& oldBlock = compute.getBody().front();
size_t newInputIdx = 0;
for (auto [oldInputIdx, oldArg] : llvm::enumerate(oldBlock.getArguments())) {
if (!promoteInput[oldInputIdx]) {
mapper.map(oldArg, newBlock->getArgument(newInputIdx++));
continue;
}
auto clonedValue = materializeWeightLikeValueInBlock(compute.getInputs()[oldInputIdx], rewriter, mapper);
if (failed(clonedValue))
return compute.emitError("failed to materialize promoted weight-like operand inside compute body");
mapper.map(oldArg, *clonedValue);
}
for (auto& op : oldBlock.without_terminator())
rewriter.clone(op, mapper);
auto oldYield = cast<spatial::SpatYieldOp>(oldBlock.getTerminator());
SmallVector<Value> newYieldOperands;
newYieldOperands.reserve(oldYield.getOutputs().size());
for (Value operand : oldYield.getOutputs()) {
auto mapped = mapper.lookupOrNull(operand);
newYieldOperands.push_back(mapped ? cast<Value>(mapped) : operand);
}
spatial::SpatYieldOp::create(rewriter, oldYield.getLoc(), newYieldOperands);
compute.replaceAllUsesWith(newCompute);
compute.erase();
}
return success();
}
std::unique_ptr<Pass> createONNXToSpatialPass() { return std::make_unique<ONNXToSpatialPass>(); }
} // namespace onnx_mlir
@@ -7,11 +7,10 @@
#include "llvm/ADT/SmallVector.h"
#include <algorithm>
#include <cassert>
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -147,162 +146,148 @@ static Value buildPackedBias(bool hasBias,
return arith::ConstantOp::create(rewriter, loc, packedBiasType, packedBiasAttr).getResult();
}
static SmallVector<Value> createIm2colRowComputes(Value x,
RankedTensorType xType,
RankedTensorType im2colType,
RankedTensorType im2colRowType,
RankedTensorType gemmInputRowType,
int64_t batchSize,
int64_t numChannelsIn,
int64_t xHeight,
int64_t xWidth,
int64_t wHeight,
int64_t wWidth,
int64_t padHeightBegin,
int64_t padHeightEnd,
int64_t padWidthBegin,
int64_t padWidthEnd,
int64_t strideHeight,
int64_t strideWidth,
int64_t dilationHeight,
int64_t dilationWidth,
int64_t outWidth,
int64_t patchSize,
int64_t numPatches,
int64_t numPatchesPerBatch,
int64_t packFactor,
ConversionPatternRewriter& rewriter,
Location loc) {
static Value createIm2colRowComputes(Value x,
RankedTensorType xType,
RankedTensorType im2colType,
RankedTensorType im2colRowType,
RankedTensorType gemmInputRowsType,
int64_t batchSize,
int64_t numChannelsIn,
int64_t xHeight,
int64_t xWidth,
int64_t wHeight,
int64_t wWidth,
int64_t padHeightBegin,
int64_t padHeightEnd,
int64_t padWidthBegin,
int64_t padWidthEnd,
int64_t strideHeight,
int64_t strideWidth,
int64_t dilationHeight,
int64_t dilationWidth,
int64_t outWidth,
int64_t patchSize,
int64_t numPatches,
int64_t numPatchesPerBatch,
int64_t packFactor,
ConversionPatternRewriter& rewriter,
Location loc) {
auto elemType = xType.getElementType();
constexpr size_t numInputs = 1;
const int64_t packedNumRows = ceilIntegerDivide(numPatches, packFactor);
SmallVector<Type> resultTypes(packedNumRows, gemmInputRowType);
auto im2colComputeOp = createSpatCompute<numInputs>(rewriter, loc, resultTypes, {}, x, [&](Value xArg) {
Value paddedInput = xArg;
auto im2colComputeOp =
createSpatCompute<numInputs>(rewriter, loc, TypeRange {gemmInputRowsType}, {}, x, [&](Value xArg) {
Value paddedInput = xArg;
// Pad input with zeros if needed:
// [1, numChannelsIn, xHeight, xWidth] -> [1, numChannelsIn, xHeight+padHeight, xWidth+padWidth]
if (padHeightBegin || padHeightEnd || padWidthBegin || padWidthEnd) {
const int64_t paddedHeight = xHeight + padHeightBegin + padHeightEnd;
const int64_t paddedWidth = xWidth + padWidthBegin + padWidthEnd;
auto paddedType = RankedTensorType::get({batchSize, numChannelsIn, paddedHeight, paddedWidth}, elemType);
SmallVector<OpFoldResult> lowPads = {rewriter.getIndexAttr(0),
rewriter.getIndexAttr(0),
rewriter.getIndexAttr(padHeightBegin),
rewriter.getIndexAttr(padWidthBegin)};
SmallVector<OpFoldResult> highPads = {rewriter.getIndexAttr(0),
rewriter.getIndexAttr(0),
rewriter.getIndexAttr(padHeightEnd),
rewriter.getIndexAttr(padWidthEnd)};
auto padOp = tensor::PadOp::create(rewriter, loc, paddedType, paddedInput, lowPads, highPads);
auto* padBlock = new Block();
for (int i = 0; i < 4; i++)
padBlock->addArgument(rewriter.getIndexType(), loc);
padOp.getRegion().push_back(padBlock);
rewriter.setInsertionPointToStart(padBlock);
auto zero = arith::ConstantOp::create(rewriter, loc, elemType, rewriter.getFloatAttr(elemType, 0.0));
tensor::YieldOp::create(rewriter, loc, zero.getResult());
rewriter.setInsertionPointAfter(padOp);
paddedInput = padOp.getResult();
}
// Pad input with zeros if needed:
// [1, numChannelsIn, xHeight, xWidth] -> [1, numChannelsIn, xHeight+padHeight, xWidth+padWidth]
if (padHeightBegin || padHeightEnd || padWidthBegin || padWidthEnd) {
const int64_t paddedHeight = xHeight + padHeightBegin + padHeightEnd;
const int64_t paddedWidth = xWidth + padWidthBegin + padWidthEnd;
auto paddedType = RankedTensorType::get({batchSize, numChannelsIn, paddedHeight, paddedWidth}, elemType);
SmallVector<OpFoldResult> lowPads = {rewriter.getIndexAttr(0),
rewriter.getIndexAttr(0),
rewriter.getIndexAttr(padHeightBegin),
rewriter.getIndexAttr(padWidthBegin)};
SmallVector<OpFoldResult> highPads = {rewriter.getIndexAttr(0),
rewriter.getIndexAttr(0),
rewriter.getIndexAttr(padHeightEnd),
rewriter.getIndexAttr(padWidthEnd)};
auto padOp = tensor::PadOp::create(rewriter, loc, paddedType, paddedInput, lowPads, highPads);
auto* padBlock = new Block();
for (int i = 0; i < 4; i++)
padBlock->addArgument(rewriter.getIndexType(), loc);
padOp.getRegion().push_back(padBlock);
rewriter.setInsertionPointToStart(padBlock);
auto zero = arith::ConstantOp::create(rewriter, loc, elemType, rewriter.getFloatAttr(elemType, 0.0));
tensor::YieldOp::create(rewriter, loc, zero.getResult());
rewriter.setInsertionPointAfter(padOp);
paddedInput = padOp.getResult();
}
// Build im2col [numPatches, patchSize] incrementally to keep the IR small
// until the late PIM unrolling step.
Value im2colInit = tensor::EmptyOp::create(rewriter, loc, im2colType.getShape(), elemType);
auto c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
auto c1 = arith::ConstantIndexOp::create(rewriter, loc, 1);
auto cNumPatches = arith::ConstantIndexOp::create(rewriter, loc, numPatches);
auto cNumPatchesPerBatch = arith::ConstantIndexOp::create(rewriter, loc, numPatchesPerBatch);
auto cOutWidth = arith::ConstantIndexOp::create(rewriter, loc, outWidth);
auto cStrideHeight = arith::ConstantIndexOp::create(rewriter, loc, strideHeight);
auto cStrideWidth = arith::ConstantIndexOp::create(rewriter, loc, strideWidth);
// Build im2col [numPatches, patchSize] incrementally to keep the IR small
// until the late PIM unrolling step.
Value im2colInit = tensor::EmptyOp::create(rewriter, loc, im2colType.getShape(), elemType);
auto c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
auto c1 = arith::ConstantIndexOp::create(rewriter, loc, 1);
auto cNumPatches = arith::ConstantIndexOp::create(rewriter, loc, numPatches);
auto cNumPatchesPerBatch = arith::ConstantIndexOp::create(rewriter, loc, numPatchesPerBatch);
auto cOutWidth = arith::ConstantIndexOp::create(rewriter, loc, outWidth);
auto cStrideHeight = arith::ConstantIndexOp::create(rewriter, loc, strideHeight);
auto cStrideWidth = arith::ConstantIndexOp::create(rewriter, loc, strideWidth);
auto im2colLoop = scf::ForOp::create(rewriter, loc, c0, cNumPatches, c1, ValueRange {im2colInit});
rewriter.setInsertionPointToStart(im2colLoop.getBody());
auto im2colLoop = scf::ForOp::create(rewriter, loc, c0, cNumPatches, c1, ValueRange {im2colInit});
rewriter.setInsertionPointToStart(im2colLoop.getBody());
Value patchIndex = im2colLoop.getInductionVar();
Value im2colAcc = im2colLoop.getRegionIterArgs().front();
Value patchIndex = im2colLoop.getInductionVar();
Value im2colAcc = im2colLoop.getRegionIterArgs().front();
Value batchIndex = arith::DivUIOp::create(rewriter, loc, patchIndex, cNumPatchesPerBatch);
Value batchPatchIndex = arith::RemUIOp::create(rewriter, loc, patchIndex, cNumPatchesPerBatch);
Value outHeightIndex = arith::DivUIOp::create(rewriter, loc, batchPatchIndex, cOutWidth);
Value outWidthIndex = arith::RemUIOp::create(rewriter, loc, batchPatchIndex, cOutWidth);
Value inputHeightOffset = arith::MulIOp::create(rewriter, loc, outHeightIndex, cStrideHeight);
Value inputWidthOffset = arith::MulIOp::create(rewriter, loc, outWidthIndex, cStrideWidth);
Value batchIndex = arith::DivUIOp::create(rewriter, loc, patchIndex, cNumPatchesPerBatch);
Value batchPatchIndex = arith::RemUIOp::create(rewriter, loc, patchIndex, cNumPatchesPerBatch);
Value outHeightIndex = arith::DivUIOp::create(rewriter, loc, batchPatchIndex, cOutWidth);
Value outWidthIndex = arith::RemUIOp::create(rewriter, loc, batchPatchIndex, cOutWidth);
Value inputHeightOffset = arith::MulIOp::create(rewriter, loc, outHeightIndex, cStrideHeight);
Value inputWidthOffset = arith::MulIOp::create(rewriter, loc, outWidthIndex, cStrideWidth);
SmallVector<OpFoldResult> offsets = {batchIndex, rewriter.getIndexAttr(0), inputHeightOffset, inputWidthOffset};
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
rewriter.getIndexAttr(numChannelsIn),
rewriter.getIndexAttr(wHeight),
rewriter.getIndexAttr(wWidth)};
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1),
rewriter.getIndexAttr(dilationHeight),
rewriter.getIndexAttr(dilationWidth)};
auto patchType = RankedTensorType::get({1, numChannelsIn, wHeight, wWidth}, elemType);
Value patch = tensor::ExtractSliceOp::create(rewriter, loc, patchType, paddedInput, offsets, sizes, strides);
SmallVector<OpFoldResult> offsets = {batchIndex, rewriter.getIndexAttr(0), inputHeightOffset, inputWidthOffset};
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
rewriter.getIndexAttr(numChannelsIn),
rewriter.getIndexAttr(wHeight),
rewriter.getIndexAttr(wWidth)};
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1),
rewriter.getIndexAttr(dilationHeight),
rewriter.getIndexAttr(dilationWidth)};
auto patchType = RankedTensorType::get({1, numChannelsIn, wHeight, wWidth}, elemType);
Value patch = tensor::ExtractSliceOp::create(rewriter, loc, patchType, paddedInput, offsets, sizes, strides);
Value row = tensor::CollapseShapeOp::create(rewriter,
loc,
im2colRowType,
patch,
SmallVector<ReassociationIndices> {
{0},
{1, 2, 3}
Value row = tensor::CollapseShapeOp::create(rewriter,
loc,
im2colRowType,
patch,
SmallVector<ReassociationIndices> {
{0},
{1, 2, 3}
});
SmallVector<OpFoldResult> rowOffsets = {patchIndex, rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> rowSizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(patchSize)};
SmallVector<OpFoldResult> rowStrides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
Value updatedIm2col =
tensor::InsertSliceOp::create(rewriter, loc, row, im2colAcc, rowOffsets, rowSizes, rowStrides);
scf::YieldOp::create(rewriter, loc, updatedIm2col);
rewriter.setInsertionPointAfter(im2colLoop);
Value im2col = im2colLoop.getResult(0);
Value gemmInputRows = im2col;
if (packFactor != 1) {
const int64_t paddedNumPatches = packedNumRows * packFactor;
auto groupedType = RankedTensorType::get({packedNumRows, packFactor, patchSize}, elemType);
auto packedType = RankedTensorType::get({packedNumRows, packFactor * patchSize}, elemType);
Value paddedIm2col = createPaddedRows(im2col, im2colType, paddedNumPatches, rewriter, loc);
Value groupedIm2col = tensor::ExpandShapeOp::create(rewriter,
loc,
groupedType,
paddedIm2col,
SmallVector<ReassociationIndices> {
{0, 1},
{2}
});
gemmInputRows = tensor::CollapseShapeOp::create(rewriter,
loc,
packedType,
groupedIm2col,
SmallVector<ReassociationIndices> {
{0},
{1, 2}
});
}
spatial::SpatYieldOp::create(rewriter, loc, gemmInputRows);
});
SmallVector<OpFoldResult> rowOffsets = {patchIndex, rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> rowSizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(patchSize)};
SmallVector<OpFoldResult> rowStrides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
Value updatedIm2col =
tensor::InsertSliceOp::create(rewriter, loc, row, im2colAcc, rowOffsets, rowSizes, rowStrides);
scf::YieldOp::create(rewriter, loc, updatedIm2col);
rewriter.setInsertionPointAfter(im2colLoop);
Value im2col = im2colLoop.getResult(0);
Value gemmInputRows = im2col;
if (packFactor != 1) {
const int64_t paddedNumPatches = packedNumRows * packFactor;
auto groupedType = RankedTensorType::get({packedNumRows, packFactor, patchSize}, elemType);
auto packedType = RankedTensorType::get({packedNumRows, packFactor * patchSize}, elemType);
Value paddedIm2col = createPaddedRows(im2col, im2colType, paddedNumPatches, rewriter, loc);
Value groupedIm2col = tensor::ExpandShapeOp::create(rewriter,
loc,
groupedType,
paddedIm2col,
SmallVector<ReassociationIndices> {
{0, 1},
{2}
});
gemmInputRows = tensor::CollapseShapeOp::create(rewriter,
loc,
packedType,
groupedIm2col,
SmallVector<ReassociationIndices> {
{0},
{1, 2}
});
}
SmallVector<Value> rowResults;
rowResults.reserve(packedNumRows);
for (int64_t rowIdx = 0; rowIdx < packedNumRows; rowIdx++) {
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(rowIdx), rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
rewriter.getIndexAttr(packFactor * patchSize)};
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
rowResults.push_back(
tensor::ExtractSliceOp::create(rewriter, loc, gemmInputRowType, gemmInputRows, offsets, sizes, strides));
}
spatial::SpatYieldOp::create(rewriter, loc, rowResults);
});
SmallVector<Value> rows;
rows.reserve(im2colComputeOp.getNumResults());
for (Value result : im2colComputeOp.getResults())
rows.push_back(result);
return rows;
return im2colComputeOp.getResult(0);
}
static Value createCollectedConvOutput(ValueRange gemmRows,
@@ -320,16 +305,12 @@ static Value createCollectedConvOutput(ValueRange gemmRows,
auto collectComputeOp = createSpatCompute(rewriter, loc, convType, {}, gemmRows, [&](ValueRange gemmRowArgs) {
Value gemmOut;
if (packFactor == 1) {
gemmOut = gemmRowArgs.size() == 1 ? gemmRowArgs.front()
: tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, gemmRowArgs).getResult();
gemmOut = createSpatConcat(rewriter, loc, /*axis=*/0, gemmRowArgs);
}
else {
auto expandedType = RankedTensorType::get({packedNumRows, packFactor, numChannelsOut}, outType.getElementType());
auto paddedType = RankedTensorType::get({paddedNumPatches, numChannelsOut}, outType.getElementType());
Value packedOutput =
gemmRowArgs.size() == 1
? gemmRowArgs.front()
: tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, gemmRowArgs).getResult();
Value packedOutput = createSpatConcat(rewriter, loc, /*axis=*/0, gemmRowArgs);
Value expandedOutput = tensor::ExpandShapeOp::create(rewriter,
loc,
expandedType,
@@ -388,11 +369,34 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
auto wType = cast<RankedTensorType>(w.getType());
auto outType = cast<RankedTensorType>(convOp.getY().getType());
assert("Only support static shapes" && xType.hasStaticShape() && wType.hasStaticShape() && outType.hasStaticShape());
assert("Only support 2D convolution" && xType.getRank() == 4);
// We need to understand what is group
assert("Only support group=1" && convOp.getGroup() == 1);
if (!xType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv input");
return failure();
}
if (!wType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv weight");
return failure();
}
if (!outType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(convOp, "conv result");
return failure();
}
if (xType.getRank() != 4) {
pim::emitUnsupportedRankDiagnostic(convOp, "conv input", xType.getRank(), {4});
return failure();
}
if (wType.getRank() != 4) {
pim::emitUnsupportedRankDiagnostic(convOp, "conv weight", wType.getRank(), {4});
return failure();
}
if (outType.getRank() != 4) {
pim::emitUnsupportedRankDiagnostic(convOp, "conv result", outType.getRank(), {4});
return failure();
}
if (convOp.getGroup() != 1) {
convOp.emitOpError("only group=1 convolution is supported for Spatial lowering");
return failure();
}
const int64_t batchSize = xType.getDimSize(0);
const int64_t numChannelsIn = xType.getDimSize(1);
@@ -409,6 +413,19 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
const auto dilationsAttr = convOp.getDilations();
const auto padsAttr = convOp.getPads();
if (stridesAttr && stridesAttr->size() != 2) {
convOp.emitOpError("requires exactly two stride values for Spatial lowering");
return failure();
}
if (dilationsAttr && dilationsAttr->size() != 2) {
convOp.emitOpError("requires exactly two dilation values for Spatial lowering");
return failure();
}
if (padsAttr && padsAttr->size() != 4) {
convOp.emitOpError("requires exactly four pad values for 2D Spatial lowering");
return failure();
}
const int64_t strideHeight = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 0) : 1;
const int64_t strideWidth = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 1) : 1;
const int64_t dilationHeight = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 0) : 1;
@@ -449,6 +466,10 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
padWidthBegin = totalPadW - padWidthEnd;
}
}
else if (autoPad != "NOTSET" && autoPad != "VALID") {
convOp.emitOpError() << "unsupported auto_pad value `" << autoPad << "` for Spatial lowering";
return failure();
}
// "NOTSET" or "VALID" -> all pads stay 0
}
@@ -505,38 +526,42 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
// B (weights): [patchSize, cOut] -- W^T, stored in crossbar columns
// and optionally repack several old rows into one GEMM row to use the available crossbar size better.
//
// The im2col compute yields each GEMM input row as a separate result so every GEMM consumes only
// the row it needs instead of receiving a full packed tensor and slicing it locally.
auto gemmInputRowType =
RankedTensorType::get({1, effectiveMaxParallelPixels * patchSize}, elemType);
auto gemmOutputRowType =
RankedTensorType::get({1, effectiveMaxParallelPixels * numChannelsOut}, outType.getElementType());
SmallVector<Value> gemmInputRows = createIm2colRowComputes(x,
xType,
im2colType,
rowType,
gemmInputRowType,
batchSize,
numChannelsIn,
xHeight,
xWidth,
wHeight,
wWidth,
padHeightBegin,
padHeightEnd,
padWidthBegin,
padWidthEnd,
strideHeight,
strideWidth,
dilationHeight,
dilationWidth,
outWidth,
patchSize,
numPatches,
numPatchesPerBatch,
effectiveMaxParallelPixels,
rewriter,
loc);
// We want to process N pixels at the same time. Instead of doing N separate operations
// of (1 x patchSize) x (patchSize x cOut), we construct a block-diagonal weight matrix
// containing N copies of W^T and concatenate N im2col rows into one longer row:
// A_packed: [ceil(numPatches / N), N * patchSize]
// B_packed: [N * patchSize, N * cOut]
// Y_packed: [ceil(numPatches / N), N * cOut]
const int64_t packedNumRows = ceilIntegerDivide(numPatches, effectiveMaxParallelPixels);
auto gemmInputRowsType = RankedTensorType::get({packedNumRows, effectiveMaxParallelPixels * patchSize}, elemType);
auto gemmOutputRowsType =
RankedTensorType::get({packedNumRows, effectiveMaxParallelPixels * numChannelsOut}, outType.getElementType());
Value gemmInputRows = createIm2colRowComputes(x,
xType,
im2colType,
rowType,
gemmInputRowsType,
batchSize,
numChannelsIn,
xHeight,
xWidth,
wHeight,
wWidth,
padHeightBegin,
padHeightEnd,
padWidthBegin,
padWidthEnd,
strideHeight,
strideWidth,
dilationHeight,
dilationWidth,
outWidth,
patchSize,
numPatches,
numPatchesPerBatch,
effectiveMaxParallelPixels,
rewriter,
loc);
Value gemmB = buildPackedWeight(wDenseAttr,
wTrans,
@@ -552,25 +577,20 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
Value gemmC = buildPackedBias(
hasB, gemmBias, biasMatrix, biasDenseAttr, outType, numChannelsOut, effectiveMaxParallelPixels, rewriter, loc);
SmallVector<Value> gemmRows;
gemmRows.reserve(gemmInputRows.size());
for (Value gemmInputRow : gemmInputRows) {
Value gemmRow = ONNXGemmOp::create(rewriter,
loc,
gemmOutputRowType,
gemmInputRow,
gemmB,
gemmC,
rewriter.getF32FloatAttr(1.0f),
rewriter.getF32FloatAttr(1.0f),
rewriter.getBoolAttr(false),
rewriter.getBoolAttr(false))
.getY();
gemmRows.push_back(gemmRow);
}
Value gemmRows = ONNXGemmOp::create(rewriter,
loc,
gemmOutputRowsType,
gemmInputRows,
gemmB,
gemmC,
rewriter.getF32FloatAttr(1.0f),
rewriter.getF32FloatAttr(1.0f),
rewriter.getBoolAttr(false),
rewriter.getBoolAttr(false))
.getY();
rewriter.replaceOp(convOp,
createCollectedConvOutput(gemmRows,
createCollectedConvOutput(ValueRange {gemmRows},
convOp.getType(),
gemmOutType,
nhwcType,
@@ -5,8 +5,9 @@
#include "llvm/ADT/SmallVector.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -15,13 +16,6 @@ using namespace mlir;
namespace onnx_mlir {
namespace {
static SmallVector<int64_t> computeRowMajorStrides(ArrayRef<int64_t> shape) {
SmallVector<int64_t> strides(shape.size(), 1);
for (int64_t i = static_cast<int64_t>(shape.size()) - 2; i >= 0; --i)
strides[i] = strides[i + 1] * shape[i + 1];
return strides;
}
static DenseElementsAttr getDenseConstantAttr(Value value) {
if (auto constantOp = value.getDefiningOp<arith::ConstantOp>())
return dyn_cast<DenseElementsAttr>(constantOp.getValue());
@@ -1,16 +1,17 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Location.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallVector.h"
#include <cassert>
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -49,6 +50,45 @@ materializeScaledConstantTensor(Value value, float factor, ConversionPatternRewr
return arith::ConstantOp::create(rewriter, loc, denseAttr.getType(), scaledAttr).getResult();
}
static Value transposeForSpatial(Value value,
RankedTensorType resultType,
ArrayRef<int64_t> permutation,
ConversionPatternRewriter& rewriter,
Location loc) {
if (isHostFoldableValue(value))
return ONNXTransposeOp::create(rewriter, loc, resultType, value, rewriter.getI64ArrayAttr(permutation));
auto computeOp = createSpatCompute<1>(rewriter, loc, TypeRange {resultType}, {}, value, [&](Value input) {
Value transposed = ONNXTransposeOp::create(rewriter, loc, resultType, input, rewriter.getI64ArrayAttr(permutation));
spatial::SpatYieldOp::create(rewriter, loc, transposed);
});
return computeOp.getResult(0);
}
static Value
expandRankOneBias(Value value, RankedTensorType resultType, ConversionPatternRewriter& rewriter, Location loc) {
if (isHostFoldableValue(value))
return tensor::ExpandShapeOp::create(rewriter,
loc,
resultType,
value,
SmallVector<ReassociationIndices> {
{0, 1}
});
auto computeOp = createSpatCompute<1>(rewriter, loc, TypeRange {resultType}, {}, value, [&](Value input) {
Value expanded = tensor::ExpandShapeOp::create(rewriter,
loc,
resultType,
input,
SmallVector<ReassociationIndices> {
{0, 1}
});
spatial::SpatYieldOp::create(rewriter, loc, expanded);
});
return computeOp.getResult(0);
}
struct GemmToManyGemv : OpConversionPattern<ONNXGemmOp> {
using OpConversionPattern::OpConversionPattern;
@@ -65,6 +105,72 @@ struct GemvToSpatialCompute : OpConversionPattern<ONNXGemmOp> {
ConversionPatternRewriter& rewriter) const override;
};
struct GemmToSpatialComputeBatch : OpConversionPattern<ONNXGemmOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(ONNXGemmOp gemmOp,
ONNXGemmOpAdaptor gemmOpAdaptor,
ConversionPatternRewriter& rewriter) const override;
};
static SmallVector<Value> materializeBatchRowSlices(Value matrix,
RankedTensorType matrixType,
ConversionPatternRewriter& rewriter,
Location loc) {
const int64_t numRows = matrixType.getDimSize(0);
auto rowType = RankedTensorType::get({1, matrixType.getDimSize(1)}, matrixType.getElementType());
SmallVector<Type> resultTypes(static_cast<size_t>(numRows), rowType);
if (isHostFoldableValue(matrix)) {
auto extractRowsOp = spatial::SpatExtractRowsOp::create(rewriter, loc, TypeRange(resultTypes), matrix);
return SmallVector<Value>(extractRowsOp->result_begin(), extractRowsOp->result_end());
}
auto buildRowSlices = [&](Value matrixArg) {
auto extractRowsOp = spatial::SpatExtractRowsOp::create(rewriter, loc, TypeRange(resultTypes), matrixArg);
return SmallVector<Value>(extractRowsOp->result_begin(), extractRowsOp->result_end());
};
auto cloneBatchInputChainIntoSliceCompute =
[&](Value rootInput, SmallVector<Operation*> chainOps, Value rootValue) -> SmallVector<Value> {
auto sliceCompute =
createSpatCompute<1>(rewriter, loc, TypeRange(resultTypes), {}, ValueRange {rootInput}, [&](Value input) {
Value transformedMatrix = input;
if (!chainOps.empty()) {
IRMapping mapper;
mapper.map(rootValue, input);
for (Operation* chainOp : chainOps)
rewriter.clone(*chainOp, mapper);
transformedMatrix = cast<Value>(mapper.lookup(matrix));
}
spatial::SpatYieldOp::create(rewriter, loc, buildRowSlices(transformedMatrix));
});
SmallVector<Value> rowSlices(sliceCompute->result_begin(), sliceCompute->result_end());
return rowSlices;
};
SmallVector<Operation*> chainOps;
Value rootValue = matrix;
while (Operation* definingOp = rootValue.getDefiningOp()) {
if (auto rootCompute = dyn_cast<spatial::SpatCompute>(definingOp)) {
SmallVector<Operation*> reversedChainOps(chainOps.rbegin(), chainOps.rend());
return cloneBatchInputChainIntoSliceCompute(
rootCompute.getResult(cast<OpResult>(rootValue).getResultNumber()), reversedChainOps, rootValue);
}
if (definingOp->getNumOperands() != 1)
break;
if (!isa<tensor::ExtractSliceOp, tensor::ExpandShapeOp, tensor::CollapseShapeOp, ONNXTransposeOp>(definingOp))
break;
chainOps.push_back(definingOp);
rootValue = definingOp->getOperand(0);
}
SmallVector<Operation*> reversedChainOps(chainOps.rbegin(), chainOps.rend());
return cloneBatchInputChainIntoSliceCompute(rootValue, reversedChainOps, rootValue);
}
} // namespace
LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
@@ -75,13 +181,23 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
Value b = gemmOpAdaptor.getB();
Value c = gemmOpAdaptor.getC();
assert("A should have been transposed already" && !gemmOpAdaptor.getTransA());
if (gemmOpAdaptor.getTransA()) {
gemmOp.emitOpError("requires transA=false before Gemm row decomposition");
return failure();
}
bool hasC = !isa<ONNXNoneOp>(c.getDefiningOp());
auto aType = cast<RankedTensorType>(a.getType());
auto outType = cast<RankedTensorType>(gemmOp.getY().getType());
assert("Only support static shapes" && aType.hasStaticShape() && outType.hasStaticShape());
if (!aType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm input A");
return failure();
}
if (!outType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm result");
return failure();
}
const int64_t numOutRows = aType.getDimSize(0);
@@ -105,47 +221,43 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
// Expand rank-1 bias [N] to rank-2 [1, N] for uniform handling
if (cType.getRank() == 1) {
auto expandedType = RankedTensorType::get({1, cType.getDimSize(0)}, cType.getElementType());
c = tensor::ExpandShapeOp::create(rewriter,
loc,
expandedType,
c,
SmallVector<ReassociationIndices> {
{0, 1}
});
c = expandRankOneBias(c, expandedType, rewriter, loc);
cType = expandedType;
}
assert("Only support rank 2 tensor for C" && cType.getRank() == 2);
if (!cType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm bias");
return failure();
}
if (cType.getRank() != 2) {
pim::emitUnsupportedRankDiagnostic(gemmOp, "Gemm bias", cType.getRank(), {1, 2});
return failure();
}
cHasNumOutRows = cType.getDimSize(0) == numOutRows;
}
auto outRowType = RankedTensorType::get({1, outType.getDimSize(1)}, outType.getElementType());
SmallVector<Value> aSlices = materializeBatchRowSlices(a, aType, rewriter, loc);
SmallVector<Value> cSlices;
if (hasC && cHasNumOutRows)
cSlices = materializeBatchRowSlices(c, cType, rewriter, loc);
SmallVector<Value> gemvOps;
gemvOps.reserve(numOutRows);
gemvOps.reserve(static_cast<size_t>(numOutRows));
for (int64_t rowIdx = 0; rowIdx < numOutRows; rowIdx++) {
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(rowIdx), rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(aType.getDimSize(1))};
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
auto aSliceType = RankedTensorType::get({1, aType.getDimSize(1)}, aType.getElementType());
auto aSlice = tensor::ExtractSliceOp::create(rewriter, loc, aSliceType, a, offsets, sizes, strides).getResult();
Value cSlice = c;
if (hasC) {
if (cHasNumOutRows) {
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(rowIdx), rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(cType.getDimSize(1))};
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
auto cSliceType = RankedTensorType::get({1, cType.getDimSize(1)}, cType.getElementType());
cSlice = tensor::ExtractSliceOp::create(rewriter, loc, cSliceType, c, offsets, sizes, strides).getResult();
if (cHasNumOutRows)
cSlice = cSlices[static_cast<size_t>(rowIdx)];
else if (!isVectorShape(getTensorShape(c))) {
gemmOp.emitOpError("requires Gemm bias C to be vector-like when shared across decomposed rows");
return failure();
}
else
assert("C should be a vector" && isVectorShape(getTensorShape(c)));
}
auto gemvOp = ONNXGemmOp::create(rewriter,
loc,
outRowType,
aSlice,
aSlices[static_cast<size_t>(rowIdx)],
b,
cSlice,
rewriter.getF32FloatAttr(1.0f),
@@ -156,8 +268,7 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
}
auto concatComputeOp = createSpatCompute(rewriter, loc, gemmOp.getType(), {}, gemvOps, [&](ValueRange gemvOpsArgs) {
auto concatOp = tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, gemvOpsArgs);
spatial::SpatYieldOp::create(rewriter, loc, concatOp.getResult());
spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, /*axis=*/0, gemvOpsArgs));
});
rewriter.replaceOp(gemmOp, concatComputeOp);
@@ -189,20 +300,31 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
// Expand rank-1 bias [N] to rank-2 [1, N] for uniform handling
if (cType.getRank() == 1) {
auto expandedType = RankedTensorType::get({1, cType.getDimSize(0)}, cType.getElementType());
c = tensor::ExpandShapeOp::create(rewriter,
gemmLoc,
expandedType,
c,
SmallVector<ReassociationIndices> {
{0, 1}
});
c = expandRankOneBias(c, expandedType, rewriter, gemmLoc);
cType = expandedType;
}
assert("Only support rank 2 tensor for C" && cType.getRank() == 2);
if (!cType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm bias");
return failure();
}
if (cType.getRank() != 2) {
pim::emitUnsupportedRankDiagnostic(gemmOp, "Gemm bias", cType.getRank(), {1, 2});
return failure();
}
}
assert("Only support static shapes" && aType.hasStaticShape() && bType.hasStaticShape()
&& (!hasC || cType.hasStaticShape()) && outType.hasStaticShape());
if (!aType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm input A");
return failure();
}
if (!bType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm input B");
return failure();
}
if (!outType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm result");
return failure();
}
if (!isVectorShape(aType.getShape()) || (hasC && !isVectorShape(cType.getShape())))
// Not a gemv
@@ -210,13 +332,14 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
if (transA) {
auto aShape = aType.getShape();
auto transposedType = aType.cloneWith(ArrayRef({aShape[1], aShape[0]}), aType.getElementType());
a = ONNXTransposeOp::create(rewriter, gemmLoc, transposedType, a, rewriter.getI64ArrayAttr({1, 0}));
auto transposedType = RankedTensorType::get({aShape[1], aShape[0]}, aType.getElementType());
a = transposeForSpatial(a, transposedType, {1, 0}, rewriter, gemmLoc);
aType = cast<RankedTensorType>(a.getType());
}
if (transB) {
auto bShape = bType.getShape();
auto transposedType = bType.cloneWith(ArrayRef({bShape[1], bShape[0]}), bType.getElementType());
b = ONNXTransposeOp::create(rewriter, gemmLoc, transposedType, b, rewriter.getI64ArrayAttr({1, 0}));
auto transposedType = RankedTensorType::get({bShape[1], bShape[0]}, bType.getElementType());
b = transposeForSpatial(b, transposedType, {1, 0}, rewriter, gemmLoc);
bType = cast<RankedTensorType>(b.getType());
}
@@ -240,7 +363,6 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
auto [aNumHSlices, aLastHSliceSize] = ceilIntegerDivideWithRemainder(aType.getDimSize(1), crossbarSize.getValue());
auto [bNumHSlices, bLastHSliceSize] = ceilIntegerDivideWithRemainder(bType.getDimSize(1), crossbarSize.getValue());
auto bNumVSlices = aNumHSlices;
auto bLastVSliceSize = aLastHSliceSize;
auto cNumHSlices = bNumHSlices;
auto cLastHSliceSize = bLastHSliceSize;
auto outNumHSlices = cNumHSlices;
@@ -281,19 +403,25 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
weights.push_back(bTiles[outSliceId][coreId][aSliceId]);
auto computeOp = createSpatCompute(
rewriter, gemmLoc, currOutHSliceType, weights, aHSlices[coreId], [&](ValueRange aHSlicesArgs) {
rewriter, gemmLoc, currOutHSliceType, weights, aHSlices[coreId], [&](ValueRange aHSlicesArgs) -> LogicalResult {
SmallVector<Value> vmmOutputs;
vmmOutputs.reserve(aHSlicesArgs.size());
for (auto [aHSliceId, computeArg] : llvm::enumerate(aHSlicesArgs))
vmmOutputs.push_back(
spatial::SpatWeightedVMMOp::create(rewriter, gemmLoc, currOutHSliceType, aHSliceId, computeArg));
assert(!vmmOutputs.empty() && "vmmOutputs must be non-empty");
spatial::SpatVMMOp::create(rewriter, gemmLoc, currOutHSliceType, aHSliceId, computeArg));
if (vmmOutputs.empty()) {
gemmOp.emitOpError("requires at least one non-empty slice when lowering tiled Gemm to Spatial VMMs");
return failure();
}
Value partialVmmSum = sumTensors(vmmOutputs, rewriter);
spatial::SpatYieldOp::create(rewriter, gemmLoc, partialVmmSum);
return success();
});
if (failed(computeOp))
return failure();
partialResults.push_back(computeOp.getResult(0));
partialResults.push_back(computeOp->getResult(0));
}
if (hasC) {
@@ -313,15 +441,134 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
auto concatComputeOp =
createSpatCompute(rewriter, gemmLoc, gemmOp.getType(), {}, outHSlices, [&](ValueRange blockArgs) {
auto concatOp = tensor::ConcatOp::create(rewriter, gemmLoc, /*axis=*/1, blockArgs);
spatial::SpatYieldOp::create(rewriter, gemmLoc, concatOp.getResult());
spatial::SpatYieldOp::create(rewriter, gemmLoc, createSpatConcat(rewriter, gemmLoc, /*axis=*/1, blockArgs));
});
rewriter.replaceOp(gemmOp, concatComputeOp);
return success();
}
LogicalResult GemmToSpatialComputeBatch::matchAndRewrite(ONNXGemmOp gemmOp,
ONNXGemmOpAdaptor gemmOpAdaptor,
ConversionPatternRewriter& rewriter) const {
Location loc = gemmOp.getLoc();
Value a = gemmOpAdaptor.getA();
Value b = gemmOpAdaptor.getB();
Value c = gemmOpAdaptor.getC();
if (gemmOpAdaptor.getTransA()) {
gemmOp.emitOpError("requires transA=false before batch Gemm lowering");
return failure();
}
bool hasC = !isa<ONNXNoneOp>(c.getDefiningOp());
auto aType = cast<RankedTensorType>(a.getType());
auto bType = cast<RankedTensorType>(b.getType());
auto outType = cast<RankedTensorType>(gemmOp.getY().getType());
if (!aType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm input A");
return failure();
}
if (!bType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm input B");
return failure();
}
if (!outType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm result");
return failure();
}
const int64_t numOutRows = aType.getDimSize(0);
if (numOutRows <= 1)
return failure();
// Only handle the single-tile case: K <= crossbarSize and N <= crossbarSize
if (aType.getDimSize(1) > static_cast<int64_t>(crossbarSize.getValue())
|| outType.getDimSize(1) > static_cast<int64_t>(crossbarSize.getValue()))
return failure();
auto scaledB = materializeScaledConstantTensor(b, gemmOpAdaptor.getAlpha().convertToFloat(), rewriter, loc);
if (failed(scaledB))
return failure();
b = *scaledB;
bType = cast<RankedTensorType>(b.getType());
if (gemmOpAdaptor.getTransB()) {
auto bShape = bType.getShape();
auto transposedType = RankedTensorType::get({bShape[1], bShape[0]}, bType.getElementType());
b = transposeForSpatial(b, transposedType, {1, 0}, rewriter, loc);
bType = cast<RankedTensorType>(b.getType());
}
(void) bType;
if (!isHostFoldableValue(b))
return failure();
Value sharedBias;
if (hasC) {
auto scaledC = materializeScaledConstantTensor(c, gemmOpAdaptor.getBeta().convertToFloat(), rewriter, loc);
if (failed(scaledC))
return failure();
c = *scaledC;
auto cType = cast<RankedTensorType>(c.getType());
if (cType.getRank() == 1) {
auto expandedType = RankedTensorType::get({1, cType.getDimSize(0)}, cType.getElementType());
c = expandRankOneBias(c, expandedType, rewriter, loc);
cType = cast<RankedTensorType>(c.getType());
}
if (!cType.hasStaticShape()) {
pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm bias");
return failure();
}
if (cType.getRank() != 2) {
pim::emitUnsupportedRankDiagnostic(gemmOp, "Gemm bias", cType.getRank(), {1, 2});
return failure();
}
// Row-specific bias can't share a single template body; fall through to GemmToManyGemv
if (cType.getDimSize(0) == numOutRows && numOutRows > 1)
return failure();
if (cType.getDimSize(0) == 1 && cType.getDimSize(1) == 1)
c = broadcastToVector(c, outType.getDimSize(1), rewriter, loc);
sharedBias = c;
}
SmallVector<Value> aSlices = materializeBatchRowSlices(a, aType, rewriter, loc);
auto aSliceType = cast<RankedTensorType>(aSlices.front().getType());
auto outRowType = RankedTensorType::get({1, outType.getDimSize(1)}, outType.getElementType());
SmallVector<Type> resultTypes(static_cast<size_t>(numOutRows), outRowType);
SmallVector<Value> weights(static_cast<size_t>(numOutRows), b);
auto batchOp = spatial::SpatComputeBatch::create(rewriter,
loc,
TypeRange(resultTypes),
rewriter.getI32IntegerAttr(static_cast<int32_t>(numOutRows)),
ValueRange(weights),
ValueRange(aSlices));
Block* body = rewriter.createBlock(
&batchOp.getBody(), batchOp.getBody().end(), TypeRange {aSliceType}, SmallVector<Location>(1, loc));
rewriter.setInsertionPointToEnd(body);
Value vmmResult = spatial::SpatVMMOp::create(rewriter, loc, outRowType, 0, body->getArgument(0)).getResult();
Value laneResult = vmmResult;
if (sharedBias)
laneResult = spatial::SpatVAddOp::create(rewriter, loc, outRowType, vmmResult, sharedBias).getResult();
spatial::SpatYieldOp::create(rewriter, loc, laneResult);
rewriter.setInsertionPointAfter(batchOp);
SmallVector<Value> laneResults(batchOp->result_begin(), batchOp->result_end());
auto concatComputeOp = createSpatCompute(rewriter, loc, gemmOp.getType(), {}, laneResults, [&](ValueRange args) {
spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, /*axis=*/0, args));
});
rewriter.replaceOp(gemmOp, concatComputeOp);
return success();
}
void populateGemmPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.insert<GemmToSpatialComputeBatch>(ctx, PatternBenefit(2));
patterns.insert<GemmToManyGemv>(ctx);
patterns.insert<GemvToSpatialCompute>(ctx);
}
@@ -4,8 +4,9 @@
#include "llvm/ADT/SmallVector.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -14,7 +15,102 @@ using namespace mlir;
namespace onnx_mlir {
namespace {
struct MatMulRank3ToGemm : OpRewritePattern<ONNXMatMulOp> {
static bool haveStaticPositiveShape(ArrayRef<int64_t> shape) {
return llvm::all_of(shape, [](int64_t dim) { return dim > 0; });
}
static Value extractBatchMatrix(Value value,
int64_t batchIndex,
int64_t batchSize,
int64_t rows,
int64_t cols,
PatternRewriter& rewriter,
Location loc) {
auto type = cast<RankedTensorType>(value.getType());
if (type.getRank() == 2)
return value;
auto sliceType = RankedTensorType::get({1, rows, cols}, type.getElementType());
SmallVector<OpFoldResult> offsets = {
rewriter.getIndexAttr(batchSize == 1 ? 0 : batchIndex), rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> sizes = {
rewriter.getIndexAttr(1), rewriter.getIndexAttr(rows), rewriter.getIndexAttr(cols)};
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
auto matrixType = RankedTensorType::get({rows, cols}, type.getElementType());
auto buildMatrix = [&](Value input) -> Value {
Value slice = tensor::ExtractSliceOp::create(rewriter, loc, sliceType, input, offsets, sizes, strides);
return tensor::CollapseShapeOp::create(rewriter,
loc,
matrixType,
slice,
SmallVector<ReassociationIndices> {
{0, 1},
{2}
});
};
if (isHostFoldableValue(value))
return buildMatrix(value);
auto batchMatrixCompute =
createSpatCompute<1>(rewriter, loc, TypeRange {matrixType}, {}, ValueRange {value}, [&](Value input) {
spatial::SpatYieldOp::create(rewriter, loc, buildMatrix(input));
});
return batchMatrixCompute.getResult(0);
}
static Value transposeLastTwoDims(Value value, PatternRewriter& rewriter, Location loc) {
auto type = cast<RankedTensorType>(value.getType());
auto shape = type.getShape();
if (type.getRank() == 2) {
auto transposedType = RankedTensorType::get({shape[1], shape[0]}, type.getElementType());
return ONNXTransposeOp::create(rewriter, loc, transposedType, value, rewriter.getI64ArrayAttr({1, 0}));
}
auto transposedType = RankedTensorType::get({shape[0], shape[2], shape[1]}, type.getElementType());
return ONNXTransposeOp::create(rewriter, loc, transposedType, value, rewriter.getI64ArrayAttr({0, 2, 1}));
}
static Value transposeLastTwoDimsInCompute(Value value, PatternRewriter& rewriter, Location loc) {
auto type = cast<RankedTensorType>(value.getType());
auto shape = type.getShape();
RankedTensorType transposedType;
SmallVector<int64_t> perm;
if (type.getRank() == 2) {
transposedType = RankedTensorType::get({shape[1], shape[0]}, type.getElementType());
perm = {1, 0};
}
else {
transposedType = RankedTensorType::get({shape[0], shape[2], shape[1]}, type.getElementType());
perm = {0, 2, 1};
}
auto transposeCompute = createSpatCompute<1>(rewriter, loc, transposedType, {}, ValueRange {value}, [&](Value input) {
Value transposed = ONNXTransposeOp::create(rewriter, loc, transposedType, input, rewriter.getI64ArrayAttr(perm));
spatial::SpatYieldOp::create(rewriter, loc, transposed);
});
return transposeCompute.getResult(0);
}
static Value concatValues(ValueRange inputs, int64_t axis, PatternRewriter& rewriter, Location loc) {
auto firstType = cast<RankedTensorType>(inputs.front().getType());
SmallVector<int64_t> outputShape(firstType.getShape().begin(), firstType.getShape().end());
int64_t concatDimSize = 0;
for (Value input : inputs)
concatDimSize += cast<RankedTensorType>(input.getType()).getDimSize(axis);
outputShape[axis] = concatDimSize;
auto resultType = RankedTensorType::get(outputShape, firstType.getElementType(), firstType.getEncoding());
if (llvm::all_of(inputs, isHostFoldableValue))
return createSpatConcat(rewriter, loc, axis, inputs);
auto concatCompute = createSpatCompute(rewriter, loc, TypeRange {resultType}, {}, inputs, [&](ValueRange args) {
spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, axis, args));
});
return concatCompute.getResult(0);
}
struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(ONNXMatMulOp matmulOp, PatternRewriter& rewriter) const override {
@@ -24,80 +120,125 @@ struct MatMulRank3ToGemm : OpRewritePattern<ONNXMatMulOp> {
if (!lhsType || !rhsType || !outType || !lhsType.hasStaticShape() || !rhsType.hasStaticShape()
|| !outType.hasStaticShape())
return failure();
if (lhsType.getRank() != 2 || rhsType.getRank() != 3 || outType.getRank() != 3)
if ((lhsType.getRank() != 2 && lhsType.getRank() != 3) || (rhsType.getRank() != 2 && rhsType.getRank() != 3)
|| (outType.getRank() != 2 && outType.getRank() != 3))
return failure();
if (!haveStaticPositiveShape(lhsType.getShape()) || !haveStaticPositiveShape(rhsType.getShape())
|| !haveStaticPositiveShape(outType.getShape()))
return failure();
const int64_t batch = rhsType.getDimSize(0);
const int64_t k = rhsType.getDimSize(1);
const int64_t n = rhsType.getDimSize(2);
const int64_t m = lhsType.getDimSize(0);
if (lhsType.getDimSize(1) != k || outType.getDimSize(0) != batch || outType.getDimSize(1) != m
|| outType.getDimSize(2) != n)
const int64_t lhsBatch = lhsType.getRank() == 3 ? lhsType.getDimSize(0) : 1;
const int64_t rhsBatch = rhsType.getRank() == 3 ? rhsType.getDimSize(0) : 1;
const int64_t batch = std::max(lhsBatch, rhsBatch);
if ((lhsBatch != 1 && lhsBatch != batch) || (rhsBatch != 1 && rhsBatch != batch))
return failure();
Location loc = matmulOp.getLoc();
auto lhsTransposedType = RankedTensorType::get({k, m}, lhsType.getElementType());
auto rhsSliceType = RankedTensorType::get({1, k, 1}, rhsType.getElementType());
auto rhsRowType = RankedTensorType::get({1, k}, rhsType.getElementType());
auto gemmRowType = RankedTensorType::get({1, m}, outType.getElementType());
auto gemmOutType = RankedTensorType::get({batch * n, m}, outType.getElementType());
auto gemmExpandedType = RankedTensorType::get({batch, n, m}, outType.getElementType());
const int64_t m = lhsType.getRank() == 3 ? lhsType.getDimSize(1) : lhsType.getDimSize(0);
const int64_t k = lhsType.getRank() == 3 ? lhsType.getDimSize(2) : lhsType.getDimSize(1);
const int64_t rhsK = rhsType.getRank() == 3 ? rhsType.getDimSize(1) : rhsType.getDimSize(0);
const int64_t n = rhsType.getRank() == 3 ? rhsType.getDimSize(2) : rhsType.getDimSize(1);
if (k != rhsK)
return failure();
Value lhsTransposed =
ONNXTransposeOp::create(rewriter, loc, lhsTransposedType, matmulOp.getA(), rewriter.getI64ArrayAttr({1, 0}));
Value none = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());
SmallVector<Value> gemmRows;
gemmRows.reserve(batch * n);
for (int64_t batchIdx = 0; batchIdx < batch; batchIdx++) {
for (int64_t colIdx = 0; colIdx < n; colIdx++) {
SmallVector<OpFoldResult> offsets = {
rewriter.getIndexAttr(batchIdx), rewriter.getIndexAttr(0), rewriter.getIndexAttr(colIdx)};
SmallVector<OpFoldResult> sizes = {
rewriter.getIndexAttr(1), rewriter.getIndexAttr(k), rewriter.getIndexAttr(1)};
SmallVector<OpFoldResult> strides = {
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
Value rhsSlice =
tensor::ExtractSliceOp::create(rewriter, loc, rhsSliceType, matmulOp.getB(), offsets, sizes, strides);
Value rhsRow = tensor::CollapseShapeOp::create(rewriter,
loc,
rhsRowType,
rhsSlice,
SmallVector<ReassociationIndices> {
{0},
{1, 2}
});
auto gemmOp = ONNXGemmOp::create(rewriter,
loc,
gemmRowType,
rhsRow,
lhsTransposed,
none,
rewriter.getF32FloatAttr(1.0f),
rewriter.getF32FloatAttr(1.0f),
rewriter.getBoolAttr(false),
rewriter.getBoolAttr(false));
gemmRows.push_back(gemmOp.getY());
}
if (outType.getRank() == 2) {
if (batch != 1 || outType.getDimSize(0) != m || outType.getDimSize(1) != n)
return failure();
}
else {
if (outType.getDimSize(0) != batch || outType.getDimSize(1) != m || outType.getDimSize(2) != n)
return failure();
}
auto concatComputeOp = createSpatCompute(rewriter, loc, gemmOutType, {}, gemmRows, [&](ValueRange gemmRowsArgs) {
auto concatOp = tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, gemmRowsArgs);
spatial::SpatYieldOp::create(rewriter, loc, concatOp.getResult());
});
Location loc = matmulOp.getLoc();
bool useTransposedForm = isHostFoldableValue(matmulOp.getA()) && !isHostFoldableValue(matmulOp.getB());
Value gemmOut = concatComputeOp.getResult(0);
Value gemmExpanded = tensor::ExpandShapeOp::create(rewriter,
loc,
gemmExpandedType,
gemmOut,
SmallVector<ReassociationIndices> {
{0, 1},
{2}
});
Value result = ONNXTransposeOp::create(rewriter, loc, outType, gemmExpanded, rewriter.getI64ArrayAttr({0, 2, 1}));
Value lhs = matmulOp.getA();
Value rhs = matmulOp.getB();
int64_t lhsBatchForGemm = lhsBatch;
int64_t rhsBatchForGemm = rhsBatch;
int64_t gemmM = m;
int64_t gemmK = k;
int64_t gemmN = n;
if (useTransposedForm) {
lhs = transposeLastTwoDimsInCompute(matmulOp.getB(), rewriter, loc);
lhsBatchForGemm = rhsBatch;
rhs = transposeLastTwoDims(matmulOp.getA(), rewriter, loc);
rhsBatchForGemm = lhsBatch;
gemmM = n;
gemmN = m;
}
auto gemmType = RankedTensorType::get({gemmM, gemmN}, outType.getElementType());
auto batchedOutType = RankedTensorType::get({1, m, n}, outType.getElementType());
Value none = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());
if (outType.getRank() == 2) {
Value lhsMatrix = extractBatchMatrix(lhs, /*batchIndex=*/0, lhsBatchForGemm, gemmM, gemmK, rewriter, loc);
Value rhsMatrix = extractBatchMatrix(rhs, /*batchIndex=*/0, rhsBatchForGemm, gemmK, gemmN, rewriter, loc);
Value gemmResult = ONNXGemmOp::create(rewriter,
loc,
gemmType,
lhsMatrix,
rhsMatrix,
none,
rewriter.getF32FloatAttr(1.0f),
rewriter.getF32FloatAttr(1.0f),
rewriter.getBoolAttr(false),
rewriter.getBoolAttr(false))
.getY();
if (useTransposedForm) {
auto transposeCompute =
createSpatCompute<1>(rewriter, loc, TypeRange {outType}, {}, gemmResult, [&](Value input) {
Value transposed = ONNXTransposeOp::create(rewriter, loc, outType, input, rewriter.getI64ArrayAttr({1, 0}));
spatial::SpatYieldOp::create(rewriter, loc, transposed);
});
gemmResult = transposeCompute.getResult(0);
}
rewriter.replaceOp(matmulOp, gemmResult);
return success();
}
SmallVector<Value> batchResults;
batchResults.reserve(batch);
for (int64_t batchIdx = 0; batchIdx < batch; batchIdx++) {
Value lhsMatrix = extractBatchMatrix(lhs, batchIdx, lhsBatchForGemm, gemmM, gemmK, rewriter, loc);
Value rhsMatrix = extractBatchMatrix(rhs, batchIdx, rhsBatchForGemm, gemmK, gemmN, rewriter, loc);
Value gemmResult = ONNXGemmOp::create(rewriter,
loc,
gemmType,
lhsMatrix,
rhsMatrix,
none,
rewriter.getF32FloatAttr(1.0f),
rewriter.getF32FloatAttr(1.0f),
rewriter.getBoolAttr(false),
rewriter.getBoolAttr(false))
.getY();
auto batchResultCompute =
createSpatCompute<1>(rewriter, loc, TypeRange {batchedOutType}, {}, gemmResult, [&](Value input) {
Value resultMatrix = input;
if (useTransposedForm) {
resultMatrix = ONNXTransposeOp::create(rewriter,
loc,
RankedTensorType::get({m, n}, outType.getElementType()),
input,
rewriter.getI64ArrayAttr({1, 0}));
}
Value expanded = tensor::ExpandShapeOp::create(rewriter,
loc,
batchedOutType,
resultMatrix,
SmallVector<ReassociationIndices> {
{0, 1},
{2}
});
spatial::SpatYieldOp::create(rewriter, loc, expanded);
});
batchResults.push_back(batchResultCompute.getResult(0));
}
Value result = concatValues(batchResults, /*axis=*/0, rewriter, loc);
rewriter.replaceOp(matmulOp, result);
return success();
}
@@ -106,7 +247,7 @@ struct MatMulRank3ToGemm : OpRewritePattern<ONNXMatMulOp> {
} // namespace
void populateMatMulRewritePatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.insert<MatMulRank3ToGemm>(ctx);
patterns.insert<MatMulToGemm>(ctx);
}
} // namespace onnx_mlir
@@ -5,8 +5,9 @@
#include <algorithm>
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -81,6 +82,24 @@ createAverageCompute(Value input, RankedTensorType resultType, ConversionPattern
return computeOp.getResult(0);
}
static Value concatValues(ValueRange inputs, int64_t axis, ConversionPatternRewriter& rewriter, Location loc) {
auto firstType = cast<RankedTensorType>(inputs.front().getType());
SmallVector<int64_t> outputShape(firstType.getShape().begin(), firstType.getShape().end());
int64_t concatDimSize = 0;
for (Value input : inputs)
concatDimSize += cast<RankedTensorType>(input.getType()).getDimSize(axis);
outputShape[axis] = concatDimSize;
auto resultType = RankedTensorType::get(outputShape, firstType.getElementType(), firstType.getEncoding());
if (llvm::all_of(inputs, isHostFoldableValue))
return createSpatConcat(rewriter, loc, axis, inputs);
auto concatCompute = createSpatCompute(rewriter, loc, TypeRange {resultType}, {}, inputs, [&](ValueRange args) {
spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, axis, args));
});
return concatCompute.getResult(0);
}
static Value buildReduceMeanKeepdims(Value input,
ArrayRef<bool> reducedAxes,
int64_t axis,
@@ -100,8 +119,7 @@ static Value buildReduceMeanKeepdims(Value input,
for (Value slice : slices)
reducedSlices.push_back(buildReduceMeanKeepdims(slice, reducedAxes, axis + 1, leafType, rewriter, loc));
return reducedSlices.size() == 1 ? reducedSlices.front()
: tensor::ConcatOp::create(rewriter, loc, axis, reducedSlices).getResult();
return concatValues(reducedSlices, axis, rewriter, loc);
}
static Value squeezeReducedAxes(Value keepdimsValue,
@@ -116,9 +134,16 @@ static Value squeezeReducedAxes(Value keepdimsValue,
return tensor::FromElementsOp::create(rewriter, loc, resultType, ValueRange {element});
}
return tensor::CollapseShapeOp::create(
rewriter, loc, resultType, keepdimsValue, buildCollapseReassociation(reducedAxes))
.getResult();
auto reassociation = buildCollapseReassociation(reducedAxes);
if (isHostFoldableValue(keepdimsValue))
return tensor::CollapseShapeOp::create(rewriter, loc, resultType, keepdimsValue, reassociation).getResult();
auto squeezeCompute =
createSpatCompute<1>(rewriter, loc, TypeRange {resultType}, {}, ValueRange {keepdimsValue}, [&](Value input) {
Value collapsed = tensor::CollapseShapeOp::create(rewriter, loc, resultType, input, reassociation);
spatial::SpatYieldOp::create(rewriter, loc, collapsed);
});
return squeezeCompute.getResult(0);
}
struct ReduceMeanToSpatialCompute : OpConversionPattern<ONNXReduceMeanV13Op> {
@@ -1,18 +1,20 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/SmallVector.h"
#include <algorithm>
#include <cassert>
#include <optional>
#include <type_traits>
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -31,13 +33,6 @@ static int64_t getOptionalI64(std::optional<ArrayAttrT> arrayAttr, size_t index,
return arrayAttr ? getI64(*arrayAttr, index) : defaultValue;
}
static Value concatAlongAxis(ConversionPatternRewriter& rewriter, Location loc, int64_t axis, ArrayRef<Value> values) {
assert(!values.empty() && "Expected at least one value to concatenate.");
if (values.size() == 1)
return values.front();
return tensor::ConcatOp::create(rewriter, loc, axis, values);
}
static Value materializeContiguousTile(ConversionPatternRewriter& rewriter, Location loc, Value tile) {
auto tileType = cast<RankedTensorType>(tile.getType());
Value empty = tensor::EmptyOp::create(rewriter, loc, tileType.getShape(), tileType.getElementType());
@@ -52,27 +47,126 @@ static Value materializeContiguousTile(ConversionPatternRewriter& rewriter, Loca
return tensor::InsertSliceOp::create(rewriter, loc, tile, empty, offsets, sizes, strides);
}
template <typename ReduceOp>
static Value reduceWindowValues(ConversionPatternRewriter& rewriter, Location loc, ArrayRef<Value> windowValues) {
assert(!windowValues.empty() && "Expected at least one pool window value.");
static Value
createPoolFillElement(ConversionPatternRewriter& rewriter, Location loc, Type elementType, bool useMinimumValue) {
if (!useMinimumValue)
return arith::ConstantOp::create(rewriter, loc, elementType, rewriter.getZeroAttr(elementType));
Value reduced = windowValues.front();
for (Value value : windowValues.drop_front())
reduced = ReduceOp::create(rewriter, loc, reduced.getType(), reduced, value);
return reduced;
if (auto floatType = dyn_cast<FloatType>(elementType)) {
auto minValue = llvm::APFloat::getInf(floatType.getFloatSemantics(), /*Negative=*/true);
return arith::ConstantOp::create(rewriter, loc, elementType, rewriter.getFloatAttr(floatType, minValue));
}
if (auto integerType = dyn_cast<IntegerType>(elementType)) {
auto minValue = llvm::APInt::getSignedMinValue(integerType.getWidth());
return arith::ConstantOp::create(rewriter, loc, elementType, rewriter.getIntegerAttr(integerType, minValue));
}
llvm_unreachable("unsupported pool element type");
}
static Value
scaleAverageWindow(ConversionPatternRewriter& rewriter, Location loc, Value reducedWindow, int64_t divisor) {
assert(divisor > 0 && "AveragePool divisor must be positive.");
if (divisor == 1)
return reducedWindow;
static Value createPoolFillTensor(ConversionPatternRewriter& rewriter,
Location loc,
RankedTensorType tensorType,
bool useMinimumValue) {
auto fillElement = createPoolFillElement(rewriter, loc, tensorType.getElementType(), useMinimumValue);
return tensor::SplatOp::create(rewriter, loc, tensorType, fillElement);
}
auto tileType = cast<RankedTensorType>(reducedWindow.getType());
double scale = 1.0 / static_cast<double>(divisor);
auto scaleAttr = DenseElementsAttr::get(tileType, rewriter.getFloatAttr(tileType.getElementType(), scale));
Value scaleTensor = arith::ConstantOp::create(rewriter, loc, tileType, scaleAttr);
return spatial::SpatVMulOp::create(rewriter, loc, tileType, reducedWindow, scaleTensor);
template <typename PoolOp>
static Value createPaddedPoolInput(ConversionPatternRewriter& rewriter,
Location loc,
PoolOp poolOp,
Value input,
RankedTensorType inputType,
int64_t padTop,
int64_t padLeft,
int64_t padBottom,
int64_t padRight) {
if (padTop == 0 && padLeft == 0 && padBottom == 0 && padRight == 0)
return input;
auto paddedType = RankedTensorType::get({inputType.getDimSize(0),
inputType.getDimSize(1),
inputType.getDimSize(2) + padTop + padBottom,
inputType.getDimSize(3) + padLeft + padRight},
inputType.getElementType(),
inputType.getEncoding());
SmallVector<OpFoldResult> lowPads = {
rewriter.getIndexAttr(0), rewriter.getIndexAttr(0), rewriter.getIndexAttr(padTop), rewriter.getIndexAttr(padLeft)};
SmallVector<OpFoldResult> highPads = {rewriter.getIndexAttr(0),
rewriter.getIndexAttr(0),
rewriter.getIndexAttr(padBottom),
rewriter.getIndexAttr(padRight)};
auto padOp = tensor::PadOp::create(rewriter, loc, paddedType, input, lowPads, highPads);
auto* padBlock = new Block();
for (int index = 0; index < paddedType.getRank(); ++index)
padBlock->addArgument(rewriter.getIndexType(), loc);
padOp.getRegion().push_back(padBlock);
rewriter.setInsertionPointToStart(padBlock);
Value padValue =
createPoolFillElement(rewriter, loc, inputType.getElementType(), std::is_same_v<PoolOp, ONNXMaxPoolSingleOutOp>);
tensor::YieldOp::create(rewriter, loc, padValue);
rewriter.setInsertionPointAfter(padOp);
return padOp.getResult();
}
static FailureOr<Value> createAverageScaleTensor(ConversionPatternRewriter& rewriter,
Location loc,
Operation* op,
RankedTensorType outType,
int64_t channels,
int64_t inputHeight,
int64_t inputWidth,
int64_t outputHeight,
int64_t outputWidth,
int64_t kernelHeight,
int64_t kernelWidth,
int64_t strideHeight,
int64_t strideWidth,
int64_t dilationHeight,
int64_t dilationWidth,
int64_t padTop,
int64_t padLeft,
bool countIncludePad) {
auto elemType = dyn_cast<FloatType>(outType.getElementType());
if (!elemType) {
op->emitOpError("AveragePool lowering requires a floating-point element type");
return failure();
}
auto scaleType = RankedTensorType::get({1, channels, outputHeight, outputWidth}, elemType, outType.getEncoding());
SmallVector<Attribute> scaleValues;
scaleValues.reserve(static_cast<size_t>(channels * outputHeight * outputWidth));
for (int64_t channel = 0; channel < channels; ++channel) {
(void) channel;
for (int64_t outH = 0; outH < outputHeight; ++outH) {
for (int64_t outW = 0; outW < outputWidth; ++outW) {
int64_t validCount = 0;
for (int64_t kernelH = 0; kernelH < kernelHeight; ++kernelH) {
const int64_t inH = outH * strideHeight + kernelH * dilationHeight - padTop;
if (inH < 0 || inH >= inputHeight)
continue;
for (int64_t kernelW = 0; kernelW < kernelWidth; ++kernelW) {
const int64_t inW = outW * strideWidth + kernelW * dilationWidth - padLeft;
if (inW < 0 || inW >= inputWidth)
continue;
++validCount;
}
}
const int64_t divisor = countIncludePad ? kernelHeight * kernelWidth : validCount;
if (divisor <= 0) {
op->emitOpError("AveragePool divisor must be positive");
return failure();
}
scaleValues.push_back(rewriter.getFloatAttr(elemType, 1.0 / static_cast<double>(divisor)));
}
}
}
auto scaleAttr = DenseElementsAttr::get(scaleType, scaleValues);
return arith::ConstantOp::create(rewriter, loc, scaleType, scaleAttr).getResult();
}
template <typename PoolOp>
@@ -150,89 +244,133 @@ struct PoolToSpatialComputeBase : public OpConversionPattern<PoolOp> {
}
}
(void) padBottom;
(void) padRight;
const int64_t xbarSize = static_cast<int64_t>(crossbarSize.getValue());
const int64_t channelTileCount = (channels + xbarSize - 1) / xbarSize;
const int64_t outputPatchCount = batchSize * outputHeight * outputWidth;
const bool countIncludePad = [&]() {
if constexpr (std::is_same_v<PoolOp, ONNXAveragePoolOp>)
return poolOp.getCountIncludePad() == 1;
return true;
}();
Value averageScaleTensor;
if constexpr (std::is_same_v<PoolOp, ONNXAveragePoolOp>) {
auto maybeAverageScaleTensor = createAverageScaleTensor(rewriter,
loc,
poolOp,
outType,
channels,
inputHeight,
inputWidth,
outputHeight,
outputWidth,
kernelHeight,
kernelWidth,
strideHeight,
strideWidth,
dilationHeight,
dilationWidth,
padTop,
padLeft,
countIncludePad);
if (failed(maybeAverageScaleTensor))
return failure();
averageScaleTensor = *maybeAverageScaleTensor;
}
constexpr size_t numInputs = 1;
auto computeOp =
createSpatCompute<numInputs>(rewriter, loc, outType, {}, ValueRange {x}, [&](Value xArg) -> LogicalResult {
SmallVector<Value> batchResults;
batchResults.reserve(batchSize);
Value paddedInput =
createPaddedPoolInput(rewriter, loc, poolOp, xArg, xType, padTop, padLeft, padBottom, padRight);
Value pooledOutputInit = tensor::EmptyOp::create(rewriter, loc, outType.getShape(), outType.getElementType());
for (int64_t batch = 0; batch < batchSize; ++batch) {
SmallVector<Value> rows;
rows.reserve(outputHeight);
Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
Value c1 = arith::ConstantIndexOp::create(rewriter, loc, 1);
Value cOutputPatchCount = arith::ConstantIndexOp::create(rewriter, loc, outputPatchCount);
Value cOutputPixelsPerBatch = arith::ConstantIndexOp::create(rewriter, loc, outputHeight * outputWidth);
Value cOutputWidth = arith::ConstantIndexOp::create(rewriter, loc, outputWidth);
Value cStrideHeight = arith::ConstantIndexOp::create(rewriter, loc, strideHeight);
Value cStrideWidth = arith::ConstantIndexOp::create(rewriter, loc, strideWidth);
for (int64_t outH = 0; outH < outputHeight; ++outH) {
SmallVector<Value> rowPixels;
rowPixels.reserve(outputWidth);
auto outputLoop = scf::ForOp::create(rewriter, loc, c0, cOutputPatchCount, c1, ValueRange {pooledOutputInit});
rewriter.setInsertionPointToStart(outputLoop.getBody());
for (int64_t outW = 0; outW < outputWidth; ++outW) {
SmallVector<Value> outputChannelTiles;
outputChannelTiles.reserve(channelTileCount);
Value outputPatchIndex = outputLoop.getInductionVar();
Value pooledOutputAcc = outputLoop.getRegionIterArgs().front();
for (int64_t channelTile = 0; channelTile < channelTileCount; ++channelTile) {
const int64_t tileChannels = std::min<int64_t>(xbarSize, channels - channelTile * xbarSize);
auto tileType = RankedTensorType::get({1, tileChannels, 1, 1}, outType.getElementType());
Value batchIndex = arith::DivUIOp::create(rewriter, loc, outputPatchIndex, cOutputPixelsPerBatch);
Value batchPatchIndex = arith::RemUIOp::create(rewriter, loc, outputPatchIndex, cOutputPixelsPerBatch);
Value outHeightIndex = arith::DivUIOp::create(rewriter, loc, batchPatchIndex, cOutputWidth);
Value outWidthIndex = arith::RemUIOp::create(rewriter, loc, batchPatchIndex, cOutputWidth);
Value windowBaseH = arith::MulIOp::create(rewriter, loc, outHeightIndex, cStrideHeight);
Value windowBaseW = arith::MulIOp::create(rewriter, loc, outWidthIndex, cStrideWidth);
SmallVector<Value> windowValues;
windowValues.reserve(kernelHeight * kernelWidth);
for (int64_t kernelH = 0; kernelH < kernelHeight; ++kernelH) {
const int64_t inH = outH * strideHeight + kernelH * dilationHeight - padTop;
if (inH < 0 || inH >= inputHeight)
continue;
Value updatedOutput = pooledOutputAcc;
for (int64_t channelTile = 0; channelTile < channelTileCount; ++channelTile) {
const int64_t tileChannels = std::min<int64_t>(xbarSize, channels - channelTile * xbarSize);
auto tileType = RankedTensorType::get({1, tileChannels, 1, 1}, outType.getElementType());
Value reducedWindow =
createPoolFillTensor(rewriter, loc, tileType, std::is_same_v<PoolOp, ONNXMaxPoolSingleOutOp>);
for (int64_t kernelW = 0; kernelW < kernelWidth; ++kernelW) {
const int64_t inW = outW * strideWidth + kernelW * dilationWidth - padLeft;
if (inW < 0 || inW >= inputWidth)
continue;
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(batch),
rewriter.getIndexAttr(channelTile * xbarSize),
rewriter.getIndexAttr(inH),
rewriter.getIndexAttr(inW)};
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
rewriter.getIndexAttr(tileChannels),
rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1)};
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1)};
Value windowValue =
tensor::ExtractSliceOp::create(rewriter, loc, tileType, xArg, offsets, sizes, strides);
windowValue = materializeContiguousTile(rewriter, loc, windowValue);
windowValues.push_back(windowValue);
}
}
if (windowValues.empty())
return rewriter.notifyMatchFailure(poolOp, "pool window resolved to zero valid elements.");
Value reducedWindow = reduceWindowValues<ReduceOp>(rewriter, loc, windowValues);
if constexpr (std::is_same_v<PoolOp, ONNXAveragePoolOp>) {
const bool countIncludePad = poolOp.getCountIncludePad() == 1;
const int64_t divisor =
countIncludePad ? kernelHeight * kernelWidth : static_cast<int64_t>(windowValues.size());
reducedWindow = scaleAverageWindow(rewriter, loc, reducedWindow, divisor);
}
outputChannelTiles.push_back(reducedWindow);
}
rowPixels.push_back(concatAlongAxis(rewriter, loc, /*axis=*/1, outputChannelTiles));
for (int64_t kernelH = 0; kernelH < kernelHeight; ++kernelH) {
Value paddedInH = windowBaseH;
if (kernelH * dilationHeight != 0) {
Value kernelHOffset = arith::ConstantIndexOp::create(rewriter, loc, kernelH * dilationHeight);
paddedInH = arith::AddIOp::create(rewriter, loc, paddedInH, kernelHOffset);
}
rows.push_back(concatAlongAxis(rewriter, loc, /*axis=*/3, rowPixels));
for (int64_t kernelW = 0; kernelW < kernelWidth; ++kernelW) {
Value paddedInW = windowBaseW;
if (kernelW * dilationWidth != 0) {
Value kernelWOffset = arith::ConstantIndexOp::create(rewriter, loc, kernelW * dilationWidth);
paddedInW = arith::AddIOp::create(rewriter, loc, paddedInW, kernelWOffset);
}
SmallVector<OpFoldResult> offsets = {
batchIndex, rewriter.getIndexAttr(channelTile * xbarSize), paddedInH, paddedInW};
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
rewriter.getIndexAttr(tileChannels),
rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1)};
SmallVector<OpFoldResult> strides = {
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
Value windowValue =
tensor::ExtractSliceOp::create(rewriter, loc, tileType, paddedInput, offsets, sizes, strides);
windowValue = materializeContiguousTile(rewriter, loc, windowValue);
reducedWindow = ReduceOp::create(rewriter, loc, tileType, reducedWindow, windowValue);
}
}
batchResults.push_back(concatAlongAxis(rewriter, loc, /*axis=*/2, rows));
if constexpr (std::is_same_v<PoolOp, ONNXAveragePoolOp>) {
SmallVector<OpFoldResult> scaleOffsets = {
rewriter.getIndexAttr(0), rewriter.getIndexAttr(channelTile * xbarSize), outHeightIndex, outWidthIndex};
SmallVector<OpFoldResult> scaleSizes = {rewriter.getIndexAttr(1),
rewriter.getIndexAttr(tileChannels),
rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1)};
SmallVector<OpFoldResult> scaleStrides = {
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
Value scaleSlice = tensor::ExtractSliceOp::create(
rewriter, loc, tileType, averageScaleTensor, scaleOffsets, scaleSizes, scaleStrides);
scaleSlice = materializeContiguousTile(rewriter, loc, scaleSlice);
reducedWindow = spatial::SpatVMulOp::create(rewriter, loc, tileType, reducedWindow, scaleSlice);
}
SmallVector<OpFoldResult> outputOffsets = {
batchIndex, rewriter.getIndexAttr(channelTile * xbarSize), outHeightIndex, outWidthIndex};
SmallVector<OpFoldResult> outputSizes = {rewriter.getIndexAttr(1),
rewriter.getIndexAttr(tileChannels),
rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1)};
SmallVector<OpFoldResult> outputStrides = {
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
updatedOutput = tensor::InsertSliceOp::create(
rewriter, loc, reducedWindow, updatedOutput, outputOffsets, outputSizes, outputStrides);
}
Value pooledOutput = concatAlongAxis(rewriter, loc, /*axis=*/0, batchResults);
spatial::SpatYieldOp::create(rewriter, loc, pooledOutput);
scf::YieldOp::create(rewriter, loc, updatedOutput);
rewriter.setInsertionPointAfter(outputLoop);
spatial::SpatYieldOp::create(rewriter, loc, outputLoop.getResult(0));
return success();
});
if (failed(computeOp))
@@ -1,6 +1,6 @@
#include "mlir/Transforms/DialectConversion.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -1,6 +1,6 @@
#include "mlir/Transforms/DialectConversion.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -1,8 +1,9 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -32,6 +33,24 @@ static Value createSoftmaxCompute(Value input, ConversionPatternRewriter& rewrit
return computeOp.getResult(0);
}
static Value concatValues(ValueRange inputs, int64_t axis, ConversionPatternRewriter& rewriter, Location loc) {
auto firstType = cast<RankedTensorType>(inputs.front().getType());
SmallVector<int64_t> outputShape(firstType.getShape().begin(), firstType.getShape().end());
int64_t concatDimSize = 0;
for (Value input : inputs)
concatDimSize += cast<RankedTensorType>(input.getType()).getDimSize(axis);
outputShape[axis] = concatDimSize;
auto resultType = RankedTensorType::get(outputShape, firstType.getElementType(), firstType.getEncoding());
if (llvm::all_of(inputs, isHostFoldableValue))
return createSpatConcat(rewriter, loc, axis, inputs);
auto concatCompute = createSpatCompute(rewriter, loc, TypeRange {resultType}, {}, inputs, [&](ValueRange args) {
spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, axis, args));
});
return concatCompute.getResult(0);
}
static Value
buildSoftmax(Value input, int64_t softmaxAxis, int64_t axis, ConversionPatternRewriter& rewriter, Location loc) {
auto inputType = cast<RankedTensorType>(input.getType());
@@ -47,8 +66,7 @@ buildSoftmax(Value input, int64_t softmaxAxis, int64_t axis, ConversionPatternRe
for (Value slice : slices)
rebuiltSlices.push_back(buildSoftmax(slice, softmaxAxis, axis + 1, rewriter, loc));
return rebuiltSlices.size() == 1 ? rebuiltSlices.front()
: tensor::ConcatOp::create(rewriter, loc, axis, rebuiltSlices).getResult();
return concatValues(rebuiltSlices, axis, rewriter, loc);
}
struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> {
@@ -93,8 +111,13 @@ struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> {
Value transposedInput = preTransposeCompute.getResult(0);
Value transposedResult = buildSoftmax(
transposedInput, /*softmaxAxis=*/inputType.getRank() - 1, /*axis=*/0, rewriter, softmaxOp.getLoc());
result = ONNXTransposeOp::create(
rewriter, softmaxOp.getLoc(), inputType, transposedResult, rewriter.getI64ArrayAttr(inversePermutation));
auto postTransposeCompute =
createSpatCompute<1>(rewriter, softmaxOp.getLoc(), TypeRange {inputType}, {}, transposedResult, [&](Value x) {
Value transposed = ONNXTransposeOp::create(
rewriter, softmaxOp.getLoc(), inputType, x, rewriter.getI64ArrayAttr(inversePermutation));
spatial::SpatYieldOp::create(rewriter, softmaxOp.getLoc(), transposed);
});
result = postTransposeCompute.getResult(0);
}
rewriter.replaceOp(softmaxOp, result);
@@ -1,7 +1,10 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/ComputeRegionBuilder.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
@@ -17,7 +20,17 @@ struct Concat : public OpConversionPattern<ONNXConcatOp> {
auto inputs = adaptor.getInputs();
int64_t axis = adaptor.getAxis();
rewriter.replaceOpWithNewOp<tensor::ConcatOp>(maxpoolOp, axis, inputs);
if (llvm::all_of(inputs, isHostFoldableValue)) {
rewriter.replaceOp(maxpoolOp, createSpatConcat(rewriter, maxpoolOp.getLoc(), axis, inputs));
return success();
}
auto computeOp = createSpatCompute(
rewriter, maxpoolOp.getLoc(), TypeRange {maxpoolOp.getResult().getType()}, {}, inputs, [&](ValueRange args) {
spatial::SpatYieldOp::create(
rewriter, maxpoolOp.getLoc(), createSpatConcat(rewriter, maxpoolOp.getLoc(), axis, args));
});
rewriter.replaceOp(maxpoolOp, computeOp.getResults());
return success();
}
@@ -5,8 +5,8 @@
#include "llvm/ADT/SmallVector.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -49,7 +49,7 @@ static Value concatGatherSlices(Value data,
}
if (slices.empty())
return {};
return slices.size() == 1 ? slices.front() : tensor::ConcatOp::create(rewriter, loc, axis, slices).getResult();
return createSpatConcat(rewriter, loc, axis, slices);
}
static Value addLeadingGatherDim(Value value, int64_t axis, ConversionPatternRewriter& rewriter, Location loc) {
@@ -130,9 +130,7 @@ struct Gather : OpConversionPattern<ONNXGatherOp> {
return failure();
rows.push_back(addLeadingGatherDim(gatheredRow, axis, rewriter, loc));
}
result = rows.size() == 1
? rows.front()
: tensor::ConcatOp::create(rewriter, loc, /*axis=*/axis, rows).getResult();
result = createSpatConcat(rewriter, loc, /*axis=*/axis, rows);
}
else {
return failure();
@@ -3,7 +3,10 @@
#include "llvm/ADT/SmallVector.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
@@ -95,18 +98,33 @@ struct Reshape : OpConversionPattern<ONNXReshapeOp> {
return success();
}
auto replaceWithReshape = [&](auto buildReshape) -> LogicalResult {
if (isHostFoldableValue(adaptor.getData())) {
rewriter.replaceOp(reshapeOp, buildReshape(adaptor.getData()));
return success();
}
auto computeOp = createSpatCompute<1>(
rewriter, reshapeOp.getLoc(), TypeRange {resultType}, {}, adaptor.getData(), [&](Value data) {
Value reshaped = buildReshape(data);
spatial::SpatYieldOp::create(rewriter, reshapeOp.getLoc(), reshaped);
});
rewriter.replaceOp(reshapeOp, computeOp.getResults());
return success();
};
SmallVector<ReassociationIndices> reassociation;
if (sourceType.getRank() > resultType.getRank()
&& inferCollapseReassociation(sourceType.getShape(), resultType.getShape(), reassociation)) {
rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(reshapeOp, resultType, adaptor.getData(), reassociation);
return success();
}
&& inferCollapseReassociation(sourceType.getShape(), resultType.getShape(), reassociation))
return replaceWithReshape([&](Value data) {
return tensor::CollapseShapeOp::create(rewriter, reshapeOp.getLoc(), resultType, data, reassociation);
});
if (sourceType.getRank() < resultType.getRank()
&& inferExpandReassociation(sourceType.getShape(), resultType.getShape(), reassociation)) {
rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(reshapeOp, resultType, adaptor.getData(), reassociation);
return success();
}
&& inferExpandReassociation(sourceType.getShape(), resultType.getShape(), reassociation))
return replaceWithReshape([&](Value data) {
return tensor::ExpandShapeOp::create(rewriter, reshapeOp.getLoc(), resultType, data, reassociation);
});
return failure();
}
@@ -5,8 +5,8 @@
#include <algorithm>
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -50,7 +50,7 @@ static Value buildNearestResize(Value input,
slices.push_back(buildNearestResize(slice, inputShape, outputShape, axis + 1, rewriter, loc));
}
return slices.size() == 1 ? slices.front() : tensor::ConcatOp::create(rewriter, loc, axis, slices).getResult();
return createSpatConcat(rewriter, loc, axis, slices);
}
struct Resize : OpConversionPattern<ONNXResizeOp> {
@@ -1,8 +1,10 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
@@ -23,7 +25,10 @@ static Value extractSliceAt(
sizes.push_back(rewriter.getIndexAttr(dim));
offsets[axis] = rewriter.getIndexAttr(offset);
sizes[axis] = rewriter.getIndexAttr(size);
return tensor::ExtractSliceOp::create(rewriter, loc, input, offsets, sizes, strides);
SmallVector<int64_t> resultShape(inputType.getShape());
resultShape[axis] = size;
auto resultType = RankedTensorType::get(resultShape, inputType.getElementType());
return tensor::ExtractSliceOp::create(rewriter, loc, resultType, input, offsets, sizes, strides);
}
struct Split : OpConversionPattern<ONNXSplitOp> {
@@ -44,21 +49,40 @@ struct Split : OpConversionPattern<ONNXSplitOp> {
outputs.reserve(splitOp.getNumResults());
int64_t offset = 0;
SmallVector<RankedTensorType> resultTypes;
resultTypes.reserve(splitOp.getNumResults());
SmallVector<int64_t> sliceSizes;
sliceSizes.reserve(splitOp.getNumResults());
for (Value result : splitOp.getResults()) {
auto resultType = dyn_cast<RankedTensorType>(result.getType());
if (!resultType || !resultType.hasStaticShape())
return failure();
int64_t sliceSize = resultType.getShape()[axis];
auto computeOp =
createSpatCompute<1>(rewriter, splitOp.getLoc(), TypeRange {resultType}, {}, adaptor.getInput(), [&](Value x) {
Value output = extractSliceAt(x, axis, offset, sliceSize, rewriter, splitOp.getLoc());
spatial::SpatYieldOp::create(rewriter, splitOp.getLoc(), output);
});
outputs.push_back(computeOp.getResult(0));
offset += sliceSize;
resultTypes.push_back(resultType);
sliceSizes.push_back(resultType.getShape()[axis]);
}
rewriter.replaceOp(splitOp, outputs);
if (isHostFoldableValue(adaptor.getInput())) {
for (int64_t sliceSize : sliceSizes) {
outputs.push_back(extractSliceAt(adaptor.getInput(), axis, offset, sliceSize, rewriter, splitOp.getLoc()));
offset += sliceSize;
}
rewriter.replaceOp(splitOp, outputs);
return success();
}
auto computeOp = createSpatCompute<1>(
rewriter, splitOp.getLoc(), TypeRange(splitOp.getResultTypes()), {}, adaptor.getInput(), [&](Value input) {
SmallVector<Value> runtimeOutputs;
runtimeOutputs.reserve(resultTypes.size());
int64_t runtimeOffset = 0;
for (int64_t sliceSize : sliceSizes) {
runtimeOutputs.push_back(extractSliceAt(input, axis, runtimeOffset, sliceSize, rewriter, splitOp.getLoc()));
runtimeOffset += sliceSize;
}
spatial::SpatYieldOp::create(rewriter, splitOp.getLoc(), runtimeOutputs);
});
rewriter.replaceOp(splitOp, computeOp.getResults());
return success();
}
};
@@ -0,0 +1,265 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/WeightMaterialization.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PostPatterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
static bool isWeightMaterializationHelperUser(Operation* op) {
return isa<tensor::ExtractSliceOp, tensor::ExpandShapeOp, tensor::CollapseShapeOp, ONNXTransposeOp>(op);
}
static bool canPromoteInputBlockArgument(BlockArgument arg) {
return !arg.use_empty() && llvm::all_of(arg.getUsers(), isWeightMaterializationHelperUser);
}
static bool isDirectConstantValue(Value value) {
return isa_and_nonnull<arith::ConstantOp, ONNXConstantOp>(value.getDefiningOp());
}
// Collapses one-lane batches so later phases do not carry batch-only structure unnecessarily.
struct FoldSingleLaneComputeBatchPattern : OpRewritePattern<spatial::SpatComputeBatch> {
using OpRewritePattern<spatial::SpatComputeBatch>::OpRewritePattern;
LogicalResult matchAndRewrite(spatial::SpatComputeBatch batchOp, PatternRewriter& rewriter) const override {
if (batchOp.getLaneCount() != 1)
return rewriter.notifyMatchFailure(batchOp, "requires a single lane");
auto loc = batchOp.getLoc();
rewriter.setInsertionPoint(batchOp);
auto computeOp =
spatial::SpatCompute::create(rewriter, loc, batchOp.getResultTypes(), batchOp.getWeights(), batchOp.getInputs());
computeOp.getProperties().setOperandSegmentSizes(
{static_cast<int>(batchOp.getWeights().size()), static_cast<int>(batchOp.getInputs().size())});
Block& templateBlock = batchOp.getBody().front();
SmallVector<Type> blockArgTypes;
SmallVector<Location> blockArgLocs;
blockArgTypes.reserve(templateBlock.getNumArguments());
blockArgLocs.reserve(templateBlock.getNumArguments());
for (BlockArgument arg : templateBlock.getArguments()) {
blockArgTypes.push_back(arg.getType());
blockArgLocs.push_back(loc);
}
auto* newBlock =
rewriter.createBlock(&computeOp.getBody(), computeOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
IRMapping mapper;
for (auto [oldArg, newArg] : llvm::zip(templateBlock.getArguments(), newBlock->getArguments()))
mapper.map(oldArg, newArg);
rewriter.setInsertionPointToEnd(newBlock);
for (Operation& op : templateBlock)
rewriter.clone(op, mapper);
batchOp->replaceAllUsesWith(computeOp->getResults());
rewriter.eraseOp(batchOp);
return success();
}
};
// Promotes foldable helper chains from runtime inputs to weights to avoid artificial compute inputs.
struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatCompute> {
using OpRewritePattern<spatial::SpatCompute>::OpRewritePattern;
LogicalResult matchAndRewrite(spatial::SpatCompute compute, PatternRewriter& rewriter) const override {
SmallVector<bool> promoteInput(compute.getInputs().size(), false);
bool needsRewrite = false;
Block& oldBlock = compute.getBody().front();
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
if (inputIdx >= oldBlock.getNumArguments())
continue;
if (!isWeightLikeComputeOperand(input))
continue;
if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(oldBlock.getArgument(inputIdx)))
continue;
promoteInput[inputIdx] = true;
needsRewrite = true;
}
if (!needsRewrite)
return rewriter.notifyMatchFailure(compute, "no weight-like inputs to promote");
rewriter.setInsertionPointAfter(compute);
SmallVector<Value> newWeights(compute.getWeights().begin(), compute.getWeights().end());
SmallVector<Value> newInputs;
SmallVector<Type> newInputTypes;
SmallVector<Location> newInputLocs;
newWeights.reserve(compute.getWeights().size() + compute.getInputs().size());
newInputs.reserve(compute.getInputs().size());
newInputTypes.reserve(compute.getInputs().size());
newInputLocs.reserve(compute.getInputs().size());
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
if (promoteInput[inputIdx]) {
newWeights.push_back(input);
continue;
}
newInputs.push_back(input);
newInputTypes.push_back(input.getType());
newInputLocs.push_back(input.getLoc());
}
auto newCompute =
spatial::SpatCompute::create(rewriter, compute.getLoc(), compute.getResultTypes(), newWeights, newInputs);
auto* newBlock =
rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), newInputTypes, newInputLocs);
newCompute.getProperties().setOperandSegmentSizes(
{static_cast<int>(newWeights.size()), static_cast<int>(newInputs.size())});
rewriter.setInsertionPointToStart(newBlock);
IRRewriter bodyRewriter(rewriter.getContext());
bodyRewriter.setInsertionPointToStart(newBlock);
IRMapping mapper;
size_t newInputIdx = 0;
for (auto [oldInputIdx, oldArg] : llvm::enumerate(oldBlock.getArguments())) {
if (!promoteInput[oldInputIdx]) {
mapper.map(oldArg, newBlock->getArgument(newInputIdx++));
continue;
}
auto clonedValue = materializeWeightLikeValueInBlock(compute.getInputs()[oldInputIdx], bodyRewriter, mapper);
if (failed(clonedValue))
return rewriter.notifyMatchFailure(compute, "failed to materialize promoted weight-like operand");
mapper.map(oldArg, *clonedValue);
}
for (Operation& op : oldBlock.without_terminator())
rewriter.clone(op, mapper);
auto oldYield = cast<spatial::SpatYieldOp>(oldBlock.getTerminator());
SmallVector<Value> newYieldOperands;
newYieldOperands.reserve(oldYield.getOutputs().size());
for (Value operand : oldYield.getOutputs()) {
auto mapped = mapper.lookupOrNull(operand);
newYieldOperands.push_back(mapped ? cast<Value>(mapped) : operand);
}
spatial::SpatYieldOp::create(rewriter, oldYield.getLoc(), newYieldOperands);
rewriter.replaceOp(compute, newCompute.getResults());
return success();
}
};
// Promotes foldable batch helper chains to weights while preserving compact compute_batch IR.
struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::SpatComputeBatch> {
using OpRewritePattern<spatial::SpatComputeBatch>::OpRewritePattern;
LogicalResult matchAndRewrite(spatial::SpatComputeBatch compute, PatternRewriter& rewriter) const override {
SmallVector<bool> promoteInput(compute.getInputs().size(), false);
bool needsRewrite = false;
Block& oldBlock = compute.getBody().front();
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
if (inputIdx >= oldBlock.getNumArguments())
continue;
if (!isWeightLikeComputeOperand(input))
continue;
if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(oldBlock.getArgument(inputIdx)))
continue;
promoteInput[inputIdx] = true;
needsRewrite = true;
}
if (!needsRewrite)
return rewriter.notifyMatchFailure(compute, "no weight-like batch inputs to promote");
rewriter.setInsertionPointAfter(compute);
SmallVector<Value> newWeights(compute.getWeights().begin(), compute.getWeights().end());
SmallVector<Value> newInputs;
SmallVector<Type> newInputTypes;
SmallVector<Location> newInputLocs;
newWeights.reserve(compute.getWeights().size() + compute.getInputs().size());
newInputs.reserve(compute.getInputs().size());
newInputTypes.reserve(compute.getInputs().size());
newInputLocs.reserve(compute.getInputs().size());
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
if (promoteInput[inputIdx]) {
newWeights.push_back(input);
continue;
}
newInputs.push_back(input);
newInputTypes.push_back(input.getType());
newInputLocs.push_back(input.getLoc());
}
auto newCompute =
spatial::SpatComputeBatch::create(rewriter,
compute.getLoc(),
compute.getResultTypes(),
rewriter.getI32IntegerAttr(static_cast<int32_t>(compute.getLaneCount())),
newWeights,
newInputs);
auto* newBlock =
rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), newInputTypes, newInputLocs);
newCompute.getProperties().setOperandSegmentSizes(
{static_cast<int>(newWeights.size()), static_cast<int>(newInputs.size())});
rewriter.setInsertionPointToStart(newBlock);
IRRewriter bodyRewriter(rewriter.getContext());
bodyRewriter.setInsertionPointToStart(newBlock);
IRMapping mapper;
size_t newInputIdx = 0;
for (auto [oldInputIdx, oldArg] : llvm::enumerate(oldBlock.getArguments())) {
if (!promoteInput[oldInputIdx]) {
mapper.map(oldArg, newBlock->getArgument(newInputIdx++));
continue;
}
auto clonedValue = materializeWeightLikeValueInBlock(compute.getInputs()[oldInputIdx], bodyRewriter, mapper);
if (failed(clonedValue))
return rewriter.notifyMatchFailure(compute, "failed to materialize promoted batch weight-like operand");
mapper.map(oldArg, *clonedValue);
}
for (Operation& op : oldBlock.without_terminator())
rewriter.clone(op, mapper);
auto oldYield = cast<spatial::SpatYieldOp>(oldBlock.getTerminator());
SmallVector<Value> newYieldOperands;
newYieldOperands.reserve(oldYield.getOutputs().size());
for (Value operand : oldYield.getOutputs()) {
auto mapped = mapper.lookupOrNull(operand);
newYieldOperands.push_back(mapped ? cast<Value>(mapped) : operand);
}
spatial::SpatYieldOp::create(rewriter, oldYield.getLoc(), newYieldOperands);
rewriter.replaceOp(compute, newCompute.getResults());
return success();
}
};
} // namespace
void populateEarlyPostPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.add<FoldSingleLaneComputeBatchPattern>(ctx);
}
void populatePostPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.add<PromoteWeightLikeComputeInputsPattern, PromoteWeightLikeComputeBatchInputsPattern>(ctx);
}
void annotateWeightsConstants(func::FuncOp funcOp) {
funcOp.walk([&](arith::ConstantOp constantOp) {
if (hasOnlySpatialMvmVmmWeightUses(constantOp.getResult()))
markWeightAlways(constantOp);
});
}
} // namespace onnx_mlir
@@ -0,0 +1,14 @@
#pragma once
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/MLIRContext.h"
namespace onnx_mlir {
void populateEarlyPostPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populatePostPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void annotateWeightsConstants(mlir::func::FuncOp funcOp);
} // namespace onnx_mlir
@@ -0,0 +1,25 @@
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PrePatterns.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.hpp.inc"
} // namespace
void populatePrePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx) {
patterns.add<onnxToArithConstant>(ctx);
patterns.add<convAddToConvWithBiasLeft>(ctx);
patterns.add<convAddToConvWithBiasRight>(ctx);
patterns.add<matMulAddToGemm>(ctx);
patterns.add<matMulToGemm>(ctx);
patterns.add<removeFlattenSameShape>(ctx);
populateMatMulRewritePatterns(patterns, ctx);
}
} // namespace onnx_mlir
@@ -0,0 +1,10 @@
#pragma once
#include "mlir/IR/MLIRContext.h"
#include "mlir/Transforms/DialectConversion.h"
namespace onnx_mlir {
void populatePrePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
} // namespace onnx_mlir
@@ -0,0 +1,218 @@
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/IRMapping.h"
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace mlir;
using namespace onnx_mlir::pim;
namespace onnx_mlir {
namespace {
static int32_t translateSpatialCoreIdToPimCoreId(size_t spatialCoreId) { return static_cast<int32_t>(spatialCoreId); }
static SmallVector<int32_t> getPimCoreIdsForBatchOp(spatial::SpatComputeBatch computeBatchOp, size_t& fallbackCoreId) {
if (auto coreIdsAttr = computeBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName))
return SmallVector<int32_t>(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end());
SmallVector<int32_t> coreIds;
coreIds.reserve(static_cast<size_t>(computeBatchOp.getLaneCount()));
for (uint32_t lane = 0; lane < computeBatchOp.getLaneCount(); ++lane)
coreIds.push_back(static_cast<int32_t>(fallbackCoreId++));
return coreIds;
}
static void lowerChannelSendTensorBatch(spatial::SpatChannelSendTensorBatchOp sendTensorBatchOp,
IRMapping& mapper,
IRRewriter& rewriter) {
SmallVector<int32_t> targetCoreIds;
targetCoreIds.reserve(sendTensorBatchOp.getTargetCoreIds().size());
for (int32_t targetCoreId : sendTensorBatchOp.getTargetCoreIds())
targetCoreIds.push_back(translateSpatialCoreIdToPimCoreId(targetCoreId));
pim::PimSendTensorBatchOp::create(rewriter,
sendTensorBatchOp.getLoc(),
mapper.lookup(sendTensorBatchOp.getInput()),
rewriter.getDenseI32ArrayAttr(targetCoreIds));
}
static void lowerChannelReceiveTensorBatch(spatial::SpatChannelReceiveTensorBatchOp receiveTensorBatchOp,
IRMapping& mapper,
IRRewriter& rewriter) {
SmallVector<int32_t> sourceCoreIds;
sourceCoreIds.reserve(receiveTensorBatchOp.getSourceCoreIds().size());
for (int32_t sourceCoreId : receiveTensorBatchOp.getSourceCoreIds())
sourceCoreIds.push_back(translateSpatialCoreIdToPimCoreId(sourceCoreId));
auto outputType = cast<ShapedType>(receiveTensorBatchOp.getOutput().getType());
auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveTensorBatchOp.getLoc(), outputType);
Value received = pim::PimReceiveTensorBatchOp::create(rewriter,
receiveTensorBatchOp.getLoc(),
outputBuffer.getType(),
outputBuffer,
rewriter.getDenseI32ArrayAttr(sourceCoreIds))
.getOutput();
mapper.map(receiveTensorBatchOp.getOutput(), received);
}
} // namespace
LogicalResult
lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState& state, IRRewriter& rewriter) {
if (computeBatchOp.getNumResults() != 0)
return computeBatchOp.emitOpError(
"batched Spatial-to-PIM lowering currently requires channelized compute_batch with no results");
Location loc = computeBatchOp.getLoc();
Block& oldBlock = computeBatchOp.getBody().front();
auto oldYield = cast<spatial::SpatYieldOp>(oldBlock.getTerminator());
if (oldYield.getNumOperands() != 0)
return computeBatchOp.emitOpError("batched Spatial-to-PIM lowering currently requires empty spat.yield");
SmallVector<int32_t> coreIds = getPimCoreIdsForBatchOp(computeBatchOp, state.nextCoreId);
SmallVector<Value> batchWeights(computeBatchOp.getWeights().begin(), computeBatchOp.getWeights().end());
SmallVector<Value> batchInputs;
if (!computeBatchOp.getInputs().empty())
batchInputs.append(computeBatchOp.getInputs().begin(), computeBatchOp.getInputs().end());
rewriter.setInsertionPointAfter(computeBatchOp);
auto coreBatchOp = pim::PimCoreBatchOp::create(rewriter,
loc,
rewriter.getI32IntegerAttr(computeBatchOp.getLaneCount()),
ValueRange(batchWeights),
ValueRange(batchInputs));
coreBatchOp.getProperties().setOperandSegmentSizes(
{static_cast<int>(batchWeights.size()), static_cast<int>(batchInputs.size())});
coreBatchOp->setAttr(onnx_mlir::kCoreIdsAttrName, rewriter.getDenseI32ArrayAttr(coreIds));
SmallVector<Type> blockArgTypes;
SmallVector<Location> blockArgLocs;
for (BlockArgument arg : oldBlock.getArguments()) {
blockArgTypes.push_back(arg.getType());
blockArgLocs.push_back(arg.getLoc());
}
Block* newBlock =
rewriter.createBlock(&coreBatchOp.getBody(), coreBatchOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
IRMapping mapper;
rewriter.setInsertionPointToStart(newBlock);
for (auto [oldArg, newArg] : llvm::zip(oldBlock.getArguments(), newBlock->getArguments())) {
auto newArgType = cast<ShapedType>(newArg.getType());
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, newArgType);
auto copied = pim::PimMemCopyHostToDevBatchOp::create(rewriter,
loc,
outputBuffer.getType(),
outputBuffer,
newArg,
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(0),
getTensorSizeInBytesAttr(rewriter, newArg))
.getOutput();
mapper.map(oldArg, copied);
}
auto materializeCapturedTensor = [&](Value capturedTensor) -> Value {
if (auto mapped = mapper.lookupOrNull(capturedTensor))
return mapped;
auto capturedType = cast<ShapedType>(capturedTensor.getType());
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, capturedType);
auto copied = pim::PimMemCopyHostToDevBatchOp::create(rewriter,
loc,
outputBuffer.getType(),
outputBuffer,
capturedTensor,
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(0),
getTensorSizeInBytesAttr(rewriter, capturedTensor))
.getOutput();
mapper.map(capturedTensor, copied);
return copied;
};
rewriter.setInsertionPointToEnd(newBlock);
for (Operation& op : oldBlock) {
if (isa<spatial::SpatYieldOp>(op))
continue;
if (auto sendBatchOp = dyn_cast<spatial::SpatChannelSendBatchOp>(op)) {
pim::PimSendBatchOp::create(rewriter,
loc,
mapper.lookup(sendBatchOp.getInput()),
getTensorSizeInBytesAttr(rewriter, mapper.lookup(sendBatchOp.getInput())),
sendBatchOp.getTargetCoreIdsAttr());
continue;
}
if (auto sendTensorBatchOp = dyn_cast<spatial::SpatChannelSendTensorBatchOp>(op)) {
lowerChannelSendTensorBatch(sendTensorBatchOp, mapper, rewriter);
continue;
}
if (auto receiveBatchOp = dyn_cast<spatial::SpatChannelReceiveBatchOp>(op)) {
auto outputType = cast<ShapedType>(receiveBatchOp.getOutput().getType());
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, outputType);
auto received = pim::PimReceiveBatchOp::create(rewriter,
loc,
outputBuffer.getType(),
outputBuffer,
getTensorSizeInBytesAttr(rewriter, receiveBatchOp.getOutput()),
receiveBatchOp.getSourceCoreIdsAttr())
.getOutput();
mapper.map(receiveBatchOp.getOutput(), received);
continue;
}
if (auto receiveTensorBatchOp = dyn_cast<spatial::SpatChannelReceiveTensorBatchOp>(op)) {
lowerChannelReceiveTensorBatch(receiveTensorBatchOp, mapper, rewriter);
continue;
}
if (auto toTensorOp = dyn_cast<bufferization::ToTensorOp>(op)) {
if (isa_and_present<memref::GetGlobalOp>(toTensorOp.getBuffer().getDefiningOp())) {
Operation* cloned = rewriter.clone(op, mapper);
auto clonedTensor = cloned->getResult(0);
auto clonedType = cast<ShapedType>(clonedTensor.getType());
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, clonedType);
auto copied = pim::PimMemCopyHostToDevBatchOp::create(rewriter,
loc,
outputBuffer.getType(),
outputBuffer,
clonedTensor,
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(0),
getTensorSizeInBytesAttr(rewriter, clonedTensor))
.getOutput();
mapper.map(toTensorOp.getResult(), copied);
continue;
}
}
for (Value operand : op.getOperands()) {
if (!isa<TensorType>(operand.getType()) || mapper.contains(operand))
continue;
Operation* definingOp = operand.getDefiningOp();
if (definingOp && definingOp->getBlock() == &oldBlock)
continue;
materializeCapturedTensor(operand);
}
Operation* cloned = rewriter.clone(op, mapper);
for (auto [originalResult, clonedResult] : llvm::zip(op.getResults(), cloned->getResults()))
mapper.map(originalResult, clonedResult);
}
rewriter.setInsertionPointToEnd(newBlock);
PimHaltOp::create(rewriter, loc);
return success();
}
} // namespace onnx_mlir
@@ -0,0 +1,10 @@
#pragma once
#include "src/Accelerators/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.hpp"
namespace onnx_mlir {
mlir::LogicalResult
lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState& state, mlir::IRRewriter& rewriter);
} // namespace onnx_mlir
@@ -4,7 +4,16 @@ add_public_tablegen_target(SpatialToPimIncGen)
add_pim_library(OMSpatialToPim
SpatialToPimPass.cpp
BatchCoreLoweringPatterns.cpp
ChannelLoweringPatterns.cpp
Cleanup.cpp
Common.cpp
ComputeLikeRegionUtils.cpp
CoreLoweringPatterns.cpp
GlobalTensorMaterialization.cpp
PhaseVerification.cpp
ReturnPathNormalization.cpp
TensorPackingPatterns.cpp
EXCLUDE_FROM_OM_LIBS
@@ -0,0 +1,136 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ChannelLoweringPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
static int32_t toPimCoreId(int32_t spatialCoreId) { return spatialCoreId; }
struct ChannelSendLowering : OpRewritePattern<spatial::SpatChannelSendOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(spatial::SpatChannelSendOp op, PatternRewriter& rewriter) const override {
pim::PimSendOp::create(rewriter,
op.getLoc(),
op.getInput(),
getTensorSizeInBytesAttr(rewriter, op.getInput()),
rewriter.getI32IntegerAttr(toPimCoreId(op.getTargetCoreId())));
rewriter.eraseOp(op);
return success();
}
};
struct ChannelReceiveLowering : OpRewritePattern<spatial::SpatChannelReceiveOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(spatial::SpatChannelReceiveOp op, PatternRewriter& rewriter) const override {
if (op->use_empty()) {
rewriter.eraseOp(op);
return success();
}
auto outputType = cast<ShapedType>(op.getResult().getType());
Value outputBuffer =
tensor::EmptyOp::create(rewriter, op.getLoc(), outputType.getShape(), outputType.getElementType()).getResult();
Value received = pim::PimReceiveOp::create(rewriter,
op.getLoc(),
op.getResult().getType(),
outputBuffer,
getTensorSizeInBytesAttr(rewriter, op.getResult()),
rewriter.getI32IntegerAttr(toPimCoreId(op.getSourceCoreId())))
.getOutput();
rewriter.replaceOp(op, received);
return success();
}
};
struct ChannelSendTensorLowering : OpRewritePattern<spatial::SpatChannelSendTensorOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(spatial::SpatChannelSendTensorOp op, PatternRewriter& rewriter) const override {
SmallVector<int32_t> targetCoreIds;
targetCoreIds.reserve(op.getTargetCoreIds().size());
for (int32_t targetCoreId : op.getTargetCoreIds())
targetCoreIds.push_back(toPimCoreId(targetCoreId));
pim::PimSendTensorOp::create(rewriter, op.getLoc(), op.getInput(), rewriter.getDenseI32ArrayAttr(targetCoreIds));
rewriter.eraseOp(op);
return success();
}
};
struct ChannelReceiveTensorLowering : OpRewritePattern<spatial::SpatChannelReceiveTensorOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(spatial::SpatChannelReceiveTensorOp op, PatternRewriter& rewriter) const override {
SmallVector<int32_t> sourceCoreIds;
sourceCoreIds.reserve(op.getSourceCoreIds().size());
for (int32_t sourceCoreId : op.getSourceCoreIds())
sourceCoreIds.push_back(toPimCoreId(sourceCoreId));
auto outputType = cast<ShapedType>(op.getOutput().getType());
Value outputBuffer =
tensor::EmptyOp::create(rewriter, op.getLoc(), outputType.getShape(), outputType.getElementType()).getResult();
Value received =
pim::PimReceiveTensorOp::create(
rewriter, op.getLoc(), op.getOutput().getType(), outputBuffer, rewriter.getDenseI32ArrayAttr(sourceCoreIds))
.getOutput();
rewriter.replaceOp(op, received);
return success();
}
};
struct ExtractRowsLowering : OpRewritePattern<spatial::SpatExtractRowsOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(spatial::SpatExtractRowsOp op, PatternRewriter& rewriter) const override {
auto inputType = cast<RankedTensorType>(op.getInput().getType());
SmallVector<Value> replacements;
replacements.reserve(op.getNumResults());
for (auto [rowIndex, output] : llvm::enumerate(op.getOutputs())) {
auto outputType = cast<RankedTensorType>(output.getType());
SmallVector<OpFoldResult> offsets = {
rewriter.getIndexAttr(static_cast<int64_t>(rowIndex) * outputType.getDimSize(0)), rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(outputType.getDimSize(0)),
rewriter.getIndexAttr(inputType.getDimSize(1))};
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
replacements.push_back(
tensor::ExtractSliceOp::create(rewriter, op.getLoc(), outputType, op.getInput(), offsets, sizes, strides)
.getResult());
}
rewriter.replaceOp(op, replacements);
return success();
}
};
struct ConcatLowering : OpRewritePattern<spatial::SpatConcatOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(spatial::SpatConcatOp op, PatternRewriter& rewriter) const override {
auto outputType = cast<ShapedType>(op.getOutput().getType());
Value outputBuffer =
tensor::EmptyOp::create(rewriter, op.getLoc(), outputType.getShape(), outputType.getElementType()).getResult();
Value concatenated =
pim::PimConcatOp::create(
rewriter, op.getLoc(), op.getOutput().getType(), op.getAxisAttr(), op.getInputs(), outputBuffer)
.getOutput();
rewriter.replaceOp(op, concatenated);
return success();
}
};
} // namespace
void populateChannelLoweringPatterns(RewritePatternSet& patterns) {
patterns.add<ChannelSendLowering,
ChannelReceiveLowering,
ChannelSendTensorLowering,
ChannelReceiveTensorLowering,
ExtractRowsLowering,
ConcatLowering>(patterns.getContext());
}
} // namespace onnx_mlir
@@ -0,0 +1,9 @@
#pragma once
#include "mlir/IR/PatternMatch.h"
namespace onnx_mlir {
void populateChannelLoweringPatterns(mlir::RewritePatternSet& patterns);
} // namespace onnx_mlir
@@ -0,0 +1,42 @@
#include "llvm/ADT/STLExtras.h"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Cleanup.hpp"
using namespace mlir;
namespace onnx_mlir {
LogicalResult erasePendingOps(SmallVectorImpl<Operation*>& pendingOps, IRRewriter& rewriter) {
while (!pendingOps.empty()) {
bool erasedAnyOp = false;
for (auto it = pendingOps.begin(); it != pendingOps.end();) {
Operation* opToRemove = *it;
if (!opToRemove->use_empty()) {
++it;
continue;
}
rewriter.eraseOp(opToRemove);
it = pendingOps.erase(it);
erasedAnyOp = true;
}
if (erasedAnyOp)
continue;
for (Operation* opToRemove : pendingOps) {
InFlightDiagnostic diag = opToRemove->emitError("pending Spatial-to-PIM cleanup could not erase operation");
diag << "; op has " << llvm::range_size(opToRemove->getUsers()) << " remaining user(s)";
for (Operation* user : opToRemove->getUsers()) {
bool userPendingRemoval = llvm::is_contained(pendingOps, user);
opToRemove->emitRemark() << "remaining user `" << user->getName() << "`"
<< (userPendingRemoval ? " is also pending removal" : " is not pending removal");
}
}
return failure();
}
return success();
}
} // namespace onnx_mlir
@@ -0,0 +1,11 @@
#pragma once
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
namespace onnx_mlir {
mlir::LogicalResult erasePendingOps(llvm::SmallVectorImpl<mlir::Operation*>& pendingOps, mlir::IRRewriter& rewriter);
} // namespace onnx_mlir
+14 -45
View File
@@ -7,23 +7,12 @@
#include <cstddef>
#include "Common.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace llvm;
using namespace mlir;
namespace onnx_mlir {
namespace {
IntegerAttr getRequiredI32Attr(Builder& builder, Operation* op, llvm::StringRef attrName) {
auto attr = op->getAttrOfType<IntegerAttr>(attrName);
assert(attr && "required precomputed channel attr is missing");
return IntegerAttr::get(builder.getI32Type(), attr.getInt());
}
} // namespace
size_t getSliceActualOffset(tensor::ExtractSliceOp& sliceOp, ShapedType& inputShape) {
/*
EXAMPLE RUN:
@@ -74,37 +63,6 @@ IntegerAttr getTensorSizeInBytesAttr(Builder& builder, mlir::Value value) {
return builder.getI32IntegerAttr(static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(value.getType()))));
}
IntegerAttr getSpatialChannelSourceCoreIdAttr(Builder& builder, mlir::Value channel) {
auto channelNewOp = channel.getDefiningOp<spatial::SpatChannelNewOp>();
assert(channelNewOp && "spatial channel value must come from spat.channel_new");
return getRequiredI32Attr(builder, channelNewOp, kChannelSourceCoreIdAttrName);
}
IntegerAttr getSpatialChannelTargetCoreIdAttr(Builder& builder, mlir::Value channel) {
auto channelNewOp = channel.getDefiningOp<spatial::SpatChannelNewOp>();
assert(channelNewOp && "spatial channel value must come from spat.channel_new");
return getRequiredI32Attr(builder, channelNewOp, kChannelTargetCoreIdAttrName);
}
bool hasSpatialChannelSourceCoreIdAttr(mlir::Value channel) {
auto channelNewOp = channel.getDefiningOp<spatial::SpatChannelNewOp>();
return channelNewOp && channelNewOp->hasAttr(kChannelSourceCoreIdAttrName);
}
bool hasSpatialChannelTargetCoreIdAttr(mlir::Value channel) {
auto channelNewOp = channel.getDefiningOp<spatial::SpatChannelNewOp>();
return channelNewOp && channelNewOp->hasAttr(kChannelTargetCoreIdAttrName);
}
mlir::Value createPimReceiveFromSpatialChannel(
PatternRewriter& rewriter, Location loc, mlir::Value output, mlir::Value channel) {
mlir::Value outputBuffer = getBestOutputTensorFromOperandsOrAllocate(rewriter, output.getDefiningOp());
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, output);
auto sourceCoreIdAttr = getSpatialChannelSourceCoreIdAttr(rewriter, channel);
return pim::PimReceiveOp::create(rewriter, loc, outputBuffer.getType(), outputBuffer, sizeAttr, sourceCoreIdAttr)
.getOutput();
}
Operation* getEarliestUserWithinBlock(mlir::Value value) {
auto users = value.getUsers();
@@ -127,15 +85,26 @@ SmallVector<mlir::Value> getOpOperandsSortedByUses(Operation* operation) {
return map_to_vector(operandsAndUses, [](auto operandAndUse) { return operandAndUse.first; });
}
mlir::Value getBestOutputTensorFromOperandsOrAllocate(PatternRewriter& rewriter, Operation* operation) {
bool hasLaterUserInBlock(mlir::Value value, Operation* operation) {
for (Operation* user : value.getUsers()) {
if (user->getBlock() != operation->getBlock())
return true;
if (operation->isBeforeInBlock(user))
return true;
}
return false;
}
mlir::Value getBestOutputTensorFromOperandsOrAllocate(RewriterBase& rewriter, Operation* operation) {
assert("Only support operations with a single result" && operation->getNumResults() == 1);
mlir::Value result = operation->getResult(0);
auto resultType = result.getType();
assert("Only support result ShapedType as result type" && isa<ShapedType>(resultType));
SmallVector<mlir::Value> operands = getOpOperandsSortedByUses(operation);
auto validOperands =
make_filter_range(operands, [resultType](mlir::Value operand) { return operand.getType() == resultType; });
auto validOperands = make_filter_range(operands, [operation, resultType](mlir::Value operand) {
return operand.getType() == resultType && !hasLaterUserInBlock(operand, operation);
});
auto bestOperand = validOperands.begin();
if (bestOperand != validOperands.end())
+1 -18
View File
@@ -2,16 +2,10 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "llvm/ADT/StringRef.h"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
namespace onnx_mlir {
inline constexpr llvm::StringLiteral kChannelSourceCoreIdAttrName = "precomp_source_core_id";
inline constexpr llvm::StringLiteral kChannelTargetCoreIdAttrName = "precomp_target_core_id";
/**
* \brief Get the offset of the ExtractSliceOp based on its static offsets and
* its static tensor input.
@@ -30,17 +24,6 @@ size_t getShapedTypeSizeInBytes(mlir::ShapedType shapedType);
mlir::IntegerAttr getTensorSizeInBytesAttr(mlir::Builder& builder, mlir::Value value);
mlir::IntegerAttr getSpatialChannelSourceCoreIdAttr(mlir::Builder& builder, mlir::Value channel);
mlir::IntegerAttr getSpatialChannelTargetCoreIdAttr(mlir::Builder& builder, mlir::Value channel);
bool hasSpatialChannelSourceCoreIdAttr(mlir::Value channel);
bool hasSpatialChannelTargetCoreIdAttr(mlir::Value channel);
mlir::Value createPimReceiveFromSpatialChannel(
mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::Value output, mlir::Value channel);
template <class T>
size_t rangeLength(const mlir::iterator_range<T> range) {
return std::distance(range.begin(), range.end());
@@ -58,7 +41,7 @@ mlir::Operation* getEarliestUserWithinBlock(mlir::Value value);
mlir::SmallVector<mlir::Value> getOpOperandsSortedByUses(mlir::Operation* operation);
mlir::Value getBestOutputTensorFromOperandsOrAllocate(mlir::PatternRewriter& rewriter, mlir::Operation* operation);
mlir::Value getBestOutputTensorFromOperandsOrAllocate(mlir::RewriterBase& rewriter, mlir::Operation* operation);
inline mlir::tensor::EmptyOp
createEmptyTensorFromShaped(mlir::IRRewriter& rewriter, mlir::Location loc, mlir::ShapedType shapedType) {
@@ -0,0 +1,44 @@
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ComputeLikeRegionUtils.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
namespace onnx_mlir {
std::optional<unsigned> getDirectComputeLikeInputIndex(Operation* owner, unsigned operandNumber) {
auto getInputIndex = [operandNumber](Operation* op, unsigned inputCount) -> std::optional<unsigned> {
if (inputCount == 0)
return std::nullopt;
unsigned inputBegin = op->getNumOperands() - inputCount;
if (operandNumber < inputBegin)
return std::nullopt;
return operandNumber - inputBegin;
};
if (auto compute = dyn_cast<spatial::SpatCompute>(owner))
return getInputIndex(owner, compute.getInputs().size());
if (auto computeBatch = dyn_cast<spatial::SpatComputeBatch>(owner))
return getInputIndex(owner, computeBatch.getInputs().size());
return std::nullopt;
}
void replaceAndEraseDirectComputeLikeInput(PatternRewriter& rewriter,
Operation* owner,
unsigned inputIndex,
Value replacement) {
Block& body = owner->getRegion(0).front();
BlockArgument bodyArgument = body.getArgument(inputIndex);
rewriter.startOpModification(owner);
bodyArgument.replaceAllUsesWith(replacement);
if (auto compute = dyn_cast<spatial::SpatCompute>(owner))
compute.getInputsMutable().erase(inputIndex);
else
cast<spatial::SpatComputeBatch>(owner).getInputsMutable().erase(inputIndex);
body.eraseArgument(inputIndex);
rewriter.finalizeOpModification(owner);
}
} // namespace onnx_mlir
@@ -0,0 +1,17 @@
#pragma once
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include <optional>
namespace onnx_mlir {
std::optional<unsigned> getDirectComputeLikeInputIndex(mlir::Operation* owner, unsigned operandNumber);
void replaceAndEraseDirectComputeLikeInput(mlir::PatternRewriter& rewriter,
mlir::Operation* owner,
unsigned inputIndex,
mlir::Value replacement);
} // namespace onnx_mlir
@@ -0,0 +1,213 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/IRMapping.h"
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
using namespace onnx_mlir::pim;
namespace onnx_mlir {
namespace {
static bool isChannelUseChainOp(Operation* op) {
return isa<tensor::ExtractSliceOp,
tensor::CollapseShapeOp,
tensor::ExpandShapeOp,
tensor::CastOp,
tosa::ReshapeOp,
ONNXTransposeOp,
pim::PimTransposeOp>(op);
}
static void cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewriter) {
for (Value operand : op->getOperands()) {
if (mapping.lookupOrNull(operand))
continue;
Operation* definingOp = operand.getDefiningOp();
if (!definingOp)
continue;
if (!isa<tensor::EmptyOp, arith::ConstantOp>(definingOp))
continue;
Operation* clonedOp = rewriter.clone(*definingOp, mapping);
for (auto [originalResult, newResult] : llvm::zip(definingOp->getResults(), clonedOp->getResults()))
mapping.map(originalResult, newResult);
rewriter.setInsertionPointAfter(clonedOp);
}
}
static int32_t translateSpatialCoreIdToPimCoreId(size_t spatialCoreId) { return static_cast<int32_t>(spatialCoreId); }
static int32_t getPimCoreIdForComputeOp(spatial::SpatCompute computeOp, size_t& fallbackCoreId) {
if (auto spatialCoreIdAttr = computeOp->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
return static_cast<int32_t>(spatialCoreIdAttr.getInt());
return static_cast<int32_t>(fallbackCoreId++);
}
static LogicalResult collectHelperComputeChain(spatial::SpatCompute computeOp,
SmallVectorImpl<Operation*>& helperChain,
bool requireReturnUse = true) {
if (computeOp.getInputs().size() != 1 || computeOp.getNumResults() != 1)
return failure();
if (requireReturnUse
&& (!computeOp.getResult(0).hasOneUse() || !isa<func::ReturnOp>(*computeOp.getResult(0).getUsers().begin())))
return failure();
Block& block = computeOp.getBody().front();
if (block.getNumArguments() != 1)
return failure();
auto yieldOp = dyn_cast<spatial::SpatYieldOp>(block.getTerminator());
if (!yieldOp || yieldOp.getNumOperands() != 1)
return failure();
SmallVector<Operation*> reverseChain;
Value currentValue = yieldOp.getOperands().front();
Value blockArg = block.getArgument(0);
while (currentValue != blockArg) {
Operation* definingOp = currentValue.getDefiningOp();
if (!definingOp || definingOp->getBlock() != &block || !isChannelUseChainOp(definingOp))
return failure();
reverseChain.push_back(definingOp);
currentValue = definingOp->getOperand(0);
}
SmallPtrSet<Operation*, 8> chainSet(reverseChain.begin(), reverseChain.end());
for (Operation& op : llvm::make_early_inc_range(block.without_terminator()))
if (!chainSet.contains(&op) && !isa<tensor::EmptyOp, arith::ConstantOp>(op))
return failure();
helperChain.assign(reverseChain.rbegin(), reverseChain.rend());
return success();
}
static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatCompute computeOp, IRRewriter& rewriter) {
if (!computeOp.getInputs().empty() || computeOp.getNumResults() != 1)
return false;
if (!llvm::all_of(computeOp.getResult(0).getUsers(), [](Operation* user) {
return isa<spatial::SpatCompute, spatial::SpatComputeBatch, pim::PimCoreOp, pim::PimCoreBatchOp>(user);
}))
return false;
Block& block = computeOp.getBody().front();
if (block.getNumArguments() != 0)
return false;
auto yieldOp = dyn_cast<spatial::SpatYieldOp>(block.getTerminator());
if (!yieldOp || yieldOp.getNumOperands() != 1)
return false;
rewriter.setInsertionPoint(computeOp);
IRMapping mapping;
for (Operation& op : block.without_terminator()) {
cloneMappedHelperOperands(&op, mapping, rewriter);
Operation* clonedOp = rewriter.clone(op, mapping);
for (auto [originalResult, newResult] : llvm::zip(op.getResults(), clonedOp->getResults()))
mapping.map(originalResult, newResult);
rewriter.setInsertionPointAfter(clonedOp);
}
Value replacement = mapping.lookupOrDefault(yieldOp.getOperand(0));
computeOp.getResult(0).replaceAllUsesWith(replacement);
return true;
}
} // namespace
void markOpToRemove(CoreLoweringState& state, Operation* op) {
if (!llvm::is_contained(state.operationsToRemove, op))
state.operationsToRemove.push_back(op);
}
LogicalResult lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState& state, IRRewriter& rewriter) {
Location loc = computeOp->getLoc();
if (inlineInputlessHelperComputeForWeightLikeUsers(computeOp, rewriter))
return success();
SmallVector<Operation*> helperChain;
if (succeeded(collectHelperComputeChain(computeOp, helperChain)))
return success();
auto& block = computeOp.getRegion().front();
auto yieldOp = cast<spatial::SpatYieldOp>(block.getTerminator());
for (auto [argIndex, blockArg] : llvm::enumerate(block.getArguments())) {
auto receiveOp = dyn_cast_or_null<spatial::SpatChannelReceiveOp>(computeOp.getInputs()[argIndex].getDefiningOp());
if (!receiveOp || blockArg.use_empty())
continue;
rewriter.setInsertionPoint(getEarliestUserWithinBlock(blockArg));
auto outputType = cast<ShapedType>(blockArg.getType());
auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveOp.getLoc(), outputType);
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, blockArg);
auto sourceCoreIdAttr = rewriter.getI32IntegerAttr(translateSpatialCoreIdToPimCoreId(receiveOp.getSourceCoreId()));
Value received = PimReceiveOp::create(
rewriter, receiveOp.getLoc(), outputBuffer.getType(), outputBuffer, sizeAttr, sourceCoreIdAttr)
.getOutput();
blockArg.replaceAllUsesWith(received);
markOpToRemove(state, receiveOp);
}
if (computeOp.getNumResults() != yieldOp.getNumOperands())
llvm_unreachable("ComputeOp must have same number of results as yieldOp operands");
for (auto [result, yieldValue] : llvm::zip(computeOp.getResults(), yieldOp.getOperands())) {
if (result.use_empty())
continue;
ReturnPathState returnPathState {state.outputTensors, state.operationsToRemove};
ReturnPathLoweringResult returnPathResult =
lowerComputeResultReturnPath(computeOp, cast<OpResult>(result), yieldValue, returnPathState, rewriter);
if (returnPathResult == ReturnPathLoweringResult::Failure)
return failure();
if (returnPathResult == ReturnPathLoweringResult::Handled)
continue;
auto resultUses = result.getUses();
if (rangeLength(resultUses) == 1) {
OpOperand& resultUse = *resultUses.begin();
Operation* resultUser = resultUse.getOwner();
if (isa<spatial::SpatChannelSendOp>(resultUser))
continue;
}
return computeOp.emitOpError("has an unsupported remaining result use during Spatial-to-PIM lowering");
}
rewriter.setInsertionPoint(yieldOp);
rewriter.replaceOpWithNewOp<PimHaltOp>(yieldOp);
SmallVector<Value> computeWeights;
if (!computeOp.getWeights().empty())
computeWeights.append(computeOp.getWeights().begin(), computeOp.getWeights().end());
rewriter.setInsertionPointAfter(computeOp);
auto coreOp = PimCoreOp::create(rewriter,
loc,
ValueRange(computeWeights),
rewriter.getI32IntegerAttr(getPimCoreIdForComputeOp(computeOp, state.nextCoreId)));
auto& coreOpBlocks = coreOp.getBody().getBlocks();
for (auto [argIndex, blockArg] : llvm::enumerate(block.getArguments()))
if (!blockArg.use_empty())
blockArg.replaceAllUsesWith(computeOp.getInputs()[argIndex]);
block.eraseArguments(0, block.getNumArguments());
coreOpBlocks.splice(coreOpBlocks.begin(), computeOp.getBody().getBlocks());
Block* tempComputeBlock = new Block();
computeOp.getBody().push_back(tempComputeBlock);
rewriter.setInsertionPointToEnd(tempComputeBlock);
PimHaltOp::create(rewriter, computeOp.getLoc());
return success();
}
} // namespace onnx_mlir
@@ -0,0 +1,21 @@
#pragma once
#include "mlir/IR/PatternMatch.h"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ReturnPathNormalization.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
namespace onnx_mlir {
struct CoreLoweringState {
size_t& nextCoreId;
llvm::SmallVectorImpl<OutputTensorFactory>& outputTensors;
llvm::SmallVectorImpl<mlir::Operation*>& operationsToRemove;
};
void markOpToRemove(CoreLoweringState& state, mlir::Operation* op);
mlir::LogicalResult
lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState& state, mlir::IRRewriter& rewriter);
} // namespace onnx_mlir

Some files were not shown because too many files have changed in this diff Show More